diff --git a/burr/core/graph.py b/burr/core/graph.py index fcbc9c5cb..5889c41f6 100644 --- a/burr/core/graph.py +++ b/burr/core/graph.py @@ -110,6 +110,31 @@ def _render_graphviz( pathlib.Path(f"{path_without_suffix}.{fmt}").write_bytes(graphviz_obj.pipe(format=fmt)) +def _get_state_label(action: Action) -> str: + """Builds a graphviz node label displaying the action's name with its + reads and writes each on their own line, e.g.:: + + action_name + reads: + read_1 + writes: + write_1 + + Uses graphviz's ``\\l`` escape so lines are left-justified, keeping the + indentation visible. + """ + if not action.reads and not action.writes: + return action.name + lines = [action.name] + if action.reads: + lines.append("reads:") + lines.extend(f" {read}" for read in action.reads) + if action.writes: + lines.append("writes:") + lines.extend(f" {write}" for write in action.writes) + return "\\l".join(lines) + "\\l" + + @dataclasses.dataclass class Graph: """Graph class allows you to specify actions and transitions between them. @@ -229,11 +254,7 @@ def visualize( digraph_attr[g_key] = g_value digraph = graphviz.Digraph(**digraph_attr) for action in self.actions: - label = ( - action.name - if not include_state - else f"{action.name}({', '.join(action.reads)}): {', '.join(action.writes)}" - ) + label = action.name if not include_state else _get_state_label(action) digraph.node(action.name, label=label, shape="box", style="rounded,filled") required_inputs, optional_inputs = action.optional_and_required_inputs for input_ in required_inputs | optional_inputs: diff --git a/tests/core/test_graphviz_display.py b/tests/core/test_graphviz_display.py index 40c016089..7eda47cba 100644 --- a/tests/core/test_graphviz_display.py +++ b/tests/core/test_graphviz_display.py @@ -68,3 +68,34 @@ def test_visualize_no_dot_output(graph, tmp_path: pathlib.Path): graph.visualize(output_file_path=None) assert not dot_file_path.exists() + + +@pytest.mark.parametrize( + "reads, writes, expected_label", + [ + ([], [], "label=counter"), + (["count"], [], r"counter\lreads:\l count\l"), + ([], ["count"], r"counter\lwrites:\l count\l"), + ( + ["count", "user_input"], + ["count"], + r"counter\lreads:\l count\l user_input\lwrites:\l count\l", + ), + ], +) +def test_visualize_include_state_multiline_label(reads: list, writes: list, expected_label: str): + """Reads and writes should each be displayed on their own line (issue #402)""" + action = PassedInAction( + reads=reads, + writes=writes, + fn=lambda state: {"count": state.get("count", 0) + 1}, + update_fn=lambda result, state: state.update(**result), + inputs=[], + ) + graph = ( + GraphBuilder().with_actions(counter=action).with_transitions(("counter", "counter")).build() + ) + + digraph = graph.visualize(include_state=True) + + assert expected_label in digraph.source