Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions burr/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,31 @@ def _render_graphviz(
pathlib.Path(f"{path_without_suffix}.{fmt}").write_bytes(graphviz_obj.pipe(format=fmt))


def _get_state_label(action: Action) -> str:
"""Builds a graphviz node label displaying the action's name with its
reads and writes each on their own line, e.g.::

action_name
reads:
read_1
writes:
write_1

Uses graphviz's ``\\l`` escape so lines are left-justified, keeping the
indentation visible.
"""
if not action.reads and not action.writes:
return action.name
lines = [action.name]
if action.reads:
lines.append("reads:")
lines.extend(f" {read}" for read in action.reads)
if action.writes:
lines.append("writes:")
lines.extend(f" {write}" for write in action.writes)
return "\\l".join(lines) + "\\l"


@dataclasses.dataclass
class Graph:
"""Graph class allows you to specify actions and transitions between them.
Expand Down Expand Up @@ -229,11 +254,7 @@ def visualize(
digraph_attr[g_key] = g_value
digraph = graphviz.Digraph(**digraph_attr)
for action in self.actions:
label = (
action.name
if not include_state
else f"{action.name}({', '.join(action.reads)}): {', '.join(action.writes)}"
)
label = action.name if not include_state else _get_state_label(action)
digraph.node(action.name, label=label, shape="box", style="rounded,filled")
required_inputs, optional_inputs = action.optional_and_required_inputs
for input_ in required_inputs | optional_inputs:
Expand Down
31 changes: 31 additions & 0 deletions tests/core/test_graphviz_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,34 @@ def test_visualize_no_dot_output(graph, tmp_path: pathlib.Path):
graph.visualize(output_file_path=None)

assert not dot_file_path.exists()


@pytest.mark.parametrize(
"reads, writes, expected_label",
[
([], [], "label=counter"),
(["count"], [], r"counter\lreads:\l count\l"),
([], ["count"], r"counter\lwrites:\l count\l"),
(
["count", "user_input"],
["count"],
r"counter\lreads:\l count\l user_input\lwrites:\l count\l",
),
],
)
def test_visualize_include_state_multiline_label(reads: list, writes: list, expected_label: str):
"""Reads and writes should each be displayed on their own line (issue #402)"""
action = PassedInAction(
reads=reads,
writes=writes,
fn=lambda state: {"count": state.get("count", 0) + 1},
update_fn=lambda result, state: state.update(**result),
inputs=[],
)
graph = (
GraphBuilder().with_actions(counter=action).with_transitions(("counter", "counter")).build()
)

digraph = graph.visualize(include_state=True)

assert expected_label in digraph.source
Loading