Skip to content

Commit

Permalink
organize some update cases better
Browse files Browse the repository at this point in the history
  • Loading branch information
odunbar committed Sep 24, 2024
1 parent be26c68 commit d61f568
Showing 1 changed file with 84 additions and 26 deletions.
110 changes: 84 additions & 26 deletions examples/UpdateGroups/calibrate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,21 +218,89 @@ function main()
#meanblocks = [mean([Γ[i, i] for i in ((j - 1) * blocksize + 1):(j * blocksize)]) for j in 1:5]
#Γ += 1e-4* kron(Diagonal(meanblocks),I(blocksize)) # this will add scaled noise to the diagonal scaled by the block

y_mean = mean(data_samples, dims = 2)
y_mean = mean(data_samples, dims = 2)
y = data_samples[:, shuffle(rng, 1:n_sample_cov)[1]] # random data point as the data

# build a nice observation object from this (blocks assumption)
full_observation = Observation(Dict("samples" => y, "covariances" => Γ, "names" => "full"))

observation_cases = [
"all-block",
"fast-slow-block",
"fast-slow-block-inconsistent", # causes error when calc timestep
"fast-slow-noxy",
]
observation_case = observation_cases[1]

if observation_case == "all-block"
data_block_names = ["<X>", "<Y>", "<X^2>", "<Y^2>", "<XY>"]

observation_vec = []
for i in 1:5
for i in 1:length(data_block_names)
idx = ((i - 1) * blocksize + 1):(i * blocksize)
push!(
observation_vec,
Observation(Dict("samples" => y[idx], "covariances" => Γ[idx, idx], "names" => data_block_names[i])),
)
end
observation = combine_observations(observation_vec)
y = get_obs(observation)
elseif observation_case == "fast-slow-block"

data_block_names = ["<X><X^2>", "<Y><Y^2><XY>"]
idx_slow = reduce(vcat, [(i - 1) * blocksize + 1:i * blocksize for i in [1,3]]) #slow
idx_fast = reduce(vcat, [(i - 1) * blocksize + 1:i * blocksize for i in [2,4,5]]) #fast

observation_vec = [
Observation(Dict("samples" => y[idx_slow], "covariances" => Γ[idx_slow, idx_slow], "names" => data_block_names[1])),
Observation(Dict("samples" => y[idx_fast], "covariances" => Γ[idx_fast, idx_fast], "names" => data_block_names[2])),
]
elseif observation_case == "fast-slow-block-inconsistent"

data_block_names = ["<X><X^2><XY>", "<Y><Y^2><XY>"]
idx_slow = reduce(vcat, [(i - 1) * blocksize + 1:i * blocksize for i in [1,3,5]]) #slow
idx_fast = reduce(vcat, [(i - 1) * blocksize + 1:i * blocksize for i in [2,4,5]]) #fast

observation_vec = [
Observation(Dict("samples" => y[idx_slow], "covariances" => Γ[idx_slow, idx_slow], "names" => data_block_names[1])),
Observation(Dict("samples" => y[idx_fast], "covariances" => Γ[idx_fast, idx_fast], "names" => data_block_names[2])),
]
elseif observation_case == "fast-slow-noxy"

data_block_names = ["<X><X^2>", "<Y><Y^2>"]
idx_slow = reduce(vcat, [(i - 1) * blocksize + 1:i * blocksize for i in [1,3]]) #slow
idx_fast = reduce(vcat, [(i - 1) * blocksize + 1:i * blocksize for i in [2,4]]) #fast

observation_vec = [
Observation(Dict("samples" => y[idx_slow], "covariances" => Γ[idx_slow, idx_slow], "names" => data_block_names[1])),
Observation(Dict("samples" => y[idx_fast], "covariances" => Γ[idx_fast, idx_fast], "names" => data_block_names[2])),
]
end
observation = combine_observations(observation_vec)

## group parameter-observation pairs for second experiment
# Each entry maps a list of parameter names to the list of observation block
# names they are paired with. The block names must match the "names" given to
# the Observation objects built for the same `observation_case`.
if observation_case == "all-block"
    group_identifiers = Dict(
        ["F", "G"] => ["<X>", "<X^2>"],
        ["h", "c", "b"] => ["<Y>", "<Y^2>", "<XY>"]
    )
elseif observation_case == "fast-slow-block"
    group_identifiers = Dict(
        ["F", "G"] => ["<X><X^2>"],
        ["h", "c", "b"] => ["<Y><Y^2><XY>"]
    )
elseif observation_case == "fast-slow-block-inconsistent"
    # both groups reference the <XY> block
    group_identifiers = Dict(
        ["F", "G"] => ["<X><X^2><XY>"],
        ["h", "c", "b"] => ["<Y><Y^2><XY>"]
    )
elseif observation_case == "fast-slow-noxy"
    # case name must match the entry in `observation_cases`; the previous
    # spelling "fast-slow-block-noxy" could never match, leaving
    # `group_identifiers` undefined for this case
    group_identifiers = Dict(
        ["F", "G"] => ["<X><X^2>"],
        ["h", "c", "b"] => ["<Y><Y^2>"]
    )
else
    # fail here with a clear message rather than later with an undefined variable
    throw(ArgumentError("unknown observation_case: $(observation_case)"))
end

update_groups = create_update_groups(priors, observation, group_identifiers)



Expand All @@ -256,19 +324,15 @@ function main()
###

# EKP parameters
N_ens = 50 # number of ensemble members
N_ens = 70 # number of ensemble members
N_iter = 5 # number of EKI iterations
# initial parameters: N_params x N_ens
initial_params = construct_initial_ensemble(rng, priors, N_ens)

ekiobj = EKP.EnsembleKalmanProcess(
initial_params,
y,
Γ,
full_observation,
Inversion(),
localization_method = Localizers.NoLocalization(),
scheduler = DataMisfitController(terminate_at = 1e4),
failure_handler_method = SampleSuccGauss(),
verbose = true,
)
@info "Built EKP object"
Expand Down Expand Up @@ -296,14 +360,18 @@ function main()
@save output_directory * "output_storage.jld2" g_stored Γ

# Plots
fig = Figure(size = (450, 450))
Γ = get_obs_noise_cov(full_observation)
y = get_obs(full_observation)
fig = Figure(size = (450, 450))
aprior = Axis(fig[1, 1][1, 1])
apost = Axis(fig[2, 1][1, 1])
g_prior = get_g(ekiobj, 1)
g_post = get_g_final(ekiobj)
data_std = sqrt.([Γ[i, i] for i in 1:size(Γ, 1)])
data_dim = length(y)
dplot = 1:data_dim
dplot = 1:data_dim


lines!(aprior, dplot, sqrt_inv_Γ * y, color = (:black, 0.5), label = "data") #plots each row as new plot
lines!(apost, dplot, sqrt_inv_Γ * y, color = (:black, 0.5), label = "data") #plots each row as new plot
band!(
Expand Down Expand Up @@ -340,27 +408,15 @@ function main()
# We see that
bs = Int64(size(Γ, 1) / 5) # known block structure

# recall the parameters are
# F[1:3],G,h,c,b
# and the data blocks are
# <X>, <Y>, <X^2>, <Y^2>, <XY> X-slow, Y-fast
# F(3) -> <X>, <X^2>,
# Observation-based update_group
group_identifiers = Dict(["F"] => ["<X>", "<X^2>"], ["G", "h", "c", "b"] => ["<Y>", "<Y^2>", "<XY>"])
update_groups = create_update_groups(priors, observation, group_identifiers)


println("update groups:")
println("**************")
println(get_group_id.(update_groups))

ekiobj_grouped = EKP.EnsembleKalmanProcess(
initial_params,
y,
Γ,
observation,
Inversion(),
localization_method = Localizers.NoLocalization(),
scheduler = DataMisfitController(terminate_at = 1e4),
failure_handler_method = SampleSuccGauss(),
verbose = true,
update_groups = update_groups,
)
Expand Down Expand Up @@ -390,7 +446,9 @@ function main()


# Plots
fig = Figure(size = (450, 450))
Γ = get_obs_noise_cov(full_observation) # get the data/noise from the true statistics, despite our assumptions
y = get_obs(full_observation)
fig = Figure(size = (450, 450))
aprior = Axis(fig[1, 1][1, 1])
apost = Axis(fig[2, 1][1, 1])
g_prior = get_g(ekiobj_grouped, 1)
Expand Down

0 comments on commit d61f568

Please sign in to comment.