Skip to content

Commit 21928d5

Browse files
committed
Rel 0.0.7 - Rely in StatsFuns' logsumexp
1 parent 440d7d1 commit 21928d5

7 files changed

Lines changed: 13 additions & 16 deletions

File tree

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
99
StanSample = "c1514b29-d3a0-5178-b312-660c88baa699"
1010
StatisticalRethinking = "2d09df54-9d0f-5258-8220-54c2a3d4fbee"
1111
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
12+
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1213
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
1314
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1415

1516
[compat]
1617
JSON = "0.21"
1718
StanSample = "3.0"
1819
StatisticalRethinking = "3.2"
20+
StatsFuns = "0.9"
1921
StatsPlots = "0.14"
2022
julia = "1"

README.md

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,6 @@ The cars WAIC example requires RDatasets.jl to be installed and functioning.
5050
`gpinv()` -
5151
Inverse Generalised Pareto distribution function.
5252

53-
`logsumexp()` -
54-
Sum of a vector where numbers are represented by their logarithms.
55-
5653
`waic()` -
5754
Compute WAIC for a loglikelihood matrix.
5855

@@ -65,9 +62,6 @@ The cars WAIC example requires RDatasets.jl to be installed and functioning.
6562
`var2()` -
6663
Uncorrected variance.
6764

68-
`log_sum_exp()` -
69-
Compute logarithmic sum of a vector.
70-
7165
### Corresponding R code
7266

7367
The corresponding R code can be found in [R package called

examples/cars_waic/cars.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,13 @@ rc = stan_sample(cars_stan_model; data)
3737
if success(rc)
3838
cars_df = read_samples(cars_stan_model; output_format=:dataframe)
3939
precis(cars_df[:, [:a, :b, :sigma]])
40-
4140
nt_cars = read_samples(cars_stan_model);
4241
end
4342

4443
log_lik = nt_cars.log_lik'
45-
ns, n = size(log_lik)
46-
lppd = [log_sum_exp(log_lik[:, i] .- log(ns)) for i in 1:n]
47-
pwaic = [var(log_lik[:, i]) for i in 1:n]
44+
n_sam, n_obs = size(log_lik)
45+
lppd = reshape(logsumexp(log_lik .- log(n_sam); dims=1), n_obs)
46+
pwaic = [var(log_lik[:, i]) for i in 1:n_obs]
4847
-2(sum(lppd) - sum(pwaic)) |> display
4948

5049
waic(log_lik) |> display

notebooks/cars.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,12 @@ end
7575
if success(rc)
7676
nt_cars = read_samples(cars_stan_model);
7777
log_lik = nt_cars.log_lik'
78-
ns, n = size(log_lik)
7978
end
8079

8180
# ╔═╡ 20ed768a-6008-11eb-13f4-458ca1a29592
8281
begin
83-
lppd = [StatsFuns.logsumexp(log_lik[:, i] .- log(ns)) for i in 1:n]
82+
ns, n = size(log_lik)
83+
lppd = reshape(logsumexp(log_lik .- log(ns); dims=1), n)
8484
pwaic = [var(log_lik[:, i]) for i in 1:n]
8585
-2(sum(lppd) - sum(pwaic))
8686
end

src/PSIS.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,15 @@ arXiv preprint arXiv:1507.02646.
3232

3333
module PSIS
3434

35+
using StatsFuns
36+
3537
psis_path = @__DIR__
3638

3739
include("psisloo.jl")
3840
include("psislw.jl")
3941
include("gpdfitnew.jl")
4042
include("gpinv.jl")
41-
include("logsumexp.jl")
43+
#include("logsumexp.jl")
4244
include("waic.jl")
4345
include("pk_utilities.jl")
4446

src/psisloo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ function psisloo(log_lik::AbstractArray, wcpp::Int64=20, wtrunc::Float64=3/4)
3838
lwp, ks = psislw(-lw, wcpp, wtrunc)
3939

4040
lwp += lw
41-
loos = logsumexp(lwp, 1)
41+
loos = reshape(logsumexp(lwp, dims=1), size(lw, 2))
4242
loo = sum(loos)
4343

4444
return loo, loos, ks

src/waic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ var2(x) = mean(x.^2) .- mean(x)^2
44
function waic( ll::AbstractArray; pointwise=false , log_lik="log_lik" , kwargs... )
55

66
n_samples, n_obs = size(ll)
7-
lpd = zeros(n_obs)
7+
#lpd = zeros(n_obs)
88
pD = zeros(n_obs)
99

10+
lpd = reshape(logsumexp(ll .- log(n_samples); dims=1), n_obs)
1011
for i in 1:n_obs
11-
lpd[i] = logsumexp(ll[:,i]) .- log(n_samples)
1212
pD[i] = var2(ll[:,i])
1313
end
1414

0 commit comments

Comments
 (0)