Skip to content

Commit 5f876a2

Browse files
committed
Rel 0.0.1 - Fixing psislw
1 parent cc1ddd2 commit 5f876a2

4 files changed

Lines changed: 124 additions & 435 deletions

File tree

Example/cvit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function cvit(n, k=10, rsubstream=false)
2929
itr = Any[]
3030

3131
for cvi in 1:a
32-
push!(itst, collect(1:b) + (cvi-1) * b)
32+
push!(itst, collect(1:b) .+ (cvi-1) * b)
3333
push!(itr, setdiff(1:n,itst[cvi]))
3434
end
3535
for cvi in (a+1):k

Example/demo_wells.jl

Lines changed: 69 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ using JSON
33
using StanSample
44
using PSIS
55
using Statistics
6+
using Printf
7+
using StatsPlots
68

79
ProjDir = @__DIR__
810

@@ -17,94 +19,97 @@ n, m = size(x)
1719
# Model
1820
# Read the Stan program by path so the file handle is opened and closed by `read` itself.
model_str = read(joinpath(ProjDir, "arsenic_logistic.stan"), String)
1921
tmpdir = joinpath(ProjDir, "tmp")
20-
sm = SampleModel("arsenic_logistic", model_str; tmpdir)
22+
sm1 = SampleModel("arsenic_logistic", model_str; tmpdir)
23+
println("\n-----------------------\n")
2124

22-
data = (p = m, N = n, y = Int.(y), x = x)
25+
data1 = (p = m, N = n, y = Int.(y), x = x)
2326
# Fit the model in Stan
24-
rc = stan_sample(sm; data)
25-
nt = read_samples(sm)
26-
27-
# Compute LOO and standard error
28-
log_lik = nt.log_lik'
29-
loo, loos, pk = psisloo(log_lik)
30-
elpd_loo = sum(loos)
31-
se_elpd_loo = std(loos) * sqrt(n)
32-
@printf(">> elpd_loo = %.1f, SE(elpd_loo) = %.1f\n", elpd_loo, se_elpd_loo)
33-
34-
# Check the shape parameter k of the generalized Pareto distribution
35-
if all(pk .< 0.5)
36-
println("All Pareto k estimates OK (k < 0.5)")
37-
else
38-
pkn1 = sum((pk .>= 0.5) & (pk .< 1))
39-
pkn2 = sum(pk .>= 1)
40-
@printf(">> %d (%.0f%%) PSIS Pareto k estimates between 0.5 and 1\n", pkn1, pkn1/n*100)
41-
@printf(">> %d (%.0f%%) PSIS Pareto k estimates greater than 1\n", pkn2, pkn2/n*100)
27+
rc1 = stan_sample(sm1; data=data1)
28+
if success(rc1)
29+
# Read the draws from the model object that was just sampled (`sm1`).
nt1 = read_samples(sm1)
30+
31+
# Compute LOO and standard error
32+
log_lik = nt1.log_lik'
33+
loo, loos, pk = psisloo(log_lik)
34+
elpd_loo = sum(loos)
35+
se_elpd_loo = std(loos) * sqrt(n)
36+
@printf(">> elpd_loo = %.1f, SE(elpd_loo) = %.1f\n", elpd_loo, se_elpd_loo)
37+
38+
# Check the shape parameter k of the generalized Pareto distribution
39+
if all(pk .< 0.5)
40+
println("All Pareto k estimates OK (k < 0.5)")
41+
else
42+
# Combine the two boolean masks elementwise; plain `&` is not defined for arrays.
pkn1 = sum((pk .>= 0.5) .& (pk .< 1))
43+
pkn2 = sum(pk .>= 1)
44+
@printf(">> %d (%.0f%%) PSIS Pareto k estimates between 0.5 and 1\n", pkn1, pkn1/n*100)
45+
@printf(">> %d (%.0f%%) PSIS Pareto k estimates greater than 1\n", pkn2, pkn2/n*100)
46+
end
4247
end
48+
println("\n-----------------------\n")
4349

44-
exit()
4550
# Fit a second model, using log(arsenic) instead of arsenic
4651
x2 = Float64[log.(data["arsenic"]) data["dist"]]
4752

4853
# Model
49-
if isfile("sim2.jls")
50-
sim2 = read("sim2.jls", Chains)
51-
else
52-
standata2 = [Dict("p" => m, "N" => n, "y" => y, "x" => x2)]
53-
# Fit the model in Stan
54-
sim2 = stan(stanmodel, standata2, '.', CmdStanDir=CMDSTAN_HOME, summary=false)
55-
write("sim2.jls", sim2)
54+
data2 = (p = m, N = n, y = Int.(y), x = x2)
55+
# Fit the model in Stan
56+
rc2 = stan_sample(sm1; data=data2)
57+
58+
if success(rc2)
59+
# Read the draws of the second fit from the model object that was sampled (`sm1`).
nt2 = read_samples(sm1)
60+
# Compute LOO and standard error
61+
log_lik = nt2.log_lik'
62+
loo2, loos2, pk2 = psisloo(log_lik)
63+
elpd_loo = sum(loos2)
64+
se_elpd_loo = std(loos2) * sqrt(n)
65+
@printf(">> elpd_loo = %.1f, SE(elpd_loo) = %.1f\n", elpd_loo, se_elpd_loo)
66+
67+
# Check the shape parameter k of the generalized Pareto distribution
68+
# Diagnose the second fit with its own Pareto k estimates (`pk2`), not those of the first fit.
if all(pk2 .< 0.5)
69+
println("All Pareto k estimates OK (k < 0.5)")
70+
else
71+
# Combine the two boolean masks elementwise; plain `&` is not defined for arrays.
pkn1 = sum((pk2 .>= 0.5) .& (pk2 .< 1))
72+
pkn2 = sum(pk2 .>= 1)
73+
@printf(">> %d (%.0f%%) PSIS Pareto k estimates between 0.5 and 1\n", pkn1, pkn1/n*100)
74+
@printf(">> %d (%.0f%%) PSIS Pareto k estimates greater than 1\n", pkn2, pkn2/n*100)
75+
end
5676
end
5777

58-
# Compute LOO and standard error
59-
log_lik = sim2[:, names_sim, :]
60-
loo2, loos2, pk2 = psisloo(log_lik)
61-
elpd_loo = sum(loos2)
62-
se_elpd_loo = std(loos2) * sqrt(n)
63-
@printf(">> elpd_loo = %.1f, SE(elpd_loo) = %.1f\n", elpd_loo, se_elpd_loo)
64-
65-
# Check the shape parameter k of the generalized Pareto distribution
66-
if all(pk .< 0.5)
67-
println("All Pareto k estimates OK (k < 0.5)")
68-
else
69-
pkn1 = sum((pk .>= 0.5) & (pk .< 1))
70-
pkn2 = sum(pk .>= 1)
71-
@printf(">> %d (%.0f%%) PSIS Pareto k estimates between 0.5 and 1\n", pkn1, pkn1/n*100)
72-
@printf(">> %d (%.0f%%) PSIS Pareto k estimates greater than 1\n", pkn2, pkn2/n*100)
78+
if success(rc1) && success(rc2)
79+
## Compare the models
80+
loodiff = loos - loos2
81+
@printf("elpd_diff = %.1f, SE(elpd_diff) = %.1f\n",sum(loodiff), std(loodiff) * sqrt(n))
7382
end
74-
75-
## Compare the models
76-
loodiff = loos - loos2
77-
@printf("elpd_diff = %.1f, SE(elpd_diff) = %.1f\n",sum(loodiff), std(loodiff) * sqrt(n))
78-
83+
println("\n-----------------------\n")
7984

8085
## k-fold-CV
8186
# k-fold-CV should be used if several khats>0.5
8287
# in this case it is not needed, but provided as an example
83-
model_str = readstring(open("arsenic_logistic_t.stan"))
84-
stanmodel = Stanmodel(name="arsenic_logistic_t", adapt=500, update=500, model=model_str);
88+
89+
# Read the Stan program by path so the file handle is opened and closed by `read` itself.
model_str = read(joinpath(ProjDir, "arsenic_logistic_t.stan"), String)
90+
sm3 = SampleModel("arsenic_logistic_t", model_str; tmpdir=tmpdir);
8591

8692
cvitr, cvitst = cvit(n, 10, true)
8793
kfcvs = similar(loos)
8894
for cvi in 1:10
8995
@printf("%d\n", cvi)
9096

91-
standatacv = [Dict("p" => m, "N" => length(cvitr[cvi]), "Nt" => length(cvitst[cvi]),
92-
"x" => x[cvitr[cvi],:], "y" => y[cvitr[cvi]],
93-
"xt" => x[cvitst[cvi],:], "yt" => y[cvitst[cvi]])
94-
]
97+
standatacv = (p = m, N = length(cvitr[cvi]), Nt = length(cvitst[cvi]),
98+
x = x[cvitr[cvi],:], y = Int.(y[cvitr[cvi]]),
99+
xt = x[cvitst[cvi],:], yt = Int.(y[cvitst[cvi]]))
100+
95101
# Fit the model in Stan
96-
simcv = stan(stanmodel, standatacv, '.',
97-
CmdStanDir=CMDSTAN_HOME, summary=false)
98-
ns = filter(x->startswith(x,"log_likt"), simcv.names)
99-
log_likt = Mamba.combine(simcv[:, ns, :])
100-
kfcvs[cvitst[cvi]]= PSIS.logsumexp(log_likt) - log(size(log_likt,1))
102+
rc3 = stan_sample(sm3; data=standatacv)
103+
if success(rc3)
104+
nt3 = read_samples(sm3)
105+
# Compute LOO and standard error
106+
log_likt = nt3.log_likt'
107+
kfcvs[cvitst[cvi]] = PSIS.logsumexp(log_likt) .- log(size(log_likt, 1))
108+
end
101109
end
102110

103111
# compare PSIS-LOO and k-fold-CV
104-
p = plot(layer(x = loos, y = kfcvs, Geom.point),
105-
layer(x = [-3.5,0] ,y=[-3.5,0], Geom.line, style(default_color=colorant"red")),
106-
Guide.xlabel("PSIS-LOO"),
107-
Guide.ylabel("10-fold-CV"))
112+
plot(x = loos, y = kfcvs, xlab = "PSIS-LOO", ylab = "10-fold-CV")
108113

109-
draw(PDF("Compare.pdf", 210mm, 210mm),p)
114+
#savefig(PDF(joinpath(ProjDir, "Compare.pdf"), 210mm, 210mm),p)
110115

0 commit comments

Comments
 (0)