Skip to content

Commit 5e942e6

Browse files
committed
BasisSet updates
1 parent 7c5a66a commit 5e942e6

8 files changed

Lines changed: 73 additions & 60 deletions

src/StructuralCausalModels.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ ModelDefinition = Union{OrderedDict, AbstractString, NamedArray}
3737
include("types/DAG.jl")
3838
include("types/Path.jl")
3939
include("types/ConditionalIndependence.jl")
40+
include("types/BasisSet.jl")
4041

4142
include("methods/dag_methods.jl")
4243
include("methods/basis_set.jl")
@@ -50,7 +51,6 @@ include("methods/open_paths.jl")
5051
include("methods/backdoor_paths.jl")
5152
include("methods/adjustment_sets.jl")
5253
include("methods/ancestral_graph.jl")
53-
#include("methods/implied_conditional_independencies.jl")
5454

5555
include("utils/show_dag_path.jl")
5656
include("utils/ggm_conversions.jl")

src/types/BasisSet.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import Base: show, getindex, iterate, HasLength, HasEltype, length
2+
3+
struct BasisSet
4+
bs::Vector{Vector{Symbol}}
5+
BasisSet(bs) = new(sort(bs; by = x -> x[1]))
6+
end
7+
8+
iterate(b::BasisSet, state=1) =
9+
state > length(b.bs) ? nothing : (b.bs[state], state+1)
10+
11+
getindex(b::BasisSet, i::Int) = b.bs[i]
12+
13+
HasLength(b::BasisSet) = length(b.bs)
14+
15+
HasEltype(b::BasisSet) = eltype(b.bs)
16+
17+
length(b::BasisSet) = length(b.bs)
18+
19+
function bs_show(io::IO, bs::BasisSet)
20+
println("BasisSet[")
21+
for ci in bs.bs
22+
show(ConditionalIndependency(ci))
23+
end
24+
println("]")
25+
end
26+
27+
show(io::IO, bs::BasisSet) = bs_show(io, bs)
28+
29+
export
30+
BasisSet

src/types/ConditionalIndependence.jl

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,49 +19,13 @@ end
1919

2020
function ci_show(io::IO, ci::ConditionalIndependency)
2121
if isnothing(ci.c)
22-
println(" $(ci.f) \u2210 $(ci.s)")
22+
println(" :$(ci.f) \u2210 :$(ci.s)")
2323
else
24-
println(" $(ci.f) \u2210 $(ci.s) | $(ci.c)")
24+
println(" :$(ci.f) \u2210 :$(ci.s) | $(ci.c)")
2525
end
2626
end
2727

2828
show(io::IO, ci::ConditionalIndependency) = ci_show(io, ci)
2929

30-
function append(c::ConditionalIndependency)
31-
v = [c.f, c.s]
32-
!isnothing(c.c) && length(c.c) > 0 && push!(v, c.c...)
33-
v
34-
end
35-
36-
struct BasisSet
37-
bs::Vector{ConditionalIndependency}
38-
end
39-
40-
function BasisSet(b::Array{Array{Symbol,1},1})
41-
BasisSet(ConditionalIndependency.(b))
42-
end
43-
44-
iterate(b::BasisSet, state=1) =
45-
state > length(b.bs) ? nothing : (b.bs[state], state+1)
46-
47-
getindex(b::BasisSet, i::Int) = b.bs[i]
48-
49-
HasLength(b::BasisSet) = length(b.bs)
50-
51-
HasEltype(b::BasisSet) = eltype(b.bs)
52-
53-
length(b::BasisSet) = length(b.bs)
54-
55-
function bs_show(io::IO, bs::BasisSet)
56-
println("BasisSet[")
57-
for ci in bs.bs
58-
show(ci)
59-
end
60-
println("]")
61-
end
62-
63-
show(io::IO, bs::BasisSet) = bs_show(io, bs)
64-
6530
export
66-
BasisSet,
6731
ConditionalIndependency

test/test_bases_set_01.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using StructuralCausalModels, Test
2+
3+
ProjDir = @__DIR__
4+
5+
d_str = "dag{k1 -> k2;k2 -> y;v -> x2;w -> k1;x1 -> v;x1 -> w;x2 -> y;x3 -> w;x3 -> y}"
6+
7+
dag = DAG("test_desc", d_str)
8+
9+
to_dagitty(dag) |> display
10+
11+
fname = joinpath(ProjDir, "test_descendents_03.dot")
12+
to_graphviz(dag, fname)
13+
Sys.isapple() && run(`open -a GraphViz.app $(fname)`)
14+
15+
display(dag)
16+
17+
bs = basis_set(dag)
18+
bs |> display
19+

test/test_open_paths_02.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ dag = DAG("test_open_paths_02", d_string);
99
@test to_dagitty(dag.d) == "dag { D <- E; D <- Z; E <- Z; D <- B; Z <- B; E <- A; Z <- A }"
1010

1111
bs = basis_set(dag)
12-
@test bs[1].f == :B
13-
@test bs[1].s == :A
14-
@test bs[1].c == nothing
15-
@test bs[2].f == :B
16-
@test bs[2].s == :E
17-
@test bs[2].c == [:A, :Z]
18-
@test bs[3].f == :A
19-
@test bs[3].s == :D
20-
@test bs[3].c == [:B, :Z, :E]
12+
@test bs[1][1] == :A
13+
@test bs[1][2] == :D
14+
@test bs[1][3:end] == [:B, :Z, :E]
15+
@test bs[2][1] == :B
16+
@test bs[2][2] == :A
17+
@test length(bs[2]) == 2
18+
@test bs[3][1] == :B
19+
@test bs[3][2] == :E
20+
@test bs[3][3:end] == [:A, :Z]
2121

2222
fname = joinpath(ProjDir, "test_open_paths_02.dot")
2323
to_graphviz(dag, fname)

test/test_open_paths_03.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ adjs = adjustment_sets(dag, :X, :Y)
1919

2020
@testset "Open_path_03" begin
2121

22-
@test bs[1].f == :Z
23-
@test bs[1].s == :Y
24-
@test bs[1].c == [:X, :I]
22+
@test bs[1][1] == :Z
23+
@test bs[1][2] == :Y
24+
@test bs[1][3:end] == [:X, :I]
2525
@test bp[1] == [:X, :Z, :I, :Y]
2626
@test adjs == [[:Z], [:I]]
2727

test/test_open_paths_04.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ to_graphviz(dag, fname)
2424
:q => :r, :p => :q, :o => :p, :n => :o, :m => :n, :k => :l,
2525
:l => :m, :i => :k)
2626
@test length(bs) == 117
27-
@test bs[116].f == :b
28-
@test bs[116].s == :D
29-
@test bs[116].c == [:c, :r, :k, :E]
27+
@test bs[3][1] == :b
28+
@test bs[3][2] == :D
29+
@test bs[3][3:end] == [:c, :r, :k, :E]
3030
@test adjs == [[:i], [:k]]
3131

3232
end

test/test_sr6_4_2.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@ adjustmentsets = adjustment_sets(dag, :x, :y)
4444
@testset "sr6_4_2" begin
4545

4646
@test length(basisset) == 8
47-
@test basisset[6].f == :u
48-
@test basisset[6].s == :y
49-
@test basisset[6].c == [:a, :c, :x]
50-
@test basisset[8].f == :y
51-
@test basisset[8].s == :b
52-
@test basisset[8].c == [:c, :x, :u]
47+
@test basisset[6][1] == :u
48+
@test basisset[6][2] == :y
49+
@test basisset[6][3:end] == [:a, :c, :x]
50+
@test basisset[8][1] == :y
51+
@test basisset[8][2] == :b
52+
@test basisset[8][3:end] == [:c, :x, :u]
5353

5454
@test allpaths[1] == [:x, :u, :b, :c, :y]
5555
@test allpaths[2] == [:x, :u, :a, :c, :y]

0 commit comments

Comments
 (0)