Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor with method hierarchy #78

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,28 @@ This module provides functionality for calculating Rt (effective reproduction nu
module EpiAware

using Distributions,
Turing,
LogExpFunctions,
LinearAlgebra,
SparseArrays,
Random,
ReverseDiff,
Optim,
Parameters,
QuadGK,
DataFramesMeta
Turing,
LogExpFunctions,
LinearAlgebra,
SparseArrays,
Random,
ReverseDiff,
Optim,
Parameters,
QuadGK,
DataFramesMeta

# Exported utilities
export create_discrete_pmf,
default_rw_priors, default_delay_obs_priors,
default_initialisation_prior, spread_draws
default_rw_priors, default_delay_obs_priors,
default_initialisation_prior, spread_draws

# Exported types
export EpiData, Renewal, ExpGrowthRate, DirectInfections

# Exported Turing model constructors
export make_epi_inference_model, delay_observations_model,
initialize_incidence
initialize_incidence

include("epimodel.jl")
include("utilities.jl")
Expand Down
31 changes: 22 additions & 9 deletions EpiAware/src/latent-processes.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,41 @@
abstract type AbstractLatentProcess end
abstract type AbstractLatentProcessArg end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think defining a type hierarchy for the underlying priors and an auxiliary type hierarchy that is initimately connected to it is overkill.

For example, you can dispatch on the type of the first type without using it if you want to.


struct RandomWalkLatentProcess{D <: Sampleable, S <: Sampleable} <: AbstractLatentProcess
init_prior::D
var_prior::S
end

struct RandomWalkLatentProcessArg <: AbstractLatentProcessArg end

function default_rw_priors()
return (:var_RW_prior => truncated(Normal(0.0, 0.05), 0.0, Inf),
:init_rw_value_prior => Normal()) |> Dict
end

function generate_latent_process(latent_process::AbstractLatentProcess, n; kwargs...)
@info "No concrete implementation for generate_latent_process is defined."
function latent_process(lp::AbstractLatentProcess, n;
kwargs...)
return latent_process(
AbstractLatentProcessArg(), n; var_prior = lp.var_prior, init_prior = lp.init_prior)
end

function latent_process(lp::AbstractLatentProcess, n; kwargs...)
@info "No concrete implementation for latent_process is defined."
return nothing
end

@model function generate_latent_process(latent_process::RandomWalkLatentProcess, n)
function latent_process(lp::RandomWalkLatentProcess, n)
return latent_process(
RandomWalkLatentProcessArg(), n; var_prior = lp.var_prior,
init_prior = lp.init_prior
)
end

@model function latent_process(lp::RandomWalkLatentProcessArg, n;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like the point of the above is to lower down to this function; but you can get the same effect by specifying the fields of RandomWalkLatentProcess (i.e. not using a NamedTuple or Dict of priors of potentially any name) and dispatching on that.

I'd avoid that by having RandomWalkLatentProcess have determined typed fields and saving a bunch of code repetition.

var_prior::ContinuousDistribution, init_prior::ContinuousDistribution)
ϵ_t ~ MvNormal(ones(n))
σ²_RW ~ latent_process.var_prior
rw_init ~ latent_process.init_prior
σ²_RW ~ var_prior
rw_init ~ init_prior
σ_RW = sqrt(σ²_RW)
rw = Vector{eltype(ϵ_t)}(undef, n)

Expand All @@ -28,7 +45,3 @@ end
end
return rw, (; σ_RW, rw_init)
end

# function random_walk_process(; latent_process_priors = default_rw_priors())
# LatentProcess(random_walk, latent_process_priors)
# end
3 changes: 2 additions & 1 deletion EpiAware/src/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
observation_model::AbstractObservationModel,
pos_shift = 1e-6)
#Latent process
@submodel latent_process, latent_process_aux = generate_latent_process(latent_process_model,
@submodel latent_process, latent_process_aux = generate_latent_process(
latent_process_model,
time_steps)

#Transform into infections
Expand Down
10 changes: 5 additions & 5 deletions EpiAware/test/test_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ end
@testset "Test case 5" begin
dist = Exponential(1.0)
expected_pmf_uncond = [exp(-1)
[(1 - exp(-1)) * (exp(1) - 1) * exp(-s) for s in 1:9]]
[(1 - exp(-1)) * (exp(1) - 1) * exp(-s) for s in 1:9]]
expected_pmf = expected_pmf_uncond ./ sum(expected_pmf_uncond)
pmf = create_discrete_pmf(dist; Δd = 1.0, D = 10.0)
@test expected_pmf≈pmf atol=1e-15
Expand Down Expand Up @@ -100,10 +100,10 @@ end
delay_int = [0.2, 0.5, 0.3]
time_horizon = 5
expected_K = SparseMatrixCSC([0.2 0 0 0 0
0.5 0.2 0 0 0
0.3 0.5 0.2 0 0
0 0.3 0.5 0.2 0
0 0 0.3 0.5 0.2])
0.5 0.2 0 0 0
0.3 0.5 0.2 0 0
0 0.3 0.5 0.2 0
0 0 0.3 0.5 0.2])
K = EpiAware.generate_observation_kernel(delay_int, time_horizon)
@test K == expected_K
end
Expand Down
Loading