Update groups, i.e. domain and observation localization #380

Open
wants to merge 32 commits into
base: main
bb98266
Update groups object, consistency and unit tests
odunbar May 23, 2024
0be1f93
working example
odunbar Aug 29, 2024
4049936
format
odunbar Aug 29, 2024
a149214
update groups runs with minibatching
odunbar Aug 30, 2024
e021ff2
basic update following u->update groups
odunbar Aug 30, 2024
e969040
better error catch for typo
odunbar Aug 30, 2024
f5655b2
organize some update cases better
odunbar Aug 31, 2024
f06aba3
format
odunbar Oct 3, 2024
9889b93
pass more tests
odunbar Oct 3, 2024
bc8c327
pass more tests
odunbar Oct 3, 2024
8c1dc36
use API in ETKI
odunbar Oct 3, 2024
cab1519
resolved ETKI runs but is still slow
odunbar Nov 13, 2024
679b137
index and exploit diag for speed
odunbar Nov 14, 2024
a47cb7f
compatible with GNKI
odunbar Nov 14, 2024
adf315c
bugfix in tests
odunbar Nov 14, 2024
a694c3f
refactored localization input args, now indep of dimension until runtime
odunbar Nov 16, 2024
c1e85c3
formatting
odunbar Nov 16, 2024
9bceb0b
rm transform
odunbar Nov 16, 2024
14e3c8d
try fail-fast to stop Mac instability killing all jobs
odunbar Nov 16, 2024
8243365
docs
odunbar Nov 16, 2024
c31636e
fail-fast in strategies
odunbar Nov 16, 2024
3205d3a
rm fail-fast
odunbar Nov 16, 2024
a7bc440
try again - indenting changed
odunbar Nov 17, 2024
d1d0c18
docs cleaning
odunbar Nov 17, 2024
36be62d
typo
odunbar Nov 17, 2024
a9482c5
typo
odunbar Nov 17, 2024
eb73ba4
format
odunbar Nov 17, 2024
86aa1dd
typo
odunbar Nov 17, 2024
7512aec
add proper args
odunbar Nov 18, 2024
3ff54be
codecov and == override
odunbar Nov 18, 2024
5136f2c
formatting
odunbar Nov 19, 2024
765418e
docs typos
odunbar Nov 19, 2024
3 changes: 2 additions & 1 deletion .github/workflows/Tests.yml
@@ -9,7 +9,8 @@ jobs:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
fail-fast: false #don't cancel all jobs if one fails
matrix:
version:
- 'lts' # Long-Term Support release
- '1' # Latest 1.x release of julia
1 change: 1 addition & 0 deletions docs/make.jl
@@ -82,6 +82,7 @@ pages = [
"Learning rate schedulers" => "learning_rate_scheduler.md",
"Prior distributions" => "parameter_distributions.md",
"Observations and Minibatching" => "observations.md",
"Update Groups" => "update_groups.md",
"Localization and SEC" => "localization.md",
"Inflation" => "inflation.md",
"Parallelism and HPC" => "parallel_hpc.md",
81 changes: 81 additions & 0 deletions docs/src/update_groups.md
@@ -0,0 +1,81 @@
# [Update Groups](@id update-groups)

The `UpdateGroup` object facilitates blocked EKP updates by updating a series of user-defined pairings of parameters and data. This allows users to enforce any *known* (in)dependencies between different groups of parameters during the update.

!!! note "This improves scaling at the cost of user-imposed structure"
    Many of the `Process` updates scale super-linearly in the data dimension ``d``, say as ``d^\alpha`` with ``\alpha > 1``. Using ``K`` update groups of equal size improves this scaling to ``K (\frac{d}{K})^\alpha``.
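
As a worked instance of the note above, assuming the cost is exactly ``d^\alpha`` and that the data splits evenly across the ``K`` groups (both are idealizations), the blocked cost simplifies as follows:
```math
K \left(\frac{d}{K}\right)^{\alpha} = \frac{d^{\alpha}}{K^{\alpha - 1}}, \qquad \text{e.g. } \alpha = 3,\ K = 4: \quad 4 \left(\frac{d}{4}\right)^{3} = \frac{d^{3}}{16}.
```
Under these assumptions the saving grows with both the number of groups and the exponent ``\alpha``.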

## Recommended construction - shown by example

Constructing update groups starts with constructing the prior and the observations. Parameter distributions and observations may be constructed in units and given names, and these names are then used to build the update groups with the convenience constructor `create_update_groups`.

For illustration, we take code snippets from the example found in [examples/UpdateGroups/](https://github.com/CliMA/EnsembleKalmanProcesses.jl/blob/main/examples/UpdateGroups/). This example is concerned with learning several parameters in a coupled two-scale Lorenz 96 system:
```math
\begin{aligned}
\frac{\partial X_i}{\partial t} & = -X_{i-1}(X_{i-2} - X_{i+1}) - X_i - GY_i + F_1 + F_2 \sin(2\pi t F_3)\\
\frac{\partial Y_i}{\partial t} & = -cbY_{i+1}(Y_{i+2} - Y_{i-1}) - cY_i + \frac{hc}{b} X_i
\end{aligned}
```
Parameters are learnt by fitting moments of a realized `X` and `Y` system, to some target moments.

We create a prior by combining several *named* `ParameterDistribution`s.
```julia
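using LinearAlgebra # provides I
using Distributions # provides MvNormal
using EnsembleKalmanProcesses.ParameterDistributions

param_names = ["F", "G", "h", "c", "b"]

prior_F = ParameterDistribution(
    Dict(
        "name" => param_names[1],
        "distribution" => Parameterized(MvNormal([1.0, 0.0, -2.0], I)),
        "constraint" => repeat([bounded_below(0)], 3),
    ),
) # gives 3-D dist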
prior_G = constrained_gaussian(param_names[2], 5.0, 4.0, 0, Inf)
prior_h = constrained_gaussian(param_names[3], 5.0, 4.0, 0, Inf)
prior_c = constrained_gaussian(param_names[4], 5.0, 4.0, 0, Inf)
prior_b = constrained_gaussian(param_names[5], 5.0, 4.0, 0, Inf)
priors = combine_distributions([prior_F, prior_G, prior_h, prior_c, prior_b])
```
Now we likewise construct observed moments by combining several *named* `Observation`s
```julia
# given a list of vector statistics y and their covariances Γ
data_block_names = ["<X>", "<Y>", "<X^2>", "<Y^2>", "<XY>"]

observation_vec = []
for i in 1:length(data_block_names)
    push!(
        observation_vec,
        Observation(Dict("samples" => y[i], "covariances" => Γ[i], "names" => data_block_names[i])),
    )
end
observation = combine_observations(observation_vec)
```
We define the update groups of our choice by partitioning the parameter names as keys of a dictionary, and their paired data names as values. Here we create two groups:
```julia
# update parameters F,G with data <X>, <X^2>, <XY>
# update parameters h, c, b with data <Y>, <Y^2>, <XY>
group_identifiers = Dict(
    ["F", "G"] => ["<X>", "<X^2>", "<XY>"],
    ["h", "c", "b"] => ["<Y>", "<Y^2>", "<XY>"],
)
```
We then create the update groups with our convenient constructor
```julia
update_groups = create_update_groups(priors, observation, group_identifiers)
```
and this can then be entered into the `EnsembleKalmanProcess` object as a keyword argument
```julia
# ... initial_params = construct_initial_ensemble(rng, priors, N_ens)
ekiobj = EnsembleKalmanProcess(initial_params, observation, Inversion(), update_groups = update_groups)
```
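
To make the name-based pairing concrete, the following is a minimal sketch of how names could be resolved to index sets, assuming the named blocks are stacked in the order they were combined. The helper `name_to_indices` is hypothetical and is not part of the EKP API; `create_update_groups` handles this for you.
```julia
# Hypothetical helper (not part of EKP): map each name to the indices it
# occupies when the named blocks are stacked in order.
function name_to_indices(names, dims)
    stops = cumsum(dims)
    starts = stops .- dims .+ 1
    return Dict(n => collect(s:e) for (n, s, e) in zip(names, starts, stops))
end

# "F" is 3-dimensional, the remaining parameters are scalars
param_idx = name_to_indices(["F", "G", "h", "c", "b"], [3, 1, 1, 1, 1])

# rows of the parameter ensemble touched by the group ["F", "G"]
u_group = reduce(vcat, [param_idx[n] for n in ["F", "G"]])
```
Here `u_group` collects the three rows of `F` followed by the single row of `G`. The same construction applies to the data block names once their dimensions are known.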

## Advice for constructing blocks
1. A parameter cannot appear in more than one block (i.e. parameters cannot be updated more than once)
2. The block structure is user-defined, and directly assumes that there is no correlation between blocks. It is up to the user to confirm if there truly is independence between different blocks. Otherwise convergence properties may suffer.
3. This can be used in conjunction with minibatching, so long as the defined data objects are available in all `Observation`s in the series.
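
The first point can be checked before constructing the groups. The sketch below uses a hypothetical helper, `repeated_parameters`, which is not part of the EKP API, to list any parameter name that appears in more than one key of the identifier dictionary. Data names may be shared between groups, as `<XY>` is in the example above, so only the keys are inspected.
```julia
# Hypothetical helper (not part of EKP): parameter names that appear in
# more than one group, i.e. in more than one key of the dictionary.
function repeated_parameters(group_identifiers)
    counts = Dict{String, Int}()
    for names in keys(group_identifiers), n in names
        counts[n] = get(counts, n, 0) + 1
    end
    return sort([n for (n, c) in counts if c > 1])
end

valid_groups = Dict(
    ["F", "G"] => ["<X>", "<X^2>", "<XY>"],
    ["h", "c", "b"] => ["<Y>", "<Y^2>", "<XY>"],
)
invalid_groups = Dict(["F", "G"] => ["<X>"], ["G", "h"] => ["<Y>"])

repeated_valid = repeated_parameters(valid_groups)     # no repeats expected
repeated_invalid = repeated_parameters(invalid_groups) # "G" is in two groups
```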

## What happens internally?

We simply perform an independent `update_ensemble!` for each provided pairing, and combine the model output and updated parameters afterwards. Note that even without specifying an update group, EKP will always construct one under the hood.
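
The idea can be sketched in plain Julia. This is an illustration under stated assumptions, not the EKP implementation: it uses a deterministic ensemble Kalman inversion-style update without observation perturbations, and the functions `eki_block_update` and `blocked_update`, as well as the index-pair representation of the groups, are hypothetical.
```julia
using LinearAlgebra, Statistics

# Hypothetical helper: update the parameter rows `u_idx` using only the
# data rows `g_idx` (deterministic EKI-style update, no perturbed data).
function eki_block_update(u, g, y, Γ, u_idx, g_idx)
    u_b = u[u_idx, :]                 # group parameters, size (p_b, N_ens)
    g_b = g[g_idx, :]                 # group model output, size (d_b, N_ens)
    du = u_b .- mean(u_b, dims = 2)
    dg = g_b .- mean(g_b, dims = 2)
    N_ens = size(u, 2)
    C_ug = du * dg' / N_ens           # parameter-output cross-covariance
    C_gg = dg * dg' / N_ens           # output covariance
    return u_b + C_ug * ((C_gg + Γ[g_idx, g_idx]) \ (y[g_idx] .- g_b))
end

# Apply the update group by group and reassemble the full ensemble.
function blocked_update(u, g, y, Γ, groups)
    u_new = copy(u)
    for (u_idx, g_idx) in groups
        u_new[u_idx, :] = eki_block_update(u, g, y, Γ, u_idx, g_idx)
    end
    return u_new
end

# Toy problem: 5 parameters, 6 data, 20 ensemble members
u = randn(5, 20)
g = randn(6, 20)
y = zeros(6)
Γ = Matrix(1.0I, 6, 6)
groups = [(1:2, [1, 3, 5]), (3:5, [2, 4, 5])]  # data row 5 is shared
u_new = blocked_update(u, g, y, Γ, groups)
```
Each group only ever forms and solves with a covariance of its own data block, which is where the scaling benefit comes from. A single group containing all parameters and all data recovers the unblocked update.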

!!! note "In future..."
    In theory this opens up the possibility of having different configurations, or even different processes, in different groups. This could be useful when parameter-data pairings are highly heterogeneous and the user wishes to exploit, for example, the different scaling properties of different processes. However, this has not yet been implemented.
52 changes: 43 additions & 9 deletions examples/Localization/localization_example_lorenz96.jl
@@ -76,8 +76,10 @@ prior = combine_distributions(priors)

initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens)


# Solve problem without localization
ekiobj_vanilla = EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng = rng)
ekiobj_vanilla =
EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng = rng, scheduler = DefaultScheduler())
for i in 1:N_iter
g_ens_vanilla = G(get_ϕ_final(prior, ekiobj_vanilla))
EKP.update_ensemble!(ekiobj_vanilla, g_ens_vanilla, deterministic_forward_map = true)
@@ -91,6 +93,7 @@ ekiobj_inflated = EKP.EnsembleKalmanProcess(
Γ,
Inversion();
rng = rng,
scheduler = DefaultScheduler(),
# localization_method = BernoulliDropout(0.98),
)

@@ -108,7 +111,15 @@ end
@info "EKI (inflated) - complete"

# Test SEC
ekiobj_sec = EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng = rng, localization_method = SEC(1.0))
ekiobj_sec = EKP.EnsembleKalmanProcess(
initial_ensemble,
y,
Γ,
Inversion();
rng = rng,
localization_method = SEC(1.0),
scheduler = DefaultScheduler(),
)

for i in 1:N_iter
g_ens = G(get_ϕ_final(prior, ekiobj_sec))
@@ -117,8 +128,15 @@ end
@info "EKI (SEC) - complete"

# Test SEC with cutoff
ekiobj_sec_cutoff =
EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng = rng, localization_method = SEC(1.0, 0.1))
ekiobj_sec_cutoff = EKP.EnsembleKalmanProcess(
initial_ensemble,
y,
Γ,
Inversion();
rng = rng,
localization_method = SEC(1.0, 0.1),
scheduler = DefaultScheduler(),
)

for i in 1:N_iter
g_ens = G(get_ϕ_final(prior, ekiobj_sec_cutoff))
@@ -127,8 +145,15 @@ end
@info "EKI (SEC cut-off) - complete"

# Test SECFisher
ekiobj_sec_fisher =
EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng = rng, localization_method = SECFisher())
ekiobj_sec_fisher = EKP.EnsembleKalmanProcess(
initial_ensemble,
y,
Γ,
Inversion();
rng = rng,
localization_method = SECFisher(),
scheduler = DefaultScheduler(),
)

for i in 1:N_iter
g_ens = G(get_ϕ_final(prior, ekiobj_sec_fisher))
@@ -137,8 +162,15 @@ end
@info "EKI (SEC Fisher) - complete"

# Test SECNice
ekiobj_sec_nice =
EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng = rng, localization_method = SECNice())
ekiobj_sec_nice = EKP.EnsembleKalmanProcess(
initial_ensemble,
y,
Γ,
Inversion();
rng = rng,
localization_method = SECNice(),
scheduler = DefaultScheduler(),
)

for i in 1:N_iter
g_ens = G(get_ϕ_final(prior, ekiobj_sec_nice))
@@ -150,7 +182,9 @@ end
u_final = get_u_final(ekiobj_sec)
g_final = get_g_final(ekiobj_sec)
cov_est = cov([u_final; g_final], [u_final; g_final], dims = 2, corrected = false)
cov_localized = get_localizer(ekiobj_sec).localize(cov_est)
# need dimension args too
cov_localized =
get_localizer(ekiobj_sec).localize(cov_est, eltype(g_final), size(u_final, 1), size(g_final, 1), size(u_final, 2))

fig = plot(
get_error(ekiobj_vanilla),
7 changes: 7 additions & 0 deletions examples/UpdateGroups/Project.toml
@@ -0,0 +1,7 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"