Skip to content

Commit e5a781d

Browse files
committed
Use vector version of logsumexp.
1 parent 21928d5 commit e5a781d

2 files changed

Lines changed: 10 additions & 14 deletions

File tree

notebooks/arsenic.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
### A Pluto.jl notebook ###
2-
# v0.12.18
2+
# v0.12.19
33

44
using Markdown
55
using InteractiveUtils
@@ -71,10 +71,7 @@ if success(rc1)
7171

7272
# Check the shape parameter k of the generalized Pareto distribution
7373

74-
pk_good1 = sum(pk1 .<= 0.5)
75-
pk_ok1 = length(pk1[pk1 .<= 0.7]) - pk_good1
76-
pk_bad1 = length(pk1[pk1 .<= 1]) - pk_good1 - pk_ok1
77-
(good=pk_good1, ok=pk_ok1, bad=pk_bad1, very_bad=sum(pk1 .> 1))
74+
pk_qualify(pk1)
7875
end
7976

8077
# ╔═╡ 5a76e4aa-5ebd-11eb-2e15-6d5808ead825
@@ -105,10 +102,7 @@ if success(rc2)
105102

106103
# Check the shape parameter k of the generalized Pareto distribution
107104

108-
pk_good2 = sum(pk2 .<= 0.5)
109-
pk_ok2 = length(pk2[pk2 .<= 0.7]) - pk_good2
110-
pk_bad2 = length(pk2[pk2 .<= 1]) - pk_good2 - pk_ok2
111-
(good=pk_good2, ok=pk_ok2, bad=pk_bad2, very_bad=sum(pk2 .> 1))
105+
pk_qualify(pk2)
112106
end
113107

114108
# ╔═╡ dd28d430-5ebd-11eb-1854-ab4c13e82c34
@@ -147,7 +141,9 @@ begin
147141
nt3 = read_samples(sm3)
148142
# Compute LOO and standard error
149143
log_likt = nt3.log_likt'
150-
kfcvs[cvitst[cvi]] = PSIS.logsumexp(log_likt) .- log(size(log_likt, 1))
144+
local n_sam, n_obs = size(log_likt)
145+
kfcvs[cvitst[cvi]] .=
146+
reshape(logsumexp(log_likt .- log(n_sam), dims=1), n_obs)
151147
end
152148
end
153149
end
@@ -158,7 +154,7 @@ begin
158154
# compare PSIS-LOO and k-fold-CV
159155

160156
plot([-3.5, 0], [-3.5, 0], color=:red)
161-
scatter!(loos1[1,:], kfcvs[1,:], xlab = "PSIS-LOO", ylab = "10-fold-CV",
157+
scatter!(loos1, kfcvs, xlab = "PSIS-LOO", ylab = "10-fold-CV",
162158
leg=false, color=:darkblue)
163159
end
164160

test/test_demo_wells.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,6 @@ sm3 = SampleModel("arsenic_logistic_t", model_str);
8888
cvitr, cvitst = cvit(n, 10, true)
8989
kfcvs = similar(loos)
9090
for cvi in 1:3
91-
@printf("%d\n", cvi)
92-
9391
standatacv = (p = m, N = length(cvitr[cvi]), Nt = length(cvitst[cvi]),
9492
x = x[cvitr[cvi],:], y = Int.(y[cvitr[cvi]]),
9593
xt = x[cvitst[cvi],:], yt = Int.(y[cvitst[cvi]]))
@@ -100,6 +98,8 @@ for cvi in 1:3
10098
nt3 = read_samples(sm3)
10199
# Compute LOO and standard error
102100
log_likt = nt3.log_likt'
103-
kfcvs[cvitst[cvi]] = PSIS.logsumexp(log_likt) .- log(size(log_likt, 1))
101+
local n_sam, n_obs = size(log_likt)
102+
kfcvs[cvitst[cvi]] .=
103+
reshape(logsumexp(log_likt .- log(n_sam), dims=1), n_obs)
104104
end
105105
end

0 commit comments

Comments
 (0)