@@ -3,6 +3,8 @@ using JSON
33using StanSample
44using PSIS
55using Statistics
6+ using Printf
7+ using StatsPlots
68
79ProjDir = @__DIR__
810
@@ -17,94 +19,97 @@ n, m = size(x)
1719# Model
1820model_str = read (open (joinpath (ProjDir, " arsenic_logistic.stan" )), String)
1921tmpdir = joinpath (ProjDir, " tmp" )
20- sm = SampleModel (" arsenic_logistic" , model_str; tmpdir)
22+ sm1 = SampleModel (" arsenic_logistic" , model_str; tmpdir)
23+ println (" \n -----------------------\n " )
2124
22- data = (p = m, N = n, y = Int .(y), x = x)
25+ data1 = (p = m, N = n, y = Int .(y), x = x)
2326# Fit the model in Stan
24- rc = stan_sample (sm; data)
25- nt = read_samples (sm)
26-
27- # Compute LOO and standard error
28- log_lik = nt. log_lik'
29- loo, loos, pk = psisloo (log_lik)
30- elpd_loo = sum (loos)
31- se_elpd_loo = std (loos) * sqrt (n)
32- @printf (" >> elpd_loo = %.1f, SE(elpd_loo) = %.1f\n " , elpd_loo, se_elpd_loo)
33-
34- # Check the shape parameter k of the generalized Pareto distribution
35- if all (pk .< 0.5 )
36- println (" All Pareto k estimates OK (k < 0.5)" )
37- else
38- pkn1 = sum ((pk .>= 0.5 ) & (pk .< 1 ))
39- pkn2 = sum (pk .>= 1 )
40- @printf (" >> %d (%.0f%%) PSIS Pareto k estimates between 0.5 and 1\n " , pkn1, pkn1/ n* 100 )
41- @printf (" >> %d (%.0f%%) PSIS Pareto k estimates greater than 1\n " , pkn2, pkn2/ n* 100 )
27+ rc1 = stan_sample (sm1; data= data1)
28+ if success (rc1)
29+ nt1 = read_samples (sm1) # read the draws from the model sampled just above (sm1)
30+
31+ # Compute LOO and standard error
32+ log_lik = nt1. log_lik'
33+ loo, loos, pk = psisloo (log_lik)
34+ elpd_loo = sum (loos)
35+ se_elpd_loo = std (loos) * sqrt (n)
36+ @printf (" >> elpd_loo = %.1f, SE(elpd_loo) = %.1f\n " , elpd_loo, se_elpd_loo)
37+
38+ # Check the shape parameter k of the generalized Pareto distribution
39+ if all (pk .< 0.5 )
40+ println (" All Pareto k estimates OK (k < 0.5)" )
41+ else
42+ pkn1 = sum ((pk .>= 0.5 ) .& (pk .< 1 )) # element-wise AND of the two broadcast comparisons
43+ pkn2 = sum (pk .>= 1 )
44+ @printf (" >> %d (%.0f%%) PSIS Pareto k estimates between 0.5 and 1\n " , pkn1, pkn1/ n* 100 )
45+ @printf (" >> %d (%.0f%%) PSIS Pareto k estimates greater than 1\n " , pkn2, pkn2/ n* 100 )
46+ end
4247end
48+ println (" \n -----------------------\n " )
4349
44- exit ()
4550# Fit a second model, using log(arsenic) instead of arsenic
4651x2 = Float64[log .(data[" arsenic" ]) data[" dist" ]]
4752
4853# Model
49- if isfile (" sim2.jls" )
50- sim2 = read (" sim2.jls" , Chains)
51- else
52- standata2 = [Dict (" p" => m, " N" => n, " y" => y, " x" => x2)]
53- # Fit the model in Stan
54- sim2 = stan (stanmodel, standata2, ' .' , CmdStanDir= CMDSTAN_HOME, summary= false )
55- write (" sim2.jls" , sim2)
54+ data2 = (p = m, N = n, y = Int .(y), x = x2)
55+ # Fit the model in Stan
56+ rc2 = stan_sample (sm1; data= data2)
57+
58+ if success (rc2)
59+ nt2 = read_samples (sm1) # read the draws from the model sampled just above (sm1)
60+ # Compute LOO and standard error
61+ log_lik = nt2. log_lik'
62+ loo2, loos2, pk2 = psisloo (log_lik)
63+ elpd_loo = sum (loos2)
64+ se_elpd_loo = std (loos2) * sqrt (n)
65+ @printf (" >> elpd_loo = %.1f, SE(elpd_loo) = %.1f\n " , elpd_loo, se_elpd_loo)
66+
67+ # Check the shape parameter k of the generalized Pareto distribution
68+ if all (pk2 .< 0.5 ) # diagnose this fit's own Pareto k values (pk2), not the first model's pk
69+ println (" All Pareto k estimates OK (k < 0.5)" )
70+ else
71+ pkn1 = sum ((pk2 .>= 0.5 ) .& (pk2 .< 1 )) # element-wise AND of the two broadcast comparisons
72+ pkn2 = sum (pk2 .>= 1 )
73+ @printf (" >> %d (%.0f%%) PSIS Pareto k estimates between 0.5 and 1\n " , pkn1, pkn1/ n* 100 )
74+ @printf (" >> %d (%.0f%%) PSIS Pareto k estimates greater than 1\n " , pkn2, pkn2/ n* 100 )
75+ end
5676end
5777
58- # Compute LOO and standard error
59- log_lik = sim2[:, names_sim, :]
60- loo2, loos2, pk2 = psisloo (log_lik)
61- elpd_loo = sum (loos2)
62- se_elpd_loo = std (loos2) * sqrt (n)
63- @printf (" >> elpd_loo = %.1f, SE(elpd_loo) = %.1f\n " , elpd_loo, se_elpd_loo)
64-
65- # Check the shape parameter k of the generalized Pareto distribution
66- if all (pk .< 0.5 )
67- println (" All Pareto k estimates OK (k < 0.5)" )
68- else
69- pkn1 = sum ((pk .>= 0.5 ) & (pk .< 1 ))
70- pkn2 = sum (pk .>= 1 )
71- @printf (" >> %d (%.0f%%) PSIS Pareto k estimates between 0.5 and 1\n " , pkn1, pkn1/ n* 100 )
72- @printf (" >> %d (%.0f%%) PSIS Pareto k estimates greater than 1\n " , pkn2, pkn2/ n* 100 )
78+ if success (rc1) && success (rc2)
79+ # # Compare the models
80+ loodiff = loos - loos2 # element-wise difference of the two psisloo outputs (first fit minus second fit)
81+ @printf (" elpd_diff = %.1f, SE(elpd_diff) = %.1f\n " ,sum (loodiff), std (loodiff) * sqrt (n))
7382end
74-
75- # # Compare the models
76- loodiff = loos - loos2
77- @printf (" elpd_diff = %.1f, SE(elpd_diff) = %.1f\n " ,sum (loodiff), std (loodiff) * sqrt (n))
78-
83+ println (" \n -----------------------\n " )
7984
8085# # k-fold-CV
8186# k-fold-CV should be used if several khats>0.5
8287# in this case it is not needed, but provided as an example
83- model_str = readstring (open (" arsenic_logistic_t.stan" ))
84- stanmodel = Stanmodel (name= " arsenic_logistic_t" , adapt= 500 , update= 500 , model= model_str);
88+
89+ model_str = read (open (joinpath (ProjDir, " arsenic_logistic_t.stan" )), String)
90+ sm3 = SampleModel (" arsenic_logistic_t" , model_str; tmpdir= tmpdir);
8591
8692cvitr, cvitst = cvit (n, 10 , true ) # NOTE(review): cvit is defined outside this view; presumably returns per-fold train/test index sets - confirm
8793kfcvs = similar (loos) # uninitialized buffer; only entries of folds whose sampling succeeds are written below
8894for cvi in 1 : 10
8995 @printf (" %d\n " , cvi)
9096
91- standatacv = [ Dict ( " p " => m, " N " => length (cvitr[cvi]), " Nt " => length (cvitst[cvi]),
92- " x " => x[cvitr[cvi],:], " y " => y[cvitr[cvi]],
93- " xt " => x[cvitst[cvi],:], " yt " => y[cvitst[cvi]])
94- ]
97+ standatacv = (p = m, N = length (cvitr[cvi]), Nt = length (cvitst[cvi]),
98+ x = x[cvitr[cvi],:], y = Int .( y[cvitr[cvi]]) ,
99+ xt = x[cvitst[cvi],:], yt = Int .( y[cvitst[cvi]]) )
100+
95101 # Fit the model in Stan
96- simcv = stan (stanmodel, standatacv, ' .' ,
97- CmdStanDir= CMDSTAN_HOME, summary= false )
98- ns = filter (x-> startswith (x," log_likt" ), simcv. names)
99- log_likt = Mamba. combine (simcv[:, ns, :])
100- kfcvs[cvitst[cvi]]= PSIS. logsumexp (log_likt) - log (size (log_likt,1 ))
102+ rc3 = stan_sample (sm3; data= standatacv)
103+ if success (rc3)
104+ nt3 = read_samples (sm3)
105+ # Store PSIS.logsumexp of the transposed log_likt minus log(its row count) for this fold's test indices; presumably a log-mean-exp over draws - confirm the reduction axis
106+ log_likt = nt3. log_likt'
107+ kfcvs[cvitst[cvi]] = PSIS. logsumexp (log_likt) .- log (size (log_likt, 1 ))
108+ end
101109end
102110
103111# compare PSIS-LOO and k-fold-CV
104- p = plot (layer (x = loos, y = kfcvs, Geom. point),
105- layer (x = [- 3.5 ,0 ] ,y= [- 3.5 ,0 ], Geom. line, style (default_color= colorant " red" )),
106- Guide. xlabel (" PSIS-LOO" ),
107- Guide. ylabel (" 10-fold-CV" ))
112+ plot (x = loos, y = kfcvs, xlab = " PSIS-LOO" , ylab = " 10-fold-CV" )
108113
109- draw (PDF (" Compare.pdf" , 210 mm, 210 mm),p)
114+ # savefig (PDF(joinpath(ProjDir, "Compare.pdf") , 210mm, 210mm),p)
110115
0 commit comments