Skip to content

Commit 6b6a713

Browse files
committed
Rel 0.1.0 - Expanded BasisSet
1 parent 72f8e00 commit 6b6a713

9 files changed

Lines changed: 136 additions & 122 deletions

src/methods/basis_set.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function basis_set(dag::DAG; debug=false)
3030
append!(ind, [ed])
3131
end
3232
end
33-
BasisSet(ind)
33+
BasisSet(dag, ind)
3434
end
3535

3636
export

src/methods/dag_methods.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,38 @@ function topological_sort(a::NamedArray)
215215
a[ord, ord]
216216
end
217217

218+
"""
219+
220+
# `topological_sort`
221+
222+
$(SIGNATURES)
223+
224+
Part of the API, exported
225+
"""
226+
function topological_sort(dag::DAG)
227+
new_dag = deepcopy(dag)
228+
ord = topological_order(new_dag.a)
229+
new_dag.a = new_dag.a[ord, ord]
230+
new_dag.e = new_dag.a[ord, ord]
231+
new_dag.vars = new_dag.vars[ord]
232+
new_dag
233+
end
234+
235+
"""
236+
237+
# `topological_sort!`
238+
239+
$(SIGNATURES)
240+
241+
Part of the API, exported
242+
"""
243+
function topological_sort!(dag::DAG)
244+
ord = topological_order(dag.a)
245+
dag.a = dag.a[ord, ord]
246+
dag.e = dag.a[ord, ord]
247+
dag.vars = dag.vars[ord]
248+
end
249+
218250
export
219251
dag_vars,
220252
adjacency_matrix,

src/methods/induced_covariance_graph.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ end
1515

1616
"""
1717
18-
# indicator_matrix
18+
## indicator_matrix
1919
2020
$(SIGNATURES)
2121

src/types/BasisSet.jl

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,86 @@
1-
import Base: show, getindex, iterate, HasLength, HasEltype, length
1+
import Base: show, getindex, iterate, HasLength, HasEltype, length, sort
2+
3+
4+
5+
function sort(dag::DAG, bs::Vector{Vector{Symbol}})
6+
7+
#topological_sort!(dag)
8+
9+
dct = OrderedDict()
10+
for (i, s) in enumerate(dag.vars)
11+
dct[s] = i
12+
end
13+
14+
# Arrange first sym < second sym
15+
16+
for (i, s) in enumerate(bs)
17+
if isless(dct[s[2]], dct[s[1]])
18+
tmp = s[1]
19+
bs[i][1] = s[2]
20+
bs[i][2] = tmp
21+
end
22+
end
23+
fset = Symbol[]
24+
for (i, s) in enumerate(bs)
25+
push!(fset, s[1])
26+
end
27+
sort!(unique!(fset))
28+
bs2 = sort(bs; by = x -> x[1])
29+
30+
# Now sort on second symbol within first symbol
31+
32+
bs3 = Vector{Symbol}[]
33+
for s in fset
34+
indx = filter(x -> x[1] == s, bs2)
35+
sort!(indx; by = x -> dct[x[2]])
36+
for i in indx
37+
push!(bs3, i)
38+
end
39+
end
40+
bs3
41+
end
42+
43+
function d_sep_combinations(dag::DAG, bs::Vector{Symbol})
44+
csets = Vector{Symbol}[Symbol[]]
45+
if length(bs) > 2
46+
for c in combinations(bs[3:end])
47+
push!(csets, c)
48+
end
49+
end
50+
bs_new = Vector{Symbol}[]
51+
for cset in csets
52+
if length(cset) == 0
53+
if d_separation(dag, bs[1], bs[2])
54+
push!(bs_new, [bs[1], bs[2]])
55+
end
56+
else
57+
if d_separation(dag, bs[1], bs[2]; cset=cset)
58+
push!(bs_new, [bs[1], bs[2], cset...])
59+
end
60+
end
61+
end
62+
bs_new
63+
end
64+
265

366
struct BasisSet
67+
dag::DAG
468
bs::Vector{Vector{Symbol}}
5-
BasisSet(bs) = new(sort(bs; by = x -> x[1]))
69+
70+
# Inner constructor
71+
#BasisSet(dag, bs) = new(dag, sort(bs; by = x -> x[1]))
72+
73+
74+
function BasisSet(dag, bs)
75+
# Update,create and return the sorted BasisSet
76+
bs2 = sort(dag, bs)
77+
bs3 = Vector{Symbol}[]
78+
for b in bs2
79+
bs3 = vcat(bs3, d_sep_combinations(dag, b))
80+
end
81+
new(dag, bs3)
82+
end
83+
684
end
785

886
iterate(b::BasisSet, state=1) =

test/test_basis_set_01.jl

Lines changed: 0 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -18,86 +18,3 @@ display(dag)
1818

1919
bs = basis_set(dag)
2020
bs |> display
21-
22-
show(ConditionalIndependency(bs[1]))
23-
@show d_separation(dag, :k1, :y, cset=[:w, :k2, :x2, :x3])
24-
@show d_separation(dag, :k1, :y, cset=[:k2, :x1, :x3])
25-
@show d_separation(dag, :k1, :y, cset=[:k2, :x2, :x3])
26-
@show d_separation(dag, :k1, :y, cset=[:k2, :x2, :x3])
27-
@show d_separation(dag, :k1, :y, cset=[:k2, :v, :x3])
28-
@show d_separation(dag, :k1, :y, cset=[:k2, :x1, :x3])
29-
@show d_separation(dag, :k1, :y, cset=[:k2, :x1])
30-
@show d_separation(dag, :k1, :y, cset=[:k2, :x3])
31-
@show d_separation(dag, :k1, :y, cset=[:k2, :w])
32-
@show d_separation(dag, :k1, :y, cset=[:w])
33-
@show d_separation(dag, :k1, :y, cset=[:k2])
34-
println()
35-
36-
#=
37-
@show d_separation(dag, :k1, :v, cset=[:x1])
38-
@show d_separation(dag, :k1, :v, cset=[:w])
39-
@show d_separation(dag, :k1, :x1, cset=[:w])
40-
@show d_separation(dag, :k1, :x2, cset=[:v])
41-
@show d_separation(dag, :k1, :x2, cset=[:x1])
42-
@show d_separation(dag, :k1, :x2, cset=[:w])
43-
@show d_separation(dag, :k1, :x3, cset=[:w])
44-
45-
bs[9] |> display
46-
@show d_separation(dag, :x1, :x2; cset=[:v])
47-
println()
48-
49-
bs[10] |> display
50-
@show d_separation(dag, :x1, :x3)
51-
println()
52-
53-
bs[11] |> display
54-
@show d_separation(dag, :x1, :k1; cset=[:w])
55-
println()
56-
57-
bs[12] |> display
58-
@show d_separation(dag, :x1, :k2; cset=[:k2])
59-
println()
60-
61-
bs[13] |> display
62-
@show d_separation(dag, :x1, :y; cset=[:x2, :x3, :k2])
63-
println()
64-
65-
@show d_separation(dag, :k2, :v, cset=[:x1])
66-
@show d_separation(dag, :k2, :v, cset=[:w])
67-
@show d_separation(dag, :k2, :v, cset=[:k1])
68-
@show d_separation(dag, :k2, :w, cset=[:k1])
69-
@show d_separation(dag, :k2, :x1, cset=[:w])
70-
@show d_separation(dag, :k2, :x1, cset=[:k1])
71-
@show d_separation(dag, :k2, :x2, cset=[:v])
72-
@show d_separation(dag, :k2, :x2, cset=[:x1])
73-
@show d_separation(dag, :k2, :x2, cset=[:w])
74-
@show d_separation(dag, :k2, :x2, cset=[:k1])
75-
@show d_separation(dag, :k2, :x3, cset=[:w])
76-
@show d_separation(dag, :k2, :x3, cset=[:k1])
77-
println()
78-
79-
@show d_separation(dag, :v, :w; cset=:x1)
80-
@show d_separation(dag, :v, :x3)
81-
@show d_separation(dag, :v, :y; cset=[:k2, :x2, :x3])
82-
@show d_separation(dag, :v, :y; cset=[:k1, :x2, :x3])
83-
@show d_separation(dag, :v, :y; cset=[:w, :x2, :x3])
84-
@show d_separation(dag, :v, :y; cset=[:x2, :x1])
85-
println()
86-
87-
@show d_separation(dag, :w, :x2; cset=[:v])
88-
@show d_separation(dag, :w, :x2; cset=[:x1])
89-
@show d_separation(dag, :w, :y; cset=[:k2, :x2, :x3])
90-
@show d_separation(dag, :w, :y; cset=[:k2, :v, :x3])
91-
@show d_separation(dag, :w, :y; cset=[:k1, :x2, :x3])
92-
@show d_separation(dag, :w, :y; cset=[:k1, :v, :x3])
93-
@show d_separation(dag, :w, :y; cset=[:k1, :x1, :x3])
94-
@show d_separation(dag, :w, :y; cset=[:k2, :x2, :x3])
95-
println()
96-
97-
@show d_separation(dag, :x1, :y; cset=[:k1, :v, :x3])
98-
@show d_separation(dag, :x1, :y; cset=[:v, :w, :x3])
99-
println()
100-
101-
@show d_separation(dag, :x2, :x3)
102-
println()
103-
=#

test/test_open_paths_01.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,6 @@ using StructuralCausalModels, Test
33
ProjDir = @__DIR__
44
cd(ProjDir) #do
55

6-
function syms_in_paths(paths, f, l)
7-
thepaths = deepcopy(paths)
8-
syms = Symbol[]
9-
for p in thepaths
10-
setdiff!(p, [f, l])
11-
append!(syms, p)
12-
unique!(syms)
13-
end
14-
syms
15-
end
16-
17-
function sym_in_all_paths(paths, sym)
18-
all([sym in p for p in paths])
19-
end
20-
216
d = OrderedDict(
227
:w => :s,
238
:d => [:a, :w, :m],

test/test_open_paths_02.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using StructuralCausalModels
1+
using StructuralCausalModels, Test
22

33
ProjDir = @__DIR__
44
cd(ProjDir) #do
@@ -9,19 +9,21 @@ dag = DAG("test_open_paths_02", d_string);
99
@test to_dagitty(dag.d) == "dag { D <- E; D <- Z; E <- Z; D <- B; Z <- B; E <- A; Z <- A }"
1010

1111
bs = basis_set(dag)
12-
@test bs[1][1] == :A
13-
@test bs[1][2] == :D
14-
@test bs[1][3:end] == [:B, :Z, :E]
15-
@test bs[2][1] == :B
12+
@test bs[1][1] == :B
13+
@test bs[1][2] == :A
14+
@test bs[1][3:end] == Symbol[]
15+
@test bs[2][1] == :D
1616
@test bs[2][2] == :A
17-
@test length(bs[2]) == 2
18-
@test bs[3][1] == :B
19-
@test bs[3][2] == :E
17+
@test length(bs[2]) == 5
18+
@test bs[3][1] == :E
19+
@test bs[3][2] == :B
2020
@test bs[3][3:end] == [:A, :Z]
2121

22+
#=
2223
fname = joinpath(ProjDir, "test_open_paths_02.dot")
2324
to_graphviz(dag, fname)
24-
#Sys.isapple() && run(`open -a GraphViz.app $(fname)`)
25+
Sys.isapple() && run(`open -a GraphViz.app $(fname)`)
26+
=#
2527

2628
ap = all_paths(dag, :D, :E)
2729
bp = backdoor_paths(dag, ap, :D)

test/test_open_paths_04.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ to_graphviz(dag, fname)
2323
:h => :i, :g => :h, :f => :g, :c => :f, :b => :c, :a => :b,
2424
:q => :r, :p => :q, :o => :p, :n => :o, :m => :n, :k => :l,
2525
:l => :m, :i => :k)
26-
@test length(bs) == 117
27-
@test bs[3][1] == :b
28-
@test bs[3][2] == :D
29-
@test bs[3][3:end] == [:c, :r, :k, :E]
26+
@test length(bs) == 272
27+
@test bs[3][1] == :D
28+
@test bs[3][2] == :a
29+
@test bs[3][3:end] == [:k, :E]
3030
@test adjs == [[:i], [:k]]
3131

3232
end

test/test_sr6_4_2.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ adjustmentsets = adjustment_sets(dag, :x, :y)
4343

4444
@testset "sr6_4_2" begin
4545

46-
@test length(basisset) == 8
47-
@test basisset[6][1] == :u
48-
@test basisset[6][2] == :y
49-
@test basisset[6][3:end] == [:a, :c, :x]
50-
@test basisset[8][1] == :y
51-
@test basisset[8][2] == :b
52-
@test basisset[8][3:end] == [:c, :x, :u]
46+
@test length(basisset) == 15
47+
@test basisset[6][1] == :x
48+
@test basisset[6][2] == :b
49+
@test basisset[6][3:end] == [:u]
50+
@test basisset[8][1] == :x
51+
@test basisset[8][2] == :a
52+
@test basisset[8][3:end] == [:u]
5353

5454
@test allpaths[1] == [:x, :u, :b, :c, :y]
5555
@test allpaths[2] == [:x, :u, :a, :c, :y]

0 commit comments

Comments
 (0)