Skip to content

Commit

Permalink
Implement summarystats for InferenceObjects types (#294)
Browse files Browse the repository at this point in the history
* Add initial summarystats implementation

* Add summarystats to API docs

* Increment patch number

* Update src/ArviZStats/summarystats.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Fix summarystats methods

* Rename utility function

* Add utilities for formatting to strings

* Generalize interval probability variable name

* Make test pass on v1.10.x

* Implement common interfaces for SummaryStats

* Add alignment of column entries

* Use formatter utility functions

* Update keywords

* Add utility function

* Use parent

* Rename keyword to prob_interval

* Rename to compact_labels

* Fix docstring

* Update doctests

* Use new utilities in compare

* Change table format for ELPD results

* Update tests and docstrings

* Update docstrings

* Simplify formatter implementations

* Add formatter tests

* Define metric dim as constant

* Refactor utility functions

* Use refactored utilities

* Call correct version of ess_rhat

* Concretize sample dims

* Add summarystats tests

* Remove duplicate keyword doc

* Add doctest

* Add SummaryStats tests

* Add missing tests

* Unify PrettyTables code

* Add HTML show method for tables

* Fix compare tests

* Just test that HTML is returned

* Add Tables row table methods

Pluto uses these

* sReuse table utilities for AbstractELPDResult

* Fix indexing on v1.6

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
sethaxen and github-actions[bot] authored Aug 2, 2023
1 parent a572fe0 commit c31176a
Show file tree
Hide file tree
Showing 18 changed files with 971 additions and 146 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ArviZ"
uuid = "131c737c-5715-5e2e-ad31-c244f01c1dc7"
authors = ["Seth Axen <[email protected]>"]
version = "0.9.1"
version = "0.9.2"

[deps]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Expand Down
7 changes: 7 additions & 0 deletions docs/src/api/stats.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
Pages = ["stats.md"]
```

## Summary statistics

```@docs
SummaryStats
summarystats
```

## General statistics

```@docs
Expand Down
2 changes: 1 addition & 1 deletion src/ArviZ.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import Base:
+
import Base.Docs: getdoc
using StatsBase: StatsBase
import StatsBase: summarystats
import Markdown: @doc_str

using InferenceObjects
Expand All @@ -48,6 +47,7 @@ export PSIS, PSISResult, psis, psis!
export elpd_estimates, information_criterion, loo, waic
export AbstractModelWeightsMethod, BootstrappedPseudoBMA, PseudoBMA, Stacking, model_weights
export ModelComparisonResult, compare
export SummaryStats, summarystats
export hdi, hdi!, loo_pit, r2_score

## Diagnostics
Expand Down
7 changes: 6 additions & 1 deletion src/ArviZStats/ArviZStats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using PSIS: PSIS, PSISResult, psis, psis!
using Random: Random
using Setfield: Setfield
using Statistics: Statistics
using StatsBase: StatsBase
using StatsBase: StatsBase, summarystats
using Tables: Tables
using TableTraits: TableTraits

Expand All @@ -33,12 +33,16 @@ export elpd_estimates, information_criterion, loo, waic
export AbstractModelWeightsMethod, BootstrappedPseudoBMA, PseudoBMA, Stacking, model_weights
export ModelComparisonResult, compare

# Summary statistics
export SummaryStats, summarystats

# Others
export hdi, hdi!, loo_pit, r2_score

# load for docstrings
using ArviZ: InferenceData, convert_to_dataset, ess

const DEFAULT_INTERVAL_PROB = 0.94
const INFORMATION_CRITERION_SCALES = (deviance=-2, log=1, negative_log=-1)

include("utils.jl")
Expand All @@ -50,5 +54,6 @@ include("model_weights.jl")
include("compare.jl")
include("loo_pit.jl")
include("r2_score.jl")
include("summarystats.jl")

end # module
77 changes: 34 additions & 43 deletions src/ArviZStats/compare.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ julia> mc = compare(models)
┌ Warning: 1 parameters had Pareto shape values 0.7 < k ≤ 1. Resulting importance sampling estimates are likely to be unstable.
└ @ PSIS ~/.julia/packages/PSIS/...
ModelComparisonResult with Stacking weights
name rank elpd elpd_mcse elpd_diff elpd_diff_mcse weight p ⋯
non_centered 1 -31 1.4 0 0 1.0 0.9 ⋯
centered 2 -31 1.4 0.06 0.067 0.0 0.9 ⋯
rank elpd elpd_mcse elpd_diff elpd_diff_mcse weight p ⋯
non_centered 1 -31 1.4 0 0.0 1.0 0.9 ⋯
centered 2 -31 1.4 0.06 0.067 0.0 0.9 ⋯
1 column omitted
```
Expand All @@ -68,9 +68,9 @@ julia> elpd_results = mc.elpd_result;
julia> compare(elpd_results; weights_method=BootstrappedPseudoBMA())
ModelComparisonResult with BootstrappedPseudoBMA weights
name rank elpd elpd_mcse elpd_diff elpd_diff_mcse weight p ⋯
non_centered 1 -31 1.4 0 0 0.5 0.9 ⋯
centered 2 -31 1.4 0.06 0.067 0.5 0.9 ⋯
rank elpd elpd_mcse elpd_diff elpd_diff_mcse weight p ⋯
non_centered 1 -31 1.4 0 0.0 0.52 0.9 ⋯
centered 2 -31 1.4 0.06 0.067 0.48 0.9 ⋯
1 column omitted
```
"""
Expand Down Expand Up @@ -144,46 +144,35 @@ struct ModelComparisonResult{E,N,R,W,ER,M}
weights_method::M
end

function _print_comparison_results(
io::IO, ::MIME"text/plain", r::ModelComparisonResult; sigdigits_se=2
)
table = Tables.columntable(r)
cols = Tables.columnnames(table)
formatters = function (v, i, j)
nm = cols[j]
if nm (:elpd, :elpd_diff, :p)
nm_se = Symbol("$(nm)_mcse")
v_se = table[nm_se][i]
sigdigits = sigdigits_matching_error(v, v_se)
return sprint(Printf.format, Printf.Format("%.$(sigdigits)g"), v)
elseif nm (:elpd_mcse, :elpd_diff_mcse, :p_mcse)
sigdigits = sigdigits_se
return sprint(Printf.format, Printf.Format("%.$(sigdigits)g"), v)
elseif nm === :rank
return string(v)
elseif nm === :weight
return sprint(Printf.format, Printf.Format("%.1f"), v)
else
return string(v)
end
end
PrettyTables.pretty_table(
#### custom tabular show methods

function Base.show(io::IO, mime::MIME"text/plain", r::ModelComparisonResult; kwargs...)
return _show(io, mime, r; kwargs...)
end
function Base.show(io::IO, mime::MIME"text/html", r::ModelComparisonResult; kwargs...)
return _show(io, mime, r; kwargs...)
end

function _show(io::IO, mime::MIME, r::ModelComparisonResult; kwargs...)
row_labels = collect(r.name)
cols = Tables.columnnames(r)[2:end]
table = NamedTuple{cols}(Tables.columntable(r))

weights_method_name = _typename(r.weights_method)
weights = table.weight
digits_weights = ceil(Int, -log10(maximum(weights))) + 1
weight_formatter = PrettyTables.ft_printf(
"%.$(digits_weights)f", findfirst(==(:weight), cols)
)
return _show_prettytable(
io,
mime,
table;
show_subheader=false,
hlines=:none,
vlines=:none,
formatters,
newline_at_end=false,
title="ModelComparisonResult with $(weights_method_name) weights",
row_labels,
extra_formatters=(weight_formatter,),
kwargs...,
)
return nothing
end

function Base.show(io::IO, mime::MIME"text/plain", result::ModelComparisonResult)
weights_method_name = _typename(result.weights_method)
println(io, "ModelComparisonResult with $(weights_method_name) weights")
_print_comparison_results(io, mime, result)
return nothing
end

function _permute(r::ModelComparisonResult, perm)
Expand Down Expand Up @@ -213,6 +202,8 @@ function Tables.getcolumn(r::ModelComparisonResult, nm::Symbol)
end
throw(ArgumentError("Unrecognized column name $nm"))
end
Tables.rowaccess(::Type{<:ModelComparisonResult}) = true
Tables.rows(r::ModelComparisonResult) = Tables.rows(Tables.columntable(r))

IteratorInterfaceExtensions.isiterable(::ModelComparisonResult) = true
function IteratorInterfaceExtensions.getiterator(r::ModelComparisonResult)
Expand Down
23 changes: 4 additions & 19 deletions src/ArviZStats/elpdresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,12 @@ Subtypes implement the following functions:
"""
abstract type AbstractELPDResult end

function _print_elpd_estimates(
io::IO, ::MIME"text/plain", r::AbstractELPDResult; sigdigits_se=2
function _show_elpd_estimates(
io::IO, mime::MIME"text/plain", r::AbstractELPDResult; kwargs...
)
estimates = elpd_estimates(r)
elpd, elpd_mcse = estimates.elpd, estimates.elpd_mcse
p, p_mcse = estimates.p, estimates.p_mcse
table = (; Estimate=[elpd, p], SE=[elpd_mcse, p_mcse])
formatters = function (v, i, j)
sigdigits = j == 1 ? sigdigits_matching_error(v, table.SE[i]) : sigdigits_se
return sprint(Printf.format, Printf.Format("%.$(sigdigits)g"), v)
end
PrettyTables.pretty_table(
io,
table;
show_subheader=false,
row_labels=["elpd", "p"],
hlines=:none,
vlines=:none,
formatters,
newline_at_end=false,
)
table = map(Base.vect, NamedTuple{(:elpd, :elpd_mcse, :p, :p_mcse)}(estimates))
_show_prettytable(io, mime, table; kwargs...)
return nothing
end

Expand Down
15 changes: 7 additions & 8 deletions src/ArviZStats/hdi.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
const HDI_DEFAULT_PROB = 0.94
# this pattern ensures that the type is completely specified at compile time
const HDI_BOUND_DIM = Dimensions.format(
Dimensions.Dim{:hdi_bound}([:lower, :upper]), Base.OneTo(2)
)

"""
hdi(samples::AbstractArray{<:Real}; prob=$(HDI_DEFAULT_PROB)) -> (; lower, upper)
hdi(samples::AbstractArray{<:Real}; prob=$(DEFAULT_INTERVAL_PROB)) -> (; lower, upper)
Estimate the unimodal highest density interval (HDI) of `samples` for the probability `prob`.
Expand All @@ -20,8 +19,8 @@ This implementation uses the algorithm of [^ChenShao1999].
!!! note
Any default value of `prob` is arbitrary. The default value of
`prob=$(HDI_DEFAULT_PROB)` instead of a more common default like `prob=0.95` is chosen
to reminder the user of this arbitrariness.
`prob=$(DEFAULT_INTERVAL_PROB)` instead of a more common default like `prob=0.95` is
chosen to reminder the user of this arbitrariness.
[^Hyndman1996]: Rob J. Hyndman (1996) Computing and Graphing Highest Density Regions,
Amer. Stat., 50(2): 120-6.
Expand Down Expand Up @@ -66,11 +65,11 @@ function hdi(x::AbstractArray{<:Real}; kwargs...)
end

"""
hdi!(samples::AbstractArray{<:Real}; prob=$(HDI_DEFAULT_PROB)) -> (; lower, upper)
hdi!(samples::AbstractArray{<:Real}; prob=$(DEFAULT_INTERVAL_PROB)) -> (; lower, upper)
A version of [hdi](@ref) that sorts `samples` in-place while computing the HDI.
"""
function hdi!(x::AbstractArray{<:Real}; prob::Real=HDI_DEFAULT_PROB)
function hdi!(x::AbstractArray{<:Real}; prob::Real=DEFAULT_INTERVAL_PROB)
0 < prob < 1 || throw(DomainError(prob, "HDI `prob` must be in the range `(0, 1)`.]"))
return _hdi!(x, prob)
end
Expand Down Expand Up @@ -104,8 +103,8 @@ function _hdi!(x::AbstractArray{<:Real}, prob::Real)
end

"""
hdi(data::InferenceData; prob=$HDI_DEFAULT_PROB) -> Dataset
hdi(data::Dataset; prob=$HDI_DEFAULT_PROB) -> Dataset
hdi(data::InferenceData; prob=$DEFAULT_INTERVAL_PROB) -> Dataset
hdi(data::Dataset; prob=$DEFAULT_INTERVAL_PROB) -> Dataset
Calculate the highest density interval (HDI) for each parameter in the data.
Expand Down
15 changes: 6 additions & 9 deletions src/ArviZStats/loo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ function elpd_estimates(r::PSISLOOResult; pointwise::Bool=false)
return pointwise ? r.pointwise : r.estimates
end

function Base.show(io::IO, mime::MIME"text/plain", result::PSISLOOResult)
println(io, "PSISLOOResult with estimates")
_print_elpd_estimates(io, mime, result)
function Base.show(io::IO, mime::MIME"text/plain", result::PSISLOOResult; kwargs...)
_show_elpd_estimates(io, mime, result; title="PSISLOOResult with estimates", kwargs...)
println(io)
println(io)
print(io, "and ")
Expand Down Expand Up @@ -71,9 +70,8 @@ julia> reff = ess(log_like; kind=:basic, split_chains=1, relative=true);
julia> loo(log_like; reff)
PSISLOOResult with estimates
Estimate SE
elpd -31 1.4
p 0.9 0.34
elpd elpd_mcse p p_mcse
-31 1.4 0.9 0.34
and PSISResult with 500 draws, 4 chains, and 8 parameters
Pareto shape (k) diagnostic values:
Expand Down Expand Up @@ -103,9 +101,8 @@ julia> idata = load_example_data("centered_eight");
julia> loo(idata)
PSISLOOResult with estimates
Estimate SE
elpd -31 1.4
p 0.9 0.34
elpd elpd_mcse p p_mcse
-31 1.4 0.9 0.34
and PSISResult with 500 draws, 4 chains, and 8 parameters
Pareto shape (k) diagnostic values:
Expand Down
Loading

2 comments on commit c31176a

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/88874

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.2 -m "<description of version>" c31176a952bb6cd17554eeff3ad5d70f46675368
git push origin v0.9.2

Please sign in to comment.