Skip to content

Commit e65c94b

Browse files
committed
Mamba/Stan Compatible
1 parent e389965 commit e65c94b

5 files changed

Lines changed: 48 additions & 29 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
*.jl.mem
44
*.jls
55
*.R
6+
*.bck
67
tmp/
78

Example/Compare.pdf

-92.8 KB
Binary file not shown.

Example/demo_wells.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,10 @@ else
3030
write("sim.jls", sim)
3131
end
3232

33-
r,v,c = size(sim)
34-
ns = filter(x->startswith(x,"log_lik"),sim.names)
35-
log_lik = reshape(permutedims(sim[:,ns,:].value,[3, 1 ,2]), (r*c ,length(ns)))
33+
names_sim = filter(x->startswith(x,"log_lik"), sim.names)
3634

3735
# Compute LOO and standard error
36+
log_lik = sim[:, names_sim, :]
3837
loo, loos, pk = psisloo(log_lik)
3938
elpd_loo = sum(loos)
4039
se_elpd_loo = std(loos) * sqrt(n)
@@ -50,7 +49,7 @@ else
5049
@printf(">> %d (%.0f%%) PSIS Pareto k estimates greater than 1\n", pkn2, pkn2/n*100)
5150
end
5251

53-
52+
exit()
5453
# Fit a second model, using log(arsenic) instead of arsenic
5554
x2 = Float64[log.(data["arsenic"]) data["dist"]]
5655

@@ -64,11 +63,8 @@ else
6463
write("sim2.jls", sim2)
6564
end
6665

67-
r,v,c = size(sim2)
68-
ns = filter(x->startswith(x,"log_lik"),sim2.names)
69-
log_lik = reshape(permutedims(sim2[:,ns,:].value,[3, 1 ,2]), (r*c ,length(ns)))
70-
7166
# Compute LOO and standard error
67+
log_lik = sim2[:, names_sim, :]
7268
loo2, loos2, pk2 = psisloo(log_lik)
7369
elpd_loo = sum(loos2)
7470
se_elpd_loo = std(loos2) * sqrt(n)
@@ -105,10 +101,10 @@ for cvi in 1:10
105101
"xt" => x[cvitst[cvi],:], "yt" => y[cvitst[cvi]])
106102
]
107103
# Fit the model in Stan
108-
simcv = stan(stanmodel, standatacv, '.', CmdStanDir=CMDSTAN_HOME, summary=false)
109-
r,v,c = size(sim2)
110-
ns = filter(x->startswith(x,"log_likt"),simcv.names)
111-
log_likt = reshape(permutedims(simcv[:,ns,:].value,[3, 1 ,2]), (r*c ,length(ns)))
104+
simcv = stan(stanmodel, standatacv, '.',
105+
CmdStanDir=CMDSTAN_HOME, summary=false)
106+
ns = filter(x->startswith(x,"log_likt"), simcv.names)
107+
log_likt = Mamba.combine(simcv[:, ns, :])
112108
kfcvs[cvitst[cvi]]= PSIS.logsumexp(log_likt) - log(size(log_likt,1))
113109
end
114110

REQUIRE

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
julia 0.5
22
Mamba 0.10
3-
Stan 1.0.0
43

src/PSIS.jl

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ arXiv preprint arXiv:1507.02646.
3131
"""
3232
module PSIS
3333

34+
using Mamba
35+
3436
export psislw
3537
export psisloo
3638

@@ -57,23 +59,31 @@ not exist and if tail index k>1 the mean of the raw estimate does not exist
5759
and the PSIS estimate is likely to have large variation and some bias.
5860
5961
# Arguments
60-
* `log_lik::AbstractArray`: Array of size n x m containing n posterior samples of the log likelihood terms p(y_i|\theta^s).
62+
* `log_lik::Union{AbstractArray, Mamba.Chains}`: Array of size n x m containing n posterior samples of the log likelihood terms p(y_i|\theta^s).
6163
* `wcpp::Real`: Percentage of samples used for GPD fit estimate (default is 20).
6264
* `wtrunc::Float64`: Positive parameter for truncating very large weights to n^wtrunc. Providing False or 0 disables truncation. Default values is 3/4.
6365
6466
# Returns
6567
* `loo::Real`: sum of the leave-one-out log predictive densities.
6668
* `loos::AbstractArray`: Individual leave-one-out log predictive density terms.* `ks::AbstractArray`: Estimated Pareto tail indeces.
6769
"""
68-
function psisloo(log_lik, wcpp=20, wtrunc=3/4)
70+
# Compute LOO and standard error
71+
function psisloo(log_lik::Union{Mamba.Chains, AbstractArray},
72+
wcpp::Int64=20, wtrunc::Float64=3/4)
73+
6974
# log raw weights from log_lik
70-
lw = -copy(log_lik)
75+
if isa(log_lik, Mamba.Chains)
76+
lw = Mamba.combine(log_lik)
77+
else
78+
lw = copy(log_lik)
79+
end
7180
# compute Pareto smoothed log weights given raw log weights
72-
lw, ks = psislw(lw, wcpp, wtrunc)
73-
# compute
74-
lw += log_lik
75-
loos = logsumexp(lw, 1)
81+
lwp, ks = psislw(-lw, wcpp, wtrunc)
82+
83+
lwp += lw
84+
loos = logsumexp(lwp, 1)
7685
loo = sum(loos)
86+
7787
return loo, loos, ks
7888
end
7989

@@ -83,27 +93,40 @@ end
8393
Compute the Pareto smoothed importance sampling (PSIS).
8494
8595
# Arguments
86-
* `lw::AbstractArray`: Array of size n x m containing m sets of n log weights. It is also possible to provide one dimensional array of length n.
96+
* `lw::Union{AbstractArray, Mamba.Chains}`: Array of size n x m containing m sets of n log weights. It is also possible to provide one dimensional array of length n.
8797
* `wcpp::Real`: Percentage of samples used for GPD fit estimate (default is 20).
8898
* `wtrunc::Float64`: Positive parameter for truncating very large weights to n^wtrunc. Providing False or 0 disables truncation. Default values is 3/4.
8999
90100
# Returns
91101
* `lw_out::AbstractArray`: Smoothed log weights
92102
* `kss::AbstractArray`: Pareto tail indices
93103
"""
94-
function psislw(lw, wcpp=20, wtrunc=3/4)
95-
if ndims(lw) == 2
96-
n, m = size(lw)
97-
elseif ndims(lw) == 1
98-
n = length(lw)
99-
m = 1
104+
function psislw(lw::Union{AbstractArray, Mamba.Chains},
105+
wcpp::Int64=20, wtrunc::Float64=3/4)
106+
107+
if isa(lw, Mamba.Chains)
108+
lw_out = Mamba.combine(lw)
100109
else
110+
lw_out = copy(lw)
111+
end
112+
113+
if ~(1 <= ndims(lw_out) <= 2)
101114
throw(DimensionMismatch("Argument `lw` must be 1 or 2 dimensional."))
102115
end
103-
if n <= 1
116+
if size(lw_out,1) <= 1
104117
error("More than one log-weight needed.")
105118
end
106-
lw_out = copy(lw)
119+
return _psislw(lw_out, wcpp, wtrunc)
120+
end
121+
122+
function _psislw(lw_out::Array{Float64}, wcpp::Int64, wtrunc::Float64)
123+
124+
if ndims(lw_out) == 2
125+
n, m = size(lw_out)
126+
elseif ndims(lw_out) == 1
127+
n = length(lw_out)
128+
m = 1
129+
end
107130
kss = zeros(Float64, m)
108131

109132
# precalculate constants

0 commit comments

Comments
 (0)