From bdae9d45fbfd164a12fae185fb7d2c79ecf80c7f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 18 May 2026 23:19:15 +0000 Subject: [PATCH 1/3] Add vessel graph connected-components annotation rule Agent-Logs-Url: https://github.com/khanlab/SPIMquant/sessions/1da3e99d-a7e3-4503-9a58-4625aeaefe07 Co-authored-by: akhanf <11492701+akhanf@users.noreply.github.com> --- spimquant/workflow/rules/vessels.smk | 43 ++++++- ...otate_vessel_graph_connected_components.py | 109 +++++++++++++++++ .../convert_vessel_graph_to_nodes_edges.py | 11 +- ...otate_vessel_graph_connected_components.py | 114 ++++++++++++++++++ 4 files changed, 272 insertions(+), 5 deletions(-) create mode 100644 spimquant/workflow/scripts/annotate_vessel_graph_connected_components.py create mode 100644 tests/test_annotate_vessel_graph_connected_components.py diff --git a/spimquant/workflow/rules/vessels.smk b/spimquant/workflow/rules/vessels.smk index 4528b96..522ecdc 100644 --- a/spimquant/workflow/rules/vessels.smk +++ b/spimquant/workflow/rules/vessels.smk @@ -166,13 +166,13 @@ rule vessel_graph_to_nodes_edges: **inputs["spim"].wildcards, ), output: - nodes_parquet=bids( + nodes_raw_parquet=bids( root=root, datatype="vessels", stain="{stain}", level="{level}", desc="{desc}+skeleton", - suffix="nodes.parquet", + suffix="nodes_raw.parquet", **inputs["spim"].wildcards, ), edges_parquet=bids( @@ -190,3 +190,42 @@ rule vessel_graph_to_nodes_edges: runtime=360, script: "../scripts/convert_vessel_graph_to_nodes_edges.py" + + +rule vessel_graph_connected_components: + """Annotate vessel graph nodes with ranked connected-component labels.""" + input: + nodes_parquet=bids( + root=root, + datatype="vessels", + stain="{stain}", + level="{level}", + desc="{desc}+skeleton", + suffix="nodes_raw.parquet", + **inputs["spim"].wildcards, + ), + edges_parquet=bids( + root=root, + datatype="vessels", + stain="{stain}", + level="{level}", + desc="{desc}+skeleton", + suffix="edges.parquet", + **inputs["spim"].wildcards, + ), + output: + nodes_parquet=bids( + root=root, + datatype="vessels", + stain="{stain}", + level="{level}", + desc="{desc}+skeleton", + suffix="nodes.parquet", + **inputs["spim"].wildcards, + ), + threads: 16 + resources: + mem_mb=64000, + runtime=360, + script: + "../scripts/annotate_vessel_graph_connected_components.py" diff --git a/spimquant/workflow/scripts/annotate_vessel_graph_connected_components.py b/spimquant/workflow/scripts/annotate_vessel_graph_connected_components.py new file mode 100644 index 0000000..98652c1 --- /dev/null +++ b/spimquant/workflow/scripts/annotate_vessel_graph_connected_components.py @@ -0,0 +1,109 @@ +"""Annotate vessel graph nodes with ranked connected-component labels. + +Expected Snakemake I/O +---------------------- +Inputs: + snakemake.input.nodes_parquet + snakemake.input.edges_parquet +Output: + snakemake.output.nodes_parquet +""" + +import pandas as pd + +COMPONENT_COLUMN = "component_label" + + +def annotate_nodes_with_connected_components(nodes_df, edges_df): + """Annotate each node with a ranked connected-component label. + + Components are ranked by descending size (largest first). Ties are resolved + deterministically by the minimum node_id in each component. + Labels start at 1. + """ + annotated = nodes_df.copy() + + if annotated.empty: + annotated[COMPONENT_COLUMN] = pd.Series(dtype="int64") + return annotated + + if "node_id" not in annotated.columns: + raise ValueError("nodes table must contain a node_id column") + + node_ids = annotated["node_id"].astype("int64").tolist() + node_to_idx = {node_id: idx for idx, node_id in enumerate(node_ids)} + + parent = list(range(len(node_ids))) + rank = [0] * len(node_ids) + + def find(i): + while parent[i] != i: + parent[i] = parent[parent[i]] + i = parent[i] + return i + + def union(i, j): + root_i = find(i) + root_j = find(j) + if root_i == root_j: + return + if rank[root_i] < rank[root_j]: + parent[root_i] = root_j + elif rank[root_i] > rank[root_j]: + parent[root_j] = root_i + else: + parent[root_j] = root_i + rank[root_i] += 1 + + if not edges_df.empty: + required_edge_cols = {"src_node_id", "dst_node_id"} + missing = required_edge_cols.difference(edges_df.columns) + if missing: + raise ValueError( + "edges table is missing expected columns: " + ", ".join(sorted(missing)) + ) + + src_nodes = edges_df["src_node_id"].astype("int64") + dst_nodes = edges_df["dst_node_id"].astype("int64") + + missing_node_ids = sorted( + (set(src_nodes.unique()) | set(dst_nodes.unique())).difference(node_to_idx) + ) + if missing_node_ids: + raise ValueError( + "edges reference node_ids not present in nodes table: " + + ", ".join(map(str, missing_node_ids[:10])) + ) + + for src_id, dst_id in zip(src_nodes.tolist(), dst_nodes.tolist()): + union(node_to_idx[src_id], node_to_idx[dst_id]) + + components_by_root = {} + for idx, node_id in enumerate(node_ids): + root = find(idx) + components_by_root.setdefault(root, []).append(node_id) + + ranked_components = sorted( + components_by_root.values(), + key=lambda comp_node_ids: (-len(comp_node_ids), min(comp_node_ids)), + ) + + component_label_by_node_id = {} + for label, component_node_ids in enumerate(ranked_components, start=1): + for node_id in component_node_ids: + component_label_by_node_id[node_id] = label + + annotated[COMPONENT_COLUMN] = ( + annotated["node_id"] + .astype("int64") + .map(component_label_by_node_id) + .astype("int64") + ) + return annotated + + +if __name__ == "__main__": + nodes_df = pd.read_parquet(snakemake.input.nodes_parquet) + edges_df = pd.read_parquet(snakemake.input.edges_parquet) + annotated_nodes = annotate_nodes_with_connected_components(nodes_df, edges_df) + annotated_nodes.to_parquet(snakemake.output.nodes_parquet, index=False) diff --git a/spimquant/workflow/scripts/convert_vessel_graph_to_nodes_edges.py b/spimquant/workflow/scripts/convert_vessel_graph_to_nodes_edges.py index 0558b6d..c245f41 100644 --- a/spimquant/workflow/scripts/convert_vessel_graph_to_nodes_edges.py +++ b/spimquant/workflow/scripts/convert_vessel_graph_to_nodes_edges.py @@ -312,9 +312,9 @@ def write_edges_table_from_parquet( graph_parquet, nodes_parquet, edges_parquet, batch_size=PARQUET_BATCH_SIZE ): """Map edges to node IDs with Dask and write output parquet.""" + import dask.dataframe as dd import pyarrow as pa import pyarrow.parquet as pq - import dask.dataframe as dd edges_ddf = dd.read_parquet( graph_parquet, @@ -405,14 +405,19 @@ def write_edges_table_from_parquet( graph_parquet = snakemake.input.graph_parquet _validate_input_parquet_columns(graph_parquet) + nodes_output = getattr( + snakemake.output, + "nodes_raw_parquet", + getattr(snakemake.output, "nodes_parquet"), + ) scheduler = snakemake.config.get("dask_scheduler", "threads") with get_dask_client(scheduler, snakemake.threads): write_nodes_table_from_parquet( graph_parquet, - snakemake.output.nodes_parquet, + nodes_output, ) write_edges_table_from_parquet( graph_parquet, - snakemake.output.nodes_parquet, + nodes_output, snakemake.output.edges_parquet, ) diff --git a/tests/test_annotate_vessel_graph_connected_components.py b/tests/test_annotate_vessel_graph_connected_components.py new file mode 100644 index 0000000..dee4523 --- /dev/null +++ b/tests/test_annotate_vessel_graph_connected_components.py @@ -0,0 +1,114 @@ +from importlib.util import module_from_spec, spec_from_file_location +from pathlib import Path + +import pandas as pd +import pytest + + +def _find_repo_root(start: Path) -> Path: + current = start.resolve() + for candidate in [current, *current.parents]: + if (candidate / "pyproject.toml").exists(): + return candidate + raise RuntimeError("Could not locate repository root from test path") + + +REPO_ROOT = _find_repo_root(Path(__file__).parent) +SCRIPT_PATH = ( + REPO_ROOT + / "spimquant/workflow/scripts/annotate_vessel_graph_connected_components.py" +) + +spec = spec_from_file_location( + "annotate_vessel_graph_connected_components", SCRIPT_PATH +) +mod = module_from_spec(spec) +spec.loader.exec_module(mod) + + +def _sample_nodes(node_ids): + return pd.DataFrame( + { + "node_id": node_ids, + "channel": [0] * len(node_ids), + "vox_x": node_ids, + "vox_y": [0] * len(node_ids), + "vox_z": [0] * len(node_ids), + "x": [float(n) for n in node_ids], + "y": [0.0] * len(node_ids), + "z": [0.0] * len(node_ids), + "radius": [1.0] * len(node_ids), + } + ) + + +def test_connected_components_ranked_by_size(): + nodes_df = _sample_nodes([0, 1, 2, 3, 4, 5]) + edges_df = pd.DataFrame( + { + "edge_id": [0, 1, 2], + "channel": [0, 0, 0], + "src_node_id": [0, 1, 3], + "dst_node_id": [1, 2, 4], + } + ) + + out = mod.annotate_nodes_with_connected_components(nodes_df, edges_df) + + labels = out.set_index("node_id")[mod.COMPONENT_COLUMN].to_dict() + assert labels[0] == labels[1] == labels[2] == 1 + assert labels[3] == labels[4] == 2 + assert labels[5] == 3 + + +def test_component_tie_breaks_by_min_node_id(): + nodes_df = _sample_nodes([0, 1, 2, 3]) + edges_df = pd.DataFrame( + { + "edge_id": [0, 1], + "channel": [0, 0], + "src_node_id": [0, 2], + "dst_node_id": [1, 3], + } + ) + + out = mod.annotate_nodes_with_connected_components(nodes_df, edges_df) + + labels = out.set_index("node_id")[mod.COMPONENT_COLUMN].to_dict() + assert labels[0] == labels[1] == 1 + assert labels[2] == labels[3] == 2 + + +def test_empty_nodes_keeps_schema_and_is_empty(): + nodes_df = pd.DataFrame( + columns=[ + "node_id", + "channel", + "vox_x", + "vox_y", + "vox_z", + "x", + "y", + "z", + "radius", + ] + ) + edges_df = pd.DataFrame(columns=["src_node_id", "dst_node_id"]) + + out = mod.annotate_nodes_with_connected_components(nodes_df, edges_df) + + assert out.empty + assert mod.COMPONENT_COLUMN in out.columns + + +def test_edges_with_unknown_node_raise_error(): + nodes_df = _sample_nodes([0, 1]) + edges_df = pd.DataFrame( + { + "src_node_id": [0], + "dst_node_id": [99], + } + ) + + with pytest.raises(ValueError, match="edges reference node_ids"): + mod.annotate_nodes_with_connected_components(nodes_df, edges_df) From d3c0409c9f4b95849183eaa14c07fe3c1b5a720d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 18 May 2026 23:20:16 +0000 Subject: [PATCH 2/3] Document union-find helpers in vessel component annotation script Agent-Logs-Url: https://github.com/khanlab/SPIMquant/sessions/1da3e99d-a7e3-4503-9a58-4625aeaefe07 Co-authored-by: akhanf <11492701+akhanf@users.noreply.github.com> --- .../scripts/annotate_vessel_graph_connected_components.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spimquant/workflow/scripts/annotate_vessel_graph_connected_components.py b/spimquant/workflow/scripts/annotate_vessel_graph_connected_components.py index 98652c1..c04c4b8 100644 --- a/spimquant/workflow/scripts/annotate_vessel_graph_connected_components.py +++ b/spimquant/workflow/scripts/annotate_vessel_graph_connected_components.py @@ -37,12 +37,14 @@ def annotate_nodes_with_connected_components(nodes_df, edges_df): rank = [0] * len(node_ids) def find(i): + """Return root index with path compression for near-constant lookups.""" while parent[i] != i: parent[i] = parent[parent[i]] i = parent[i] return i def union(i, j): + """Merge two sets using union-by-rank to keep trees shallow.""" root_i = find(i) root_j = find(j) if root_i == root_j: From 39404521e5272be385cb3d422b8ac97b9a8195ca Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 18 May 2026 23:21:01 +0000 Subject: [PATCH 3/3] Clarify output fallback logic and refine CC docstring wording Agent-Logs-Url: https://github.com/khanlab/SPIMquant/sessions/1da3e99d-a7e3-4503-9a58-4625aeaefe07 Co-authored-by: akhanf <11492701+akhanf@users.noreply.github.com> --- .../annotate_vessel_graph_connected_components.py | 2 +- .../scripts/convert_vessel_graph_to_nodes_edges.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/spimquant/workflow/scripts/annotate_vessel_graph_connected_components.py b/spimquant/workflow/scripts/annotate_vessel_graph_connected_components.py index c04c4b8..ac7384f 100644 --- a/spimquant/workflow/scripts/annotate_vessel_graph_connected_components.py +++ b/spimquant/workflow/scripts/annotate_vessel_graph_connected_components.py @@ -37,7 +37,7 @@ def annotate_nodes_with_connected_components(nodes_df, edges_df): rank = [0] * len(node_ids) def find(i): - """Return root index with path compression for near-constant lookups.""" + """Return root index with path compression for amortized near-constant lookups.""" while parent[i] != i: parent[i] = parent[parent[i]] i = parent[i] diff --git a/spimquant/workflow/scripts/convert_vessel_graph_to_nodes_edges.py b/spimquant/workflow/scripts/convert_vessel_graph_to_nodes_edges.py index c245f41..a1edf0c 100644 --- a/spimquant/workflow/scripts/convert_vessel_graph_to_nodes_edges.py +++ b/spimquant/workflow/scripts/convert_vessel_graph_to_nodes_edges.py @@ -405,11 +405,12 @@ def write_edges_table_from_parquet( graph_parquet = snakemake.input.graph_parquet _validate_input_parquet_columns(graph_parquet) - nodes_output = getattr( - snakemake.output, - "nodes_raw_parquet", - getattr(snakemake.output, "nodes_parquet"), - ) + if hasattr(snakemake.output, "nodes_raw_parquet"): + # Supports the newer vessel rule that writes intermediate raw nodes. + nodes_output = snakemake.output.nodes_raw_parquet + else: + # Backward compatibility for direct/script usage expecting nodes_parquet. + nodes_output = snakemake.output.nodes_parquet scheduler = snakemake.config.get("dask_scheduler", "threads") with get_dask_client(scheduler, snakemake.threads): write_nodes_table_from_parquet(