Skip to content

Commit

Permalink
update groups runs with minibatching
Browse files Browse the repository at this point in the history
  • Loading branch information
odunbar committed Aug 30, 2024
1 parent 2ea9a9e commit 3372f64
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 57 deletions.
42 changes: 28 additions & 14 deletions examples/UpdateGroups/calibrate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function run_G_ensemble(state, lsettings, ϕ_i, window, data_dim)
N_ens = size(ϕ_i, 2)
data_sample = zeros(data_dim, N_ens)

for j in 1:N_ens
Threads.@threads for j in 1:N_ens
# ϕ_i is n_params x n_ens
params_i = LParams(
ϕ_i[1:3, j], # F
Expand Down Expand Up @@ -214,12 +214,28 @@ function main()
Γ = cov(data_samples, dims = 2) # estimate covariance from samples
# add a little additive and multiplicative inflation
Γ += 1e4 * eps() * I # 10^-12 just to make things nonzero
#blocksize = Int64(size(Γ, 1) / 5) # known block structure
blocksize = Int64(size(Γ, 1) / 5) # known block structure
#meanblocks = [mean([Γ[i, i] for i in ((j - 1) * blocksize + 1):(j * blocksize)]) for j in 1:5]
#Γ += 1e-4* kron(Diagonal(meanblocks),I(blocksize)) # this will add scaled noise to the diagonal scaled by the block

y_mean = mean(data_samples, dims = 2)
y = data_samples[:, shuffle(rng, 1:n_sample_cov)[1]] # random data point as the data

# build a nice observation object from this (blocks assumption)
data_block_names = ["<X>", "<Y>", "<X^2>", "<Y^2>", "<XY>"]
observation_vec = []
for i in 1:5
idx = ((i - 1) * blocksize + 1):(i * blocksize)
push!(
observation_vec,
Observation(Dict("samples" => y[idx], "covariances" => Γ[idx, idx], "names" => data_block_names[i])),
)
end
observation = combine_observations(observation_vec)
y = get_obs(observation)



fig = Figure(size = (450, 450))
= Axis(fig[1, 1][1, 1])
adata = Axis(fig[2, 1][1, 1])
Expand All @@ -240,8 +256,8 @@ function main()
###

# EKP parameters
N_ens = 30 # number of ensemble members
N_iter = 10 # number of EKI iterations
N_ens = 50 # number of ensemble members
N_iter = 5 # number of EKI iterations
# initial parameters: N_params x N_ens
initial_params = construct_initial_ensemble(rng, priors, N_ens)

Expand Down Expand Up @@ -328,16 +344,14 @@ function main()
# F[1:3],G,h,c,b
# and the data blocks are
# <X>, <Y>, <X^2>, <Y^2>, <XY> X-slow, Y-fast
# F(3) -> <X>, <X^2>,
# Observation-based update_group
group_identifiers = Dict(["F"] => ["<X>", "<X^2>"], ["G", "h", "c", "b"] => ["<Y>", "<Y^2>", "<XY>"])
update_groups = create_update_groups(priors, observation, group_identifiers)

# F(3) -> <X>, <X^2>,
# group_slow = UpdateGroup(collect(1:4), reduce(vcat, [collect(1:bs), collect(2 * bs + 1:3 * bs), collect(4*bs+1:5*bs)]))
group_slow = UpdateGroup(collect(1:3), reduce(vcat, [collect(1:bs), collect((2 * bs + 1):(3 * bs))]))#, collect(4*bs+1:5*bs)]))
# G,h,c,b -> <Y>, <Y^2>,<XY>
#group_fast = UpdateGroup(collect(4:7), collect(1:(5 * bs)))
group_fast = UpdateGroup(collect(4:7), reduce(vcat, [collect((bs + 1):(2 * bs)), collect((3 * bs + 1):(5 * bs))]))



println("update groups:")
println("**************")
println(get_group_id.(update_groups))

ekiobj_grouped = EKP.EnsembleKalmanProcess(
initial_params,
Expand All @@ -348,7 +362,7 @@ function main()
scheduler = DataMisfitController(terminate_at = 1e4),
failure_handler_method = SampleSuccGauss(),
verbose = true,
update_groups = [group_slow, group_fast],
update_groups = update_groups,
)
@info "Built grouped EKP object"

Expand Down
38 changes: 30 additions & 8 deletions src/EnsembleKalmanProcess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ function EnsembleKalmanProcess(

obs_for_minibatch = get_obs(observation_series) # get stacked observation over minibatch
obs_size_for_minibatch = length(obs_for_minibatch) # number of dims in the stacked observation

IT = typeof(N_ens)
#store for model evaluations
g = []
Expand All @@ -225,13 +224,14 @@ function EnsembleKalmanProcess(
# timestep store
Δt = FT[]

# defined groups of parameters to be updated by groups of data
# defined groups of parameters to be updated by groups of data
obs_size = length(get_obs(get_observations(observation_series)[1])) #deduce size just from first observation
if isnothing(update_groups)
groups = [UpdateGroup(1:N_par, 1:obs_size_for_minibatch)] # vec length 1
groups = [UpdateGroup(1:N_par, 1:obs_size)] # vec length 1
else
groups = update_groups
end
update_group_consistency(groups, N_par, obs_size_for_minibatch) # consistency checks
update_group_consistency(groups, N_par, obs_size) # consistency checks
VVV = typeof(groups)

scheduler = configuration["scheduler"]
Expand Down Expand Up @@ -573,6 +573,28 @@ function get_update_groups(ekp::EnsembleKalmanProcess)
return ekp.update_groups
end

"""
list_update_groups_over_minibatch(ekp::EnsembleKalmanProcess)
Return u_groups and g_groups for the current minibatch, i.e. the subset of
"""
function list_update_groups_over_minibatch(ekp::EnsembleKalmanProcess)
os = get_observation_series(ekp)
len_mb = length(get_current_minibatch(os)) # number of obs per batch
len_obs = Int(length(get_obs(os)) / len_mb) # length of obs in a batch
update_groups = get_update_groups(ekp)
u_groups = get_u_group.(update_groups) # update_group indices
g_groups = get_g_group.(update_groups)

# extend group indices from one obs to the minibatch of obs
new_u_groups = [reduce(vcat, [(i - 1) * len_obs .+ u_group for i in 1:len_mb]) for u_group in u_groups]
new_g_groups = [reduce(vcat, [(i - 1) * len_obs .+ g_group for i in 1:len_mb]) for g_group in g_groups]

return new_u_groups, new_g_groups
end




"""
get_process(ekp::EnsembleKalmanProcess)
Return `process` field of EnsembleKalmanProcess.
Expand Down Expand Up @@ -895,12 +917,12 @@ function update_ensemble!(

terminate = calculate_timestep!(ekp, g, Δt_new)
if isnothing(terminate)
update_groups = get_update_groups(ekp)
u_groups, g_groups = list_update_groups_over_minibatch(ekp)
u = zeros(size(get_u_prior(ekp)))
# with several g_groups we want to do
# u_n+1 = u_n + sum(update{g_i})
# but get u_n+1 = sum(u_n + update{g_i}),remove the extra u_ns
n_g_groups = length(get_g_group(update_groups))
n_g_groups = length(g_groups)
u -= get_u_final(ekp) * (n_g_groups - 1)

if ekp.verbose
Expand All @@ -915,8 +937,8 @@ function update_ensemble!(

# update each u_block with every g_block
# for (u_idx,g_idx) in zip(get_u_group(update_groups),get_g_group(update_groups))
for u_idx in get_u_group(update_groups)
for g_idx in get_g_group(update_groups)
for u_idx in u_groups
for g_idx in g_groups
u[u_idx, :] += update_ensemble!(ekp, g, get_process(ekp), u_idx, g_idx; ekp_kwargs...)
end
end
Expand Down
112 changes: 77 additions & 35 deletions src/UpdateGroup.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# EKP implementation of the domain and observation localization from the literature

using ..ParameterDistributions
export UpdateGroup
export get_u_group, get_g_group, update_group_consistency
export get_u_group, get_g_group, get_group_id, update_group_consistency, create_update_groups

"""
struct UpdateGroup {VV <: AbstractVector}
Expand All @@ -16,72 +17,113 @@ $(TYPEDFIELDS)
"""
struct UpdateGroup
"vector of parameter indices, forms part(or all) of a partition of 1:input_dim with other UpdateGroups provided"
"vector of parameter indices to form a partition of 1:input_dim) with other UpdateGroups provided"
u_group::Vector{Int}
"vector of data indices, must lie within 1:output_dim"
"vector of data indices that lie within 1:output_dim)"
g_group::Vector{Int}
# process::Process # in future
# localizer::Localizer # in future
# inflation::Inflation # in future
group_id::Dict
end

function UpdateGroup(u_group, g_group)
return UpdateGroup(
u_group,
g_group,
Dict("[$(minimum(u_group)),...,$(maximum(u_group))]" => "[$(minimum(g_group)),...,$(maximum(g_group))]"),
)
end


get_u_group(group::UpdateGroup) = group.u_group
get_g_group(group::UpdateGroup) = group.g_group
get_group_id(group::UpdateGroup) = group.group_id

function get_u_group(groups::VV) where {VV <: AbstractVector}
u_group = []
for group in groups
push!(u_group, get_u_group(group))
end
return u_group
end

function get_g_group(groups::VV) where {VV <: AbstractVector}
g_group = []
for group in groups
push!(g_group, get_g_group(group))
end
return g_group
end

# check an array of update_groups are consistent, i.e. common sizing for u,g, and that u is a partition.
"""
$(TYPEDSIGNATURES)
Check the consistency of sizing and partitioning of the `UpdateGroup` array
Check the consistency of sizing and partitioning of the `UpdateGroup` array if it contains indices
No consistency check if u,g has strings internally
"""
function update_group_consistency(groups::VV, input_dim::Int, output_dim::Int) where {VV <: AbstractVector}

u_groups = get_u_group(groups)
g_groups = get_g_group(groups)
u_groups = get_u_group.(groups)
g_groups = get_g_group.(groups)

# check there is an index in each group
if any(length(group) == 0 for group in u_groups)
throw(ArgumentError("all `UpdateGroup.u_group` must contain at least one parameter index"))
throw(ArgumentError("all `UpdateGroup.u_group` must contain at least one parameter identifier"))
end
if any(length(group) == 0 for group in g_groups)
throw(ArgumentError("all `UpdateGroup.g_group` must contain at least one parameter index"))
throw(ArgumentError("all `UpdateGroup.g_group` must contain at least one data identifier"))
end

# check for partition
# check for partition (only if indices passed)
u_flat = reduce(vcat, u_groups)
if !(1:input_dim == sort(u_flat))
throw(
ArgumentError(
"The combined 'UpdateGroup.u_group's must partition the indices of the input parameters: 1:$(input_dim), received: $(sort(u_flat))",
),
)
if eltype(sort(u_flat)) == Int
throw(
ArgumentError(
"The combined 'UpdateGroup.u_group's must partition the indices of the input parameters: 1:$(input_dim), received: $(sort(u_flat))",
),
)
end
end

g_flat = reduce(vcat, g_groups)
if any(gf > output_dim for gf in g_flat) || any(gf <= 0 for gf in g_flat)
throw(
ArgumentError(
"The UpdateGroup.g_group must contains values in: 1:$(output_dim), found values outside this range",
),
)
if eltype(g_flat) == Int
if any(gf > output_dim for gf in g_flat) || any(gf <= 0 for gf in g_flat)
throw(
ArgumentError(
"The UpdateGroup.g_group must contains values in: 1:$(output_dim), found values outside this range",
),
)
end
end

# pass the tests
return true
end

## Convience constructor for update_groups
"""
$(TYPEDSIGNATURES)
To construct a list of update-groups populated by indices of parameter distributions and indices of observations, from a dictionary of `group_identifiers = Dict(..., group_k_input_names => group_k_output_names, ...)`
"""
function create_update_groups(
prior::PD,
observation::OB,
group_identifiers::Dict,
) where {PD <: ParameterDistribution, OB <: Observation}

param_names = get_name(prior)
param_indices = batch(prior, function_parameter_opt = "dof")

obs_names = get_names(observation)
obs_indices = get_indices(observation)

update_groups = []
for (key, val) in pairs(group_identifiers)
key_vec = isa(key, AbstractString) ? [key] : key # make it iterable
val_vec = isa(val, AbstractString) ? [val] : val
u_group = []
g_group = []
for pn in key_vec
pi = param_indices[pn .== param_names]
push!(u_group, isa(pi, Int) ? [pi] : pi)
end
for obn in val_vec
oi = obs_indices[obn .== obs_names]
push!(g_group, isa(oi, Int) ? [oi] : oi)
end
u_group = reduce(vcat, reduce(vcat, u_group))
g_group = reduce(vcat, reduce(vcat, g_group))
push!(update_groups, UpdateGroup(u_group, g_group, Dict(key_vec => val_vec)))
end
return update_groups

end

0 comments on commit 3372f64

Please sign in to comment.