Skip to content

Commit c36a594

Browse files
committed
REl 0.1.0
1 parent f2369de commit c36a594

13 files changed

Lines changed: 52 additions & 158 deletions

docs/src/index.md

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,6 @@ adjacency_matrix(e::NamedArray)
5151
adjacency_matrix_to_dict(ea::NamedArray)
5252
ancester_graph(e::NamedArray)
5353
ancestral_graph(a::NamedArray{Int, 2}; m=Symbol[], c=Symbol[])
54-
DAG(name::AbstractString, d::OrderedDict)
55-
DAG(name::AbstractString, str::AbstractString, df::DataFrame)
56-
DAG(name::AbstractString, str::AbstractString)
57-
DAG(name::AbstractString, a::NamedArray, df::DataFrame)
58-
DAG(name::AbstractString, a::NamedArray)
5954
dag_show(io::IO, d::DAG)
6055
dag_vars(d::OrderedDict)
6156
edge_matrix(d::OrderedDict)

src/StructuralCausalModels.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ SymbolListOrNothing = Union{SymbolList, Nothing}
3232
OrderedDictOrNothing = Union{OrderedDict, Nothing}
3333
NamedArrayOrNothing = Union{NamedArray, Nothing}
3434
DataFrameOrNothing = Union{DataFrame, Nothing}
35+
ModelDefinition = Union{OrderedDict, AbstractString, NamedArray}
3536

3637
include("types/DAG.jl")
3738
include("types/Path.jl")

src/methods/d_separation.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ d_separation(
1818
1919
### Keyword arguments
2020
```julia
21-
* `cond::SymbolListOrNoting=noting` : Conditioning set
21+
* `cond::SymbolListOrNothing=nothing` : Conditioning set
2222
* `debug=false` : Trace execution
2323
```
2424
@@ -69,14 +69,5 @@ function d_separation(d::DAG, first::SymbolList, second::SymbolList;
6969

7070
end
7171

72-
#=
73-
function d_separation(d::DAG, first::SymbolList, second::SymbolList, cond::SymbolList; debug=false)
74-
75-
e = induced_covariance_graph(d, vcat(first, second), cond; debug=debug)
76-
sum(e[first, second]) == 0
77-
78-
end
79-
=#
80-
8172
export
8273
d_separation

src/methods/dag_methods.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ Part of the API, exported.
135135
"""
136136
function adjacency_matrix_to_dict(a::NamedArray)
137137
vars = names(a, 1)
138-
dct = Dict()
138+
dct = OrderedDict()
139139
for (ind, r) in enumerate(eachrow(a))
140140
rhs = vars[findall(x -> x ==1, r)]
141141
if length(rhs) == 1

src/types/DAG.jl

Lines changed: 33 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,17 @@ $(SIGNATURES)
4444
### Required arguments
4545
```julia
4646
* `name::AbstractString` : Name for the DAG object
47-
* `d` : DAG definition as an
48-
OrderedDict (see extended help)
49-
AbstractString (as in ggm or dagitty)
50-
AdjacencyMatrix
47+
* `d::ModelDefinition` : DAG definition
5148
```
5249
53-
### Optional positional argument
50+
where
51+
```
52+
ModelDefinition = Union{OrderedDict, AbstractString, NamedArray}
53+
```
54+
55+
See the extended help for a usage example.
56+
57+
### Keyword arguments
5458
```julia
5559
* `df::DataFrame` : DataFrame with observations
5660
```
@@ -80,15 +84,15 @@ Coming from R's dagitty:
8084
amat <- dagitty("dag { {X V} -> U; S1 <- U; {Y V} -> W; S2 <- W}”)
8185
```julia
8286
dag = DAG("my_name", "dag { {X V} -> U; S1 <- U; {Y V} -> W; S2 <- W}”)
83-
display(dag.a) # Show the adjacency_matrix
87+
display(dag) # Show the DAG
8488
```
8589
8690
Coming from R's ggm:
8791
8892
amat <- DAG(U~X+V, S1~U, W~V+Y, S2~W, order=FALSE)
8993
```julia
9094
dag = DAG("my_name", "DAG(U~X+V, S1~U, W~V+Y, S2~W”)
91-
display(dag.a) # Show the adjacency_matrix
95+
display(dag) # Show the DAG
9296
```
9397
9498
### Acknowledgements
@@ -105,121 +109,36 @@ The Julia translation is licenced under: MIT.
105109
106110
Part of API, exported.
107111
"""
108-
function DAG(name::AbstractString, d::OrderedDict, df::DataFrame)
109-
110-
vars = dag_vars(d)
111-
a = adjacency_matrix(d)
112-
e = edge_matrix(d)
113-
114-
# Compute covariance matrix and store as NamedArray
115-
116-
@assert length(names(df)) == length(vars) "DataFrame has different number of columns"
117-
s = NamedArray(cov(Array(df)), (names(df), names(df)), ("Rows", "Cols"))
118-
119-
# Create object
120-
121-
DAG(name, d, a, e, s, df, vars)
122-
123-
end
124-
125-
"""
126-
127-
# `DAG`
128-
129-
$(SIGNATURES)
130-
131-
Part of API, exported.
132-
"""
133-
function DAG(name::AbstractString, d::OrderedDict)
134-
135-
vars = dag_vars(d)
136-
a = adjacency_matrix(d)
137-
e = edge_matrix(d)
138-
139-
# Create object
140-
141-
DAG(name, d, a, e, nothing, nothing, vars)
142-
end
143-
144-
"""
145-
146-
# `DAG`
147-
148-
$(SIGNATURES)
149-
150-
Part of API, exported.
151-
"""
152-
function DAG(name::AbstractString, str::AbstractString, df::DataFrame)
153-
ds = strip(str)
154-
if ds[1:3] == "DAG"
155-
d = from_ggm(ds)
156-
elseif ds[1:3] == "dag"
157-
d = from_dagitty(ds)
158-
else
159-
@error "Unrecognized model string: $(ds))"
160-
end
112+
function DAG(name::AbstractString, model::ModelDefinition; df::DataFrameOrNothing=nothing)
113+
114+
local d
115+
if typeof(model) <: OrderedDict
116+
d = model
117+
elseif typeof(model) <: AbstractString
118+
ds = strip(model)
119+
if ds[1:3] == "DAG"
120+
d = from_ggm(ds)
121+
elseif ds[1:3] == "dag"
122+
d = from_dagitty(ds)
123+
else
124+
@error "Unrecognized model string: $(ds))"
125+
end
126+
elseif typeof(model) <: NamedArray
127+
d = adjacency_matrix_to_dict(model)
128+
end
161129

162130
vars = dag_vars(d)
163131
a = adjacency_matrix(d)
164132
e = edge_matrix(d)
165133

166-
# Compute covariance matrix and store as NamedArray
167-
168-
@assert length(names(df)) == length(vars) "DataFrame has different number of columns"
169-
s = NamedArray(cov(Array(df)), (names(df), names(df)), ("Rows", "Cols"))
170-
171-
# Create object
172-
173-
DAG(name, d, a, e, s, df, vars)
174-
175-
end
176-
177-
"""
178-
179-
# `DAG`
180-
181-
$(SIGNATURES)
182-
183-
Part of API, exported.
184-
"""
185-
function DAG(name::AbstractString, str::AbstractString)
186-
ds = strip(str)
187-
if ds[1:3] == "DAG"
188-
d = from_ggm(ds)
189-
elseif ds[1:3] == "dag"
190-
d = from_dagitty(ds)
134+
if isnothing(df)
135+
s = nothing
191136
else
192-
@error "Unrecognized model string: $(ds))"
137+
# Compute covariance matrix and store as NamedArray
138+
@assert length(names(df)) == length(vars) "DataFrame has different number of columns"
139+
s = NamedArray(cov(Array(df)), (names(df), names(df)), ("Rows", "Cols"))
193140
end
194141

195-
vars = dag_vars(d)
196-
a = adjacency_matrix(d)
197-
e = edge_matrix(d)
198-
199-
# Create object
200-
201-
DAG(name, d, a, e, nothing, nothing, vars)
202-
end
203-
204-
"""
205-
206-
# `DAG`
207-
208-
$(SIGNATURES)
209-
210-
Part of API, exported.
211-
"""
212-
function DAG(name::AbstractString, a::NamedArray, df::DataFrame)
213-
214-
vars = names(a, 1)
215-
d = adjacency_matrix_to_dict(a)
216-
e = StructuralCausalModels.edge_matrix(a)
217-
218-
# Compute covariance matrix and store as NamedArray if df is present
219-
220-
@assert length(names(df)) == length(vars) "DataFrame has different number of columns"
221-
s = NamedArray(cov(Array(df)), (names(df), names(df)), ("Rows", "Cols"))
222-
223142
# Create object
224143

225144
DAG(name, d, a, e, s, df, vars)
@@ -228,26 +147,6 @@ end
228147

229148
"""
230149
231-
# `DAG`
232-
233-
$(SIGNATURES)
234-
235-
Part of API, exported.
236-
"""
237-
function DAG(name::AbstractString, a::NamedArray)
238-
239-
vars = names(a, 1)
240-
d = adjacency_matrix_to_dict(a)
241-
e = StructuralCausalModels.edge_matrix(a)
242-
243-
# Create object
244-
245-
DAG(name, d, a, e, nothing, nothing, vars)
246-
247-
end
248-
249-
"""
250-
251150
# `set_dag_df!`
252151
253152
Set or update Dataframe associated to DAG

test/runtests.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ using Test
1414
include("test_sr6_4_2a.jl")
1515
@test adjustmentsets == [[:a], [:c]]
1616
=#
17+
println("\n")
1718

1819
end
1920

@@ -28,6 +29,7 @@ end
2829
@test show_dag_path(dag, allpaths[1]) == ":w ⇒ :s ⇐ :a ⇐ :d"
2930
@test adjustmentsets == Array{Symbol,1}[]
3031
=#
32+
println("\n")
3133

3234
end
3335

@@ -40,6 +42,7 @@ end
4042
include("test_dagitty_conversion.jl")
4143
include("test_ggm_conversion.jl")
4244
include("test_graphviz_conversions.jl")
45+
println("\n")
4346

4447
end
4548

@@ -57,5 +60,6 @@ end
5760
@test show_dag_path(dag, allpaths[1]) == ":w ⇒ :s ⇐ :a ⇐ :d"
5861
@test adjustmentsets == Array{Symbol,1}[]
5962
=#
60-
end
63+
println("\n")
6164

65+
end

test/test_ag.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ fr1 = ancestral_graph(a; m=m, c=c)
4242
@test all(fr .== fr1)
4343

4444
fr2 = test_ag(dag.a; m=m, c=c)
45-
@test all(fr .== fr2);
45+
for i in names(fr, 1)
46+
for j in names(fr, 1)
47+
@test fr[i, j] == fr2[i, j]
48+
end
49+
end
4650

4751
println()
4852
display(fr)

test/test_graphviz_conversions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ d = from_ggm("DAG(
2525
analysis ~ algebra)"
2626
);
2727

28-
dag = DAG("marks", d, df);
28+
dag = DAG("marks", d, df=df);
2929
show(dag)
3030

3131
fn = joinpath(mktempdir(), "marks.dot")

test/test_sr6_4_2.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ d = OrderedDict(
3030
:u => :a
3131
)
3232

33-
dag = DAG("sr6_4_2", d, df);
33+
dag = DAG("sr6_4_2", d, df=df);
3434

3535
fn = joinpath(mktempdir(), "sr6_4_2.dot")
3636
to_graphviz(dag, fn)

test/test_sr6_4_2a.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ d = OrderedDict(
2929
);
3030
u = [:u]
3131

32-
dag = DAG("sr6_4_2a", d, df);
32+
dag = DAG("sr6_4_2a", d; df=df);
3333
show(dag)
3434

3535
fn = joinpath(mktempdir(), "sr6_4_2a.dot")

0 commit comments

Comments
 (0)