From 496c6e9bf80dbd7d63b02168701895f5cdeb5d6b Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Wed, 29 May 2024 12:28:16 +0100 Subject: [PATCH] Write and test a forecast function which projects inference forwards in time (#239) * new constructor function for InferenceConfig * fix constructor unit test * Update test_InferenceConfig.jl * define_epiprob function and unit test * reformat * add get_param_array function * Update get_param_array.jl * Simpler utility using `rowtable` * expanding test (not working yet) * forecast method template (not working yet) * remove scrap code * forecasting function first pass * Update test_infer_and_forecast.jl * Add MCMCChains dep * rework generate_forecasts to sample white noise forwards in time * end-to-end test: generate, infer, forecast, viz * generate predictions * reorganise tests * generate forecast function --------- Co-authored-by: Sam Abbott --- pipeline/Project.toml | 8 +- pipeline/src/EpiAwarePipeline.jl | 8 +- .../src/forecast/define_forecast_epiprob.jl | 25 ++++++ pipeline/src/forecast/forecast.jl | 2 + pipeline/src/forecast/generate_forecasts.jl | 29 +++++++ pipeline/src/infer/InferenceConfig.jl | 36 ++++----- pipeline/src/infer/define_epiprob.jl | 31 +++++++ .../src/infer/generate_inference_results.jl | 16 +--- pipeline/src/infer/infer.jl | 1 + .../{ => end-to-end}/test_full_inference.jl | 0 .../end-to-end/test_infer_and_forecast.jl | 80 +++++++++++++++++++ .../{ => end-to-end}/test_prior_predictive.jl | 0 pipeline/test/forecast/test_forecast.jl | 19 +++++ pipeline/test/infer/test_InferenceConfig.jl | 28 ++++--- pipeline/test/infer/test_define_epiprob.jl | 15 ++++ pipeline/test/runtests.jl | 2 + 16 files changed, 248 insertions(+), 52 deletions(-) create mode 100644 pipeline/src/forecast/define_forecast_epiprob.jl create mode 100644 pipeline/src/forecast/forecast.jl create mode 100644 pipeline/src/forecast/generate_forecasts.jl create mode 100644 
pipeline/src/infer/define_epiprob.jl rename pipeline/test/{ => end-to-end}/test_full_inference.jl (100%) create mode 100644 pipeline/test/end-to-end/test_infer_and_forecast.jl rename pipeline/test/{ => end-to-end}/test_prior_predictive.jl (100%) create mode 100644 pipeline/test/forecast/test_forecast.jl create mode 100644 pipeline/test/infer/test_define_epiprob.jl diff --git a/pipeline/Project.toml b/pipeline/Project.toml index 745f57f32..764a5b874 100644 --- a/pipeline/Project.toml +++ b/pipeline/Project.toml @@ -14,19 +14,23 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1" EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -RCall = "6f49c342-dc21-5d91-9882-a32aef131414" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] +ADTypes = "0.2" +AbstractMCMC = "5.2" CSV = "0.10" Dagger = "0.18" DataFramesMeta = "0.14" +Dates = "1.10" Distributions = "0.25" DocStringExtensions = "0.9" DrWatson = "2.15" +JLD2 = "0.4" Plots = "1.40" -RCall = "0.14" Statistics = "1.10" julia = "1.10" diff --git a/pipeline/src/EpiAwarePipeline.jl b/pipeline/src/EpiAwarePipeline.jl index a53c2229a..fd0dbb118 100644 --- a/pipeline/src/EpiAwarePipeline.jl +++ b/pipeline/src/EpiAwarePipeline.jl @@ -11,7 +11,7 @@ with execution determined by available computational resources. 
module EpiAwarePipeline using CSV, Dagger, DataFramesMeta, Dates, Distributions, DocStringExtensions, DrWatson, - EpiAware, Plots, Statistics, ADTypes, AbstractMCMC, Plots, JLD2 + EpiAware, Plots, Statistics, ADTypes, AbstractMCMC, Plots, JLD2, MCMCChains, Turing # Exported pipeline types export AbstractEpiAwarePipeline, EpiAwarePipeline, RtwithoutRenewalPipeline, @@ -32,7 +32,10 @@ export do_truthdata, do_inference, do_pipeline export simulate, generate_truthdata # Exported functions: infer functions -export infer, generate_inference_results, map_inference_results +export infer, generate_inference_results, map_inference_results, define_epiprob + +# Exported functions: forecast functions +export define_forecast_epiprob, generate_forecasts # Exported functions: plot functions export plot_truth_data, plot_Rt @@ -42,5 +45,6 @@ include("pipeline/pipeline.jl") include("constructors/constructors.jl") include("simulate/simulate.jl") include("infer/infer.jl") +include("forecast/forecast.jl") include("plot_functions.jl") end diff --git a/pipeline/src/forecast/define_forecast_epiprob.jl b/pipeline/src/forecast/define_forecast_epiprob.jl new file mode 100644 index 000000000..79eab10c5 --- /dev/null +++ b/pipeline/src/forecast/define_forecast_epiprob.jl @@ -0,0 +1,25 @@ +""" +Create a forecast EpiProblem by extending the given EpiProblem with additional +forecast steps. + +# Arguments +- `epiprob::EpiProblem`: The original EpiProblem to be extended. +- `n::Integer`: The number of forecast steps to be added. + +# Returns +- `forecast_epiprob::EpiProblem`: The forecast EpiProblem with extended time +span. 
+ +""" +function define_forecast_epiprob(epiprob::EpiProblem, n::Integer) + @assert n>0 "number of forecast steps n must be positive" + + forecast_epiprob = EpiProblem( + epi_model = epiprob.epi_model, + latent_model = epiprob.latent_model, + observation_model = epiprob.observation_model, + tspan = (epiprob.tspan[1], epiprob.tspan[2] + n) + ) + + return forecast_epiprob +end diff --git a/pipeline/src/forecast/forecast.jl b/pipeline/src/forecast/forecast.jl new file mode 100644 index 000000000..75dcd2c91 --- /dev/null +++ b/pipeline/src/forecast/forecast.jl @@ -0,0 +1,2 @@ +include("define_forecast_epiprob.jl") +include("generate_forecasts.jl") diff --git a/pipeline/src/forecast/generate_forecasts.jl b/pipeline/src/forecast/generate_forecasts.jl new file mode 100644 index 000000000..52842bd75 --- /dev/null +++ b/pipeline/src/forecast/generate_forecasts.jl @@ -0,0 +1,29 @@ +""" +Generate forecasts for `n` time steps ahead, based on the given inference results. + +# Arguments +- `inference_results`: The results of the inference process. +- `n`: The number of time steps to forecast ahead. + +# Returns +- `forecast_quantities`: The generated forecast quantities. 
+ +""" +function generate_forecasts(inference_results, n::Integer) + inference_chn = inference_results["inference_results"].samples + data = inference_results["inference_results"].data + epiprob = inference_results["epiprob"] + forecast_epiprob = define_forecast_epiprob(epiprob, n) + forecast_mdl = generate_epiaware(forecast_epiprob, (y_t = missing,)) + + # Add forward generation of latent variables using `predict` + pred_chn = mapreduce(chainscat, 1:size(inference_chn, 3)) do c + mapreduce(vcat, 1:size(inference_chn, 1)) do i + fwd_chn = predict(forecast_mdl, inference_chn[i, :, c]; include_all = true) + setrange(fwd_chn, i:i) + end + end + + forecast_quantities = generated_observables(forecast_mdl, data, pred_chn) + return forecast_quantities +end diff --git a/pipeline/src/infer/InferenceConfig.jl b/pipeline/src/infer/InferenceConfig.jl index 9333b541f..e2bbd2f69 100644 --- a/pipeline/src/infer/InferenceConfig.jl +++ b/pipeline/src/infer/InferenceConfig.jl @@ -50,6 +50,17 @@ struct InferenceConfig{T, F, I, L, E} gi_mean, gi_std, igp, latent_model, case_data, tspan, epimethod, D_gen, transformation, delay_distribution, D_obs, log_I0_prior, cluster_factor_prior) end + + function InferenceConfig(inference_config::Dict; case_data, tspan, epimethod) + InferenceConfig( + inference_config["igp"], inference_config["latent_namemodels"].second; + gi_mean = inference_config["gi_mean"], + gi_std = inference_config["gi_std"], + case_data = case_data, + tspan = tspan, + epimethod = epimethod + ) + end end """ @@ -65,34 +76,15 @@ to make inference on and model configuration. 
""" function infer(config::InferenceConfig) - #Define infection-generating model - shape = (config.gi_mean / config.gi_std)^2 - scale = config.gi_std^2 / config.gi_mean - gen_distribution = Gamma(shape, scale) - - #Define the infection-generating process - model_data = EpiData(gen_distribution = gen_distribution, D_gen = config.D_gen, - transformation = config.transformation) - - epi = config.igp(model_data, config.log_I0_prior) - - #Define the infection conditional observation distribution - obs = LatentDelay( - NegativeBinomialError(cluster_factor_prior = config.cluster_factor_prior), - config.delay_distribution; D = config.D_obs) - #Define the EpiProblem - epi_prob = EpiProblem(epi_model = epi, - latent_model = config.latent_model, - observation_model = obs, - tspan = config.tspan) - + epi_prob = define_epiprob(config) idxs = config.tspan[1]:config.tspan[2] + #Return the sampled infections and observations y_t = ismissing(config.case_data) ? missing : config.case_data[idxs] inference_results = apply_method(epi_prob, config.epimethod, (y_t = y_t,) ) - return Dict("inference_results" => inference_results) + return Dict("inference_results" => inference_results, "epiprob" => epi_prob) end diff --git a/pipeline/src/infer/define_epiprob.jl b/pipeline/src/infer/define_epiprob.jl new file mode 100644 index 000000000..42a9242f6 --- /dev/null +++ b/pipeline/src/infer/define_epiprob.jl @@ -0,0 +1,31 @@ +""" +Create an `EpiProblem` object based on the provided `InferenceConfig`. + +# Arguments +- `config::InferenceConfig`: An instance of the `InferenceConfig` type. + +# Returns +- `epi_prob::EpiProblem`: An `EpiProblem` object representing the defined epidemiological problem. 
+ +""" +function define_epiprob(config::InferenceConfig) + shape = (config.gi_mean / config.gi_std)^2 + scale = config.gi_std^2 / config.gi_mean + gen_distribution = Gamma(shape, scale) + + model_data = EpiData(gen_distribution = gen_distribution, D_gen = config.D_gen, + transformation = config.transformation) + + epi = config.igp(model_data, config.log_I0_prior) + + obs = LatentDelay( + NegativeBinomialError(cluster_factor_prior = config.cluster_factor_prior), + config.delay_distribution; D = config.D_obs) + + epi_prob = EpiProblem(epi_model = epi, + latent_model = config.latent_model, + observation_model = obs, + tspan = config.tspan) + + return epi_prob +end diff --git a/pipeline/src/infer/generate_inference_results.jl b/pipeline/src/infer/generate_inference_results.jl index be12bbea3..a878334fa 100644 --- a/pipeline/src/infer/generate_inference_results.jl +++ b/pipeline/src/infer/generate_inference_results.jl @@ -18,13 +18,7 @@ function generate_inference_results( tspan, inference_method, prfix_name = "observables", datadir_name = "epiaware_observables") config = InferenceConfig( - inference_config["igp"], inference_config["latent_namemodels"].second; - gi_mean = inference_config["gi_mean"], - gi_std = inference_config["gi_std"], - case_data = truthdata["y_t"], - tspan = tspan, - epimethod = inference_method - ) + inference_config; case_data = truthdata["y_t"], tspan, epimethod = inference_method) # produce or load inference results prfx = prfix_name * "_igp_" * string(inference_config["igp"]) * "_latentmodel_" * @@ -44,13 +38,7 @@ function generate_inference_results( tspan, inference_method, prfix_name = "prior_observables", datadir_name = "epiaware_observables") config = InferenceConfig( - inference_config["igp"], inference_config["latent_namemodels"].second; - gi_mean = inference_config["gi_mean"], - gi_std = inference_config["gi_std"], - case_data = missing, - tspan = tspan, - epimethod = inference_method - ) + inference_config; case_data = missing, tspan, 
epimethod = inference_method) # produce or load inference results prfx = prfix_name * "_igp_" * string(inference_config["igp"]) * "_latentmodel_" * diff --git a/pipeline/src/infer/infer.jl b/pipeline/src/infer/infer.jl index 557ca59f4..f8b9feb9d 100644 --- a/pipeline/src/infer/infer.jl +++ b/pipeline/src/infer/infer.jl @@ -1,3 +1,4 @@ include("InferenceConfig.jl") include("generate_inference_results.jl") include("map_inference_results.jl") +include("define_epiprob.jl") diff --git a/pipeline/test/test_full_inference.jl b/pipeline/test/end-to-end/test_full_inference.jl similarity index 100% rename from pipeline/test/test_full_inference.jl rename to pipeline/test/end-to-end/test_full_inference.jl diff --git a/pipeline/test/end-to-end/test_infer_and_forecast.jl b/pipeline/test/end-to-end/test_infer_and_forecast.jl new file mode 100644 index 000000000..6d08e656e --- /dev/null +++ b/pipeline/test/end-to-end/test_infer_and_forecast.jl @@ -0,0 +1,80 @@ +using Test + +@testset "run inference and forecast for generated data" begin + using DrWatson + quickactivate(@__DIR__(), "EpiAwarePipeline") + + using EpiAwarePipeline, EpiAware, Plots, Statistics + pipeline = RtwithoutRenewalPipeline() + prior = RtwithoutRenewalPriorPipeline() + + # Set up data generation on a random scenario + lookahead = 60 + n_observation_steps = 28 + tspan_gen = (1, n_observation_steps + lookahead) + tspan_inf = (1, n_observation_steps) + inference_method = make_inference_method(pipeline) + truth_data_config = make_truth_data_configs(pipeline)[1] + inference_configs = make_inference_configs(pipeline) + inference_config = rand(inference_configs) + + # Generate truth data and plot + truth_sampling = generate_inference_results( + Dict("y_t" => missing, "truth_gi_mean" => 1.5), inference_config, + prior; tspan = tspan_gen, inference_method = make_inference_method(prior)) |> + d -> d["inference_results"] + + #Choose first sample to represent truth data + truthdata = 
truth_sampling.generated[1].generated_y_t + + plt = scatter(truthdata, xlabel = "t", ylabel = "y_t", label = "truth data") + vline!(plt, [n_observation_steps], label = "forecast start") + + # Run inference + obs_truthdata = truthdata[tspan_inf[1]:tspan_inf[2]] + inference_results = generate_inference_results( + Dict("y_t" => obs_truthdata, "truth_gi_mean" => 1.5), + inference_config, pipeline; tspan = tspan_inf, inference_method) + + @test inference_results["inference_results"] isa EpiAwareObservables + + # Make forecast `lookahead` steps ahead + forecast_quantities = generate_forecasts(inference_results, lookahead) + @test forecast_quantities isa EpiAwareObservables + + # Plot forecast quantiles of y_t over the truth data + forecast_y_t = mapreduce(hcat, forecast_quantities.generated) do gen + gen.generated_y_t + end + forecast_qs = mapreduce(hcat, [0.025, 0.25, 0.5, 0.75, 0.975]) do q + map(eachrow(forecast_y_t)) do row + quantile(row, q) + end + end + plot!(plt, forecast_qs, label = "forecast quantiles", + color = :grey, lw = [0.5 1.5 3 1.5 0.5]) + plot!(plt, ylims = (-0.5, maximum(truthdata) * 1.25)) + savefig(plt, + joinpath(@__DIR__(), "forecast_y_t.png") + ) + + # Make forecast plot for Z_t + infer_Z_t = mapreduce(hcat, inference_results["inference_results"].generated) do gen + gen.Z_t + end + forecast_Z_t = mapreduce(hcat, forecast_quantities.generated) do gen + gen.Z_t + end + plt_Zt = plot( + truth_sampling.generated[1].Z_t, lw = 3, color = :black, label = "truth Z_t") + plot!(plt_Zt, infer_Z_t, xlabel = "t", ylabel = "Z_t", + label = "", color = :grey, alpha = 0.05) + plot!((n_observation_steps + 1):size(forecast_Z_t, 1), + forecast_Z_t[(n_observation_steps + 1):end, :], + label = "", color = :red, alpha = 0.05) + vline!(plt_Zt, [n_observation_steps], label = "forecast start") + + savefig(plt_Zt, + joinpath(@__DIR__(), "forecast_Z_t.png") + ) +end diff --git a/pipeline/test/test_prior_predictive.jl b/pipeline/test/end-to-end/test_prior_predictive.jl similarity index 100% rename from 
pipeline/test/test_prior_predictive.jl rename to pipeline/test/end-to-end/test_prior_predictive.jl diff --git a/pipeline/test/forecast/test_forecast.jl b/pipeline/test/forecast/test_forecast.jl new file mode 100644 index 000000000..c428b9232 --- /dev/null +++ b/pipeline/test/forecast/test_forecast.jl @@ -0,0 +1,19 @@ +@testset "define_forecast_epiprob" begin + using EpiAwarePipeline + pipeline = RtwithoutRenewalPipeline() + + inference_configs = make_inference_configs(pipeline) + + case_data = missing + tspan = (1, 28) + epimethod = make_inference_method(pipeline) + + epiprob = InferenceConfig(rand(inference_configs); case_data, tspan, epimethod) |> + define_epiprob + + @test_throws AssertionError define_forecast_epiprob(epiprob, -1) + + n_fr = 7 + forecast_epiprob = define_forecast_epiprob(epiprob, 7) + @test forecast_epiprob.tspan == (epiprob.tspan[1], epiprob.tspan[2] + n_fr) +end diff --git a/pipeline/test/infer/test_InferenceConfig.jl b/pipeline/test/infer/test_InferenceConfig.jl index b298f292f..5b98d6aa3 100644 --- a/pipeline/test/infer/test_InferenceConfig.jl +++ b/pipeline/test/infer/test_InferenceConfig.jl @@ -1,6 +1,4 @@ - -# Test the InferenceConfig struct constructor -@testset "InferenceConfig" begin +@testset "InferenceConfig: constructor function" begin using Distributions, EpiAwarePipeline, EpiAware struct TestLatentModel <: AbstractLatentModel @@ -16,16 +14,15 @@ epimethod = TestMethod() case_data = [10, 20, 30, 40, 50] tspan = (1, 5) + @testset "config_parameters back from constructor" begin + config = InferenceConfig(igp, latent_model; + gi_mean = gi_mean, + gi_std = gi_std, + case_data = case_data, + tspan = tspan, + epimethod = epimethod + ) - config = InferenceConfig(igp, latent_model; - gi_mean = gi_mean, - gi_std = gi_std, - case_data = case_data, - tspan = tspan, - epimethod = epimethod - ) - - @testset "config_parameters" begin @test config.gi_mean == gi_mean @test config.gi_std == gi_std @test config.igp == igp @@ -34,4 +31,11 @@ @test 
config.tspan == tspan @test config.epimethod == epimethod end + + @testset "construct from config dictionary" begin + pipeline = RtwithoutRenewalPipeline() + inference_configs = make_inference_configs(pipeline) + @test [InferenceConfig(ic; case_data, tspan, epimethod) isa InferenceConfig + for ic in inference_configs] |> all + end end diff --git a/pipeline/test/infer/test_define_epiprob.jl b/pipeline/test/infer/test_define_epiprob.jl new file mode 100644 index 000000000..61db6bbd2 --- /dev/null +++ b/pipeline/test/infer/test_define_epiprob.jl @@ -0,0 +1,15 @@ +@testset "test define_epiprob" begin + using EpiAwarePipeline + pipeline = RtwithoutRenewalPipeline() + + inference_configs = make_inference_configs(pipeline) + + case_data = missing + tspan = (1, 28) + epimethod = make_inference_method(pipeline) + + epiprob = InferenceConfig(rand(inference_configs); case_data, tspan, epimethod) |> + define_epiprob + + @test epiprob isa EpiProblem +end diff --git a/pipeline/test/runtests.jl b/pipeline/test/runtests.jl index 371a2d1a7..c52f15c07 100644 --- a/pipeline/test/runtests.jl +++ b/pipeline/test/runtests.jl @@ -7,3 +7,5 @@ include("constructors/test_constructors.jl"); include("simulate/test_TruthSimulationConfig.jl"); include("simulate/test_SimulationConfig.jl"); include("infer/test_InferenceConfig.jl"); +include("infer/test_define_epiprob.jl"); +include("forecast/test_forecast.jl");