Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions spimquant/workflow/rules/vessels.smk
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,13 @@ rule vessel_graph_to_nodes_edges:
**inputs["spim"].wildcards,
),
output:
nodes_parquet=bids(
nodes_raw_parquet=bids(
root=root,
datatype="vessels",
stain="{stain}",
level="{level}",
desc="{desc}+skeleton",
suffix="nodes.parquet",
suffix="nodes_raw.parquet",
**inputs["spim"].wildcards,
),
edges_parquet=bids(
Expand All @@ -190,3 +190,42 @@ rule vessel_graph_to_nodes_edges:
runtime=360,
script:
"../scripts/convert_vessel_graph_to_nodes_edges.py"


rule vessel_graph_connected_components:
    """Annotate vessel graph nodes with ranked connected-component labels.

    Reads the raw nodes table (suffix ``nodes_raw.parquet``) and the edges
    table (suffix ``edges.parquet``) and writes the annotated nodes table
    (suffix ``nodes.parquet``).
    """
    input:
        # Raw node table; exposed to the script as ``input.nodes_parquet``.
        nodes_parquet=bids(
            root=root,
            datatype="vessels",
            stain="{stain}",
            level="{level}",
            desc="{desc}+skeleton",
            suffix="nodes_raw.parquet",
            **inputs["spim"].wildcards,
        ),
        edges_parquet=bids(
            root=root,
            datatype="vessels",
            stain="{stain}",
            level="{level}",
            desc="{desc}+skeleton",
            suffix="edges.parquet",
            **inputs["spim"].wildcards,
        ),
    output:
        # Same columns as the raw node table plus the component label column.
        nodes_parquet=bids(
            root=root,
            datatype="vessels",
            stain="{stain}",
            level="{level}",
            desc="{desc}+skeleton",
            suffix="nodes.parquet",
            **inputs["spim"].wildcards,
        ),
    # The annotation script runs its union-find in a single Python loop and
    # has no parallel section, so reserving more than one core would only
    # keep other jobs from being scheduled.
    threads: 1
    resources:
        mem_mb=64000,
        runtime=360,
    script:
        "../scripts/annotate_vessel_graph_connected_components.py"
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Annotate vessel graph nodes with ranked connected-component labels.

Expected Snakemake I/O
----------------------
Inputs:
snakemake.input.nodes_parquet
snakemake.input.edges_parquet
Output:
snakemake.output.nodes_parquet
"""

import pandas as pd

# Name of the column added to the nodes table by the annotation below.
COMPONENT_COLUMN = "component_label"


def annotate_nodes_with_connected_components(
    nodes_df: pd.DataFrame, edges_df: pd.DataFrame
) -> pd.DataFrame:
    """Annotate each node with a ranked connected-component label.

    Nodes are grouped into the connected components of the undirected graph
    whose vertices are ``nodes_df["node_id"]`` and whose edges are the
    ``(src_node_id, dst_node_id)`` pairs of ``edges_df``.

    Components are ranked by descending size (largest first). Ties are
    resolved deterministically by the minimum node_id in each component.
    Labels start at 1.

    Parameters
    ----------
    nodes_df
        Node table. Must contain a ``node_id`` column with unique values
        unless the table is empty. It is not modified.
    edges_df
        Edge table. Must contain ``src_node_id`` and ``dst_node_id`` columns
        unless it is empty; every referenced id must exist in ``nodes_df``.

    Returns
    -------
    pandas.DataFrame
        A copy of ``nodes_df`` with an additional int64 column named
        ``COMPONENT_COLUMN``.

    Raises
    ------
    ValueError
        If ``node_id`` is missing, if node ids are duplicated, if the edge
        table lacks a required column, or if an edge references an unknown
        node id.
    """
    annotated = nodes_df.copy()

    if annotated.empty:
        # Keep the output schema stable even when there is nothing to label.
        annotated[COMPONENT_COLUMN] = pd.Series(dtype="int64")
        return annotated

    if "node_id" not in annotated.columns:
        raise ValueError("nodes table must contain a node_id column")

    node_id_series = annotated["node_id"].astype("int64")

    # A repeated node_id would collapse onto one union-find slot and leave the
    # other rows as phantom singleton components, silently corrupting both
    # the ranking and the labels. Fail loudly instead.
    duplicate_ids = sorted(
        node_id_series[node_id_series.duplicated()].unique().tolist()
    )
    if duplicate_ids:
        raise ValueError(
            "nodes table contains duplicate node_ids: "
            + ", ".join(map(str, duplicate_ids[:10]))
        )

    node_ids = node_id_series.tolist()
    node_to_idx = {node_id: idx for idx, node_id in enumerate(node_ids)}

    # Disjoint-set forest over row positions.
    parent = list(range(len(node_ids)))
    rank = [0] * len(node_ids)

    def find(i):
        """Return the root index of ``i``, halving the path while walking up."""
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    def union(i, j):
        """Merge the sets containing ``i`` and ``j`` using union-by-rank."""
        root_i = find(i)
        root_j = find(j)
        if root_i == root_j:
            return
        if rank[root_i] < rank[root_j]:
            parent[root_i] = root_j
        elif rank[root_i] > rank[root_j]:
            parent[root_j] = root_i
        else:
            parent[root_j] = root_i
            rank[root_i] += 1

    if not edges_df.empty:
        required_edge_cols = {"src_node_id", "dst_node_id"}
        missing = required_edge_cols.difference(edges_df.columns)
        if missing:
            raise ValueError(
                "edges table is missing expected columns: "
                + ", ".join(sorted(missing))
            )

        src_nodes = edges_df["src_node_id"].astype("int64")
        dst_nodes = edges_df["dst_node_id"].astype("int64")

        # Validate every endpoint up front so the merge loop cannot hit a
        # KeyError half-way through.
        missing_node_ids = sorted(
            (set(src_nodes.unique()) | set(dst_nodes.unique())).difference(
                node_to_idx
            )
        )
        if missing_node_ids:
            raise ValueError(
                "edges reference node_ids not present in nodes table: "
                + ", ".join(map(str, missing_node_ids[:10]))
            )

        for src_id, dst_id in zip(src_nodes.tolist(), dst_nodes.tolist()):
            union(node_to_idx[src_id], node_to_idx[dst_id])

    # Collect the member node ids of every component, keyed by its root.
    components_by_root = {}
    for idx, node_id in enumerate(node_ids):
        components_by_root.setdefault(find(idx), []).append(node_id)

    # Largest component first; equal sizes are ordered by smallest node id.
    ranked_components = sorted(
        components_by_root.values(),
        key=lambda comp_node_ids: (-len(comp_node_ids), min(comp_node_ids)),
    )

    component_label_by_node_id = {
        node_id: label
        for label, component_node_ids in enumerate(ranked_components, start=1)
        for node_id in component_node_ids
    }

    annotated[COMPONENT_COLUMN] = node_id_series.map(
        component_label_by_node_id
    ).astype("int64")
    return annotated


if __name__ == "__main__":
    # Snakemake injects the ``snakemake`` object when running this file as a
    # script; paths come from the rule's named inputs and outputs.
    raw_nodes = pd.read_parquet(snakemake.input.nodes_parquet)
    edges = pd.read_parquet(snakemake.input.edges_parquet)

    labeled_nodes = annotate_nodes_with_connected_components(raw_nodes, edges)

    # Drop the pandas index so the output holds only the table columns.
    labeled_nodes.to_parquet(snakemake.output.nodes_parquet, index=False)
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,9 @@ def write_edges_table_from_parquet(
graph_parquet, nodes_parquet, edges_parquet, batch_size=PARQUET_BATCH_SIZE
):
"""Map edges to node IDs with Dask and write output parquet."""
import dask.dataframe as dd
import pyarrow as pa
import pyarrow.parquet as pq
import dask.dataframe as dd

edges_ddf = dd.read_parquet(
graph_parquet,
Expand Down Expand Up @@ -405,14 +405,20 @@ def write_edges_table_from_parquet(

graph_parquet = snakemake.input.graph_parquet
_validate_input_parquet_columns(graph_parquet)
if hasattr(snakemake.output, "nodes_raw_parquet"):
# Supports the newer vessel rule that writes intermediate raw nodes.
nodes_output = snakemake.output.nodes_raw_parquet
else:
# Backward compatibility for direct/script usage expecting nodes_parquet.
nodes_output = snakemake.output.nodes_parquet
scheduler = snakemake.config.get("dask_scheduler", "threads")
with get_dask_client(scheduler, snakemake.threads):
write_nodes_table_from_parquet(
graph_parquet,
snakemake.output.nodes_parquet,
nodes_output,
)
write_edges_table_from_parquet(
graph_parquet,
snakemake.output.nodes_parquet,
nodes_output,
snakemake.output.edges_parquet,
)
114 changes: 114 additions & 0 deletions tests/test_annotate_vessel_graph_connected_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path

import pandas as pd
import pytest


def _find_repo_root(start: Path) -> Path:
    """Return the nearest directory at or above ``start`` holding pyproject.toml.

    Raises RuntimeError when no such directory exists.
    """
    origin = start.resolve()
    # Closest match wins: the resolved start itself is checked before parents.
    located = next(
        (
            directory
            for directory in (origin, *origin.parents)
            if (directory / "pyproject.toml").exists()
        ),
        None,
    )
    if located is None:
        raise RuntimeError("Could not locate repository root from test path")
    return located


# Locate the repository by walking up from this test file's directory, then
# build the path of the workflow script under test.
REPO_ROOT = _find_repo_root(Path(__file__).parent)
SCRIPT_PATH = (
REPO_ROOT
/ "spimquant/workflow/scripts/annotate_vessel_graph_connected_components.py"
)

# Load the script directly from its file path and expose it as ``mod``.
spec = spec_from_file_location(
"annotate_vessel_graph_connected_components", SCRIPT_PATH
)
mod = module_from_spec(spec)
# Executing the module defines its functions for the tests below. It is loaded
# under a name other than "__main__", so a ``__name__ == "__main__"`` guard in
# the script does not fire.
spec.loader.exec_module(mod)


def _sample_nodes(node_ids):
    """Build a minimal nodes table with one row per entry of ``node_ids``.

    The ids double as the voxel x position and (as floats) the x coordinate;
    every other column is filled with a constant.
    """
    count = len(node_ids)
    columns = {}
    columns["node_id"] = node_ids
    columns["channel"] = [0] * count
    columns["vox_x"] = node_ids
    columns["vox_y"] = [0] * count
    columns["vox_z"] = [0] * count
    columns["x"] = [float(node_id) for node_id in node_ids]
    columns["y"] = [0.0] * count
    columns["z"] = [0.0] * count
    columns["radius"] = [1.0] * count
    return pd.DataFrame(columns)


def test_connected_components_ranked_by_size():
    """Components of sizes 3, 2 and 1 receive labels 1, 2 and 3."""
    # Chain 0-1-2, pair 3-4, and node 5 left isolated.
    edge_pairs = [(0, 1), (1, 2), (3, 4)]
    edges_df = pd.DataFrame(
        {
            "edge_id": list(range(len(edge_pairs))),
            "channel": [0] * len(edge_pairs),
            "src_node_id": [src for src, _ in edge_pairs],
            "dst_node_id": [dst for _, dst in edge_pairs],
        }
    )
    nodes_df = _sample_nodes([0, 1, 2, 3, 4, 5])

    result = mod.annotate_nodes_with_connected_components(nodes_df, edges_df)

    label_of = result.set_index("node_id")[mod.COMPONENT_COLUMN].to_dict()
    assert label_of == {0: 1, 1: 1, 2: 1, 3: 2, 4: 2, 5: 3}


def test_component_tie_breaks_by_min_node_id():
    """Equal-sized components are ordered by their smallest node id."""
    # Two pairs of the same size: {0, 1} and {2, 3}.
    edge_pairs = [(0, 1), (2, 3)]
    edges_df = pd.DataFrame(
        {
            "edge_id": list(range(len(edge_pairs))),
            "channel": [0] * len(edge_pairs),
            "src_node_id": [src for src, _ in edge_pairs],
            "dst_node_id": [dst for _, dst in edge_pairs],
        }
    )
    nodes_df = _sample_nodes([0, 1, 2, 3])

    result = mod.annotate_nodes_with_connected_components(nodes_df, edges_df)

    label_of = result.set_index("node_id")[mod.COMPONENT_COLUMN].to_dict()
    assert label_of == {0: 1, 1: 1, 2: 2, 3: 2}


def test_empty_nodes_keeps_schema_and_is_empty():
    """An empty nodes table stays empty but still gains the label column."""
    node_columns = [
        "node_id",
        "channel",
        "vox_x",
        "vox_y",
        "vox_z",
        "x",
        "y",
        "z",
        "radius",
    ]
    empty_nodes = pd.DataFrame(columns=node_columns)
    empty_edges = pd.DataFrame(columns=["src_node_id", "dst_node_id"])

    result = mod.annotate_nodes_with_connected_components(
        empty_nodes, empty_edges
    )

    assert result.empty
    assert mod.COMPONENT_COLUMN in result.columns


def test_edges_with_unknown_node_raise_error():
    """An edge endpoint missing from the nodes table is rejected."""
    known_nodes = _sample_nodes([0, 1])
    # Node 99 does not exist in the nodes table.
    dangling_edge = pd.DataFrame({"src_node_id": [0], "dst_node_id": [99]})

    with pytest.raises(ValueError) as excinfo:
        mod.annotate_nodes_with_connected_components(known_nodes, dangling_edge)

    assert "edges reference node_ids" in str(excinfo.value)