Skip to content

Commit 55276e4

Browse files
committed
Rel 0.0.4 - Added waic() and pk utilities
1 parent ed47a46 commit 55276e4

4 files changed

Lines changed: 66 additions & 3 deletions

File tree

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
name = "PSIS"
22
uuid = "fb740163-aa3c-59c1-9c12-c3f890714cde"
3-
authors = ["Rob J Goedman <goedman@icloud.com"]
3+
authors = ["@alvaro1101, Rob J Goedman <goedman@icloud.com"]
44
version = "0.0.3"
55

66
[deps]
77
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
88
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
99
StanSample = "c1514b29-d3a0-5178-b312-660c88baa699"
10-
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
1110
StatisticalRethinking = "2d09df54-9d0f-5258-8220-54c2a3d4fbee"
1211
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
12+
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
1313
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1414

1515
[compat]
1616
JSON = "0.21"
17-
StatsPlots = "0.14"
1817
StanSample = "3.0"
1918
StatisticalRethinking = "3.2"
19+
StatsPlots = "0.14"
2020
julia = "1"

src/PSIS.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ include("psislw.jl")
3939
include("gpdfitnew.jl")
4040
include("gpinv.jl")
4141
include("logsumexp.jl")
42+
include("waic.jl")
43+
include("pk_utilities.jl")
4244

4345
export
4446
psis_path,

src/pk_utilities.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
function pk_qualify(pk::Vector{Float64})
2+
pk_good = sum(pk .<= 0.5)
3+
pk_ok = length(pk[pk .<= 0.7]) - pk_good
4+
pk_bad = length(pk[pk .<= 1]) - pk_good - pk_ok
5+
(good=pk_good, ok=pk_ok, bad=pk_bad, very_bad=sum(pk .> 1))
6+
end
7+
8+
function pk_plot(pk::Vector{Float64}, title="PSIS diagnostic plot.")
9+
scatter(pk1, xlab="Datapoint", ylab="Pareto shape k",
10+
marker=2.5, lab="Pk points", leg=:topleft)
11+
hline!([0.5], lab="pk = 0.5");hline!([0.7], lab="pk = 0.7")
12+
hline!([1], lab="pk = 1.0")
13+
title!(title)
14+
end
15+
16+
export
17+
pk_qualify,
18+
pk_plot

src/waic.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
2+
var2(x) = mean(x.^2) .- mean(x)^2
3+
4+
function log_sum_exp(x)
5+
xmax = maximum(x)
6+
xsum = sum(exp.(x .- xmax))
7+
xmax + log(xsum)
8+
end
9+
10+
function waic( ll::AbstractArray; pointwise=FALSE , log_lik="log_lik" , kwargs... )
11+
12+
n_samples, n_obs = size(ll)
13+
lpd <- zeros(n_obs)
14+
pD <- zeros(n_obs)
15+
16+
for i in 1:n_obs
17+
lpd[i] = log_sum_exp(ll[:,i]) .- log(n_samples)
18+
pD[i] = var2(ll[:,i])
19+
end
20+
21+
waic_vec = (-2) .* ( lpd - pD )
22+
if pointwise==FALSE
23+
waic = sum(waic_vec)
24+
lpd = sum(lpd)
25+
pD = sum(pD)
26+
else
27+
waic = waic_vec
28+
end
29+
30+
try
31+
se = sqrt( n_obs*var2(waic_vec) )
32+
catch e
33+
prinrln(e)
34+
se = nothing
35+
end
36+
37+
(WAIC=waic, lppd=lpd, penalty=pD, std_err=se)
38+
end
39+
40+
export
41+
var2,
42+
log_sum_exp,
43+
waic

0 commit comments

Comments
 (0)