Skip to content

Commit 190a38d

Browse files
committed
Rel 0.0.5 - Added cars based test case (also for waic)
1 parent 55276e4 commit 190a38d

4 files changed

Lines changed: 184 additions & 7 deletions

File tree

examples/cars_waic/cars.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
using StatisticalRethinking, StanSample, PSIS, RDatasets
2+
3+
df = RDatasets.dataset("datasets", "cars")
4+
5+
cars_stan = "
6+
data {
7+
int N;
8+
vector[N] speed;
9+
vector[N] dist;
10+
}
11+
parameters {
12+
real a;
13+
real b;
14+
real sigma;
15+
}
16+
transformed parameters{
17+
vector[N] mu;
18+
mu = a + b * speed;
19+
}
20+
model {
21+
a ~ normal(0, 100);
22+
b ~ normal(0, 10);
23+
sigma ~ exponential(1);
24+
dist ~ normal(mu, sigma) ;
25+
}
26+
generated quantities {
27+
vector[N] log_lik;
28+
for (i in 1:N)
29+
log_lik[i] = normal_lpdf(dist[i] | mu[i], sigma);
30+
}
31+
"
32+
33+
cars_stan_model = SampleModel("cars.model", cars_stan)
34+
data = (N = size(df, 1), speed = df.Speed, dist = df.Dist)
35+
rc = stan_sample(cars_stan_model; data)
36+
37+
if success(rc)
38+
cars_df = read_samples(cars_stan_model; output_format=:dataframe)
39+
precis(cars_df[:, [:a, :b, :sigma]])
40+
41+
nt_cars = read_samples(cars_stan_model);
42+
end
43+
44+
log_lik = nt_cars.log_lik'
45+
ns, n = size(log_lik)
46+
lppd = [log_sum_exp(log_lik[:, i] .- log(ns)) for i in 1:n]
47+
pwaic = [var(log_lik[:, i]) for i in 1:n]
48+
-2(sum(lppd) - sum(pwaic)) |> display
49+
50+
waic(log_lik) |> display

notebooks/cars.jl

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
### A Pluto.jl notebook ###
2+
# v0.12.18
3+
4+
using Markdown
5+
using InteractiveUtils
6+
7+
# ╔═╡ b9fc511c-6007-11eb-0c6b-1f6871a40710
8+
using Pkg, DrWatson, PSIS
9+
10+
# ╔═╡ 20d3ad36-6008-11eb-2f2a-d379b234b0e9
11+
begin
12+
using StatisticalRethinking
13+
using StanSample
14+
using RDatasets
15+
end;
16+
17+
# ╔═╡ af6b0b20-6008-11eb-2fa1-2f61145ab7db
18+
md"
19+
!!! note
20+
21+
This script assumes that RDatasets.jl is loaded. RDatasets.jl is not included in the dependencies of PSIS.jl (as it would require R to be installed)."
22+
23+
# ╔═╡ 20d377b2-6008-11eb-364a-617b6934ecb2
24+
begin
25+
cd(psis_path)
26+
@quickactivate "PSIS"
27+
pkg"instantiate"
28+
end
29+
30+
# ╔═╡ 20d43cda-6008-11eb-09f0-53489a26110d
31+
df = RDatasets.dataset("datasets", "cars");
32+
33+
# ╔═╡ 20e22570-6008-11eb-1565-f13541b41861
34+
cars_stan = "
35+
data {
36+
int N;
37+
vector[N] speed;
38+
vector[N] dist;
39+
}
40+
parameters {
41+
real a;
42+
real b;
43+
real sigma;
44+
}
45+
transformed parameters{
46+
vector[N] mu;
47+
mu = a + b * speed;
48+
}
49+
model {
50+
a ~ normal(0, 100);
51+
b ~ normal(0, 10);
52+
sigma ~ exponential(1);
53+
dist ~ normal(mu, sigma) ;
54+
}
55+
generated quantities {
56+
vector[N] log_lik;
57+
for (i in 1:N)
58+
log_lik[i] = normal_lpdf(dist[i] | mu[i], sigma);
59+
}
60+
";
61+
62+
# ╔═╡ 20e2ebb6-6008-11eb-34f5-61f1ec3d024c
63+
begin
64+
cars_stan_model = SampleModel("cars.model", cars_stan)
65+
data = (N = size(df, 1), speed = df.Speed, dist = df.Dist)
66+
rc = stan_sample(cars_stan_model; data)
67+
68+
if success(rc)
69+
cars_df = read_samples(cars_stan_model; output_format=:dataframe)
70+
PRECIS(cars_df[:, [:a, :b, :sigma]])
71+
end
72+
end
73+
74+
# ╔═╡ 5fc59200-6008-11eb-3e06-1d0bcdf11d7d
75+
if success(rc)
76+
nt_cars = read_samples(cars_stan_model);
77+
log_lik = nt_cars.log_lik'
78+
ns, n = size(log_lik)
79+
end
80+
81+
# ╔═╡ 20ed768a-6008-11eb-13f4-458ca1a29592
82+
begin
83+
lppd = [log_sum_exp(log_lik[:, i] .- log(ns)) for i in 1:n]
84+
pwaic = [var(log_lik[:, i]) for i in 1:n]
85+
-2(sum(lppd) - sum(pwaic))
86+
end
87+
88+
# ╔═╡ efe960f0-600a-11eb-1df4-5be83899715a
89+
begin
90+
waic_vec = -2(lppd - pwaic)
91+
sqrt(n*var(waic_vec))
92+
end
93+
94+
# ╔═╡ 20f5a31e-6008-11eb-2ce1-075893273872
95+
waic(log_lik)
96+
97+
# ╔═╡ 6b67e38e-6009-11eb-3d9f-afc517b7d9fb
98+
begin
99+
loo, loos, pk = psisloo(log_lik)
100+
loo
101+
end
102+
103+
# ╔═╡ 82232dfe-6009-11eb-3b32-dbc487c6e4e7
104+
pk_qualify(pk)
105+
106+
# ╔═╡ 8ddf22b0-6009-11eb-08da-5198ba046628
107+
pk_plot(pk)
108+
109+
# ╔═╡ Cell order:
110+
# ╟─af6b0b20-6008-11eb-2fa1-2f61145ab7db
111+
# ╠═b9fc511c-6007-11eb-0c6b-1f6871a40710
112+
# ╠═20d377b2-6008-11eb-364a-617b6934ecb2
113+
# ╠═20d3ad36-6008-11eb-2f2a-d379b234b0e9
114+
# ╠═20d43cda-6008-11eb-09f0-53489a26110d
115+
# ╠═20e22570-6008-11eb-1565-f13541b41861
116+
# ╠═20e2ebb6-6008-11eb-34f5-61f1ec3d024c
117+
# ╠═5fc59200-6008-11eb-3e06-1d0bcdf11d7d
118+
# ╠═20ed768a-6008-11eb-13f4-458ca1a29592
119+
# ╠═efe960f0-600a-11eb-1df4-5be83899715a
120+
# ╠═20f5a31e-6008-11eb-2ce1-075893273872
121+
# ╠═6b67e38e-6009-11eb-3d9f-afc517b7d9fb
122+
# ╠═82232dfe-6009-11eb-3b32-dbc487c6e4e7
123+
# ╠═8ddf22b0-6009-11eb-08da-5198ba046628

src/pk_utilities.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
using StatsPlots
2+
13
function pk_qualify(pk::Vector{Float64})
24
pk_good = sum(pk .<= 0.5)
35
pk_ok = length(pk[pk .<= 0.7]) - pk_good
46
pk_bad = length(pk[pk .<= 1]) - pk_good - pk_ok
57
(good=pk_good, ok=pk_ok, bad=pk_bad, very_bad=sum(pk .> 1))
68
end
79

8-
function pk_plot(pk::Vector{Float64}, title="PSIS diagnostic plot.")
9-
scatter(pk1, xlab="Datapoint", ylab="Pareto shape k",
10-
marker=2.5, lab="Pk points", leg=:topleft)
10+
function pk_plot(pk::Vector{Float64}, title="PSIS diagnostic plot.",
11+
leg=:topleft; kwargs...)
12+
scatter(pk, xlab="Datapoint", ylab="Pareto shape k",
13+
marker=2.5, lab="Pk points", leg=leg)
1114
hline!([0.5], lab="pk = 0.5");hline!([0.7], lab="pk = 0.7")
1215
hline!([1], lab="pk = 1.0")
1316
title!(title)

src/waic.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,27 @@ function log_sum_exp(x)
77
xmax + log(xsum)
88
end
99

10-
function waic( ll::AbstractArray; pointwise=FALSE , log_lik="log_lik" , kwargs... )
10+
function waic( ll::AbstractArray; pointwise=false , log_lik="log_lik" , kwargs... )
1111

1212
n_samples, n_obs = size(ll)
13-
lpd <- zeros(n_obs)
14-
pD <- zeros(n_obs)
13+
lpd = zeros(n_obs)
14+
pD = zeros(n_obs)
1515

1616
for i in 1:n_obs
1717
lpd[i] = log_sum_exp(ll[:,i]) .- log(n_samples)
1818
pD[i] = var2(ll[:,i])
1919
end
2020

2121
waic_vec = (-2) .* ( lpd - pD )
22-
if pointwise==FALSE
22+
if pointwise == false
2323
waic = sum(waic_vec)
2424
lpd = sum(lpd)
2525
pD = sum(pD)
2626
else
2727
waic = waic_vec
2828
end
2929

30+
local se
3031
try
3132
se = sqrt( n_obs*var2(waic_vec) )
3233
catch e

0 commit comments

Comments
 (0)