Skip to content

Commit

Permalink
Write and test a forecast function which projects inference forwards …
Browse files Browse the repository at this point in the history
…in time (#239)

* new constructor function for InferenceConfig

* fix constructor unit test

* Update test_InferenceConfig.jl

* define_epiprob function and unit test

* reformat

* add get_param_array function

* Update get_param_array.jl

* Simpler utility using `rowtable`

* expanding test (not working yet)

* forecast method template (not working yet)

* remove scrap code

* forecasting function first pass

* Update test_infer_and_forecast.jl

* Add MCMCChains dep

* rework generate_forecasts to sample white noise forwards in time

* end-to-end test: generate, infer, forecast, viz

* generate predictions

* reorganise tests

* generate forecast function

---------

Co-authored-by: Sam Abbott <[email protected]>
  • Loading branch information
SamuelBrand1 and seabbs authored May 29, 2024
1 parent 919b8a8 commit 496c6e9
Show file tree
Hide file tree
Showing 16 changed files with 248 additions and 52 deletions.
8 changes: 6 additions & 2 deletions pipeline/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,23 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1"
EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
ADTypes = "0.2"
AbstractMCMC = "5.2"
CSV = "0.10"
Dagger = "0.18"
DataFramesMeta = "0.14"
Dates = "1.10"
Distributions = "0.25"
DocStringExtensions = "0.9"
DrWatson = "2.15"
JLD2 = "0.4"
Plots = "1.40"
RCall = "0.14"
Statistics = "1.10"
julia = "1.10"
8 changes: 6 additions & 2 deletions pipeline/src/EpiAwarePipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ with execution determined by available computational resources.
module EpiAwarePipeline

using CSV, Dagger, DataFramesMeta, Dates, Distributions, DocStringExtensions, DrWatson,
EpiAware, Plots, Statistics, ADTypes, AbstractMCMC, Plots, JLD2
EpiAware, Plots, Statistics, ADTypes, AbstractMCMC, Plots, JLD2, MCMCChains, Turing

# Exported pipeline types
export AbstractEpiAwarePipeline, EpiAwarePipeline, RtwithoutRenewalPipeline,
Expand All @@ -32,7 +32,10 @@ export do_truthdata, do_inference, do_pipeline
export simulate, generate_truthdata

# Exported functions: infer functions
export infer, generate_inference_results, map_inference_results
export infer, generate_inference_results, map_inference_results, define_epiprob

# Exported functions: forecast functions
export define_forecast_epiprob, generate_forecasts

# Exported functions: plot functions
export plot_truth_data, plot_Rt
Expand All @@ -42,5 +45,6 @@ include("pipeline/pipeline.jl")
include("constructors/constructors.jl")
include("simulate/simulate.jl")
include("infer/infer.jl")
include("forecast/forecast.jl")
include("plot_functions.jl")
end
25 changes: 25 additions & 0 deletions pipeline/src/forecast/define_forecast_epiprob.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
    define_forecast_epiprob(epiprob::EpiProblem, n::Integer)

Build an `EpiProblem` that reuses the models of `epiprob` but whose time span
ends `n` steps later.

# Arguments
- `epiprob::EpiProblem`: The problem to extend.
- `n::Integer`: The number of forecast steps to append; must be strictly positive.

# Returns
- An `EpiProblem` with the same `epi_model`, `latent_model` and
  `observation_model` as `epiprob`, and
  `tspan = (epiprob.tspan[1], epiprob.tspan[2] + n)`.

# Throws
- `AssertionError` if `n` is not positive.
"""
function define_forecast_epiprob(epiprob::EpiProblem, n::Integer)
    @assert n>0 "number of forecast steps n must be positive"

    t_first = epiprob.tspan[1]
    t_last = epiprob.tspan[2]

    # Only the end of the time span moves; every model component is reused as-is.
    return EpiProblem(;
        epi_model = epiprob.epi_model,
        latent_model = epiprob.latent_model,
        observation_model = epiprob.observation_model,
        tspan = (t_first, t_last + n)
    )
end
2 changes: 2 additions & 0 deletions pipeline/src/forecast/forecast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
include("define_forecast_epiprob.jl")
include("generate_forecasts.jl")
29 changes: 29 additions & 0 deletions pipeline/src/forecast/generate_forecasts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""
    generate_forecasts(inference_results, n::Integer)

Project inference results `n` time steps beyond the end of the inference
window.

A model is built over the extended time span with `y_t = missing`. Each
sample of the inference chain is then passed through `predict` on that model,
the resulting single-sample chains are joined back together, and the combined
chain is handed to `generated_observables`.

# Arguments
- `inference_results`: A collection indexed by `"inference_results"` (an
  object with `samples` and `data` fields) and `"epiprob"` (the `EpiProblem`
  that was fitted).
- `n::Integer`: The number of forecast steps to add to the time span.

# Returns
- `forecast_quantities`: The value returned by `generated_observables` for the
  forecast model, the original data and the forward-predicted chain.
"""
function generate_forecasts(inference_results, n::Integer)
    fitted = inference_results["inference_results"]
    inference_chn = fitted.samples
    data = fitted.data
    forecast_epiprob = define_forecast_epiprob(inference_results["epiprob"], n)
    forecast_mdl = generate_epiaware(forecast_epiprob, (y_t = missing,))

    n_iterations = size(inference_chn, 1)
    n_chains = size(inference_chn, 3)

    # Forward prediction for a single sample `i` of chain `c`. The result is
    # relabelled with range `i:i` before the per-chain pieces are joined with `vcat`.
    function forward_draw(i, c)
        fwd_chn = predict(forecast_mdl, inference_chn[i, :, c]; include_all = true)
        return setrange(fwd_chn, i:i)
    end

    # All forward draws belonging to chain `c`, stacked in sample order.
    chain_forecast(c) = mapreduce(i -> forward_draw(i, c), vcat, 1:n_iterations)

    pred_chn = mapreduce(chain_forecast, chainscat, 1:n_chains)

    return generated_observables(forecast_mdl, data, pred_chn)
end
36 changes: 14 additions & 22 deletions pipeline/src/infer/InferenceConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ struct InferenceConfig{T, F, I, L, E}
gi_mean, gi_std, igp, latent_model, case_data, tspan, epimethod,
D_gen, transformation, delay_distribution, D_obs, log_I0_prior, cluster_factor_prior)
end

# Construct an `InferenceConfig` from an inference-configuration dictionary.
# The dictionary supplies `"igp"`, `"gi_mean"`, `"gi_std"` and
# `"latent_namemodels"` (presumably a `name => model` pair, of which only
# `.second` is used); `case_data`, `tspan` and `epimethod` are passed through
# unchanged to the `InferenceConfig(igp, latent_model; ...)` constructor.
function InferenceConfig(inference_config::Dict; case_data, tspan, epimethod)
    return InferenceConfig(
        inference_config["igp"], inference_config["latent_namemodels"].second;
        gi_mean = inference_config["gi_mean"], gi_std = inference_config["gi_std"],
        case_data, tspan, epimethod)
end
end

"""
Expand All @@ -65,34 +76,15 @@ to make inference on and model configuration.
"""
function infer(config::InferenceConfig)
    # Build the EpiProblem through the shared helper rather than re-deriving it
    # inline, so inference and forecasting use one model definition.
    epi_prob = define_epiprob(config)

    # Restrict the observations to the inference window; `missing` case data
    # is passed through unchanged.
    idxs = config.tspan[1]:config.tspan[2]
    y_t = ismissing(config.case_data) ? missing : config.case_data[idxs]

    inference_results = apply_method(epi_prob, config.epimethod, (y_t = y_t,))

    # The problem definition is returned alongside the results because
    # `generate_forecasts` reads it back from the `"epiprob"` key.
    return Dict("inference_results" => inference_results, "epiprob" => epi_prob)
end
31 changes: 31 additions & 0 deletions pipeline/src/infer/define_epiprob.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
    define_epiprob(config::InferenceConfig)

Assemble the `EpiProblem` described by `config`.

The generation interval is a `Gamma(shape, scale)` distribution with
`shape = (gi_mean / gi_std)^2` and `scale = gi_std^2 / gi_mean`, so that its
mean is `gi_mean` and its standard deviation is `gi_std`.

# Arguments
- `config::InferenceConfig`: The inference configuration to translate.

# Returns
- `epi_prob::EpiProblem`: The problem built from the configured infection
  model, latent model, observation model and time span.
"""
function define_epiprob(config::InferenceConfig)
    # Gamma parameters derived from the configured mean and standard deviation.
    gi_shape = (config.gi_mean / config.gi_std)^2
    gi_scale = config.gi_std^2 / config.gi_mean

    # Infection-generating process.
    epi_data = EpiData(;
        gen_distribution = Gamma(gi_shape, gi_scale),
        D_gen = config.D_gen,
        transformation = config.transformation)
    infection_model = config.igp(epi_data, config.log_I0_prior)

    # Observation model: negative-binomial error wrapped in a latent delay.
    error_model = NegativeBinomialError(;
        cluster_factor_prior = config.cluster_factor_prior)
    delayed_obs = LatentDelay(
        error_model, config.delay_distribution; D = config.D_obs)

    return EpiProblem(;
        epi_model = infection_model,
        latent_model = config.latent_model,
        observation_model = delayed_obs,
        tspan = config.tspan)
end
16 changes: 2 additions & 14 deletions pipeline/src/infer/generate_inference_results.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,7 @@ function generate_inference_results(
tspan, inference_method,
prfix_name = "observables", datadir_name = "epiaware_observables")
config = InferenceConfig(
inference_config["igp"], inference_config["latent_namemodels"].second;
gi_mean = inference_config["gi_mean"],
gi_std = inference_config["gi_std"],
case_data = truthdata["y_t"],
tspan = tspan,
epimethod = inference_method
)
inference_config; case_data = truthdata["y_t"], tspan, epimethod = inference_method)

# produce or load inference results
prfx = prfix_name * "_igp_" * string(inference_config["igp"]) * "_latentmodel_" *
Expand All @@ -44,13 +38,7 @@ function generate_inference_results(
tspan, inference_method,
prfix_name = "prior_observables", datadir_name = "epiaware_observables")
config = InferenceConfig(
inference_config["igp"], inference_config["latent_namemodels"].second;
gi_mean = inference_config["gi_mean"],
gi_std = inference_config["gi_std"],
case_data = missing,
tspan = tspan,
epimethod = inference_method
)
inference_config; case_data = missing, tspan, epimethod = inference_method)

# produce or load inference results
prfx = prfix_name * "_igp_" * string(inference_config["igp"]) * "_latentmodel_" *
Expand Down
1 change: 1 addition & 0 deletions pipeline/src/infer/infer.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
include("InferenceConfig.jl")
include("generate_inference_results.jl")
include("map_inference_results.jl")
include("define_epiprob.jl")
File renamed without changes.
80 changes: 80 additions & 0 deletions pipeline/test/end-to-end/test_infer_and_forecast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using Test

# End-to-end check: sample a synthetic data set from the prior pipeline, fit
# the first `n_observation_steps` points, forecast `lookahead` further steps,
# and write diagnostic plots of `y_t` and `Z_t` next to this file.
@testset "run inference and forecast for generated data" begin
    using DrWatson
    quickactivate(@__DIR__(), "EpiAwarePipeline")

    using EpiAwarePipeline, EpiAware, Plots, Statistics
    pipeline = RtwithoutRenewalPipeline()
    prior = RtwithoutRenewalPriorPipeline()

    # Set up data generation on a random scenario
    lookahead = 60
    n_observation_steps = 28
    tspan_gen = (1, n_observation_steps + lookahead)
    tspan_inf = (1, n_observation_steps)
    inference_method = make_inference_method(pipeline)
    truth_data_config = make_truth_data_configs(pipeline)[1]
    inference_configs = make_inference_configs(pipeline)
    inference_config = rand(inference_configs)

    # Generate truth data over the full (observed + forecast) span by sampling
    # with `y_t` missing through the prior pipeline
    truth_sampling = generate_inference_results(
        Dict("y_t" => missing, "truth_gi_mean" => 1.5), inference_config,
        prior; tspan = tspan_gen, inference_method = make_inference_method(prior)) |>
                     d -> d["inference_results"]

    # Choose first sample to represent truth data
    truthdata = truth_sampling.generated[1].generated_y_t

    plt = scatter(truthdata, xlabel = "t", ylabel = "y_t", label = "truth data")
    vline!(plt, [n_observation_steps], label = "forecast start")

    # Run inference on the observed window only
    obs_truthdata = truthdata[tspan_inf[1]:tspan_inf[2]]
    inference_results = generate_inference_results(
        Dict("y_t" => obs_truthdata, "truth_gi_mean" => 1.5),
        inference_config, pipeline; tspan = tspan_inf, inference_method)

    @test inference_results["inference_results"] isa EpiAwareObservables

    # Forecast `lookahead` steps beyond the observed window
    forecast_quantities = generate_forecasts(inference_results, lookahead)
    @test forecast_quantities isa EpiAwareObservables

    # Plot forecast quantiles of y_t: one column per quantile level, one row
    # per time step
    forecast_y_t = mapreduce(hcat, forecast_quantities.generated) do gen
        gen.generated_y_t
    end
    forecast_qs = mapreduce(hcat, [0.025, 0.25, 0.5, 0.75, 0.975]) do q
        map(eachrow(forecast_y_t)) do row
            quantile(row, q)
        end
    end
    plot!(plt, forecast_qs, label = "forecast quantiles",
        color = :grey, lw = [0.5 1.5 3 1.5 0.5])
    plot!(plt, ylims = (-0.5, maximum(truthdata) * 1.25))
    savefig(plt,
        joinpath(@__DIR__(), "forecast_y_t.png")
    )

    # Make forecast plot for Z_t
    infer_Z_t = mapreduce(hcat, inference_results["inference_results"].generated) do gen
        gen.Z_t
    end
    forecast_Z_t = mapreduce(hcat, forecast_quantities.generated) do gen
        gen.Z_t
    end
    plt_Zt = plot(
        truth_sampling.generated[1].Z_t, lw = 3, color = :black, label = "truth Z_t")
    plot!(plt_Zt, infer_Z_t, xlabel = "t", ylabel = "Z_t",
        label = "", color = :grey, alpha = 0.05)
    # Target `plt_Zt` explicitly instead of relying on the implicit current plot
    plot!(plt_Zt, (n_observation_steps + 1):size(forecast_Z_t, 1),
        forecast_Z_t[(n_observation_steps + 1):end, :],
        label = "", color = :red, alpha = 0.05)
    vline!(plt_Zt, [n_observation_steps], label = "forecast start")

    savefig(plt_Zt,
        joinpath(@__DIR__(), "forecast_Z_t.png")
    )
end
File renamed without changes.
19 changes: 19 additions & 0 deletions pipeline/test/forecast/test_forecast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Checks that `define_forecast_epiprob` rejects non-positive step counts and
# extends only the end of the time span.
@testset "define_forecast_epiprob" begin
    using EpiAwarePipeline
    pipeline = RtwithoutRenewalPipeline()

    inference_configs = make_inference_configs(pipeline)

    case_data = missing
    tspan = (1, 28)
    epimethod = make_inference_method(pipeline)

    epiprob = InferenceConfig(rand(inference_configs); case_data, tspan, epimethod) |>
              define_epiprob

    # Non-positive step counts must be rejected, including the n = 0 boundary.
    @test_throws AssertionError define_forecast_epiprob(epiprob, -1)
    @test_throws AssertionError define_forecast_epiprob(epiprob, 0)

    # Use `n_fr` for both the call and the expectation so they cannot drift apart.
    n_fr = 7
    forecast_epiprob = define_forecast_epiprob(epiprob, n_fr)
    @test forecast_epiprob.tspan == (epiprob.tspan[1], epiprob.tspan[2] + n_fr)
end
28 changes: 16 additions & 12 deletions pipeline/test/infer/test_InferenceConfig.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@

# Test the InferenceConfig struct constructor
@testset "InferenceConfig" begin
@testset "InferenceConfig: constructor function" begin
using Distributions, EpiAwarePipeline, EpiAware

struct TestLatentModel <: AbstractLatentModel
Expand All @@ -16,16 +14,15 @@
epimethod = TestMethod()
case_data = [10, 20, 30, 40, 50]
tspan = (1, 5)
@testset "config_parameters back from constructor" begin
config = InferenceConfig(igp, latent_model;
gi_mean = gi_mean,
gi_std = gi_std,
case_data = case_data,
tspan = tspan,
epimethod = epimethod
)

config = InferenceConfig(igp, latent_model;
gi_mean = gi_mean,
gi_std = gi_std,
case_data = case_data,
tspan = tspan,
epimethod = epimethod
)

@testset "config_parameters" begin
@test config.gi_mean == gi_mean
@test config.gi_std == gi_std
@test config.igp == igp
Expand All @@ -34,4 +31,11 @@
@test config.tspan == tspan
@test config.epimethod == epimethod
end

@testset "construct from config dictionary" begin
pipeline = RtwithoutRenewalPipeline()
inference_configs = make_inference_configs(pipeline)
@test [InferenceConfig(ic; case_data, tspan, epimethod) isa InferenceConfig
for ic in inference_configs] |> all
end
end
15 changes: 15 additions & 0 deletions pipeline/test/infer/test_define_epiprob.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Checks that a randomly chosen pipeline inference configuration can be turned
# into an `EpiProblem`.
@testset "test define_epiprob" begin
    using EpiAwarePipeline
    pipeline = RtwithoutRenewalPipeline()

    candidate_configs = make_inference_configs(pipeline)
    epimethod = make_inference_method(pipeline)

    # No case data is needed just to define the problem.
    config = InferenceConfig(
        rand(candidate_configs); case_data = missing, tspan = (1, 28), epimethod)

    @test define_epiprob(config) isa EpiProblem
end
Loading

0 comments on commit 496c6e9

Please sign in to comment.