Skip to content

Commit

Permalink
Use PSIS.jl for Pareto-smoothing (#144)
Browse files Browse the repository at this point in the history
* Add PSIS dependency

* Import psis/psis!

* Deprecate psislw

* Export psis/psis!

* Update api docs

* Point reference to psis

* Run formatter

* Increment version number

* Use latest version of PSIS

* Import PSISResult as well

* Implement wrapping psislw

* Add PSISResult to API

* Add LogExpFunctions as dependency

* Remove outdated deprecation warnings

* Add psislw test

* Run formatter

* Point to arviz issue

* Fix reference

* Also export PSIS

* Add PSIS.jl's public docstrings
  • Loading branch information
sethaxen authored Nov 22, 2021
1 parent 4a10410 commit da15978
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 10 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
name = "ArviZ"
uuid = "131c737c-5715-5e2e-ad31-c244f01c1dc7"
authors = ["Seth Axen <[email protected]>"]
version = "0.5.7"
version = "0.5.8"

[deps]
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
PSIS = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Expand All @@ -18,9 +20,11 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
CmdStan = "5.2.3, 6.0"
Conda = "1.0"
DataFrames = "0.20, 0.21, 0.22, 1.0"
LogExpFunctions = "0.2.0, 0.3"
MCMCChains = "0.3.15, 0.4, 1.0, 2.0, 3.0, 4.0"
MonteCarloMeasurements = "0.6.4, 0.7, 0.8"
NamedTupleTools = "0.11.0, 0.12, 0.13"
PSIS = "0.2"
PyCall = "1.91.2"
PyPlot = "2.8.2"
Requires = "0.5.2, 1.0"
Expand Down
5 changes: 4 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
| [`hdi`](@ref) | Calculate highest density interval (HDI) of array for given probability. |
| [`loo`](@ref) | Pareto-smoothed importance sampling leave-one-out (LOO) cross-validation. |
| [`loo_pit`](@ref) | Compute leave-one-out probability integral transform (PIT) values. |
| [`psislw`](@ref) | Pareto smoothed importance sampling (PSIS). |
| [`psislw`](@ref) | Pareto smoothed importance sampling (PSIS). (deprecated) |
| [`psis`](@ref) | Pareto smoothed importance sampling (PSIS). |
| [`psis!`](@ref) | Pareto smoothed importance sampling (PSIS) in-place. |
| [`PSISResult`](@ref) | Container for results of Pareto smoothed importance sampling. |
| [`r2_score`](@ref) | $R^2$ for Bayesian regression models. |
| [`waic`](@ref) | Calculate the widely available information criterion (WAIC). |

Expand Down
2 changes: 1 addition & 1 deletion docs/src/mpl_examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ plot_loo_pit(idata; y="y", ecdf=true, color="maroon")
gcf()
```

See [`psislw`](@ref), [`plot_loo_pit`](@ref)
See [`psis`](@ref), [`plot_loo_pit`](@ref)

## LOO-PIT Overlay Plot

Expand Down
12 changes: 11 additions & 1 deletion docs/src/reference.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# [API Reference](@id reference)
# [Reference](@id reference)

## Exported

```@autodocs
Modules = [ArviZ, PSIS]
Private = false
```

## Internal

```@autodocs
Modules = [ArviZ]
Public = false
```
5 changes: 4 additions & 1 deletion src/ArviZ.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ using DataFrames
using PyCall
using Conda
using PyPlot
using PSIS: PSIS, PSISResult, psis, psis!
using LogExpFunctions: logsumexp

import Base:
convert,
Expand Down Expand Up @@ -58,7 +60,8 @@ export plot_autocorr,
plot_violin

## Stats
export summarystats, compare, hdi, loo, loo_pit, psislw, r2_score, waic
export PSIS, PSISResult, psis, psis!, psislw
export summarystats, compare, hdi, loo, loo_pit, r2_score, waic

## Diagnostics
export bfmi, ess, rhat, mcse
Expand Down
2 changes: 1 addition & 1 deletion src/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ arviz_version() = VersionNumber(arviz.__version__)
function check_needs_update(; update=true)
if arviz_version() < _min_arviz_version
@warn "ArviZ.jl only officially supports arviz version $(_min_arviz_version) or " *
"greater but found version $(arviz_version())."
"greater but found version $(arviz_version())."
if update
if update_arviz()
# yay, but we still already imported the old version
Expand Down
31 changes: 30 additions & 1 deletion src/stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
@forwardfun hdi
@forwardfun loo
@forwardfun loo_pit
@forwardfun psislw
@forwardfun r2_score
@forwardfun waic

Expand All @@ -22,6 +21,36 @@ function convert_result(::typeof(compare), result)
return todataframes(result; index_name=:name)
end

"""
psislw(log_weights, reff=1.0) -> (lw_out, kss)
Pareto smoothed importance sampling (PSIS).
!!! note
This function is deprecated and is just a thin wrapper around [`psis`](@ref).
# Arguments
- `log_weights`: Array of size `(nobs, ndraws)`
- `reff`: relative MCMC efficiency, `ess / n`
# Returns
- `lw_out`: Smoothed log weights
- `kss`: Pareto tail indices
"""
function psislw(logw, reff=1)
@warn "`psislw(logw[, reff])` is deprecated, use `psis(logw[, reff])`" maxlog = 1
result = psis(logw, reff)
log_weights = result.log_weights
d = ndims(log_weights)
dims = d == 1 ? Colon() : ntuple(Base.Fix1(+, 1), d - 1)
log_norm_exp = logsumexp(log_weights; dims=dims)
log_weights .-= log_norm_exp
return log_weights, result.pareto_shape
end

@doc doc"""
summarystats(
data::InferenceData;
Expand Down
2 changes: 1 addition & 1 deletion test/test_backend.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ if !ispynull(ArviZ.bokeh) && "plot.backend" in keys(ArviZ.rcParams)
end
bytes = read(fn)
@test bytes[1:8] ==
UInt8[0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a]
UInt8[0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a]
end

@testset "show MIME\"$(mime)\"" for mime in [
Expand Down
4 changes: 2 additions & 2 deletions test/test_samplechains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ end
end
@testset for var_name in propertynames(multichain)
@test getproperty(getproperty(idata, group), var_name).values ==
getproperty(getproperty(idata_nt, group), var_name).values
getproperty(getproperty(idata_nt, group), var_name).values
@test getproperty(getproperty(idata, group), var_name).values ==
getproperty(getproperty(idata_conv, group), var_name).values
getproperty(getproperty(idata_conv, group), var_name).values
end
end
end
Expand Down
19 changes: 19 additions & 0 deletions test/test_stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,25 @@ using DataFrames: DataFrames
@test all(df == ArviZ.todataframes(ArviZ.arviz.loo(idata)))
end

@testset "psislw" begin
@testset for sz in ((1000,), (10, 1000))
logw = randn(sz)
logw_smoothed, k = psislw(copy(logw), 0.9)

# check against PSIS.jl
result = psis(copy(logw), 0.9)
@test exp.(logw_smoothed) result.weights
@test k result.pareto_shape

# check against Python ArviZ
# NOTE: currently these implementations disagree
# see https://github.com/arviz-devs/arviz/issues/1941
logw_smoothed2, k2 = ArviZ.arviz.psislw(copy(logw), 0.9)
@test_broken logw_smoothed logw_smoothed2
@test_broken k k2
end
end

@testset "waic" begin
df = waic(idata)
@test df isa DataFrames.DataFrame
Expand Down

2 comments on commit da15978

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/49215

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.8 -m "<description of version>" da15978dab341ab2d7161cdc34ace9492c75bb81
git push origin v0.5.8

Please sign in to comment.