diff --git a/docs/howto/segmentation.md b/docs/howto/segmentation.md index 84bcbf2..5a4eee1 100644 --- a/docs/howto/segmentation.md +++ b/docs/howto/segmentation.md @@ -10,10 +10,12 @@ flowchart LR B --> C{seg_method} C -->|threshold| D[Threshold] C -->|otsu+k3i2| E[Multi-Otsu + k-means] - D --> F[Binary Mask] - E --> F - F --> G[Clean: remove edge / small objects] - G --> H[Segmentation Output] + C -->|groupotsu+k3i2| F[Group Multi-Otsu] + D --> G[Binary Mask] + E --> G + F --> G + G --> H[Clean: remove edge / small objects] + H --> I[Segmentation Output] ``` After segmentation the binary mask is used to compute [field fraction](../reference/outputs.md#field-fraction-map-seg), [object count](../reference/outputs.md#object-count-map-seg), and [per-object region properties](../reference/outputs.md#region-properties-statistics-table-tabular). @@ -81,6 +83,45 @@ stain_defaults: **Limitations:** Can fail on images with unusual histograms (e.g. very sparse pathology that does not form a distinct peak) or when the background is very noisy. +### Group Multi-Otsu (`seg_method: groupotsu+k3i2`) + +A variant of Multi-Otsu that derives a **single shared threshold from the aggregate histogram of all subjects** rather than computing a threshold independently per image. This is preferred when subjects were acquired with common acquisition settings and you want to ensure consistent, comparable quantification across the cohort. + +The workflow is a two-step process: + +**Step 1 — compute group threshold** (run once for the whole cohort): + +```bash +spimquant /bids /output participant \ + --targets all_group_otsu \ + --seg_method groupotsu+k3i2 +``` + +This triggers: + +1. For each subject: compute a percentile-clipped intensity histogram from the bias-field corrected image and save it as an NPZ file. +2. Aggregate all subject histograms onto a common intensity grid, apply multi-level Otsu thresholding, and save the resulting thresholds as a JSON file in `{output}/group/`. + +**Step 2 — segment each subject using the group threshold**: + +```bash +spimquant /bids /output participant \ + --seg_method groupotsu+k3i2 +``` + +Each subject's binary mask is produced by applying the group-level threshold from the JSON file. A per-subject PNG is also generated showing the group threshold overlaid on the individual histogram, useful for visual quality control. + +**Config key:** + +```yaml +seg_method: + - groupotsu+k3i2 +``` + +**When to use:** Preferred when a batch of subjects shares the same acquisition protocol and you want consistent thresholding across subjects. Reduces subject-to-subject variability in the segmentation boundary that can occur with per-subject Otsu. + +**Limitations:** Less adaptive than per-subject Otsu — if staining intensity varies substantially across subjects (e.g. due to different batches of antibody or tissue preparation), a single group threshold may over- or under-segment some subjects. + --- ## Post-Segmentation Cleaning diff --git a/spimquant/workflow/Snakefile b/spimquant/workflow/Snakefile index 0a123fa..b7fcb1c 100644 --- a/spimquant/workflow/Snakefile +++ b/spimquant/workflow/Snakefile @@ -85,6 +85,9 @@ stains_for_seg = list(set(config["stains_for_seg"]).intersection(set(stains))) # seg methods that use multi-Otsu thresholding (otsu+k{}i{} pattern) otsu_seg_methods = [m for m in config["seg_method"] if m.startswith("otsu+")] +# seg methods that use group-level multi-Otsu thresholding (groupotsu+k{}i{} pattern) +groupotsu_seg_methods = [m for m in config["seg_method"] if m.startswith("groupotsu+")] + if len(stains_for_seg) == 0 or config["no_segmentation"]: do_seg = False do_coloc = False @@ -908,6 +911,43 @@ rule all_group: rules.all_group_stats_coloc.input if do_coloc else [], +rule all_group_otsu: + """Target rule for group-level Otsu threshold computation. + + Aggregates intensity histograms across all subjects for each stain and + computes a single set of Otsu thresholds that can be applied consistently + to every subject. Run this rule before participant-level segmentation + when ``groupotsu+k{}i{}`` is included in ``seg_method``. + + Example:: + + spimquant /bids /output participant \\ + --targets all_group_otsu \\ + --seg_method groupotsu+k3i2 + + Then run segmentation using the computed group thresholds:: + + spimquant /bids /output participant \\ + --seg_method groupotsu+k3i2 + """ + input: + expand( + bids( + root=root, + datatype="group", + stain="{stain}", + level="{level}", + desc="{desc}", + suffix="thresholds.json", + ), + stain=stains_for_seg, + level=config["segmentation_level"], + desc=groupotsu_seg_methods, + ) + if (do_seg and groupotsu_seg_methods) + else [], + + include: "rules/import.smk" include: "rules/masking.smk" include: "rules/templatereg.smk" diff --git a/spimquant/workflow/rules/groupstats.smk b/spimquant/workflow/rules/groupstats.smk index 3e41c6f..84603bf 100644 --- a/spimquant/workflow/rules/groupstats.smk +++ b/spimquant/workflow/rules/groupstats.smk @@ -4,9 +4,67 @@ Group-level statistical analysis rules for SPIMquant. This module performs group-based statistical tests on segmentation statistics (e.g., fieldfrac, density, volume) across participants, using metadata from participants.tsv to define contrasts. + +It also provides the ``group_otsu`` rule which aggregates per-subject intensity +histograms (produced by ``compute_subject_histogram``) to compute a single set +of Otsu thresholds shared across all subjects. """ +rule group_otsu: + """Compute group-level Otsu thresholds from aggregated per-subject histograms. + + Collects the intensity histogram NPZ files computed by + ``compute_subject_histogram`` for all subjects, merges them onto a + common intensity grid, and applies multi-level Otsu thresholding to the + aggregate histogram. The resulting thresholds are saved as a JSON file + (consumed by ``multiotsu_group`` during participant-level segmentation) + and as a PNG figure for visual inspection. + + This rule is the target of ``all_group_otsu`` and should be run before + participant-level segmentation when ``groupotsu+k{}i{}`` is used as the + segmentation method. + """ + input: + histogram_npz=lambda wildcards: inputs["spim"].expand( + bids( + root=work, + datatype="seg", + stain=wildcards.stain, + level=wildcards.level, + desc="groupotsu+k{k}i{i}".format(k=wildcards.k, i=wildcards.i), + suffix="histogram.npz", + **inputs["spim"].wildcards, + ) + ), + params: + otsu_k=lambda wildcards: int(wildcards.k), + otsu_threshold_index=lambda wildcards: int(wildcards.i), + output: + thresholds_json=bids( + root=root, + datatype="group", + stain="{stain}", + level="{level}", + desc="groupotsu+k{k,[0-9]+}i{i,[0-9]+}", + suffix="thresholds.json", + ), + thresholds_png=bids( + root=root, + datatype="group", + stain="{stain}", + level="{level}", + desc="groupotsu+k{k,[0-9]+}i{i,[0-9]+}", + suffix="thresholds.png", + ), + threads: 4 + resources: + mem_mb=8000, + runtime=10, + script: + "../scripts/group_otsu.py" + + rule perform_group_stats: """Perform group-based statistical tests on segmentation statistics. diff --git a/spimquant/workflow/rules/segmentation.smk b/spimquant/workflow/rules/segmentation.smk index 70904b6..c29c9ab 100644 --- a/spimquant/workflow/rules/segmentation.smk +++ b/spimquant/workflow/rules/segmentation.smk @@ -78,19 +78,14 @@ rule n4_biasfield: shrink_factor=16 if config["sloppy"] else 1, target_chunk_size=512, #this sets the chunk size for this and downstream masks output: - corrected=temp( - directory( - bids( - root=work, + corrected=bids( + root=root, datatype="seg", stain="{stain}", level="{level}", desc="correctedn4", - suffix="SPIM.ome.zarr", + suffix="SPIM.ozx", **inputs["spim"].wildcards, - ) - ), - group_jobs=True, ), threads: 128 if config["dask_scheduler"] == "distributed" else 32 resources: @@ -110,12 +105,12 @@ rule multiotsu: """ input: corrected=bids( - root=work, + root=root, datatype="seg", stain="{stain}", level="{level}", desc="corrected{method}".format(method=config["correction_method"]), - suffix="SPIM.ome.zarr", + suffix="SPIM.ozx", **inputs["spim"].wildcards, ), params: @@ -152,6 +147,109 @@ rule multiotsu: "../scripts/multiotsu.py" +rule compute_subject_histogram: + """Compute intensity histogram for a single subject for group-level Otsu thresholding. + + Calculates a percentile-clipped histogram of the bias-field corrected image + and saves it as an NPZ file (hist_counts + bin_edges). These per-subject + histogram files are later aggregated by the ``group_otsu`` rule to derive + a single set of thresholds that can be applied consistently across the whole + cohort with the ``multiotsu_group`` rule. + """ + input: + corrected=bids( + root=root, + datatype="seg", + stain="{stain}", + level="{level}", + desc="corrected{method}".format(method=config["correction_method"]), + suffix="SPIM.ozx", + **inputs["spim"].wildcards, + ), + params: + hist_bin_width=float(config["seg_hist_bin_width"]), + hist_percentile_range=[float(x) for x in config["seg_hist_percentile_range"]], + zarrnii_kwargs={"orientation": config["orientation"]}, + output: + histogram_npz=bids( + root=work, + datatype="seg", + stain="{stain}", + level="{level}", + desc="groupotsu+k{k,[0-9]+}i{i,[0-9]+}", + suffix="histogram.npz", + **inputs["spim"].wildcards, + ), + threads: 128 if config["dask_scheduler"] == "distributed" else 32 + resources: + mem_mb=500000 if config["dask_scheduler"] == "distributed" else 250000, + runtime=90, + script: + "../scripts/compute_subject_histogram.py" + + +rule multiotsu_group: + """Apply group-level Otsu threshold for segmentation. + + Uses a pre-computed group-level Otsu threshold (derived from an aggregate + histogram across all subjects) to create a binary mask for the current + subject. This ensures consistent thresholding across the whole cohort. + A per-subject PNG is also produced showing the group threshold overlaid + on this subject's own intensity histogram for visual quality control. + + Run ``all_group_otsu`` before using this rule so that the group threshold + JSON file is available. + """ + input: + corrected=bids( + root=root, + datatype="seg", + stain="{stain}", + level="{level}", + desc="corrected{method}".format(method=config["correction_method"]), + suffix="SPIM.ozx", + **inputs["spim"].wildcards, + ), + thresholds_json=bids( + root=root, + datatype="group", + stain="{stain}", + level="{level}", + desc="groupotsu+k{k,[0-9]+}i{i,[0-9]+}", + suffix="thresholds.json", + ), + params: + hist_bin_width=float(config["seg_hist_bin_width"]), + hist_percentile_range=[float(x) for x in config["seg_hist_percentile_range"]], + zarrnii_kwargs={"orientation": config["orientation"]}, + output: + mask=bids( + root=root, + datatype="seg", + stain="{stain}", + level="{level}", + desc="groupotsu+k{k,[0-9]+}i{i,[0-9]+}", + suffix="mask.ozx", + **inputs["spim"].wildcards, + ), + thresholds_png=bids( + root=root, + datatype="seg", + stain="{stain}", + level="{level}", + desc="groupotsu+k{k,[0-9]+}i{i,[0-9]+}", + suffix="thresholds.png", + **inputs["spim"].wildcards, + ), + threads: 128 if config["dask_scheduler"] == "distributed" else 32 + resources: + mem_mb=500000 if config["dask_scheduler"] == "distributed" else 250000, + disk_mb=2097152, + runtime=180, + script: + "../scripts/multiotsu_group.py" + + rule threshold: """Apply simple intensity threshold for segmentation. @@ -160,12 +258,12 @@ rule threshold: """ input: corrected=bids( - root=work, + root=root, datatype="seg", stain="{stain}", level="{level}", desc="corrected{method}".format(method=config["correction_method"]), - suffix="SPIM.ome.zarr", + suffix="SPIM.ozx", **inputs["spim"].wildcards, ), params: diff --git a/spimquant/workflow/scripts/compute_subject_histogram.py b/spimquant/workflow/scripts/compute_subject_histogram.py new file mode 100644 index 0000000..d34dc68 --- /dev/null +++ b/spimquant/workflow/scripts/compute_subject_histogram.py @@ -0,0 +1,65 @@ +"""Compute and save an intensity histogram for a single subject. + +Used as the first step of group-level Otsu thresholding. Each subject +independently computes its histogram from the bias-field corrected image; +the resulting per-subject NPZ files are later aggregated by the +``group_otsu`` rule to derive a single set of thresholds shared across the +whole cohort. + +This is a Snakemake script; the ``snakemake`` object is automatically +provided when executed as part of a Snakemake workflow. +""" + +import numpy as np + +from dask_setup import get_dask_client +from zarrnii import ZarrNii + +if __name__ == "__main__": + with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): + + zarrnii_kwargs = snakemake.params.zarrnii_kwargs + pct_lo, pct_hi = snakemake.params.hist_percentile_range + bin_width = snakemake.params.hist_bin_width + + # load a downsampled version to estimate the percentile-based range + print("estimating intensity range from downsampled image...") + znimg_ds = None + for ds_level in [5, 4, 3, 2, 1]: + try: + candidate = ZarrNii.from_ome_zarr( + snakemake.input.corrected, level=ds_level, **zarrnii_kwargs + ) + znimg_ds = candidate + break + except Exception as e: + print(f" level {ds_level} not available ({e}), trying lower level") + + if znimg_ds is None: + znimg_ds = ZarrNii.from_ome_zarr( + snakemake.input.corrected, **zarrnii_kwargs + ) + + data_ds = znimg_ds.data.compute().ravel().astype(np.float32) + range_lo = float(np.percentile(data_ds, pct_lo)) + range_hi = float(np.percentile(data_ds, pct_hi)) + print( + f" 📊 percentile range [{pct_lo}%, {pct_hi}%]: [{range_lo:.3f}, {range_hi:.3f}]" + ) + + # compute number of bins from bin width + n_bins = max(2, int(np.ceil((range_hi - range_lo) / bin_width))) + print(f" 📊 bins: {n_bins} (bin width: {bin_width})") + + # compute full-resolution histogram + znimg = ZarrNii.from_ome_zarr(snakemake.input.corrected, **zarrnii_kwargs) + (hist_counts, bin_edges) = znimg.compute_histogram( + bins=n_bins, range=[range_lo, range_hi] + ) + + print(f"saving histogram to {snakemake.output.histogram_npz}") + np.savez( + snakemake.output.histogram_npz, + hist_counts=hist_counts, + bin_edges=bin_edges, + ) diff --git a/spimquant/workflow/scripts/group_otsu.py b/spimquant/workflow/scripts/group_otsu.py new file mode 100644 index 0000000..8fd6b93 --- /dev/null +++ b/spimquant/workflow/scripts/group_otsu.py @@ -0,0 +1,83 @@ +"""Aggregate per-subject histograms and compute group-level Otsu thresholds. + +Loads the intensity histogram NPZ files produced by ``compute_subject_histogram`` +for every subject, merges them onto a common intensity grid, and applies +multi-level Otsu thresholding to the aggregate histogram. The resulting +thresholds are saved as a JSON file (for downstream use by ``multiotsu_group``) +and as a PNG figure for visual inspection. + +This is a Snakemake script; the ``snakemake`` object is automatically +provided when executed as part of a Snakemake workflow. +""" + +import json + +import matplotlib + +matplotlib.use("agg") +import numpy as np + +from zarrnii.analysis import compute_otsu_thresholds + +if __name__ == "__main__": + # Load all per-subject histograms + histograms = [] + for path in snakemake.input.histogram_npz: + data = np.load(path) + histograms.append((data["hist_counts"], data["bin_edges"])) + + print(f"Loaded {len(histograms)} subject histograms") + + # Find the common intensity range spanning all subjects + overall_lo = min(float(be[0]) for _, be in histograms) + overall_hi = max(float(be[-1]) for _, be in histograms) + + # Use the first histogram's bin width as reference for the common grid + ref_bin_edges = histograms[0][1] + bin_width = float(ref_bin_edges[1] - ref_bin_edges[0]) + + n_bins = max(2, int(np.ceil((overall_hi - overall_lo) / bin_width))) + common_bin_edges = np.linspace(overall_lo, overall_hi, n_bins + 1) + + print( + f" 📊 common range: [{overall_lo:.3f}, {overall_hi:.3f}], bins: {n_bins}" + ) + + # Aggregate all subject histograms onto the common grid + aggregate_counts = np.zeros(n_bins, dtype=np.float64) + for hist_counts, bin_edges in histograms: + bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) + # Map each original bin center to the nearest common bin + common_indices = np.searchsorted(common_bin_edges[1:], bin_centers) + common_indices = np.clip(common_indices, 0, n_bins - 1) + np.add.at(aggregate_counts, common_indices, hist_counts) + + # Apply multi-level Otsu thresholding to the aggregated histogram + print("computing group Otsu thresholds") + (thresholds, fig) = compute_otsu_thresholds( + aggregate_counts, + classes=snakemake.params.otsu_k, + bin_edges=common_bin_edges, + return_figure=True, + ) + print(f" 📈 group thresholds: {[f'{t:.3f}' for t in thresholds]}") + + otsu_threshold_index = snakemake.params.otsu_threshold_index + selected_threshold = float(thresholds[otsu_threshold_index]) + print( + f" ✅ selected threshold (index {otsu_threshold_index}): {selected_threshold:.3f}" + ) + + fig.savefig(snakemake.output.thresholds_png) + + # Save thresholds as JSON for use by the per-subject segmentation rule + result = { + "thresholds": thresholds.tolist(), + "otsu_threshold_index": otsu_threshold_index, + "selected_threshold": selected_threshold, + "n_subjects": len(histograms), + } + with open(snakemake.output.thresholds_json, "w") as f: + json.dump(result, f, indent=2) + + print(f"saved group thresholds to {snakemake.output.thresholds_json}") diff --git a/spimquant/workflow/scripts/multiotsu_group.py b/spimquant/workflow/scripts/multiotsu_group.py new file mode 100644 index 0000000..a974edc --- /dev/null +++ b/spimquant/workflow/scripts/multiotsu_group.py @@ -0,0 +1,109 @@ +"""Apply a precomputed group-level Otsu threshold to a single subject. + +Reads the group threshold JSON produced by ``group_otsu``, applies it to +the bias-field corrected image for the current subject, and writes the +binary mask. A per-subject PNG is also saved that overlays the group +threshold on this subject's own intensity histogram, which is useful for +visual quality control. + +This is a Snakemake script; the ``snakemake`` object is automatically +provided when executed as part of a Snakemake workflow. +""" + +import json + +import matplotlib + +matplotlib.use("agg") +import matplotlib.pyplot as plt +import numpy as np + +from dask_setup import get_dask_client +from zarrnii import ZarrNii + +if __name__ == "__main__": + with get_dask_client(snakemake.config["dask_scheduler"], snakemake.threads): + + # Load group threshold from JSON + with open(snakemake.input.thresholds_json) as f: + group_data = json.load(f) + + all_thresholds = group_data["thresholds"] + otsu_threshold_index = group_data["otsu_threshold_index"] + selected_threshold = group_data["selected_threshold"] + + print(f" 📈 group thresholds: {[f'{t:.3f}' for t in all_thresholds]}") + print(f" ✅ applying threshold: {selected_threshold:.3f}") + + zarrnii_kwargs = snakemake.params.zarrnii_kwargs + pct_lo, pct_hi = snakemake.params.hist_percentile_range + bin_width = snakemake.params.hist_bin_width + + # Load a downsampled version to estimate the percentile-based range for + # the per-subject histogram visualisation + print("estimating intensity range from downsampled image for QC figure...") + znimg_ds = None + for ds_level in [5, 4, 3, 2, 1]: + try: + candidate = ZarrNii.from_ome_zarr( + snakemake.input.corrected, level=ds_level, **zarrnii_kwargs + ) + znimg_ds = candidate + break + except Exception as e: + print(f" level {ds_level} not available ({e}), trying lower level") + + if znimg_ds is None: + znimg_ds = ZarrNii.from_ome_zarr( + snakemake.input.corrected, **zarrnii_kwargs + ) + + data_ds = znimg_ds.data.compute().ravel().astype(np.float32) + range_lo = float(np.percentile(data_ds, pct_lo)) + range_hi = float(np.percentile(data_ds, pct_hi)) + + n_bins = max(2, int(np.ceil((range_hi - range_lo) / bin_width))) + + # Compute full-resolution histogram for this subject + znimg = ZarrNii.from_ome_zarr(snakemake.input.corrected, **zarrnii_kwargs) + (hist_counts, bin_edges) = znimg.compute_histogram( + bins=n_bins, range=[range_lo, range_hi] + ) + + # Generate per-subject visualisation: subject histogram + group threshold + bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) + fig, ax = plt.subplots(figsize=(8, 4)) + ax.bar( + bin_centers, + hist_counts, + width=bin_width, + color="steelblue", + alpha=0.7, + label="Subject histogram", + ) + for j, t in enumerate(all_thresholds): + linestyle = "-" if j == otsu_threshold_index else "--" + label = f"Group threshold[{j}]={t:.3f}" + if j == otsu_threshold_index: + label += " (selected)" + ax.axvline(t, color="red", linestyle=linestyle, label=label) + ax.set_xlabel("Intensity") + ax.set_ylabel("Count") + ax.set_title( + f"Subject histogram with group Otsu threshold " + f"(selected: {selected_threshold:.3f})" + ) + ax.legend() + fig.tight_layout() + fig.savefig(snakemake.output.thresholds_png) + plt.close(fig) + + # Apply the group threshold to create the binary mask + print("thresholding image with group threshold, saving as ome zarr") + znimg_mask = znimg.segment_threshold(selected_threshold) + + # Multiply by 100 (values 0 and 100) to enable field-fraction + # calculation by subsequent local-mean downsampling + znimg_mask = znimg_mask * 100 + + znimg_mask.to_ome_zarr(snakemake.output.mask, max_layer=5)