Skip to content

Commit 3cc6353

Browse files
committed
Rel 0.0.1 - Updated runtests.jl
1 parent 20d9c38 commit 3cc6353

9 files changed

Lines changed: 306 additions & 57 deletions

File tree

old.appveyor.yml

Lines changed: 0 additions & 34 deletions
This file was deleted.

old.travis.yml

Lines changed: 0 additions & 19 deletions
This file was deleted.

rest/cvit.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# CVIT - Create itr and itst indices for k-fold-cv
2+
#
3+
# Description
4+
# [ITR,ITST]=CVITR(N,K) returns 1xK cell arrays ITR and ITST holding
5+
# cross-validation indices for train and test sets respectively.
6+
# K-fold division is balanced with all sets having floor(N/K) or
7+
# ceil(N/K) elements.
8+
#
9+
# [ITR,ITST]=CVITR(N,K,RS) with integer RS=true also makes random
10+
# permutation, using substream RS. This way different permutations
11+
# can be produced with different RS values, but same permutation is
12+
# obtained when called again with same RS. Function restores the
13+
# previous random stream before exiting.
14+
#
15+
16+
17+
# Copyright (c) 2010 Aki Vehtari
18+
19+
# This software is distributed under the GNU General Public
20+
# License (version 2 or later); please refer to the file
21+
# License.txt, included with the software, for details.
22+
23+
"""
    cvit(n, k=10, rsubstream=false)

Create train/test index sets for `k`-fold cross-validation over the indices `1:n`.

Returns the tuple `(itr, itst)`. Both are length-`k` vectors of index vectors:
`itst[i]` holds the test indices of fold `i` and `itr[i]` holds the remaining
(training) indices of `1:n`.

The split is balanced: the first `k - rem(n, k)` folds receive `div(n, k)` test
indices and the remaining folds receive one more.

`rsubstream` controls shuffling:
- `false` (default): folds are contiguous index ranges.
- `true`: indices are mapped through a random permutation drawn from a freshly
  constructed `MersenneTwister()`.
- an integer seed: indices are mapped through the permutation drawn from
  `MersenneTwister(rsubstream)`, so the same seed reproduces the same folds.

Throws `ArgumentError` if `n < 0` or `k < 1`.
"""
function cvit(n, k=10, rsubstream=false)
    n >= 0 || throw(ArgumentError("n must be non-negative, got $n"))
    k >= 1 || throw(ArgumentError("k must be at least 1, got $k"))

    nsmall = k - rem(n, k)  # number of folds that get the smaller size
    small = div(n, k)       # size of a smaller fold

    itst = Vector{Int}[]
    itr = Vector{Int}[]

    # Walk through 1:n once, handing out consecutive ranges of the right length.
    start = 1
    for cvi in 1:k
        len = cvi <= nsmall ? small : small + 1
        test = collect(start:(start + len - 1))
        push!(itst, test)
        push!(itr, setdiff(1:n, test))
        start += len
    end

    if rsubstream !== false
        # `true` keeps the unseeded behaviour; any other value is used as the seed.
        rng = rsubstream === true ? MersenneTwister() : MersenneTwister(rsubstream)
        rii = randperm(rng, n)
        for cvi in 1:k
            itst[cvi] = rii[itst[cvi]]
            itr[cvi] = rii[itr[cvi]]
        end
    end

    return itr, itst
end
50+

test/arsenic_logistic.stan

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Logistic regression on column-standardized predictors.
// Emits the pointwise log-likelihood `log_lik` in generated quantities.
data {
  int<lower=0> p;               // number of predictors (columns of x)
  int<lower=0> N;               // number of observations (rows of x)
  int<lower=0,upper=1> y[N];    // binary outcomes
  matrix[N,p] x;                // predictor matrix
}

transformed data {
  matrix[N,p] z;                // x with each column centred and scaled
  vector[p] mean_x;             // per-column mean of x
  vector[p] sd_x;               // per-column standard deviation of x
  for (j in 1:p) {
    mean_x[j] = mean(col(x,j));
    sd_x[j] = sd(col(x,j));
    for (i in 1:N)
      z[i,j] = (x[i,j] - mean_x[j]) / sd_x[j];
  }
}

parameters {
  real beta0;                   // intercept
  vector[p] beta;               // coefficients on the standardized predictors
  real<lower=0> phi;            // prior scale of beta
}

model {
  vector[N] eta;                // linear predictor on the logit scale
  eta = beta0 + z*beta;
  beta ~ normal(0, phi);
  // phi is constrained to be >= 0, so only the positive half of this prior applies
  phi ~ double_exponential(0, 10);
  y ~ bernoulli_logit(eta);
}

generated quantities {
  vector[N] log_lik;            // log-likelihood of each observation
  vector[N] eta;                // linear predictor, recomputed for this block
  eta = beta0 + z*beta;
  for (i in 1:N)
    log_lik[i] = bernoulli_logit_lpmf(y[i] | eta[i]);
}

test/arsenic_logistic_t.stan

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Logistic regression on column-standardized predictors, with a held-out test set.
// Emits pointwise log-likelihoods for the training data (`log_lik`) and for the
// test data (`log_likt`) in generated quantities.
data {
  int<lower=0> p;               // number of predictors (columns of x and xt)
  int<lower=0> N;               // number of training observations
  int<lower=0,upper=1> y[N];    // binary training outcomes
  matrix[N,p] x;                // training predictor matrix
  int<lower=0> Nt;              // number of test observations
  int<lower=0,upper=1> yt[Nt];  // binary test outcomes
  matrix[Nt,p] xt;              // test predictor matrix
}
transformed data {
  matrix[N,p] z;                // standardized training predictors
  matrix[Nt,p] zt;              // test predictors scaled with the training statistics
  vector[p] mean_x;             // per-column mean of x
  vector[p] sd_x;               // per-column standard deviation of x
  for (j in 1:p) {
    mean_x[j] = mean(col(x,j));
    sd_x[j] = sd(col(x,j));
    for (i in 1:N)
      z[i,j] = (x[i,j] - mean_x[j]) / sd_x[j];
    // test rows reuse mean_x / sd_x computed from the training matrix x
    for (i in 1:Nt)
      zt[i,j] = (xt[i,j] - mean_x[j]) / sd_x[j];
  }
}
parameters {
  real beta0;                   // intercept
  vector[p] beta;               // coefficients on the standardized predictors
  real<lower=0> phi;            // prior scale of beta
}
model {
  vector[N] eta;                // training linear predictor on the logit scale
  eta = beta0 + z*beta;
  beta ~ normal(0, phi);
  // phi is constrained to be >= 0, so only the positive half of this prior applies
  phi ~ double_exponential(0, 10);
  y ~ bernoulli_logit(eta);
}

generated quantities {
  vector[N] log_lik;            // log-likelihood of each training observation
  vector[Nt] log_likt;          // log-likelihood of each test observation
  vector[N] eta;                // training linear predictor
  vector[Nt] etat;              // test linear predictor
  eta = beta0 + z*beta;
  etat = beta0 + zt*beta;
  for (i in 1:N)
    log_lik[i] = bernoulli_logit_lpmf(y[i] | eta[i]);
  for (i in 1:Nt)
    log_likt[i] = bernoulli_logit_lpmf(yt[i] | etat[i]);
}
49+

test/cvit.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# CVIT - Create itr and itst indices for k-fold-cv
2+
#
3+
# Description
4+
# [ITR,ITST]=CVITR(N,K) returns 1xK cell arrays ITR and ITST holding
5+
# cross-validation indices for train and test sets respectively.
6+
# K-fold division is balanced with all sets having floor(N/K) or
7+
# ceil(N/K) elements.
8+
#
9+
# [ITR,ITST]=CVITR(N,K,RS) with integer RS=true also makes random
10+
# permutation, using substream RS. This way different permutations
11+
# can be produced with different RS values, but same permutation is
12+
# obtained when called again with same RS. Function restores the
13+
# previous random stream before exiting.
14+
#
15+
16+
17+
# Copyright (c) 2010 Aki Vehtari
18+
19+
# This software is distributed under the GNU General Public
20+
# License (version 2 or later); please refer to the file
21+
# License.txt, included with the software, for details.
22+
23+
"""
    cvit(n, k=10, rsubstream=false)

Create train/test index sets for `k`-fold cross-validation over the indices `1:n`.

Returns the tuple `(itr, itst)`. Both are length-`k` vectors of index vectors:
`itst[i]` holds the test indices of fold `i` and `itr[i]` holds the remaining
(training) indices of `1:n`.

The split is balanced: the first `k - rem(n, k)` folds receive `div(n, k)` test
indices and the remaining folds receive one more.

`rsubstream` controls shuffling:
- `false` (default): folds are contiguous index ranges.
- `true`: indices are mapped through a random permutation drawn from a freshly
  constructed `MersenneTwister()`.
- an integer seed: indices are mapped through the permutation drawn from
  `MersenneTwister(rsubstream)`, so the same seed reproduces the same folds.

Throws `ArgumentError` if `n < 0` or `k < 1`.
"""
function cvit(n, k=10, rsubstream=false)
    n >= 0 || throw(ArgumentError("n must be non-negative, got $n"))
    k >= 1 || throw(ArgumentError("k must be at least 1, got $k"))

    nsmall = k - rem(n, k)  # number of folds that get the smaller size
    small = div(n, k)       # size of a smaller fold

    itst = Vector{Int}[]
    itr = Vector{Int}[]

    # Walk through 1:n once, handing out consecutive ranges of the right length.
    start = 1
    for cvi in 1:k
        len = cvi <= nsmall ? small : small + 1
        test = collect(start:(start + len - 1))
        push!(itst, test)
        push!(itr, setdiff(1:n, test))
        start += len
    end

    if rsubstream !== false
        # `true` keeps the unseeded behaviour; any other value is used as the seed.
        rng = rsubstream === true ? MersenneTwister() : MersenneTwister(rsubstream)
        rii = randperm(rng, n)
        for cvi in 1:k
            itst[cvi] = rii[itst[cvi]]
            itr[cvi] = rii[itr[cvi]]
        end
    end

    return itr, itst
end
50+

test/runtests.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1-
using PSIS
2-
using Base.Test
1+
using PSIS, StanSample
2+
using Test
33

4-
# write your own tests here
5-
@test 1 == 2
4+
# The demo needs the JULIA_CMDSTAN_HOME environment variable; without it the
# run is skipped with a message instead of being attempted.
if !haskey(ENV, "JULIA_CMDSTAN_HOME")
    println("\nJULIA_CMDSTAN_HOME not set. Skipping tests")
else
    ProjDir = @__DIR__
    include(joinpath(ProjDir, "test_demo_wells.jl"))
end

test/test_demo_wells.jl

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
using StatisticalRethinking
2+
using JSON
3+
using StanSample
4+
using PSIS
5+
#using Statistics
6+
using Printf
7+
#using StatsPlots
8+
9+
ProjDir = @__DIR__
10+
11+
include(joinpath(ProjDir, "cvit.jl"))
12+
13+
# Data
# Read the wells data set: `switched` becomes the response, `arsenic` and `dist`
# become the two predictor columns.
data = JSON.parsefile(joinpath(ProjDir, "wells.data.json"))
y = Float64.(data["switched"])
x = Float64[data["arsenic"] data["dist"]]
n, m = size(x)

# Model
# `read(path, String)` opens and closes the file itself, so no handle is left open.
model_str = read(joinpath(ProjDir, "arsenic_logistic.stan"), String)
tmpdir = joinpath(ProjDir, "tmp")
sm1 = SampleModel("arsenic_logistic", model_str)

data1 = (p = m, N = n, y = Int.(y), x = x)
# Fit the model in Stan
rc1 = stan_sample(sm1; data=data1)
if success(rc1)
    nt1 = read_samples(sm1)

    # Compute LOO and standard error
    log_lik = nt1.log_lik'
    loo, loos, pk = psisloo(log_lik)
    elpd_loo = sum(loos)
    se_elpd_loo = std(loos) * sqrt(n)
    @printf(">> elpd_loo = %.1f, SE(elpd_loo) = %.1f\n", elpd_loo, se_elpd_loo)

    # Check the shape parameter k of the generalized Pareto distribution
    if all(pk .< 0.5)
        println("All Pareto k estimates OK (k < 0.5)")
    else
        # `.&` is the element-wise AND; plain `&` has no method for two arrays.
        pkn1 = sum((pk .>= 0.5) .& (pk .< 1))
        pkn2 = sum(pk .>= 1)
        @printf(">> %d (%.0f%%) PSIS Pareto k estimates between 0.5 and 1\n", pkn1, pkn1/n*100)
        @printf(">> %d (%.0f%%) PSIS Pareto k estimates greater than 1\n", pkn2, pkn2/n*100)
    end
end
47+
48+
# Fit a second model, using log(arsenic) instead of arsenic
x2 = Float64[log.(data["arsenic"]) data["dist"]]

# Model
# The compiled model `sm1` is reused; only the predictor matrix changes.
data2 = (p = m, N = n, y = Int.(y), x = x2)
# Fit the model in Stan
rc2 = stan_sample(sm1; data=data2)

if success(rc2)
    nt2 = read_samples(sm1)
    # Compute LOO and standard error
    log_lik = nt2.log_lik'
    loo2, loos2, pk2 = psisloo(log_lik)
    elpd_loo = sum(loos2)
    se_elpd_loo = std(loos2) * sqrt(n)
    @printf(">> elpd_loo = %.1f, SE(elpd_loo) = %.1f\n", elpd_loo, se_elpd_loo)

    # Check the shape parameter k of the generalized Pareto distribution.
    # The diagnostics must use `pk2` from this fit, not `pk` from the first model.
    if all(pk2 .< 0.5)
        println("All Pareto k estimates OK (k < 0.5)")
    else
        # `.&` is the element-wise AND; plain `&` has no method for two arrays.
        pkn1 = sum((pk2 .>= 0.5) .& (pk2 .< 1))
        pkn2 = sum(pk2 .>= 1)
        @printf(">> %d (%.0f%%) PSIS Pareto k estimates between 0.5 and 1\n", pkn1, pkn1/n*100)
        @printf(">> %d (%.0f%%) PSIS Pareto k estimates greater than 1\n", pkn2, pkn2/n*100)
    end
end

if success(rc1) && success(rc2)
    ## Compare the models
    # Pointwise difference of the LOO contributions of the two fits.
    loodiff = loos - loos2
    @printf("elpd_diff = %.1f, SE(elpd_diff) = %.1f\n",sum(loodiff), std(loodiff) * sqrt(n))
end
81+
82+
## k-fold-CV
# k-fold-CV should be used if several khats>0.5
# in this case it is not needed, but provided as an example

# `read(path, String)` opens and closes the file itself, so no handle is left open.
model_str = read(joinpath(ProjDir, "arsenic_logistic_t.stan"), String)
sm3 = SampleModel("arsenic_logistic_t", model_str);

cvitr, cvitst = cvit(n, 10, true)
# Only some folds are fitted below, so start from NaN rather than the
# uninitialized memory `similar` returns; untouched entries stay recognisable.
# NOTE(review): assumes `loos` has a floating-point element type — confirm.
kfcvs = fill!(similar(loos), NaN)
# Only the first 3 of the 10 folds are fitted here.
for cvi in 1:3
    @printf("%d\n", cvi)

    # Training rows go to x/y, held-out rows of this fold go to xt/yt.
    standatacv = (p = m, N = length(cvitr[cvi]), Nt = length(cvitst[cvi]),
        x = x[cvitr[cvi],:], y = Int.(y[cvitr[cvi]]),
        xt = x[cvitst[cvi],:], yt = Int.(y[cvitst[cvi]]))

    # Fit the model in Stan
    rc3 = stan_sample(sm3; data=standatacv)
    if success(rc3)
        nt3 = read_samples(sm3)
        # Store the held-out log predictive values at this fold's test indices
        log_likt = nt3.log_likt'
        kfcvs[cvitst[cvi]] = PSIS.logsumexp(log_likt) .- log(size(log_likt, 1))
    end
end

test/wells.data.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)