-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Write and test a forecast function which projects inference forwards …
…in time (#239) * new constructor function for InferenceConfig * fix constructor unit test * Update test_InferenceConfig.jl * define_epiprob function and unit test * reformat * add get_param_array function * Update get_param_array.jl * Simpler utility using `rowtable` * expanding test (not working yet) * forecast method template (not workign yet) * remove scrap code * forecasting function first pass * Update test_infer_and_forecast.jl * Add MCMCChains dep * rework generate_forecasts to sample white noise forwards in time * end-to-end test: generate, infer, forecast, viz * generate predictions * reorganise tests * generate forecast function --------- Co-authored-by: Sam Abbott <[email protected]>
- Loading branch information
1 parent
919b8a8
commit 496c6e9
Showing
16 changed files
with
248 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
""" | ||
Create a forecast EpiProblem by extending the given EpiProblem with additional | ||
forecast steps. | ||
# Arguments | ||
- `epiprob::EpiProblem`: The original EpiProblem to be extended. | ||
- `n::Integer`: The number of forecast steps to be added. | ||
# Returns | ||
- `forecast_epiprob::EpiProblem`: The forecast EpiProblem with extended time | ||
span. | ||
""" | ||
function define_forecast_epiprob(epiprob::EpiProblem, n::Integer) | ||
@assert n>0 "number of forecast steps n must be positive" | ||
|
||
forecast_epiprob = EpiProblem( | ||
epi_model = epiprob.epi_model, | ||
latent_model = epiprob.latent_model, | ||
observation_model = epiprob.observation_model, | ||
tspan = (epiprob.tspan[1], epiprob.tspan[2] + n) | ||
) | ||
|
||
return forecast_epiprob | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
include("define_forecast_epiprob.jl") | ||
include("generate_forecasts.jl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
""" | ||
Generate forecasts for `n` time steps above based on the given inference results. | ||
# Arguments | ||
- `inference_results`: The results of the inference process. | ||
- `n`: The number of forecasts to generate. | ||
# Returns | ||
- `forecast_quantities`: The generated forecast quantities. | ||
""" | ||
function generate_forecasts(inference_results, n::Integer) | ||
inference_chn = inference_results["inference_results"].samples | ||
data = inference_results["inference_results"].data | ||
epiprob = inference_results["epiprob"] | ||
forecast_epiprob = define_forecast_epiprob(epiprob, n) | ||
forecast_mdl = generate_epiaware(forecast_epiprob, (y_t = missing,)) | ||
|
||
# Add forward generation of latent variables using `predict` | ||
pred_chn = mapreduce(chainscat, 1:size(inference_chn, 3)) do c | ||
mapreduce(vcat, 1:size(inference_chn, 1)) do i | ||
fwd_chn = predict(forecast_mdl, inference_chn[i, :, c]; include_all = true) | ||
setrange(fwd_chn, i:i) | ||
end | ||
end | ||
|
||
forecast_quantities = generated_observables(forecast_mdl, data, pred_chn) | ||
return forecast_quantities | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
""" | ||
Create an `EpiProblem` object based on the provided `InferenceConfig`. | ||
# Arguments | ||
- `config::InferenceConfig`: An instance of the `InferenceConfig` type. | ||
# Returns | ||
- `epi_prob::EpiProblem`: An `EpiProblem` object representing the defined epidemiological problem. | ||
""" | ||
function define_epiprob(config::InferenceConfig) | ||
shape = (config.gi_mean / config.gi_std)^2 | ||
scale = config.gi_std^2 / config.gi_mean | ||
gen_distribution = Gamma(shape, scale) | ||
|
||
model_data = EpiData(gen_distribution = gen_distribution, D_gen = config.D_gen, | ||
transformation = config.transformation) | ||
|
||
epi = config.igp(model_data, config.log_I0_prior) | ||
|
||
obs = LatentDelay( | ||
NegativeBinomialError(cluster_factor_prior = config.cluster_factor_prior), | ||
config.delay_distribution; D = config.D_obs) | ||
|
||
epi_prob = EpiProblem(epi_model = epi, | ||
latent_model = config.latent_model, | ||
observation_model = obs, | ||
tspan = config.tspan) | ||
|
||
return epi_prob | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
include("InferenceConfig.jl") | ||
include("generate_inference_results.jl") | ||
include("map_inference_results.jl") | ||
include("define_epiprob.jl") |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
using Test | ||
|
||
@testset "run inference and forecast for generated data" begin | ||
using DrWatson | ||
quickactivate(@__DIR__(), "EpiAwarePipeline") | ||
|
||
using EpiAwarePipeline, EpiAware, Plots, Statistics | ||
pipeline = RtwithoutRenewalPipeline() | ||
prior = RtwithoutRenewalPriorPipeline() | ||
|
||
# Set up data generation on a random scenario | ||
lookahead = 60 | ||
n_observation_steps = 28 | ||
tspan_gen = (1, n_observation_steps + lookahead) | ||
tspan_inf = (1, n_observation_steps) | ||
inference_method = make_inference_method(pipeline) | ||
truth_data_config = make_truth_data_configs(pipeline)[1] | ||
inference_configs = make_inference_configs(pipeline) | ||
inference_config = rand(inference_configs) | ||
|
||
# Generate truth data and plot | ||
truth_sampling = generate_inference_results( | ||
Dict("y_t" => missing, "truth_gi_mean" => 1.5), inference_config, | ||
prior; tspan = tspan_gen, inference_method = make_inference_method(prior)) |> | ||
d -> d["inference_results"] | ||
|
||
#Choose first sample to represent truth data | ||
truthdata = truth_sampling.generated[1].generated_y_t | ||
|
||
plt = scatter(truthdata, xlabel = "t", ylabel = "y_t", label = "truth data") | ||
vline!(plt, [n_observation_steps], label = "forecast start") | ||
|
||
# Run inference | ||
obs_truthdata = truthdata[tspan_inf[1]:tspan_inf[2]] | ||
inference_results = generate_inference_results( | ||
Dict("y_t" => obs_truthdata, "truth_gi_mean" => 1.5), | ||
inference_config, pipeline; tspan = tspan_inf, inference_method) | ||
|
||
@test inference_results["inference_results"] isa EpiAwareObservables | ||
|
||
# Make 21-day forecast | ||
forecast_quantities = generate_forecasts(inference_results, lookahead) | ||
@test forecast_quantities isa EpiAwareObservables | ||
|
||
# Make forecast spaghetti plot | ||
forecast_y_t = mapreduce(hcat, forecast_quantities.generated) do gen | ||
gen.generated_y_t | ||
end | ||
forecast_qs = mapreduce(hcat, [0.025, 0.25, 0.5, 0.75, 0.975]) do q | ||
map(eachrow(forecast_y_t)) do row | ||
quantile(row, q) | ||
end | ||
end | ||
plot!(plt, forecast_qs, label = "forecast quantiles", | ||
color = :grey, lw = [0.5 1.5 3 1.5 0.5]) | ||
plot!(plt, ylims = (-0.5, maximum(truthdata) * 1.25)) | ||
savefig(plt, | ||
joinpath(@__DIR__(), "forecast_y_t.png") | ||
) | ||
|
||
# Make forecast plot for Z_t | ||
infer_Z_t = mapreduce(hcat, inference_results["inference_results"].generated) do gen | ||
gen.Z_t | ||
end | ||
forecast_Z_t = mapreduce(hcat, forecast_quantities.generated) do gen | ||
gen.Z_t | ||
end | ||
plt_Zt = plot( | ||
truth_sampling.generated[1].Z_t, lw = 3, color = :black, label = "truth Z_t") | ||
plot!(plt_Zt, infer_Z_t, xlabel = "t", ylabel = "Z_t", | ||
label = "", color = :grey, alpha = 0.05) | ||
plot!((n_observation_steps + 1):size(forecast_Z_t, 1), | ||
forecast_Z_t[(n_observation_steps + 1):end, :], | ||
label = "", color = :red, alpha = 0.05) | ||
vline!(plt_Zt, [n_observation_steps], label = "forecast start") | ||
|
||
savefig(plt_Zt, | ||
joinpath(@__DIR__(), "forecast_Z_t.png") | ||
) | ||
end |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
@testset "define_forecast_epiprob" begin | ||
using EpiAwarePipeline | ||
pipeline = RtwithoutRenewalPipeline() | ||
|
||
inference_configs = make_inference_configs(pipeline) | ||
|
||
case_data = missing | ||
tspan = (1, 28) | ||
epimethod = make_inference_method(pipeline) | ||
|
||
epiprob = InferenceConfig(rand(inference_configs); case_data, tspan, epimethod) |> | ||
define_epiprob | ||
|
||
@test_throws AssertionError define_forecast_epiprob(epiprob, -1) | ||
|
||
n_fr = 7 | ||
forecast_epiprob = define_forecast_epiprob(epiprob, 7) | ||
@test forecast_epiprob.tspan == (epiprob.tspan[1], epiprob.tspan[2] + n_fr) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
@testset "test define_epiprob" begin | ||
using EpiAwarePipeline | ||
pipeline = RtwithoutRenewalPipeline() | ||
|
||
inference_configs = make_inference_configs(pipeline) | ||
|
||
case_data = missing | ||
tspan = (1, 28) | ||
epimethod = make_inference_method(pipeline) | ||
|
||
epiprob = InferenceConfig(rand(inference_configs); case_data, tspan, epimethod) |> | ||
define_epiprob | ||
|
||
@test epiprob isa EpiProblem | ||
end |
Oops, something went wrong.