From 667baaa945877a7d4333d892a5e5d57e1e99a6b9 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 27 Feb 2024 11:24:37 +0000 Subject: [PATCH] refactor with method hierarchy --- EpiAware/src/EpiAware.jl | 26 +++++++++++++------------- EpiAware/src/latent-processes.jl | 31 ++++++++++++++++++++++--------- EpiAware/src/models.jl | 3 ++- EpiAware/test/test_utilities.jl | 10 +++++----- 4 files changed, 42 insertions(+), 28 deletions(-) diff --git a/EpiAware/src/EpiAware.jl b/EpiAware/src/EpiAware.jl index c89d41b7b..67af2a6aa 100644 --- a/EpiAware/src/EpiAware.jl +++ b/EpiAware/src/EpiAware.jl @@ -20,28 +20,28 @@ This module provides functionality for calculating Rt (effective reproduction nu module EpiAware using Distributions, - Turing, - LogExpFunctions, - LinearAlgebra, - SparseArrays, - Random, - ReverseDiff, - Optim, - Parameters, - QuadGK, - DataFramesMeta + Turing, + LogExpFunctions, + LinearAlgebra, + SparseArrays, + Random, + ReverseDiff, + Optim, + Parameters, + QuadGK, + DataFramesMeta # Exported utilities export create_discrete_pmf, - default_rw_priors, default_delay_obs_priors, - default_initialisation_prior, spread_draws + default_rw_priors, default_delay_obs_priors, + default_initialisation_prior, spread_draws # Exported types export EpiData, Renewal, ExpGrowthRate, DirectInfections # Exported Turing model constructors export make_epi_inference_model, delay_observations_model, - initialize_incidence + initialize_incidence include("epimodel.jl") include("utilities.jl") diff --git a/EpiAware/src/latent-processes.jl b/EpiAware/src/latent-processes.jl index 95f8a8fa1..700c9f676 100644 --- a/EpiAware/src/latent-processes.jl +++ b/EpiAware/src/latent-processes.jl @@ -1,24 +1,41 @@ abstract type AbstractLatentProcess end +abstract type AbstractLatentProcessArg end struct RandomWalkLatentProcess{D <: Sampleable, S <: Sampleable} <: AbstractLatentProcess init_prior::D var_prior::S end +struct RandomWalkLatentProcessArg <: AbstractLatentProcessArg end + function default_rw_priors() return (:var_RW_prior => truncated(Normal(0.0, 0.05), 0.0, Inf), :init_rw_value_prior => Normal()) |> Dict end -function generate_latent_process(latent_process::AbstractLatentProcess, n; kwargs...) - @info "No concrete implementation for generate_latent_process is defined." +function latent_process(lp::AbstractLatentProcess, n; + kwargs...) + return latent_process( + AbstractLatentProcessArg(), n; var_prior = lp.var_prior, init_prior = lp.init_prior) +end + +function latent_process(lp::AbstractLatentProcess, n; kwargs...) + @info "No concrete implementation for latent_process is defined." return nothing end -@model function generate_latent_process(latent_process::RandomWalkLatentProcess, n) +function latent_process(lp::RandomWalkLatentProcess, n) + return latent_process( + RandomWalkLatentProcessArg(), n; var_prior = lp.var_prior, + init_prior = lp.init_prior + ) +end + +@model function latent_process(lp::RandomWalkLatentProcessArg, n; + var_prior::ContinuousDistribution, init_prior::ContinuousDistribution) ϵ_t ~ MvNormal(ones(n)) - σ²_RW ~ latent_process.var_prior - rw_init ~ latent_process.init_prior + σ²_RW ~ var_prior + rw_init ~ init_prior σ_RW = sqrt(σ²_RW) rw = Vector{eltype(ϵ_t)}(undef, n) @@ -28,7 +45,3 @@ end end return rw, (; σ_RW, rw_init) end - -# function random_walk_process(; latent_process_priors = default_rw_priors()) -# LatentProcess(random_walk, latent_process_priors) -# end diff --git a/EpiAware/src/models.jl b/EpiAware/src/models.jl index e02242c76..73a54e23a 100644 --- a/EpiAware/src/models.jl +++ b/EpiAware/src/models.jl @@ -5,7 +5,8 @@ observation_model::AbstractObservationModel, pos_shift = 1e-6) #Latent process - @submodel latent_process, latent_process_aux = generate_latent_process(latent_process_model, + @submodel latent_process, latent_process_aux = generate_latent_process( + latent_process_model, time_steps) #Transform into infections diff --git a/EpiAware/test/test_utilities.jl b/EpiAware/test/test_utilities.jl index 5764b05dc..8ae9b58d2 100644 --- a/EpiAware/test/test_utilities.jl +++ b/EpiAware/test/test_utilities.jl @@ -64,7 +64,7 @@ end @testset "Test case 5" begin dist = Exponential(1.0) expected_pmf_uncond = [exp(-1) - [(1 - exp(-1)) * (exp(1) - 1) * exp(-s) for s in 1:9]] + [(1 - exp(-1)) * (exp(1) - 1) * exp(-s) for s in 1:9]] expected_pmf = expected_pmf_uncond ./ sum(expected_pmf_uncond) pmf = create_discrete_pmf(dist; Δd = 1.0, D = 10.0) @test expected_pmf≈pmf atol=1e-15 @@ -100,10 +100,10 @@ end delay_int = [0.2, 0.5, 0.3] time_horizon = 5 expected_K = SparseMatrixCSC([0.2 0 0 0 0 - 0.5 0.2 0 0 0 - 0.3 0.5 0.2 0 0 - 0 0.3 0.5 0.2 0 - 0 0 0.3 0.5 0.2]) + 0.5 0.2 0 0 0 + 0.3 0.5 0.2 0 0 + 0 0.3 0.5 0.2 0 + 0 0 0.3 0.5 0.2]) K = EpiAware.generate_observation_kernel(delay_int, time_horizon) @test K == expected_K end