Skip to content

Commit

Permalink
Enable functions to take generic inputs and improve test coverage (#56)
Browse files Browse the repository at this point in the history
* Add converters for wrapped functions

* Add more converts for Dataset and InferenceData

* Use new convert overloads

* Make consistent with arviz

* Add converters for stats and plots functions

* Use default converts in quickstart

* Increment version

* Add converter for waic

* Add tests

* Add diagnostic tests

* Don't gridplot if already a single BokehPlot

* Fix some conversions

* Test all plots

* Run JuliaFormatter

* Fix test for current ArviZ version

* Wrap all tests in one test set

* Test vector overload
  • Loading branch information
sethaxen authored Feb 26, 2020
1 parent 94cd5c0 commit e037516
Show file tree
Hide file tree
Showing 15 changed files with 479 additions and 71 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ArviZ"
uuid = "131c737c-5715-5e2e-ad31-c244f01c1dc7"
authors = ["Seth Axen <[email protected]>"]
version = "0.3.3-DEV"
version = "0.3.3"

[deps]
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
Expand Down
18 changes: 15 additions & 3 deletions docs/src/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ using PyCall
np = pyimport_conda("numpy", "numpy")
np.seterr(divide="ignore", invalid="ignore")
turing_chns = read("../src/assets/turing_centered_eight_chains.jls", MCMCChains.Chains)
turing_chns = read(
"../src/assets/turing_centered_eight_chains.jls",
MCMCChains.Chains,
)
# use fancy HTML for xarray.Dataset if available
try
Expand Down Expand Up @@ -124,7 +127,7 @@ turing_chns = psample(
Most ArviZ functions work fine with `Chains` objects from Turing:

```@example quickstart
plot_autocorr(convert_to_inference_data(turing_chns); var_names = ["μ", "τ"]);
plot_autocorr(turing_chns; var_names = ["μ", "τ"]);
savefig("quick_turingautocorr.svg"); nothing # hide
```

Expand Down Expand Up @@ -238,7 +241,7 @@ nothing # hide
```

```@example quickstart
plot_density(convert_to_inference_data(stan_chns); var_names=["mu", "tau"]);
plot_density(stan_chns; var_names=["mu", "tau"]);
savefig("quick_cmdstandensity.svg"); nothing # hide
```

Expand Down Expand Up @@ -336,6 +339,15 @@ nothing # hide
```

Each Soss draw is a `NamedTuple`.
We can plot the rank order statistics of the posterior to identify poor convergence:

```@example quickstart
plot_rank(post; var_names = ["μ", "τ"]);
savefig("quick_sossrank.png"); nothing # hide
```

![](quick_sossrank.png)

Now we combine all of the samples to an `InferenceData`:

```@example quickstart
Expand Down
4 changes: 3 additions & 1 deletion src/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ InferenceData(; kwargs...) = reorder_groups!(arviz.InferenceData(; kwargs...))

@inline PyObject(data::InferenceData) = getfield(data, :o)

Base.convert(::Type{InferenceData}, o::PyObject) = InferenceData(o)
Base.convert(::Type{InferenceData}, obj::PyObject) = InferenceData(obj)
Base.convert(::Type{InferenceData}, obj) = convert_to_inference_data(obj)
Base.convert(::Type{InferenceData}, obj::InferenceData) = obj

Base.hash(data::InferenceData) = hash(PyObject(data))

Expand Down
4 changes: 3 additions & 1 deletion src/dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ Dataset(; kwargs...) = xarray.Dataset(; kwargs...)

@inline PyObject(data::Dataset) = getfield(data, :o)

Base.convert(::Type{Dataset}, o::PyObject) = Dataset(o)
Base.convert(::Type{Dataset}, obj::PyObject) = Dataset(obj)
Base.convert(::Type{Dataset}, obj::Dataset) = obj
Base.convert(::Type{Dataset}, obj) = convert_to_dataset(obj)

Base.hash(data::Dataset) = hash(PyObject(data))

Expand Down
24 changes: 22 additions & 2 deletions src/diagnostics.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,25 @@
@forwardfun bfmi
@forwardfun geweke
@forwardfun ess
@forwardfun rhat
@forwardfun geweke
@forwardfun mcse
@forwardfun rhat

function convert_arguments(::typeof(bfmi), data, args...; kwargs...)
dataset = convert_to_dataset(data; group = :sample_stats)
return tuple(dataset, args...), kwargs
end
function convert_arguments(::typeof(bfmi), data::AbstractArray, args...; kwargs...)
return tuple(data, args...), kwargs
end

for f in (:ess, :mcse, :rhat)
@eval begin
function convert_arguments(::typeof($(f)), data, args...; kwargs...)
dataset = convert_to_dataset(data; group = :posterior)
return tuple(dataset, args...), kwargs
end
function convert_arguments(::typeof($(f)), data::AbstractArray, args...; kwargs...)
return tuple(data, args...), kwargs
end
end
end
85 changes: 85 additions & 0 deletions src/plots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,88 @@
@forwardplotfun plot_rank
@forwardplotfun plot_trace
@forwardplotfun plot_violin

# TODO: Add conversions for plot_compare, plot_elpd, and plot_khat

for f in (
:plot_autocorr,
:plot_ess,
:plot_joint,
:plot_mcse,
:plot_pair,
:plot_posterior,
:plot_trace,
:plot_violin,
)
@eval begin
function convert_arguments(::typeof($(f)), data, args...; kwargs...)
idata = convert_to_inference_data(data; group = :posterior)
return tuple(idata, args...), kwargs
end
end
end

for f in (:plot_autocorr, :plot_ess, :plot_mcse, :plot_posterior, :plot_violin)
@eval begin
function convert_arguments(::typeof($(f)), data::AbstractArray, args...; kwargs...)
return tuple(data, args...), kwargs
end
end
end

for f in (:plot_energy, :plot_parallel)
@eval begin
function convert_arguments(::typeof($(f)), data, args...; kwargs...)
dataset = convert_to_dataset(data; group = :sample_stats)
return tuple(dataset, args...), kwargs
end
end
end

for f in (:plot_density, :plot_forest, :plot_rank)
@eval begin
function convert_arguments(
::typeof($(f)),
data,
args...;
transform = identity,
group = :posterior,
kwargs...,
)
tdata = transform(data)
dataset = convert_to_dataset(tdata; group = group)
return tuple(dataset, args...), kwargs
end
end
end

for f in (:plot_density, :plot_forest)
@eval begin
function convert_arguments(
::typeof($(f)),
data::Union{AbstractVector,Tuple},
transform = identity,
group = :posterior,
args...;
kwargs...,
)
tdata = transform(data)
datasets = map(tdata) do datum
return convert_to_dataset(datum; group = group)
end
return tuple(datasets, args...), kwargs
end
function convert_arguments(
::typeof($(f)),
data::AbstractVector{<:Real},
transform = identity,
group = :posterior,
args...;
kwargs...,
)
tdata = transform(data)
dataset = convert_to_dataset(tdata; group = group)
return tuple(dataset, args...), kwargs
end
end
end
48 changes: 25 additions & 23 deletions src/stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,27 @@ const sample_stats_types = Dict(
"diverging" => Bool,
)

@doc forwarddoc(:compare) compare(args...; kwargs...) =
arviz.compare(args...; kwargs...) |> Pandas.DataFrame

Docs.getdoc(::typeof(compare)) = forwardgetdoc(:compare)

@forwardfun compare
@forwardfun hpd

@doc forwarddoc(:loo) loo(args...; kwargs...) =
arviz.loo(args...; kwargs...) |> Pandas.Series

Docs.getdoc(::typeof(loo)) = forwardgetdoc(:loo)

@forwardfun loo
@forwardfun loo_pit

@forwardfun psislw
@forwardfun r2_score
@forwardfun waic

for f in (:loo, :waic)
@eval begin
function convert_arguments(::typeof($(f)), data, args...; kwargs...)
idata = convert_to_inference_data(data)
return tuple(idata, args...), kwargs
end
end
end

@doc forwarddoc(:r2_score) r2_score(args...; kwargs...) =
arviz.r2_score(args...; kwargs...) |> Pandas.Series

Docs.getdoc(::typeof(r2_score)) = forwardgetdoc(:r2_score)

@doc forwarddoc(:waic) waic(args...; kwargs...) =
arviz.waic(args...; kwargs...) |> Pandas.Series

Docs.getdoc(::typeof(waic)) = forwardgetdoc(:waic)
convert_result(::typeof(loo), result) = Pandas.Series(result)
convert_result(::typeof(waic), result) = Pandas.Series(result)
convert_result(::typeof(r2_score), result) = Pandas.Series(result)
convert_result(::typeof(compare), result) = Pandas.DataFrame(result)

"""
summarystats(data::Dataset; kwargs...) -> Union{Pandas.DataFrame,Dataset}
Expand Down Expand Up @@ -75,6 +71,12 @@ Compute summary statistics on `data`.
- `order::String="C"`: If `fmt` is "wide", use either "C" or "F" unpacking order.
- `index_origin::Int=1`: If `fmt` is "wide", select 𝑛-based indexing for multivariate
parameters.
- `skipna::Bool=false`: If `true`, ignores `NaN` values when computing the summary
statistics. It does not affect the behaviour of the functions passed to `stat_funcs`.
- `coords::Dict{String,Vector}=Dict()`: Coordinates specification to be used if the `fmt`
is `"xarray"`.
- `dims::Dict{String,Vector}=Dict()`: Dimensions specification for the variables to be used
if the `fmt` is `"xarray"`.
# Returns
Expand Down Expand Up @@ -150,6 +152,6 @@ Compute summary statistics on any object that can be passed to [`convert_to_data
- `kwargs`: Keyword arguments passed to [`summarystats`](@ref).
"""
function summary(data; group = :posterior, coords = nothing, dims = nothing, kwargs...)
idata = convert_to_inference_data(data; group = group, coords = coords, dims = dims)
return summarystats(idata; group = group, kwargs...)
dataset = convert_to_dataset(data; group = group, coords = coords, dims = dims)
return summarystats(dataset; kwargs...)
end
45 changes: 41 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,25 @@ function with_interactive_backend(f; backend = nothing)
return ret
end

"""
convert_arguments(f, args...; kwargs...) -> NTuple{2}
Convert arguments to the function `f` before calling.
This function is used primarily for pre-processing arguments within macros before sending
to arviz.
"""
convert_arguments(::Any, args...; kwargs...) = args, kwargs

"""
convert_result(f, result)
Convert result of the function `f` before returning.
This function is used primarily for post-processing outputs of arviz before returning.
"""
convert_result(::Any, result) = result

forwarddoc(f::Symbol) =
"See documentation for [`arviz.$(f)`](https://arviz-devs.github.io/arviz/generated/arviz.$(f).html)."

Expand All @@ -101,12 +120,19 @@ forwardgetdoc(f::Symbol) = Docs.getdoc(getproperty(arviz, f))
@forwardfun(f)
Wrap a function `arviz.f` in `f`, forwarding its docstrings.
Use [`convert_arguments`](@ref) and [`convert_result`](@ref) to customize what is passed to
and returned from `f`.
"""
macro forwardfun(f)
fdoc = forwarddoc(f)
quote
@doc $fdoc
$(f)(args...; kwargs...) = arviz.$(f)(args...; kwargs...)
function $(f)(args...; kwargs...)
args, kwargs = convert_arguments($(f), args...; kwargs...)
result = arviz.$(f)(args...; kwargs...)
return convert_result($(f), result)
end

Docs.getdoc(::typeof($(f))) = forwardgetdoc(Symbol($(f)))
end |> esc
Expand All @@ -119,6 +145,8 @@ end
Wrap a plotting function `arviz.f` in `f`, forwarding its docstrings.
This macro also ensures that outputs for the different backends are correctly handled.
Use [`convert_arguments`](@ref) and [`convert_result`](@ref) to customize what is passed to
and returned from `f`.
"""
macro forwardplotfun(f)
fdoc = forwarddoc(f)
Expand All @@ -133,22 +161,31 @@ macro forwardplotfun(f)
return $(f)(Val(Symbol(backend)), args...; kwargs...)
end

$(f)(::Val, args...; kwargs...) = arviz.$(f)(args...; kwargs...)
function $(f)(::Val, args...; kwargs...)
args, kwargs = convert_arguments($(f), args...; kwargs...)
result = arviz.$(f)(args...; kwargs...)
return convert_result($(f), result)
end

function $(f)(::Val{:matplotlib}, args...; kwargs...)
args, kwargs = convert_arguments($(f), args...; kwargs...)
kwargs = merge(kwargs, Dict(:backend => "matplotlib"))
try
return arviz.$(f)(args...; kwargs...)
result = arviz.$(f)(args...; kwargs...)
return convert_result($(f), result)
catch e
e isa PyCall.PyError || rethrow(e)
pop!(kwargs, :backend)
return arviz.$(f)(args...; kwargs...)
result = arviz.$(f)(args...; kwargs...)
return convert_result($(f), result)
end
end

function $(f)(::Val{:bokeh}, args...; kwargs...)
args, kwargs = convert_arguments($(f), args...; kwargs...)
kwargs = merge(kwargs, Dict(:backend => "bokeh", :show => false))
plots = arviz.$(f)(args...; kwargs...)
plots isa BokehPlot && return plots
return bokeh.plotting.gridplot(plots)
end

Expand Down
6 changes: 6 additions & 0 deletions test/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@ using Random
using PyCall
using ArviZ: attributes

py"""
class PyNullObject(object):
def __init__(self):
pass
"""

function create_model(seed = 10)
rng = MersenneTwister(seed)
J = 8
Expand Down
20 changes: 12 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
using ArviZ
using Test

include("helpers.jl")
include("test_rcparams.jl")
include("test_backend.jl")
include("test_dataset.jl")
include("test_data.jl")
include("test_stats.jl")
include("test_namedtuple.jl")
include("test_mcmcchains.jl")
@testset "ArviZ" begin
include("helpers.jl")
include("test_rcparams.jl")
include("test_backend.jl")
include("test_dataset.jl")
include("test_data.jl")
include("test_diagnostics.jl")
include("test_stats.jl")
include("test_plots.jl")
include("test_namedtuple.jl")
include("test_mcmcchains.jl")
end
Loading

2 comments on commit e037516

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/10127

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.3 -m "<description of version>" e037516f712091c90e85603b96b4de1b20dbeeef
git push origin v0.3.3

Please sign in to comment.