diff --git a/README.md b/README.md index 87a1775..7374ad3 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,7 @@ +[![arXiv](https://img.shields.io/badge/Paper-10.1007-blue)](https://arxiv.org/abs/2605.03098) +[![Python Versions](https://img.shields.io/pypi/pyversions/spineps)](https://pypi.org/project/spineps/) + [![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0) + # AugLab This repository investigates the influence of different data augmentation strategies on MRI training performance. @@ -120,3 +124,16 @@ Scripts developped in this repository use JSON files to specify image and segmen To track parameters used during data augmentation, JSON files are also used: see this [example](https://github.com/neuropoly/AugLab/blob/16653a84e031c40e25a72e946c2724494606b21c/auglab/configs/transform_params.json) + +## Citation + +If you use AugLab, please make sure to cite the following paper: + +``` +@article{molinier2026one, + title={One Sequence to Segment Them All: Efficient Data Augmentation for CT and MRI Cross-Domain 3D Spine Segmentation}, + author={Molinier, Nathan and M{\"o}ller, Hendrik and Dagonneau, Thomas and Curto-Vilalta, Anna and Graf, Robert and Atad, Matan and Rueckert, Daniel and Kirschke, Jan S and Cohen-Adad, Julien}, + journal={arXiv preprint arXiv:2605.03098}, + year={2026} +} +``` \ No newline at end of file diff --git a/auglab/configs/transform_params_gpu.json b/auglab/configs/transform_params_gpu.json index 291b652..d1fd109 100644 --- a/auglab/configs/transform_params_gpu.json +++ b/auglab/configs/transform_params_gpu.json @@ -1,4 +1,9 @@ { + "ImageContrastGPUTransform": { + "label_classes": [1,2,3], + "num_bins": 32, + "probability": 0.25 + }, "ScharrTransform": { "kernel_type": "Scharr", "absolute": true, @@ -43,6 +48,8 @@ "RedistributeSegTransform": { "in_seg": 0.25, "retain_stats": true, + "std_noise_range": [0.1, 0.3], + "dilation_iterations_range": [1, 5], "probability": 0.4 }, "GaussianNoiseTransform": { diff --git a/auglab/configs/transform_params_gpu_default01-23_ImageContrastV26_6_2GPUTransform.json b/auglab/configs/transform_params_gpu_default01-23_ImageContrastV26_6_2GPUTransform.json new file mode 100644 index 0000000..0afb60a --- /dev/null +++ b/auglab/configs/transform_params_gpu_default01-23_ImageContrastV26_6_2GPUTransform.json @@ -0,0 +1,179 @@ +{ + "ImageContrastV26_6_2GPUTransform": { + "probability": 0.5, + "c_choices": [2, 3, 4, 5, 6], + "s_choices": [2, 3, 4, 5, 6, 7, 8, 9, 10], + "blur_sigmas": [0.0, 0.0, 0.0, 0.3, 0.5, 0.8], + "dark_threshold": 0.01, + "n_kmeans_subsample": 10000, + "skip_parcellation_prob": 0.10, + "skip_sub_parc_prob": 0.40, + "alpha_magnitude_range": [0.5, 2.0], + "label_remap_prob": 0.5, + "min_label_voxels": 4, + "label_classes": null + }, + "ScharrTransform": { + "kernel_type": "Scharr", + "absolute": true, + "retain_stats": true, + "mix_prob": 0.50, + "in_seg": 0.0, + "out_seg": 0.0, + "mix_in_out": true, + "probability": 0.0 + }, + "GaussianBlurTransform": { + "kernel_type": "GaussianBlur", + "sigma": 1.0, + "retain_stats": false, + "mix_prob": 0.00, + "in_seg": 0.0, + "out_seg": 0.0, + "mix_in_out": true, + "probability": 0.20 + }, + "UnsharpMaskTransform": { + "kernel_type": "UnsharpMask", + "sigma": 1.0, + "unsharp_amount": 1.5, + "retain_stats": false, + "in_seg": 0.0, + "out_seg": 0.0, + "mix_in_out": true, + "mix_prob": 0.50, + "probability": 0.0 + }, + "RandomConvTransform": { + "kernel_type": "RandConv", + "kernel_sizes": [3, 5, 7], + "retain_stats": true, + "in_seg": 0.0, + "out_seg": 0.0, + "mix_in_out": true, + "mix_prob": 0.50, + "probability": 0.0 + }, + "RedistributeSegTransform": { + "in_seg": 0.25, + "retain_stats": true, + "probability": 0.0 + }, + "GaussianNoiseTransform": { + "mean": 0.0, + "std": 1.0, + "in_seg": 0.0, + "out_seg": 0.0, + "mix_in_out": true, + "probability": 0.1 + }, + "ClampTransform": { + "max_clamp_amount": 0.2, + "retain_stats": true, + "in_seg": 0.0, + "out_seg": 0.0, + "mix_in_out": true, + "probability": 0.0 + }, + "BrightnessTransform": { + "brightness_range": [0.75, 1.25], + "in_seg": 0.0, + "out_seg": 0.0, + "mix_in_out": true, + "probability": 0.15 + }, + "GammaTransform": { + "gamma_range": [0.7, 1.5], + "retain_stats": true, + "in_seg": 0.0, + "out_seg": 0.0, + "mix_in_out": true, + "probability": 0.30 + }, + "InvGammaTransform": { + "gamma_range": [0.7, 1.5], + "retain_stats": true, + "in_seg": 0.0, + "out_seg": 0.0, + "mix_in_out": true, + "probability": 0.10 + }, + "ContrastTransform": { + "contrast_range": [0.75, 1.25], + "retain_stats": false, + "in_seg": 0.0, + "out_seg": 0.0, + "mix_in_out": true, + "probability": 0.15 + }, + "FunctionTransform": { + "retain_stats": true, + "in_seg": 0, + "out_seg": 0, + "mix_in_out": false, + "probability": 0.1 + }, + "InverseTransform": { + "retain_stats": true, + "in_seg": 0.0, + "out_seg": 0.0, + "mix_in_out": true, + "mix_prob": 0.00, + "probability": 0.0 + }, + "HistogramEqualizationTransform": { + "retain_stats": true, + "in_seg": 0, + "out_seg": 0, + "mix_in_out": true, + "mix_prob": 0.60, + "probability": 0.0 + }, + "SimulateLowResTransform": { + "scale": [0.5, 1.0], + "crop": [1.0, 1.0], + "same_on_batch": false, + "probability": 0.25 + }, + "AcqTransform": { + "scale": [0.6, 1.0], + "crop": [1.0, 1.0], + "same_on_batch": false, + "probability": 0.00 + }, + "BiasFieldTransform": { + "retain_stats": true, + "coefficients": 0.2, + "in_seg": 0.0, + "out_seg": 0.0, + "mix_in_out": true, + "probability": 0.15 + }, + "FlipTransform": { + "flip_axis": [0], + "same_on_batch": false, + "keepdim": true, + "probability": 0.5 + }, + "AffineTransform": { + "degrees": 5, + "translate": [0.1, 0.1, 0.1], + "scale": [0.7, 1.4], + "shear": [-5, 5, -5, 5, -5, 5], + "resample": "bilinear", + "probability": 0.0 + }, + "nnUNetSpatialTransform": { + "patch_center_dist_from_border": 80, + "random_crop": false, + "p_elastic_deform": 0.0, + "p_rotation": 0.2, + "p_scaling": 0.2, + "scaling": [0.7, 1.4], + "p_synchronize_scaling_across_axes": 1, + "bg_style_seg_sampling": false + }, + "ZscoreNormalizationTransform": { + "probability": 0.00 + } +} diff --git a/auglab/transforms/gpu/fromSeg.py b/auglab/transforms/gpu/fromSeg.py index 83b3b90..dc2c4b4 100644 --- a/auglab/transforms/gpu/fromSeg.py +++ b/auglab/transforms/gpu/fromSeg.py @@ -1,12 +1,88 @@ +import random + import torch +from torch import nn from torch.nn import functional as F -from typing import Any, Dict, Optional, Tuple, Union, List +from typing import Any, Dict, Optional, Tuple, Union, List, Protocol from kornia.core import Tensor +import torch.distributed as dist from auglab.transforms.gpu.base import ImageOnlyTransform +# ── V26_6 / V26_6_2 helpers ────────────────────────────────────────────────── + +def _kmeans_1d(values: torch.Tensor, C: int, n_iter: int = 10) -> torch.Tensor: + """1-D K-means on foreground values. Returns (C,) centroids.""" + centroids = torch.linspace(values.min().item(), values.max().item(), C, device=values.device) + for _ in range(n_iter): + d = torch.abs(values.unsqueeze(1) - centroids.unsqueeze(0)) + lbl = torch.argmin(d, dim=1) + s = torch.zeros(C, device=values.device).scatter_add_(0, lbl, values) + n = torch.zeros(C, device=values.device).scatter_add_(0, lbl, torch.ones_like(values)) + new_c = torch.where(n > 0, s / n, centroids) + if torch.allclose(centroids, new_c): + break + centroids = new_c + return centroids + + +def _gaussian_blur_3d(x: torch.Tensor, sigma: float) -> torch.Tensor: + """Separable 3D Gaussian blur. x: (B, 1, D, H, W).""" + k_r = max(1, int(3.0 * sigma + 0.5)) + k1d = torch.arange(-k_r, k_r + 1, dtype=x.dtype, device=x.device) + k1d = torch.exp(-0.5 * (k1d / sigma) ** 2) + k1d = k1d / k1d.sum() + pad = len(k1d) // 2 + y = F.conv3d(x, k1d.view(1, 1, -1, 1, 1), padding=(pad, 0, 0)) + y = F.conv3d(y, k1d.view(1, 1, 1, -1, 1), padding=(0, pad, 0)) + y = F.conv3d(y, k1d.view(1, 1, 1, 1, -1), padding=(0, 0, pad)) + return y.clamp(0, 1) + + +def _voronoi_region_ids( + coords: torch.Tensor, + lbl_l: torch.Tensor, + fg: torch.Tensor, + C: int, + device: torch.device, + s_choices: List[int], + skip_sub_parc_prob: float, +) -> tuple[torch.Tensor, int]: + """Spatially subdivide each K-means cluster into Voronoi sub-regions. + + Returns (rid, R): per-voxel sub-region id and total region count. + """ + N = lbl_l.shape[0] + rid = torch.zeros(N, dtype=torch.long, device=device) + offset = 0 + for c in range(C): + c_mask = lbl_l == c + c_fg_mask = c_mask & (fg > 0) + n_fg = int(c_fg_mask.sum().item()) + if n_fg == 0: + continue + if n_fg < 2 or torch.rand(1, device=device).item() < skip_sub_parc_prob: + S = 1 + else: + s_idx = int(torch.rand(1, device=device).item() * len(s_choices)) + S = min(s_choices[s_idx], n_fg) + if S <= 1: + rid = torch.where(c_mask, torch.full_like(rid, offset), rid) + offset += 1 + continue + seed_idx = torch.multinomial(c_fg_mask.float(), S, replacement=False) + centroids_v = coords.index_select(0, seed_idx) + d = torch.cdist(coords, centroids_v) + sub = torch.argmin(d, dim=1) + rid = torch.where(c_mask, offset + sub, rid) + offset += S + return rid, offset + + +# ───────────────────────────────────────────────────────────────────────────── + def _normal_pdf(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: inv = 1.0 / (std + 1e-6) return (inv / (torch.sqrt(torch.tensor(2.0 * 3.141592653589793, device=x.device, dtype=x.dtype)))) * torch.exp( @@ -29,12 +105,16 @@ def __init__( same_on_batch: bool = False, p: float = 1.0, keepdim: bool = True, + std_noise_range: list[float] = [0.1, 0.3], + dilation_iterations_range: list[int] = [1, 3], **kwargs, ) -> None: super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim) self.in_seg = in_seg self.apply_to_channel = apply_to_channel self.retain_stats = retain_stats + self.std_noise_range = std_noise_range + self.dilation_iterations_range = dilation_iterations_range @torch.no_grad() def apply_transform( @@ -96,7 +176,8 @@ def apply_transform( # Vectorized dilation for all regions (3 iterations) dilated = masks.float() - for _ in range(3): + dilation_iterations = torch.randint(self.dilation_iterations_range[0], self.dilation_iterations_range[1]+1, (1,), device=input.device)[0].item() + for _ in range(dilation_iterations): if spatial_dims == 3: dilated = F.max_pool3d(dilated.unsqueeze(0), 3, 1, 1).squeeze(0) else: @@ -125,8 +206,9 @@ def apply_transform( dil_stds = dil_vars.sqrt() # redist_std per region + std_noise_range = torch.rand(1, device=input.device)[0] * (self.std_noise_range[1] - self.std_noise_range[0]) + self.std_noise_range[0] redist_std = torch.maximum( - torch.rand(R, device=input.device) * 0.2 + 0.4 * torch.abs((means - dil_means) * stds / (dil_stds + 1e-6)), + torch.rand(R, device=input.device) * std_noise_range + 0.4 * torch.abs((means - dil_means) * stds / (dil_stds + 1e-6)), torch.full((R,), 0.01, device=input.device, dtype=input.dtype) ) @@ -171,3 +253,522 @@ def apply_transform( input[b, c] = x return input + + + +class RandomV19ContrastGPU(ImageOnlyTransform): + """ + AugLab-compatible GPU augmentation that applies V19 Stochastic Semantic + Decoupling to produce a stochastic guidance map in place of the input. + + Args: + label_classes: Integer label indices to decouple stochastically. + Defaults to BraTS convention [1, 2, 3] (NCR, ED, ET). + Pass your dataset's foreground class indices if they differ. + num_bins: Histogram bins for the internal DifferentiableHistogram3D. + p: Probability of applying the transform (standard Kornia convention). + """ + + def __init__( + self, + label_classes: Optional[List[int]] = None, + num_bins: int = 64, + p: float = 0.5, + **kwargs: Any, + ) -> None: + super().__init__(p=p, **kwargs) + self._hist = DifferentiableHistogram3D(num_bins=num_bins, value_range=(0.0, 1.0)) + self._generator = V19LabelConditionedTextureGenerator(label_classes=label_classes) + + # ------------------------------------------------------------------ + # Core AugLab contract + # ------------------------------------------------------------------ + + def apply_transform( + self, + input: torch.Tensor, + params: Dict[str, Any], + flags: Dict[str, Any], + transform: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + input: Image tensor [B, C, D, H, W], float32, values in [0, 1]. + params: AugLab parameter dict. If a segmentation mask was registered + via AugLab's DataKey.MASK / 'seg' entry, it appears here. + Supported formats: + - One-hot [B, C_seg, D, H, W] with C_seg > 1 + - Integer index [B, 1, D, H, W] (passed through directly) + flags: AugLab flags dict (unused but required by the interface). + + Returns: + guidance_map: [B, C, D, H, W] stochastic synthesis output, same + shape and dtype as input. + """ + seg_raw: Optional[torch.Tensor] = params.get("seg", None) + + labels: Optional[torch.Tensor] = None + if seg_raw is not None and seg_raw.ndim == 5 and seg_raw.shape[1] > 1: + # One-hot [B, C_seg, D, H, W] → integer index [B, 1, D, H, W] + labels = collapse_onehot_to_index(seg_raw) + elif seg_raw is not None and seg_raw.ndim == 5 and seg_raw.shape[1] == 1: + # Already a single-channel integer index mask — use directly. + labels = seg_raw.long() + + # normalize image to [0,1] + data_norm, vmin, vmax = _minmax_norm(input) + + _target_hist, guidance_map, _dup = self._generator( + input_images=data_norm, + hist_module=self._hist, + labels=labels, + ) + # normalize back with z-score normalization + guidance_map = _zscore_renorm(_minmax_denorm(guidance_map, vmin, vmax)) + return guidance_map + + +class RandomV26_6_2ContrastGPU(ImageOnlyTransform): + """ + AugLab GPU augmentation implementing V26_6_2 synthesis. + + Pipeline (mirrors src/synthesis/v26_6_2_synthesis.py, self-contained): + 1. Min-max normalise input to [0, 1] per sample. + 2. V26_6 whole-image synthesis: 1-D K-means intensity parcellation → + Voronoi spatial sub-parcellation → per-region signed-alpha affine remap + y = μ + α·(x − mean_region), optional Gaussian blur. + 3. Per-anatomical-label affine remap (V26_6_2 step): for each foreground + label, with probability `label_remap_prob`, independently remap that + label's voxels with a fresh (μ, α). + 4. Optional second Gaussian blur, then foreground z-score. + + Segmentation is read from params['seg'] (injected by AugLab's pipeline). + Supported formats: one-hot [B, C_seg, D, H, W] or index [B, 1, D, H, W]. + + Args: + c_choices: candidate K-means cluster counts. + s_choices: candidate Voronoi sub-region counts per cluster. + blur_sigmas: Gaussian blur sigma options (first pass and second pass each + sample independently; weighted toward 0 = no blur). + dark_threshold: voxels below this are treated as background. + n_kmeans_subsample: max foreground voxels used to fit K-means. + skip_parcellation_prob: probability of single global remap (no K-means). + skip_sub_parc_prob: per-cluster probability of skipping Voronoi sub-split. + alpha_magnitude_range: [min, max] for |alpha| in every affine remap. + label_remap_prob: per-label per-sample probability of applying label remap. + min_label_voxels: minimum voxels in a label to attempt the remap. + label_classes: if set, restrict label remap to these class indices + (None = all foreground classes present in the batch). + p: probability of applying the transform. + """ + + def __init__( + self, + c_choices: List[int] = [2, 3, 4, 5, 6], + s_choices: List[int] = [2, 3, 4, 5, 6, 7, 8, 9, 10], + blur_sigmas: List[float] = [0.0, 0.0, 0.0, 0.3, 0.5, 0.8], + dark_threshold: float = 0.01, + n_kmeans_subsample: int = 10_000, + skip_parcellation_prob: float = 0.10, + skip_sub_parc_prob: float = 0.40, + alpha_magnitude_range: List[float] = [0.5, 2.0], + label_remap_prob: float = 0.5, + min_label_voxels: int = 4, + label_classes: Optional[List[int]] = None, + p: float = 1.0, + **kwargs: Any, + ) -> None: + super().__init__(p=p, **kwargs) + self.c_choices = c_choices + self.s_choices = s_choices + self.blur_sigmas = blur_sigmas + self.dark_threshold = dark_threshold + self.n_kmeans_subsample = n_kmeans_subsample + self.skip_parcellation_prob = skip_parcellation_prob + self.skip_sub_parc_prob = skip_sub_parc_prob + self.alpha_magnitude_range = alpha_magnitude_range + self.label_remap_prob = label_remap_prob + self.min_label_voxels = min_label_voxels + self.label_classes = label_classes + + @torch.no_grad() + def apply_transform( + self, + input: Tensor, + params: Dict[str, Any], + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + ) -> Tensor: + seg_raw: Optional[torch.Tensor] = params.get("seg", None) + + labels: Optional[torch.Tensor] = None + if seg_raw is not None and seg_raw.ndim == 5 and seg_raw.shape[1] > 1: + labels = collapse_onehot_to_index(seg_raw) + elif seg_raw is not None and seg_raw.ndim == 5 and seg_raw.shape[1] == 1: + labels = seg_raw.long() + + B, _C, D, H, W = input.shape + N = D * H * W + device = input.device + eps = 1e-7 + alpha_lo, alpha_hi = self.alpha_magnitude_range + + # Min-max normalise channel-0 to [0, 1] per sample + flat_all = input[:, 0].float().reshape(B, N) + v_min = flat_all.min(dim=1).values.view(B, 1) + v_max = flat_all.max(dim=1).values.view(B, 1) + images_01 = ((flat_all - v_min) / (v_max - v_min + eps)).clamp(0, 1) + + flat_m_all = (images_01 > self.dark_threshold).float() # foreground mask + + # Voxel coordinates (shared — same spatial dims for every sample) + coords = torch.stack(torch.meshgrid( + torch.arange(D, device=device, dtype=torch.float32), + torch.arange(H, device=device, dtype=torch.float32), + torch.arange(W, device=device, dtype=torch.float32), + indexing="ij"), dim=-1).reshape(N, 3) + + # ── Step 1: V26_6 K-means + Voronoi per-region affine remap ────────── + synth_list = [] + for i in range(B): + flat = images_01[i] + flat_m = flat_m_all[i] + n_fg = flat_m.sum() + + if n_fg < 4 or torch.rand(1, device=device).item() < self.skip_parcellation_prob: + # Single global remap + b_mean_i = (flat * flat_m).sum() / n_fg.clamp(min=1) + mu = torch.rand(1, device=device).item() + sign = (torch.rand(1, device=device) > 0.5).float() * 2 - 1 + mag = torch.rand(1, device=device) * (alpha_hi - alpha_lo) + alpha_lo + alpha = (sign * mag).item() + synth_i = (mu + alpha * (flat - b_mean_i)).clamp(0, 1) * flat_m + else: + C_k = self.c_choices[int(torch.rand(1, device=device).item() * len(self.c_choices))] + idx = torch.randint(0, N, (min(N, 40_000),), device=device) + samp = flat[idx] + sub_fg = samp[samp > self.dark_threshold][:self.n_kmeans_subsample] + if sub_fg.numel() < 4: + sub_fg = samp[:self.n_kmeans_subsample] + + centroids = _kmeans_1d(sub_fg, C_k) + sorted_c, sort_idx = torch.sort(centroids) + boundaries = (sorted_c[:-1] + sorted_c[1:]) / 2.0 + lbl_s = torch.bucketize(flat, boundaries) + lbl_l = sort_idx[lbl_s].long() + + rid, R = _voronoi_region_ids( + coords, lbl_l, flat_m, C_k, device, + self.s_choices, self.skip_sub_parc_prob, + ) + + s_c = torch.zeros(R, device=device).scatter_add_(0, rid, flat * flat_m) + n_c = torch.zeros(R, device=device).scatter_add_(0, rid, flat_m) + mean_c = s_c / n_c.clamp(min=eps) + + mu_c = torch.rand(R, device=device) + mag_c = torch.rand(R, device=device) * (alpha_hi - alpha_lo) + alpha_lo + sign_c = (torch.rand(R, device=device) > 0.5).float() * 2 - 1 + alp_c = mag_c * sign_c + + synth_i = (mu_c[rid] + alp_c[rid] * (flat - mean_c[rid])).clamp(0, 1) * flat_m + + synth_list.append(synth_i) + + synth = torch.stack(synth_list) # (B, N) + synth_01 = synth.reshape(B, 1, D, H, W) + + sigma = random.choice(self.blur_sigmas) + if sigma > 0.0: + synth_01 = _gaussian_blur_3d(synth_01, sigma) + synth = synth_01.reshape(B, N) + + # ── Step 2: per-anatomical-label affine remap (V26_6_2) ─────────────── + if labels is not None: + if labels.shape[2:] != (D, H, W): + labels = F.interpolate(labels.float(), size=(D, H, W), mode="nearest").long() + lbl = labels[:, 0].reshape(B, N).clamp(min=0) + + unique_classes = lbl.unique() + unique_classes = unique_classes[unique_classes > 0] + if self.label_classes is not None: + keep = torch.tensor(self.label_classes, device=device) + unique_classes = unique_classes[torch.isin(unique_classes, keep)] + + for c in unique_classes: + c_val = int(c.item()) + c_mask = (lbl == c_val).float() # (B, N) + c_cnt = c_mask.sum(dim=1, keepdim=True) # (B, 1) + + apply = ( + (torch.rand(B, 1, device=device) < self.label_remap_prob) + & (c_cnt >= self.min_label_voxels) + ).float() + + if apply.sum() == 0: + continue + + c_mean = (synth * c_mask).sum(dim=1, keepdim=True) / c_cnt.clamp(min=1) + + mu_c = torch.rand(B, 1, device=device) + mag_c = torch.rand(B, 1, device=device) * (alpha_hi - alpha_lo) + alpha_lo + sign_c = (torch.rand(B, 1, device=device) > 0.5).float() * 2 - 1 + alp_c = mag_c * sign_c + + new_vals = (mu_c + alp_c * (synth - c_mean)).clamp(0, 1) + write_mask = c_mask * apply + synth = synth * (1.0 - write_mask) + new_vals * write_mask + + # ── Step 3: optional second blur, then foreground z-score ───────────── + synth_01 = synth.reshape(B, 1, D, H, W) + sigma2 = random.choice(self.blur_sigmas) + if sigma2 > 0.0: + synth_01 = _gaussian_blur_3d(synth_01, sigma2) + synth = synth_01.reshape(B, N) + + b_sum = (synth * flat_m_all).sum(dim=1, keepdim=True) + b_cnt = flat_m_all.sum(dim=1, keepdim=True).clamp(min=1) + b_mean = b_sum / b_cnt + b_sq = ((synth - b_mean) * flat_m_all).pow(2).sum(dim=1, keepdim=True) + b_std = (b_sq / b_cnt + eps).sqrt() + synth_z = ((synth - b_mean) / b_std * flat_m_all).reshape(B, 1, D, H, W) + + out = input.clone() + out[:, 0:1] = synth_z.to(input.dtype) + return out + + +_SHARED_RNG_COUNTER = 0 + + +def _next_shared_seed() -> int: + global _SHARED_RNG_COUNTER + _SHARED_RNG_COUNTER += 1 + seed = (int(torch.initial_seed()) + _SHARED_RNG_COUNTER) % (2**63 - 1) + if dist.is_available() and dist.is_initialized(): + seed_tensor = torch.tensor([seed], dtype=torch.long) + dist.broadcast(seed_tensor, src=0) + seed = int(seed_tensor.item()) + return seed + + +@staticmethod +def _minmax_norm(x: torch.Tensor, eps: float = 1e-8 + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Per-sample min-max normalise to [0, 1]. Returns (normed, min, max).""" + B = x.shape[0] + x_flat = x.view(B, -1) + vmin = x_flat.min(dim=1).values.view(B, 1, 1, 1, 1) + vmax = x_flat.max(dim=1).values.view(B, 1, 1, 1, 1) + return (x - vmin) / (vmax - vmin + eps), vmin, vmax + +@staticmethod +def _minmax_denorm(x_norm: torch.Tensor, vmin: torch.Tensor, + vmax: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + return x_norm * (vmax - vmin + eps) + vmin + +@staticmethod +def _zscore_renorm(x: torch.Tensor, bg_threshold: float = 1e-6) -> torch.Tensor: + """Per-sample foreground-masked z-score. Mirrors nnUNet's use_mask_for_norm=True. + + Background voxels (abs ≈ 0, zeroed by nnUNet masking) stay at 0. + Eliminates the train/inference distribution mismatch that would occur + because nnUNet always z-scores at inference time. + """ + fg = x.abs() > bg_threshold + fg_f = fg.float() + n = fg_f.sum(dim=(2, 3, 4), keepdim=True).clamp(min=1) + mean = (x * fg_f).sum(dim=(2, 3, 4), keepdim=True) / n + var = ((x - mean).pow(2) * fg_f).sum(dim=(2, 3, 4), keepdim=True) / n + std = var.sqrt().clamp(min=1e-8) + return torch.where(fg, (x - mean) / std, torch.zeros_like(x)) + +def _shared_cpu_generator() -> torch.Generator: + generator = torch.Generator(device="cpu") + generator.manual_seed(_next_shared_seed()) + return generator + +def _shared_rand(shape: tuple[int, ...], device: torch.device, dtype: torch.dtype) -> torch.Tensor: + if not (dist.is_available() and dist.is_initialized()): + return torch.rand(shape, device=device, dtype=dtype) + rand_cpu = torch.rand(shape, generator=_shared_cpu_generator(), device="cpu", dtype=dtype) + return rand_cpu.to(device=device, dtype=dtype) + + +def collapse_onehot_to_index(seg_raw: torch.Tensor) -> torch.Tensor: + """ + Convert a one-hot segmentation mask to a single-channel integer index mask. + + Args: + seg_raw: One-hot tensor [B, C_seg, D, H, W], bool or float. + Channel 0 is assumed to be *absent* (background is implicit). + Each foreground channel c encodes class index (c + 1). + + Returns: + labels: Integer index tensor [B, 1, D, H, W]. + Background voxels (all-zero across channels) map to 0. + Foreground voxels map to argmax(seg_raw, dim=1) + 1. + """ + foreground_mask = seg_raw.any(dim=1, keepdim=True) # [B,1,D,H,W] bool + labels = torch.argmax(seg_raw, dim=1, keepdim=True).long() + 1 # 0-based → 1-based + labels = torch.where(foreground_mask, labels, torch.zeros_like(labels)) + return labels + +class HistogramModuleLike(Protocol): + num_bins: int + min_value: float + max_value: float + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + ... + +class BaseTargetGenerator(nn.Module): + """Strategy interface for guidance and target histogram generation.""" + + def forward( + self, + input_images: torch.Tensor, + num_bins: int, + num_chunks: int, + dark_threshold: float, + hist_module: HistogramModuleLike, + return_guidance_map: bool = True, + labels: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + raise NotImplementedError + + +class V19LabelConditionedTextureGenerator(BaseTargetGenerator): + """ + V19 Stochastic Semantic Decoupling: Merges geometric label-priors with + texture-preserving latent space. + """ + + def __init__(self, label_classes: Optional[List[int]] = None): + super().__init__() + self.label_classes: List[int] = label_classes if label_classes is not None else [1, 2, 3] + + def __call__( + self, + input_images: torch.Tensor, + hist_module: nn.Module, + labels: Optional[torch.Tensor] = None, + **kwargs + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + images = input_images + B, C, D, H, W = images.shape + device = images.device + dtype = images.dtype + + # Step A: Base v18_6 Background Synthesis + mask = images > 0.01 + + y = images.clone() + + with torch.autocast(device_type="cuda" if device.type == "cuda" else "cpu", enabled=False): + images_f = images.float() + + mu_base = _shared_rand((B, 8), device=device, dtype=torch.float32) + alpha_base = _shared_rand((B, 8), device=device, dtype=torch.float32) * 1.5 + 0.5 + + q_edges = torch.linspace(0, 1, 9, device=device) + + c_i = torch.bucketize(images_f, q_edges) - 1 + c_i = torch.clamp(c_i, 0, 7) + + mu_c = mu_base.view(B, 8, 1, 1, 1).expand(B, 8, D, H, W).gather(1, c_i) + alpha_c = alpha_base.view(B, 8, 1, 1, 1).expand(B, 8, D, H, W).gather(1, c_i) + # v19_c bias fix: center the affine shift on the chunk midpoint rather than + # the lower edge, so alpha × offset is symmetric around 0 within each chunk. + q_c_lower = q_edges[:-1].view(1, 8, 1, 1, 1).expand(B, 8, D, H, W).gather(1, c_i) + q_c_upper = q_edges[1:].view(1, 8, 1, 1, 1).expand(B, 8, D, H, W).gather(1, c_i) + q_c_center = (q_c_lower + q_c_upper) * 0.5 + + y_base = mu_c + alpha_c * (images_f - q_c_center) + y = torch.where(mask, y_base.to(dtype), y) + + # Step B: Stochastic Semantic Decoupling + if labels is not None: + with torch.autocast(device_type="cuda" if device.type == "cuda" else "cpu", enabled=False): + if labels.dim() == 4: + labels = labels.unsqueeze(1) + if any(dim <= 0 for dim in labels.shape[2:]): + labels = None + else: + if labels.shape[2:] != images.shape[2:]: + labels = F.interpolate(labels.float(), size=images.shape[2:], mode="nearest") + labels = labels.to(device=device) + y_f = y.float() + images_f = images.float() + + for c in self.label_classes: + decouple = _shared_rand((B, 1, 1, 1, 1), device=device, dtype=torch.float32) > 0.5 + + mu_path = _shared_rand((B, 1, 1, 1, 1), device=device, dtype=torch.float32) + alpha_path = _shared_rand((B, 1, 1, 1, 1), device=device, dtype=torch.float32) * 1.5 + 0.5 + + class_mask = (labels == c) + + class_sum = (images_f * class_mask).sum(dim=(1, 2, 3, 4), keepdim=True) + class_count = class_mask.sum(dim=(1, 2, 3, 4), keepdim=True) + class_count_safe = torch.clamp(class_count, min=1.0) + mean_c = class_sum / class_count_safe + + y_override = mu_path + alpha_path * (images_f - mean_c) + + valid_override = class_mask & decouple & (class_count > 0) + y_f = torch.where(valid_override, y_override, y_f) + + y = y_f.to(dtype) + + # Step C: Masking & Clamping + y = torch.clamp(y, 0.0, 1.0) + y = torch.where(mask, y, torch.zeros_like(y)) + + target_hist = hist_module(y) + return target_hist, y, y + + +class DifferentiableHistogram3D(nn.Module): + """Differentiable soft histogram for 3D volumes returning RAW VOXEL COUNTS.""" + + def __init__(self, num_bins: int = 64, value_range: tuple[float, float] = (0.0, 1.0), eps: float = 1e-8): + super().__init__() + self.num_bins = num_bins + self.min_value = float(value_range[0]) + self.max_value = float(value_range[1]) + self.eps = eps + + bin_centers = torch.linspace(self.min_value, self.max_value, num_bins) + self.register_buffer("bin_centers", bin_centers.view(1, 1, num_bins, 1), persistent=False) + self.bin_width = (self.max_value - self.min_value) / max(num_bins - 1, 1) + + def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: + if x.ndim != 5: + raise ValueError(f"Expected a 5D tensor (B, C, D, H, W), got shape {tuple(x.shape)}") + + b, c, *_ = x.shape + flat_x = x.reshape(b, c, -1) + + scaled = (flat_x - self.min_value) / (self.bin_width + self.eps) + left_idx = torch.floor(scaled).to(torch.long) + right_idx = left_idx + 1 + + wl = (right_idx.to(flat_x.dtype) - scaled).clamp(0.0, 1.0) + wr = (scaled - left_idx.to(flat_x.dtype)).clamp(0.0, 1.0) + + left_idx = left_idx.clamp(0, self.num_bins - 1) + right_idx = right_idx.clamp(0, self.num_bins - 1) + + if mask is not None: + if mask.shape != x.shape: + raise ValueError("Mask shape must match the input tensor shape.") + flat_mask = mask.reshape(b, c, -1).to(dtype=flat_x.dtype) + wl = wl * flat_mask + wr = wr * flat_mask + + hist = torch.zeros((b, c, self.num_bins), device=x.device, dtype=x.dtype) + hist.scatter_add_(2, left_idx, wl) + hist.scatter_add_(2, right_idx, wr) + + return hist diff --git a/auglab/transforms/gpu/transforms.py b/auglab/transforms/gpu/transforms.py index f471370..7cb7ce2 100644 --- a/auglab/transforms/gpu/transforms.py +++ b/auglab/transforms/gpu/transforms.py @@ -11,7 +11,9 @@ from auglab.transforms.gpu.contrast import RandomConvTransformGPU, RandomGaussianNoiseGPU, RandomBrightnessGPU, RandomGammaGPU, RandomFunctionGPU, \ RandomHistogramEqualizationGPU, RandomInverseGPU, RandomBiasFieldGPU, RandomContrastGPU, ZscoreNormalizationGPU, RandomClampGPU from auglab.transforms.gpu.spatial import RandomAffine3DCustom, RandomLowResTransformGPU, RandomFlipTransformGPU, RandomAcqTransformGPU, RandomCropTransformGPU -from auglab.transforms.gpu.fromSeg import RandomRedistributeSegGPU +from auglab.transforms.gpu.fromSeg import RandomRedistributeSegGPU, RandomV19ContrastGPU, RandomV26_6_2ContrastGPU +from auglab.transforms.gpu.domain_transfer import RandomDomainTransferGPU +from auglab.transforms.synthseg.transforms import RandomSynthSegGPU from auglab.transforms.gpu.base import AugmentationSequentialCustom class AugTransformsGPU(AugmentationSequentialCustom): @@ -57,7 +59,80 @@ def _build_transforms(self) -> list[nn.Module]: p=affine_params.get('probability', 0) )) + # SynthSeg generative augmentation: replace the image with a GMM synthesis + # of the segmentation (intensity-only here, so the mask stays consistent; + # geometric transforms above deform the labels first). All SynthSeg + # generator parameters are read straight from the config block. + synthseg_params = self.transform_params.get('SynthSeg') + if synthseg_params is not None: + synthseg_kwargs = {k: v for k, v in synthseg_params.items() if k != 'probability'} + transforms.append(RandomSynthSegGPU( + p=synthseg_params.get('probability', 1.0), + **synthseg_kwargs, + )) + ## Transfer augmentations (TA) + ######################### + # Replace Image with V19Contrast (Paul) + img_contrast_params = self.transform_params.get('ImageContrastGPUTransform') + if img_contrast_params is not None: + transforms.append( + RandomV19ContrastGPU( + p=img_contrast_params.get("probability", 0), + label_classes=img_contrast_params.get("label_classes", [1]), + num_bins=img_contrast_params.get("num_bins", 32), + ) + ) + + # Replace image with V26_6_2 contrast (K-means + Voronoi + per-label remap) + v26_6_2_params = self.transform_params.get('ImageContrastV26_6_2GPUTransform') + if v26_6_2_params is not None: + transforms.append( + RandomV26_6_2ContrastGPU( + p=v26_6_2_params.get("probability", 1.0), + c_choices=v26_6_2_params.get("c_choices", [2, 3, 4, 5, 6]), + s_choices=v26_6_2_params.get("s_choices", [2, 3, 4, 5, 6, 7, 8, 9, 10]), + blur_sigmas=v26_6_2_params.get("blur_sigmas", [0.0, 0.0, 0.0, 0.3, 0.5, 0.8]), + dark_threshold=v26_6_2_params.get("dark_threshold", 0.01), + n_kmeans_subsample=v26_6_2_params.get("n_kmeans_subsample", 10000), + skip_parcellation_prob=v26_6_2_params.get("skip_parcellation_prob", 0.10), + skip_sub_parc_prob=v26_6_2_params.get("skip_sub_parc_prob", 0.40), + alpha_magnitude_range=v26_6_2_params.get("alpha_magnitude_range", [0.5, 2.0]), + label_remap_prob=v26_6_2_params.get("label_remap_prob", 0.5), + min_label_voxels=v26_6_2_params.get("min_label_voxels", 4), + label_classes=v26_6_2_params.get("label_classes", None), + ) + ) + + # Domain transfer: randomly re-render the image as another sequence/cluster (TA) + # Accept either the class-name key or the descriptive key. + domain_params = self.transform_params.get('RandomDomainTransferGPU') \ + or self.transform_params.get('DomainTransferTransform') + if domain_params is not None: + transforms.append( + RandomDomainTransferGPU( + bank_path=domain_params.get("bank_path", None), + source_label=domain_params["source_label"], + targets=domain_params.get("targets", None), + include_self=domain_params.get("include_self", False), + any_source=domain_params.get("any_source", False), + sigma=domain_params.get("sigma", 2.0), + apply_to_channel=domain_params.get("apply_to_channel", [0]), + zscore_io=domain_params.get("zscore_io", "auto"), + pct=domain_params.get("pct", 1.0), + blend_targets=domain_params.get("blend_targets", 1), + blend_concentration=domain_params.get("blend_concentration", 1.0), + p_class_mix=domain_params.get("p_class_mix", 0.0), + bias_field_std=domain_params.get("bias_field_std", 0.0), + bias_scale=domain_params.get("bias_scale", 0.03), + p_spatial_mix=domain_params.get("p_spatial_mix", 0.0), + spatial_mix_scale=domain_params.get("spatial_mix_scale", 0.03), + spatial_mix_gain=domain_params.get("spatial_mix_gain", 3.0), + p=domain_params.get("probability", 0.0), + same_on_batch=domain_params.get("same_on_batch", False), + ) + ) + # Inverse transform (max - pixel_value) inverse_params = self.transform_params.get('InverseTransform') if inverse_params is not None: @@ -93,6 +168,8 @@ def _build_transforms(self) -> list[nn.Module]: in_seg=redistribute_params.get('in_seg', 0.2), retain_stats=redistribute_params.get('retain_stats', False), p=redistribute_params.get('probability', 0), + std_noise_range=redistribute_params.get('std_noise_range', [0.1, 0.3]), + dilation_iterations_range=redistribute_params.get('dilation_iterations_range', [1, 3]), )) # Scharr filter @@ -395,22 +472,22 @@ def pad_numpy_array(arr, shape): augmentor = AugTransformsGPU(json_path) # Load images and masks tensors - img_path = '/home/GRAMES.POLYMTL.CA/p118739/data_nvme_p118739/data/datasets/data-multi-subject/sub-amu02/anat/sub-amu02_T1w.nii.gz' + img_path = '/home/ge.polymtl.ca/p118739/data/datasets/data-multi-subject/sub-amu02/anat/sub-amu02_T1w.nii.gz' img = Image(img_path).change_orientation('RSP') img = resample_nib(img, new_size=[1,1,1], new_size_type='mm', interpolation='linear') img_tensor = torch.from_numpy(img.data.copy()).to(torch.float32) - seg_path = '/home/GRAMES.POLYMTL.CA/p118739/data_nvme_p118739/data/datasets/data-multi-subject/derivatives/labels/sub-amu02/anat/sub-amu02_T1w_label-spine_dseg.nii.gz' + seg_path = '/home/ge.polymtl.ca/p118739/data/datasets/data-multi-subject/derivatives/labels/sub-amu02/anat/sub-amu02_T1w_label-spine_dseg.nii.gz' seg = Image(seg_path).change_orientation('RSP') seg = resample_nib(seg, new_size=[1,1,1], new_size_type='mm', interpolation='nn') seg_tensor_all = torch.from_numpy(seg.data.copy()) - img2_path = '/home/GRAMES.POLYMTL.CA/p118739/data_nvme_p118739/data/datasets/spider-challenge-2023/sub-002/anat/sub-002_acq-lowresSag_T2w.nii.gz' + img2_path = '/home/ge.polymtl.ca/p118739/data/datasets/spider-challenge-2023/sub-002/anat/sub-002_acq-lowresSag_T2w.nii.gz' img2 = Image(img2_path).change_orientation('RSP') img2 = resample_nib(img2, new_size=[1,1,1], new_size_type='mm', interpolation='linear') img2_tensor = torch.from_numpy(img2.data.copy()).to(torch.float32) - seg2_path = '/home/GRAMES.POLYMTL.CA/p118739/data_nvme_p118739/data/datasets/spider-challenge-2023/derivatives/labels/sub-002/anat/sub-002_acq-lowresSag_T2w_label-spine_dseg.nii.gz' + seg2_path = '/home/ge.polymtl.ca/p118739/data/datasets/spider-challenge-2023/derivatives/labels/sub-002/anat/sub-002_acq-lowresSag_T2w_label-spine_dseg.nii.gz' seg2 = Image(seg2_path).change_orientation('RSP') seg2 = resample_nib(seg2, new_size=[1,1,1], new_size_type='mm', interpolation='nn') seg2_tensor_all = torch.from_numpy(seg2.data.copy()) diff --git a/auglab/transforms/synthseg/README.md b/auglab/transforms/synthseg/README.md new file mode 100644 index 0000000..0d56058 --- /dev/null +++ b/auglab/transforms/synthseg/README.md @@ -0,0 +1,204 @@ +# SynthSeg generative augmentation + +A faithful [torch] re-implementation of the **SynthSeg** "brain generator" as an +AugLab augmentation. Unlike every other transform in AugLab — which perturbs a +*real* image — SynthSeg **ignores the input image entirely and synthesises a new +image from a label map**, using domain randomisation (a per-label Gaussian +mixture model plus random spatial, bias, intensity and resolution corruptions). +A network trained on these synthetic images becomes agnostic to MRI contrast and +resolution. + +Because the method is fundamentally different from the intensity/geometry +transforms in `gpu/` and `cpu/`, it lives in its own package. + +References: +- B. Billot et al., *SynthSeg: Segmentation of brain MRI scans of any contrast + and resolution without retraining*, Medical Image Analysis, 2023. +- B. Billot et al., *A Learning Strategy for Contrast-agnostic MRI Segmentation* + (MICCAI 2020) and *Partial Volume Segmentation of Brain MRI Scans of any + Resolution and Contrast* (MedIA 2021). +- Reference code: [`BBillot/SynthSeg`](https://github.com/BBillot/SynthSeg), + [`BBillot/lab2im`](https://github.com/BBillot/lab2im). + +## Pipeline + +`SynthSegGenerator` reproduces the exact order of `labels_to_image_model`: + +``` +label map + └─ spatial deformation (affine + diffeomorphic SVF, on labels, nearest) + └─ [optional random crop to output_shape] + └─ left/right flip with anatomical label swap + └─ GMM intensity sampling image[v] = mean[label_v] + std[label_v]·N(0,1) + └─ bias field × exp(smooth Gaussian field) + └─ intensity augmentation clip[0,300] → min-max to [0,1] → image^exp(N(0,γ)) + └─ resolution randomisation blur → subsample (nearest) → resample (linear), per channel + └─ map generation labels → output labels +=> (synthetic image, deformed label map) +``` + +Each step is a small function in [`functional.py`](functional.py), each citing +the corresponding `lab2im`/`SynthSeg` layer. + +## Default hyper-parameters + +Defaults match `BrainGenerator.__init__` (which overrides several +`labels_to_image_model` signature defaults). Notable, easy-to-miss values: + +| Parameter | Default | Notes | +|---|---|---| +| `prior_means` / `prior_stds` | `None` → `U(0, 250)` / `U(0, 30)` | full domain randomisation. The often-quoted `[25,225]`/`[5,25]` are a stale docstring; the code uses `centre=125,range=125` and `centre=15,range=15`. | +| `scaling_bounds` | `0.2` | per-axis `U(0.8, 1.2)` | +| `rotation_bounds` | `15` | per-axis `U(-15°, 15°)` | +| `shearing_bounds` | `0.012` | per off-diagonal `U(-0.012, 0.012)` | +| `translation_bounds` | `false` | off | +| `nonlin_std` / `nonlin_scale` | `4.0` / `0.04` | SVF std `~U(0,4)`; coarse grid `ceil(shape·0.04)` | +| `svf_integration_steps` | `7` | scaling-and-squaring (`VecInt`, `ss`) | +| `bias_field_std` / `bias_scale` | `0.7` / `0.025` | std `~U(0,0.7)`; coarse grid `ceil(shape·0.025)` | +| `gamma_std` / `clip` | `0.5` / `300` | hard-coded in SynthSeg's `IntensityAugmentation` | +| `randomise_res` | `true` | per-channel random acquisition resolution | +| `max_res_iso` / `max_res_aniso` | `4.0` / `8.0` | mm ceilings | +| `blur_range` | `1.03` | sigma jitter `U(1/1.03, 1.03)` (`1.15` in the 2020 model) | +| `atlas_res` | `1.0` | native resolution of the input label map (mm) | + +## Implementation notes / deviations + +- **3D only** (5D `(B, C, D, H, W)` tensors), matching AugLab's GPU transforms. +- Affine transforms are applied **about the volume centre** (like AugLab's + `RandomAffine3DCustom`), rather than the corner-origin used by neuron's + `affine_to_shift`. This keeps the anatomy in frame and is the standard choice; + the visual augmentation is equivalent. +- The SVF is integrated at full resolution after upsampling the coarse velocity + field (lab2im integrates at half resolution then upsamples — equivalent in + effect for a smooth field, simpler and less error-prone here). +- The background special-casing in GMM parameter sampling (5% black / 25% + dark-low-variance / 70% normal) is reproduced. The internal 0.95 "apply" prob + of the bias-field layer is folded into the transform-level probability. + +## ⚠️ Label maps must be dense for realistic synthesis + +SynthSeg's realism comes from a **dense anatomical label map** (e.g. a +FreeSurfer/SAMSEG segmentation covering every tissue: WM, GM, CSF, ventricles, +sub-cortical structures, extra-cerebral tissue, ...). If you feed it a *sparse* +target segmentation (a few foreground structures over a `0` background — as is +common for spinal-cord/lesion tasks), it still runs, but without the EM +completion below the synthetic image will only contain those structures over a +single-Gaussian background. + +### EM label completion for sparse maps (`em_label_completion`) + +This is exactly the situation the SynthSeg paper addresses in §5.4: + +> *"we enhance the training segmentations by subdividing all their labels +> (background and foreground) into finer subregions. This is achieved by +> clustering the intensities of the associated image with the Expectation +> Maximisation algorithm."* + +Enable `em_label_completion=True` and the generator will, **using the paired real +image** that is already available on-the-fly (the `data` / `input` tensor): + +- split every **foreground** label into `em_n_foreground_clusters` subregions (2 in the paper); +- split the **background** label into a random `N ∈ em_background_clusters_range` subregions ([3, 10] in the paper); +- give each subregion its own generation Gaussian (so the formerly single-Gaussian + background becomes an intensity-coherent patchwork — realistic extra-cerebral / unlabelled tissue); +- **merge the subregions back** to their parent labels for the output segmentation, + so the training target is unchanged. + +The EM fit per region is sub-sampled to `em_max_fit_voxels` voxels for speed (the +full region is still assigned). Unlike the paper — which precomputes these maps +offline — this runs on the fly from the real image, so no preprocessing is needed. + +```python +gen = SynthSegGenerator(em_label_completion=True).to("cuda") +image, label = gen(sparse_label_map, image=real_image) # real_image drives the EM clustering +# or via the driver / config: set "em_label_completion": true and call synth(data, target) +``` + +> The EM path needs the real image: `SynthSegTransformsGPU(data, target)` and +> `RandomSynthSegGPU` (which reads the image being augmented) both supply it +> automatically; the bare `SynthSegGenerator.forward` needs `image=...`. + +## Usage + +### 1. Full end-to-end generator (faithful SynthSeg) + +`SynthSegTransformsGPU` mirrors `AugTransformsGPU`: build from JSON, move to the +device, and call `transforms(data, target) → (image, target)`. The `data` tensor +is ignored; `target` is the label map. + +```python +import importlib, torch +import auglab.configs as configs +from auglab.transforms.synthseg import SynthSegTransformsGPU + +cfg = importlib.resources.files(configs) / "synthseg_params.json" +synth = SynthSegTransformsGPU(json_path=str(cfg)).to("cuda") + +# data: (B, 1, D, H, W) image (ignored), target: (B, 1, D, H, W) label map +image, label = synth(data, target) # image is fully synthetic, label is deformed/aligned +``` + +Or directly with the module API: + +```python +from auglab.transforms.synthseg import SynthSegGenerator +gen = SynthSegGenerator(generation_labels=[0, 2, 3, 41, 42, ...], + n_neutral_labels=1, n_channels=1).to("cuda") +image, label = gen(label_map) # label_map: (B, 1, D, H, W) +``` + +### 2. As an `ImageOnlyTransform` in an existing GPU pipeline + +`RandomSynthSegGPU` replaces the image with a GMM synthesis of `params['seg']` +(intensity-only: GMM → bias → intensity → resolution). Put AugLab's geometric +transforms *before* it so the mask is deformed first and SynthSeg generates from +the deformed labels: + +```python +from auglab.transforms.gpu.base import AugmentationSequentialCustom +from auglab.transforms.gpu.spatial import RandomAffine3DCustom +from auglab.transforms.synthseg import RandomSynthSegGPU + +aug = AugmentationSequentialCustom( + RandomAffine3DCustom(degrees=15, scale=[0.8, 1.2], p=1.0), + RandomSynthSegGPU(generation_labels=None, n_channels=1, + bias_field_std=0.7, gamma_std=0.5, randomise_res=True, p=1.0), + data_keys=["input", "mask"], same_on_batch=True, +) +image, seg = aug(image, seg) +``` + +> **Note (kornia 0.7.4 quirk):** when an `AugmentationSequentialCustom` is called +> with more than one data key, kornia detaches the returned **input image** to the +> CPU (the segmentation/mask stays on the GPU). This affects *every* AugLab GPU +> transform identically, not just SynthSeg — re-`.to(device)` the returned image +> if you need it back on the GPU. The full-pipeline `SynthSegTransformsGPU` driver +> (section 1) does **not** go through kornia's sequential and is unaffected. + +### 3. Via a `SynthSeg` key in an `AugTransformsGPU` config + +`AugTransformsGPU` recognises a top-level `"SynthSeg"` config block and appends a +`RandomSynthSegGPU` built from it (mapping `"probability"` → `p`). This lets any +harness that already constructs `AugTransformsGPU(json_path)` and calls +`augmentor(img, seg)` use SynthSeg with no code changes: + +```jsonc +{ "SynthSeg": { "probability": 1.0, "n_channels": 1, "bias_field_std": 0.7, + "gamma_std": 0.5, "randomise_res": true, "em_label_completion": false } } +``` + +Because this goes through an `ImageOnlyTransform`, it is **intensity-only**: the +image is synthesised and the segmentation is returned unchanged. The spatial keys +in the block (`flipping`, `scaling_bounds`, `rotation_bounds`, `nonlin_std`, ...) +are therefore inert in this path — add an `AffineTransform`/`FlipTransform` block +(which run *before* SynthSeg) for geometry, or use `SynthSegTransformsGPU` +(section 1) for the full pipeline with a deformed label map. + +## Smoke test + +Both modules are runnable and self-contained (no data files, CPU-friendly): + +```bash +python -m auglab.transforms.synthseg.generator +python -m auglab.transforms.synthseg.transforms +``` diff --git a/auglab/transforms/synthseg/__init__.py b/auglab/transforms/synthseg/__init__.py new file mode 100644 index 0000000..f947b0b --- /dev/null +++ b/auglab/transforms/synthseg/__init__.py @@ -0,0 +1,27 @@ +"""SynthSeg generative augmentation for AugLab. + +A faithful torch re-implementation of the SynthSeg "brain generator" +(Billot et al., Medical Image Analysis 2023; BBillot/SynthSeg, BBillot/lab2im): +synthesise a randomly-contrasted, randomly-resolved image from an anatomical +label map, together with the matching ground-truth segmentation. + +Public API: + SynthSegGenerator -- the full generative model as an nn.Module + (``forward(label_map) -> (image, label)``). + SynthSegTransformsGPU -- config-driven driver, ``forward(data, target) -> + (image, target)`` (AugLab calling convention). + RandomSynthSegGPU -- ImageOnlyTransform that replaces the image with a + GMM synthesis of ``params['seg']`` (composes inside + AugmentationSequentialCustom pipelines). +""" + +from auglab.transforms.synthseg.generator import SynthSegGenerator +from auglab.transforms.synthseg.transforms import RandomSynthSegGPU, SynthSegTransformsGPU +from auglab.transforms.synthseg import functional + +__all__ = [ + "SynthSegGenerator", + "SynthSegTransformsGPU", + "RandomSynthSegGPU", + "functional", +] diff --git a/auglab/transforms/synthseg/functional.py b/auglab/transforms/synthseg/functional.py new file mode 100644 index 0000000..697308c --- /dev/null +++ b/auglab/transforms/synthseg/functional.py @@ -0,0 +1,876 @@ +"""Functional building blocks for the SynthSeg generative model (GPU / torch). + +This module re-implements, as standalone differentiable-free torch ops, the +individual layers of the SynthSeg "brain generator" described in: + + B. Billot et al., "SynthSeg: Segmentation of brain MRI scans of any contrast + and resolution without retraining", Medical Image Analysis, 2023. + (and the earlier MICCAI-2020 contrast-agnostic / PV-segmentation papers) + +The reference TensorFlow implementation lives in ``BBillot/SynthSeg`` and +``BBillot/lab2im``. Each function below cites the corresponding reference layer. +Everything operates on 3D volumes stored as ``(B, C, D, H, W)`` torch tensors +(label maps as ``(B, 1, D, H, W)`` integer tensors), which is the convention +used throughout AugLab's GPU transforms. + +Spatial conventions +-------------------- +* Spatial axes ``(D, H, W)`` map to torch dims ``(2, 3, 4)``. Internally we work + with voxel coordinates in ``(i, j, k) = (D, H, W)`` order and only convert to + the ``(x, y, z) = (W, H, D)`` order expected by ``F.grid_sample`` at the very + end, with ``align_corners=True`` so that integer voxel indices map exactly. +* Affine transforms are applied about the volume centre (standard practice and + matching AugLab's existing ``RandomAffine3DCustom``), so small rotations / + scalings keep the anatomy in frame. +""" + +from __future__ import annotations + +import math +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F + +Number = Union[int, float] + +__all__ = [ + "to_label_map", + "infer_label_values", + "sample_gmm_parameters", + "labels_to_image_gmm", + "sample_affine_matrices", + "random_svf_field", + "warp_volume", + "bias_field", + "intensity_augmentation", + "blurring_sigma_for_downsampling", + "gaussian_blur_3d", + "sample_resolution", + "mimic_acquisition", + "em_subdivide_labels", + "flip_lr_with_swap", + "convert_labels", +] + + +# --------------------------------------------------------------------------- +# Label map helpers +# --------------------------------------------------------------------------- +def to_label_map(seg: torch.Tensor) -> torch.Tensor: + """Coerce a segmentation tensor into a single-channel integer label map. + + Accepts: + * ``(B, 1, D, H, W)`` -> rounded to integer labels (used directly). + * ``(B, C, D, H, W)`` one-hot (C > 1) -> argmax + 1, background (all-zero + across channels) stays 0. This matches AugLab's + :func:`collapse_onehot_to_index` convention where channel ``c`` encodes + label ``c + 1``. + * ``(B, D, H, W)`` -> unsqueezed to ``(B, 1, D, H, W)``. + + Returns a ``(B, 1, D, H, W)`` ``long`` tensor. + """ + if seg.dim() == 4: + seg = seg.unsqueeze(1) + if seg.dim() != 5: + raise ValueError(f"Expected a 4D or 5D segmentation tensor, got {seg.dim()}D.") + + if seg.shape[1] == 1: + return seg.round().long() + + # One-hot -> integer index (channel c -> label c + 1, background -> 0). + foreground = seg.any(dim=1, keepdim=True) + labels = torch.argmax(seg, dim=1, keepdim=True).long() + 1 + return torch.where(foreground, labels, torch.zeros_like(labels)) + + +def infer_label_values(label_map: torch.Tensor) -> torch.Tensor: + """Return the sorted unique label values present in ``label_map``.""" + return torch.unique(label_map).long() + + +# --------------------------------------------------------------------------- +# GMM intensity model (lab2im.layers.SampleConditionalGMM +# + SynthSeg.model_inputs.build_model_inputs) +# --------------------------------------------------------------------------- +def _draw_value( + prior: Optional[Union[Number, Sequence[Number], torch.Tensor]], + size: Tuple[int, int], + distribution: str, + centre: float, + default_range: float, + device: torch.device, + positive_only: bool = False, +) -> torch.Tensor: + """Port of ``lab2im.utils.draw_value_from_distribution``. + + ``size`` is ``(batch, n_classes)``. Returns a tensor of that shape. + + ``prior`` interpretations: + * ``None`` -> ``uniform``: U(centre-range, centre+range); + ``normal``: N(centre, range). + * scalar ``s`` -> ``uniform``: U(centre-s, centre+s); + ``normal``: N(centre, s). + * length-2 ``[a, b]``-> ``uniform``: U(a, b); ``normal``: N(a, b). + (shared across classes) + * array ``(2, K)`` -> per-class ``[a, b]`` rows. + """ + batch, n_classes = size + + # Resolve the two distribution parameters (a, b) of shape ``size``. + # uniform -> (low, high); normal -> (mean, std). + if prior is None: + if distribution == "uniform": + a = torch.full(size, centre - default_range, device=device) + b = torch.full(size, centre + default_range, device=device) + else: + a = torch.full(size, centre, device=device) + b = torch.full(size, default_range, device=device) + elif isinstance(prior, (int, float)): + if distribution == "uniform": + a = torch.full(size, centre - float(prior), device=device) + b = torch.full(size, centre + float(prior), device=device) + else: + a = torch.full(size, centre, device=device) + b = torch.full(size, float(prior), device=device) + else: + prior_t = torch.as_tensor(prior, dtype=torch.float32, device=device) + if prior_t.numel() == 2 and prior_t.dim() == 1: + a = prior_t[0].expand(size).clone() + b = prior_t[1].expand(size).clone() + elif prior_t.dim() == 2 and prior_t.shape[0] == 2: + if prior_t.shape[1] != n_classes: + raise ValueError( + f"Prior array has {prior_t.shape[1]} classes, expected {n_classes}." + ) + a = prior_t[0].unsqueeze(0).expand(size).clone() + b = prior_t[1].unsqueeze(0).expand(size).clone() + else: + raise ValueError(f"Unsupported prior shape {tuple(prior_t.shape)}.") + + out = _sample(distribution, a, b, device) + return out.clamp_min(0.0) if positive_only else out + + +def _sample(distribution: str, a: torch.Tensor, b: torch.Tensor, device: torch.device) -> torch.Tensor: + if distribution == "uniform": + return a + (b - a) * torch.rand(a.shape, device=device) + if distribution == "normal": + return a + b * torch.randn(a.shape, device=device) + raise ValueError(f"Unknown distribution '{distribution}' (use 'uniform' or 'normal').") + + +def sample_gmm_parameters( + n_labels: int, + n_channels: int, + batch: int, + device: torch.device, + prior_means=None, + prior_stds=None, + prior_distributions: str = "uniform", + generation_classes: Optional[Sequence[int]] = None, + background_label_index: Optional[int] = 0, + randomise_background: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Draw per-label Gaussian means/stds for one minibatch. + + Mirrors ``SynthSeg.model_inputs.build_model_inputs``. With the default + ``prior_means=None`` / ``prior_stds=None`` the *effective* priors are + ``means ~ U(0, 250)`` and ``stds ~ U(0, 30)`` per class (the code uses + ``centre=125, range=125`` and ``centre=15, range=15``; the often-quoted + ``[25, 225]`` / ``[5, 25]`` come from a stale docstring). + + ``generation_classes`` lets several labels share the same Gaussian (e.g. to + tie left/right homologues). When ``None`` every label is independent. + + Returns ``means, stds`` of shape ``(batch, n_labels, n_channels)``. + """ + if generation_classes is None: + classes = torch.arange(n_labels, device=device) + else: + classes = torch.as_tensor(generation_classes, device=device, dtype=torch.long) + if classes.numel() != n_labels: + raise ValueError("generation_classes must have one entry per generation label.") + n_classes = int(classes.max().item()) + 1 + + means = torch.empty(batch, n_classes, n_channels, device=device) + stds = torch.empty(batch, n_classes, n_channels, device=device) + for ch in range(n_channels): + means[:, :, ch] = _draw_value( + prior_means, (batch, n_classes), prior_distributions, 125.0, 125.0, device, positive_only=True + ) + stds[:, :, ch] = _draw_value( + prior_stds, (batch, n_classes), prior_distributions, 15.0, 15.0, device, positive_only=True + ) + + # Scatter class parameters to per-label parameters. + means_lab = means[:, classes, :] + stds_lab = stds[:, classes, :] + + # Background special-casing (build_model_inputs): per subject, 5% pure black, + # 25% very dark/low-variance, 70% normal draw. + if randomise_background and background_label_index is not None and 0 <= background_label_index < n_labels: + for b in range(batch): + r = float(torch.rand((), device=device)) + if r > 0.95: + means_lab[b, background_label_index, :] = 0.0 + stds_lab[b, background_label_index, :] = 0.0 + elif r > 0.70: + means_lab[b, background_label_index, :] = torch.rand(n_channels, device=device) * 15.0 + stds_lab[b, background_label_index, :] = torch.rand(n_channels, device=device) * 5.0 + # else: keep the normal draw + + return means_lab, stds_lab + + +def labels_to_image_gmm( + label_map: torch.Tensor, + label_values: torch.Tensor, + means: torch.Tensor, + stds: torch.Tensor, +) -> torch.Tensor: + """Render an image from a label map with a per-label Gaussian mixture. + + ``image[v] = mean[label_v] + std[label_v] * N(0, 1)`` (independent noise per + voxel and channel). Mirrors ``lab2im.layers.SampleConditionalGMM``. + + Args: + label_map: ``(B, 1, D, H, W)`` integer labels. + label_values: ``(K,)`` the generation label values (sorted unique). + means, stds: ``(B, K, C)`` per-label parameters. + + Returns ``(B, C, D, H, W)`` float image. + """ + B = label_map.shape[0] + spatial = label_map.shape[2:] + C = means.shape[-1] + device = label_map.device + dtype = means.dtype + + max_label = int(label_values.max().item()) + # LUT: label value -> contiguous class index 0..K-1. + lut = torch.zeros(max_label + 1, dtype=torch.long, device=device) + lut[label_values] = torch.arange(label_values.numel(), device=device) + idx = lut[label_map.clamp(min=0, max=max_label)].squeeze(1) # (B, D, H, W) + + image = torch.empty((B, C, *spatial), device=device, dtype=dtype) + for b in range(B): + idx_b = idx[b] + for ch in range(C): + mean_map = means[b, :, ch][idx_b] + std_map = stds[b, :, ch][idx_b] + image[b, ch] = mean_map + std_map * torch.randn(spatial, device=device, dtype=dtype) + return image + + +# --------------------------------------------------------------------------- +# Spatial deformation (lab2im.layers.RandomSpatialDeformation +# + utils.sample_affine_transform + neuron VecInt) +# --------------------------------------------------------------------------- +def _as_3vec(value, device: torch.device, default: float = 0.0) -> torch.Tensor: + if value is None or value is False: + return torch.full((3,), float(default), device=device) + if isinstance(value, (int, float)): + return torch.full((3,), float(value), device=device) + return torch.as_tensor(value, dtype=torch.float32, device=device) + + +def sample_affine_matrices( + batch: int, + device: torch.device, + scaling_bounds: Union[bool, Number, Sequence[Number]] = 0.2, + rotation_bounds: Union[bool, Number, Sequence[Number]] = 15.0, + shearing_bounds: Union[bool, Number, Sequence[Number]] = 0.012, + translation_bounds: Union[bool, Number, Sequence[Number]] = False, +) -> torch.Tensor: + """Sample a batch of 3D affine matrices. + + Port of ``lab2im.utils.sample_affine_transform`` / + ``create_affine_transformation_matrix`` for ``n_dims=3``. Each parameter is + drawn independently per axis: + + * scaling ``~ U(1 - s, 1 + s)`` (``scaling_bounds=0.2`` -> U(0.8, 1.2)) + * rotation ``~ U(-r, r)`` degrees (``rotation_bounds=15``) + * shearing ``~ U(-sh, sh)`` (``shearing_bounds=0.012``) + * translation ``~ U(-t, t)`` voxels (``False`` -> disabled) + + The composition is ``T = T_scaling @ T_shearing @ T_rotation`` with the + translation placed in the last column. Returns ``(B, 4, 4)``. The matrix is + applied about the volume centre by :func:`warp_volume`. + """ + def draw(bounds, centre): + vec = _as_3vec(bounds, device, default=0.0) + return centre + (2.0 * torch.rand(batch, 3, device=device) - 1.0) * vec.view(1, 3) + + scaling = draw(scaling_bounds, 1.0) if scaling_bounds not in (False, None) else torch.ones(batch, 3, device=device) + rotation_deg = draw(rotation_bounds, 0.0) if rotation_bounds not in (False, None) else torch.zeros(batch, 3, device=device) + # 6 shear parameters for the off-diagonal entries of the 3x3 linear part. + shear_vec = _as_3vec(shearing_bounds, device, default=0.0) + shear_vec6 = torch.cat([shear_vec, shear_vec]) if shearing_bounds not in (False, None) else torch.zeros(6, device=device) + shearing = (2.0 * torch.rand(batch, 6, device=device) - 1.0) * shear_vec6.view(1, 6) + translation = draw(translation_bounds, 0.0) if translation_bounds not in (False, None) else torch.zeros(batch, 3, device=device) + + rot = math.pi / 180.0 * rotation_deg + cx, cy, cz = torch.cos(rot[:, 0]), torch.cos(rot[:, 1]), torch.cos(rot[:, 2]) + sx, sy, sz = torch.sin(rot[:, 0]), torch.sin(rot[:, 1]), torch.sin(rot[:, 2]) + + zeros = torch.zeros(batch, device=device) + ones = torch.ones(batch, device=device) + + def stack3(rows): + return torch.stack([torch.stack(r, dim=-1) for r in rows], dim=-2) # (B, 3, 3) + + rx = stack3([[ones, zeros, zeros], [zeros, cx, -sx], [zeros, sx, cx]]) + ry = stack3([[cy, zeros, sy], [zeros, ones, zeros], [-sy, zeros, cy]]) + rz = stack3([[cz, -sz, zeros], [sz, cz, zeros], [zeros, zeros, ones]]) + rotation_m = torch.bmm(torch.bmm(rx, ry), rz) + + shear_m = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1) + # off-diagonal positions (0,1),(0,2),(1,0),(1,2),(2,0),(2,1) + pos = [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + for n, (i, j) in enumerate(pos): + shear_m[:, i, j] = shearing[:, n] + + scale_m = torch.zeros(batch, 3, 3, device=device) + scale_m[:, 0, 0] = scaling[:, 0] + scale_m[:, 1, 1] = scaling[:, 1] + scale_m[:, 2, 2] = scaling[:, 2] + + linear = torch.bmm(torch.bmm(scale_m, shear_m), rotation_m) # (B, 3, 3) + + affine = torch.zeros(batch, 4, 4, device=device) + affine[:, :3, :3] = linear + affine[:, :3, 3] = translation + affine[:, 3, 3] = 1.0 + return affine + + +def _identity_grid(shape: Tuple[int, int, int], device: torch.device) -> torch.Tensor: + """Voxel-coordinate identity grid in ``(i, j, k)`` order, shape ``(3, D, H, W)``.""" + d, h, w = shape + zs = torch.arange(d, device=device, dtype=torch.float32) + ys = torch.arange(h, device=device, dtype=torch.float32) + xs = torch.arange(w, device=device, dtype=torch.float32) + ii, jj, kk = torch.meshgrid(zs, ys, xs, indexing="ij") + return torch.stack([ii, jj, kk], dim=0) + + +def _coords_to_grid_sample(coords: torch.Tensor, shape: Tuple[int, int, int]) -> torch.Tensor: + """Convert ``(B, 3, D, H, W)`` voxel coords (i,j,k) to a grid_sample grid. + + Output ``(B, D, H, W, 3)`` with last-dim order ``(x, y, z) = (k, j, i)`` + normalised to ``[-1, 1]`` for ``align_corners=True``. + """ + d, h, w = shape + i = coords[:, 0] + j = coords[:, 1] + k = coords[:, 2] + norm_k = 2.0 * k / max(w - 1, 1) - 1.0 + norm_j = 2.0 * j / max(h - 1, 1) - 1.0 + norm_i = 2.0 * i / max(d - 1, 1) - 1.0 + return torch.stack([norm_k, norm_j, norm_i], dim=-1) + + +def warp_volume( + volume: torch.Tensor, + affine: Optional[torch.Tensor] = None, + displacement: Optional[torch.Tensor] = None, + interp: str = "linear", + center: bool = True, + padding_mode: str = "zeros", +) -> torch.Tensor: + """Resample ``volume`` by an affine matrix and/or a dense displacement field. + + Mirrors ``ext.neuron.layers.SpatialTransformer`` composing + ``[affine, dense_field]``: the sampling location of each output voxel is + ``affine(identity) + displacement``. + + Args: + volume: ``(B, C, D, H, W)``. + affine: ``(B, 4, 4)`` linear+translation applied about the centre (voxel + units, ``(i, j, k)`` order). ``None`` -> identity. + displacement: ``(B, 3, D, H, W)`` per-voxel shift in ``(i, j, k)`` voxel + units. ``None`` -> no elastic term. + interp: ``"linear"`` (trilinear) for images / fields, ``"nearest"`` for + label maps. + center: apply the affine about the volume centre. + padding_mode: ``grid_sample`` padding (``"zeros"`` for label/image, + ``"border"`` for field composition). + """ + B, C, D, H, W = volume.shape + shape = (D, H, W) + device = volume.device + + grid = _identity_grid(shape, device).unsqueeze(0).expand(B, -1, -1, -1, -1) # (B,3,D,H,W) + coords = grid + + if affine is not None: + flat = coords.reshape(B, 3, -1) # (B, 3, N) + if center: + centre = torch.tensor( + [(D - 1) / 2.0, (H - 1) / 2.0, (W - 1) / 2.0], device=device + ).view(1, 3, 1) + flat = flat - centre + linear = affine[:, :3, :3] + translation = affine[:, :3, 3:4] + flat = torch.bmm(linear, flat) + translation + if center: + flat = flat + centre + coords = flat.reshape(B, 3, D, H, W) + + if displacement is not None: + coords = coords + displacement + + sample_grid = _coords_to_grid_sample(coords, shape) + mode = "nearest" if interp == "nearest" else "bilinear" # 3D 'bilinear' == trilinear + return F.grid_sample( + volume, sample_grid, mode=mode, align_corners=True, padding_mode=padding_mode + ) + + +def _integrate_velocity(velocity: torch.Tensor, int_steps: int = 7) -> torch.Tensor: + """Scaling-and-squaring integration of a stationary velocity field. + + Port of ``ext.neuron.layers.VecInt`` (``method='ss'``, ``int_steps=7``): + ``phi = v / 2**N`` then ``phi <- phi + warp(phi, phi)`` repeated ``N`` times, + yielding a diffeomorphic displacement field. ``velocity`` and the returned + displacement are ``(B, 3, D, H, W)`` in voxel units. + """ + disp = velocity / (2 ** int_steps) + for _ in range(int_steps): + disp = disp + warp_volume(disp, displacement=disp, interp="linear", padding_mode="border") + return disp + + +def random_svf_field( + batch: int, + shape: Tuple[int, int, int], + device: torch.device, + nonlin_std: float = 4.0, + nonlin_scale: float = 0.04, + int_steps: int = 7, +) -> torch.Tensor: + """Sample a smooth diffeomorphic displacement field. + + Port of ``RandomSpatialDeformation``'s nonlinear branch: draw a small + stationary velocity field ``~ N(0, U(0, nonlin_std))`` on a coarse grid + (``ceil(shape * nonlin_scale)``), upsample it (trilinear) to full resolution, + then integrate it by scaling-and-squaring. Returns ``(B, 3, D, H, W)``. + """ + if nonlin_std <= 0: + return torch.zeros(batch, 3, *shape, device=device) + + small = [max(2, int(math.ceil(s * nonlin_scale))) for s in shape] + std = torch.rand(batch, 1, 1, 1, 1, device=device) * nonlin_std + velocity = torch.randn(batch, 3, *small, device=device) * std + velocity = F.interpolate(velocity, size=shape, mode="trilinear", align_corners=True) + return _integrate_velocity(velocity, int_steps=int_steps) + + +# --------------------------------------------------------------------------- +# Bias field (lab2im.layers.BiasFieldCorruption) +# --------------------------------------------------------------------------- +def bias_field( + image: torch.Tensor, + bias_field_std: float = 0.7, + bias_scale: float = 0.025, +) -> torch.Tensor: + """Apply a smooth multiplicative bias field. + + Port of ``BiasFieldCorruption``: sample a small Gaussian field + ``~ N(0, U(0, bias_field_std))`` on a coarse grid (``ceil(shape*bias_scale)``), + upsample (trilinear) to full resolution, exponentiate, and multiply. The + field is Gaussian in log space -> positive and multiplicative in intensity + space. A separate field is drawn per channel. + """ + if bias_field_std <= 0: + return image + B, C, D, H, W = image.shape + device = image.device + small = [max(2, int(math.ceil(s * bias_scale))) for s in (D, H, W)] + std = torch.rand(B, 1, 1, 1, 1, device=device) * bias_field_std + field = torch.randn(B, C, *small, device=device) * std + field = F.interpolate(field, size=(D, H, W), mode="trilinear", align_corners=True) + return image * torch.exp(field) + + +# --------------------------------------------------------------------------- +# Intensity augmentation (lab2im.layers.IntensityAugmentation) +# --------------------------------------------------------------------------- +def intensity_augmentation( + image: torch.Tensor, + clip: float = 300.0, + gamma_std: float = 0.5, + normalise: bool = True, +) -> torch.Tensor: + """Clip -> per-channel min-max normalise to [0, 1] -> gamma. + + Port of ``IntensityAugmentation(clip=300, normalise=True, gamma_std=.5, + separate_channels=True)``. Gamma is log-normal: + ``image <- image ** exp(N(0, gamma_std))`` (one exponent per channel). + """ + B, C = image.shape[:2] + reduce_dims = tuple(range(2, image.dim())) + + if clip and clip > 0: + image = image.clamp(0.0, clip) + + if normalise: + m = image.amin(dim=reduce_dims, keepdim=True) + M = image.amax(dim=reduce_dims, keepdim=True) + image = (image - m) / (M - m + 1e-7) + + if gamma_std and gamma_std > 0: + gamma = torch.exp(torch.randn(B, C, *([1] * (image.dim() - 2)), device=image.device) * gamma_std) + image = image.clamp_min(0.0) ** gamma + + return image + + +# --------------------------------------------------------------------------- +# Resolution randomisation (lab2im.edit_tensors + layers.GaussianBlur / +# DynamicGaussianBlur / SampleResolution / +# MimicAcquisition) +# --------------------------------------------------------------------------- +def blurring_sigma_for_downsampling( + current_res: torch.Tensor, + downsample_res: torch.Tensor, + thickness: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Per-axis Gaussian blur sigma for a target acquisition resolution. + + Port of ``edit_tensors.blurring_sigma_for_downsampling`` with + ``mult_coef=None``: ``sigma = 0.75 * min(downsample_res, thickness) / + current_res``; ``sigma = 0.5`` where ``downsample_res == current_res``; + ``sigma = 0`` where ``downsample_res == 0``. All inputs are ``(3,)`` tensors. + """ + if thickness is None: + thickness = downsample_res + effective = torch.minimum(downsample_res, thickness) + sigma = 0.75 * effective / current_res + sigma = torch.where(downsample_res == current_res, torch.full_like(sigma, 0.5), sigma) + sigma = torch.where(downsample_res == 0, torch.zeros_like(sigma), sigma) + return sigma + + +def _gaussian_kernel1d(sigma: float, device: torch.device) -> torch.Tensor: + if sigma <= 0: + return torch.tensor([1.0], device=device) + radius = max(1, int(math.ceil(3.0 * sigma))) + x = torch.arange(-radius, radius + 1, device=device, dtype=torch.float32) + k = torch.exp(-0.5 * (x / sigma) ** 2) + return k / k.sum() + + +def gaussian_blur_3d( + image: torch.Tensor, + sigma: torch.Tensor, + blur_range: float = 1.03, +) -> torch.Tensor: + """Separable anisotropic Gaussian blur with random sigma jitter. + + Port of ``GaussianBlur`` / ``DynamicGaussianBlur``: the per-axis ``sigma`` is + multiplied by ``U(1/blur_range, blur_range)`` (``blur_range=1.03`` in + SynthSeg, ``1.15`` in the 2020 lab2im model) and applied as three 1D + convolutions (reflect padding). ``sigma`` is a ``(3,)`` tensor. + """ + B, C, D, H, W = image.shape + device = image.device + sigma = sigma.clone().float() + if blur_range and blur_range > 1.0: + jitter = (1.0 / blur_range) + torch.rand(3, device=device) * (blur_range - 1.0 / blur_range) + sigma = sigma * jitter + + out = image + for axis, s in enumerate(sigma.tolist()): + if s <= 0: + continue + kernel = _gaussian_kernel1d(s, device) + ksize = kernel.numel() + pad = ksize // 2 + # shape the separable kernel for conv3d along the given spatial axis + shape = [1, 1, 1, 1, 1] + shape[2 + axis] = ksize + weight = kernel.view(shape).repeat(C, 1, 1, 1, 1) + padding = [0, 0, 0] + padding[axis] = pad + pad_full = (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0]) + out = F.pad(out, pad_full, mode="reflect") + out = F.conv3d(out, weight, groups=C) + return out + + +def sample_resolution( + min_res: torch.Tensor, + max_res_iso: float = 4.0, + max_res_aniso: float = 8.0, + prob_iso: float = 0.1, + prob_min: float = 0.05, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Sample a random target acquisition resolution and slice thickness. + + Port of ``lab2im.layers.SampleResolution`` (with ``return_thickness=True``): + + * with prob ``prob_iso``: isotropic, ``res ~ U(min_res, max_res_iso)`` (same + on all axes); + * else: anisotropic, one random axis gets ``U(min_res, max_res_aniso)`` and + the others stay at ``min_res``; + * with prob ``prob_min``: override to ``min_res`` (no downsampling); + * thickness ``~ U(min_res, res)`` per axis. + + ``min_res`` is a ``(3,)`` tensor (the native/atlas resolution). Returns + ``(resolution, thickness)`` as ``(3,)`` tensors. + """ + device = min_res.device + if float(torch.rand((), device=device)) < prob_iso: + r = float(torch.rand((), device=device)) + res = min_res + (max_res_iso - min_res) * r # same scalar factor -> isotropic-ish + else: + res = min_res.clone() + axis = int(torch.randint(0, 3, (1,), device=device)) + res[axis] = min_res[axis] + (max_res_aniso - float(min_res[axis])) * float(torch.rand((), device=device)) + + if float(torch.rand((), device=device)) < prob_min: + res = min_res.clone() + + thickness = min_res + (res - min_res) * torch.rand(3, device=device) + return res, thickness + + +def mimic_acquisition( + image: torch.Tensor, + current_res: torch.Tensor, + downsample_res: torch.Tensor, + output_shape: Tuple[int, int, int], +) -> torch.Tensor: + """Downsample to a target resolution, then resample to the output grid. + + Port of ``lab2im.layers.MimicAcquisition``: nearest-neighbour downsampling to + the sampled ``downsample_res`` grid (the partial-volume step) followed by + trilinear resampling to ``output_shape``. + """ + B, C, D, H, W = image.shape + in_shape = (D, H, W) + factor = (current_res / downsample_res).tolist() + down_shape = [max(1, int(round(in_shape[i] * factor[i]))) for i in range(3)] + x = F.interpolate(image, size=down_shape, mode="nearest") + x = F.interpolate(x, size=tuple(output_shape), mode="trilinear", align_corners=True) + return x + + +# --------------------------------------------------------------------------- +# EM label completion for sparse label maps (SynthSeg paper, Sec. 5.4) +# --------------------------------------------------------------------------- +def _em_gmm_1d( + x_fit: torch.Tensor, n_components: int, n_iters: int, eps: float +) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + """Fit a 1D Gaussian mixture by Expectation-Maximization. + + ``x_fit`` is a 1D tensor of intensities. Returns ``(means, vars, weights)`` + each of shape ``(K,)`` (``K = min(n_components, len(x_fit))``), or ``None`` if + ``x_fit`` is empty. Responsibilities use a numerically-stable log-sum-exp. + """ + n = x_fit.numel() + if n == 0: + return None + k = max(1, min(int(n_components), n)) + device = x_fit.device + xmin, xmax = float(x_fit.min()), float(x_fit.max()) + means = torch.linspace(xmin, xmax, k, device=device) + if xmax <= xmin: # degenerate (constant region): spread the means slightly + means = means + torch.arange(k, device=device, dtype=means.dtype) * eps + var = x_fit.var(unbiased=False).clamp_min(eps).repeat(k) + weights = torch.full((k,), 1.0 / k, device=device) + + x = x_fit.view(n, 1) + log2pi = math.log(2.0 * math.pi) + for _ in range(int(n_iters)): + logp = ( + torch.log(weights.clamp_min(eps)).view(1, k) + - 0.5 * (log2pi + torch.log(var).view(1, k)) + - 0.5 * (x - means.view(1, k)) ** 2 / var.view(1, k) + ) + logp = logp - torch.logsumexp(logp, dim=1, keepdim=True) + resp = logp.exp() # (n, k) + nk = resp.sum(0).clamp_min(eps) # (k,) + weights = nk / n + means = (resp * x).sum(0) / nk + var = (resp * (x - means.view(1, k)) ** 2).sum(0) / nk + var = var.clamp_min(eps) + return means, var, weights + + +def _assign_gmm( + x_full: torch.Tensor, means: torch.Tensor, var: torch.Tensor, weights: torch.Tensor, + eps: float, chunk: int = 2_000_000, +) -> torch.Tensor: + """Hard-assign each value in ``x_full`` to its most likely mixture component.""" + n = x_full.numel() + k = means.numel() + out = torch.empty(n, dtype=torch.long, device=x_full.device) + logw = torch.log(weights.clamp_min(eps)).view(1, k) + half_logvar = 0.5 * torch.log(var).view(1, k) + m = means.view(1, k) + v = var.view(1, k) + for s in range(0, n, chunk): + xc = x_full[s:s + chunk].view(-1, 1) + logp = logw - half_logvar - 0.5 * (xc - m) ** 2 / v + out[s:s + chunk] = logp.argmax(dim=1) + return out + + +def em_subdivide_labels( + image: torch.Tensor, + label_map: torch.Tensor, + n_foreground_clusters: int = 2, + background_clusters_range: Sequence[int] = (3, 10), + background_label: int = 0, + n_iters: int = 20, + max_fit_voxels: int = 100000, + channel: int = 0, + same_on_batch: bool = False, + eps: float = 1e-6, +) -> Tuple[torch.Tensor, List[int], List[int]]: + """Subdivide each label into intensity-coherent subregions via EM (SynthSeg §5.4). + + Reproduces SynthSeg's handling of sparse / incomplete label maps: "we enhance + the training segmentations by subdividing all their labels (background and + foreground) into finer subregions [...] by clustering the intensities of the + associated image with the Expectation Maximisation algorithm." Each foreground + label is split into ``n_foreground_clusters`` (2 in the paper); the background + label is split into a random ``N`` in ``background_clusters_range`` ([3, 10]). + The resulting fine labels are the *generation* labels (each gets its own + Gaussian), while the original labels are recovered for the segmentation target + via the returned merge map. + + Args: + image: ``(B, C, D, H, W)`` paired real intensities (channel ``channel`` is + used as the clustering reference). + label_map: ``(B, 1, D, H, W)`` integer parent labels. + n_foreground_clusters: sub-clusters per non-background label. + background_clusters_range: ``(min, max)`` for the random background split. + background_label: the label value treated as background. + n_iters: EM iterations. + max_fit_voxels: subsample this many voxels to *fit* each EM (the full + region is still assigned); ``0`` / ``None`` -> use all voxels. + same_on_batch: draw a single background ``N`` shared across the batch. + eps: numerical floor. + + Returns: + ``(fine_labels, generation_labels, output_labels)`` where ``fine_labels`` + is ``(B, 1, D, H, W)`` long, ``generation_labels`` is the sorted list of + sub-label values, and ``output_labels[i]`` is the parent label that + ``generation_labels[i]`` merges back to. + """ + B = image.shape[0] + device = image.device + ref = image[:, channel] # (B, D, H, W) + parents = torch.unique(label_map).long().tolist() # sorted, batch-wide + parent_to_idx = {p: i for i, p in enumerate(parents)} + lo, hi = int(background_clusters_range[0]), int(background_clusters_range[1]) + mult = max(hi, int(n_foreground_clusters)) + 1 # collision-free encoding + + # If the configured background label is absent (e.g. a *complete* one-hot whose + # decoding shifted every label by +1, so the real background is no longer 0), + # fall back to the largest-area label so it still receives the richer + # [min, max] background split rather than the 2-cluster foreground split. + if background_label not in parents and len(parents) > 0: + background_label = int(torch.bincount(label_map.flatten().clamp_min(0)).argmax()) + + bg_n_shared = int(torch.randint(lo, hi + 1, (1,), device=device)) if same_on_batch else None + + fine = torch.zeros_like(label_map) + for b in range(B): + lab_b = label_map[b, 0] + ref_b = ref[b] + for p in parents: + pi = parent_to_idx[p] + mask = lab_b == p + cnt = int(mask.sum()) + if cnt == 0: + continue + if p == background_label: + k = bg_n_shared if bg_n_shared is not None else int(torch.randint(lo, hi + 1, (1,), device=device)) + else: + k = int(n_foreground_clusters) + k = max(1, min(k, cnt)) + + x = ref_b[mask] + if k == 1: + assign = torch.zeros(cnt, dtype=torch.long, device=device) + else: + if max_fit_voxels and cnt > max_fit_voxels: + sel = torch.randint(0, cnt, (int(max_fit_voxels),), device=device) + x_fit = x[sel] + else: + x_fit = x + fit = _em_gmm_1d(x_fit, k, n_iters, eps) + assign = ( + torch.zeros(cnt, dtype=torch.long, device=device) + if fit is None else _assign_gmm(x, *fit, eps=eps) + ) + fine[b, 0][mask] = pi * mult + assign + + gen_values = torch.unique(fine).long().tolist() + out_values = [parents[g // mult] for g in gen_values] + return fine, gen_values, out_values + + +# --------------------------------------------------------------------------- +# Label utilities (lab2im.layers.RandomFlip / ConvertLabels) +# --------------------------------------------------------------------------- +def flip_lr_with_swap( + label_map: torch.Tensor, + flip_axis: int, + label_values: Optional[torch.Tensor] = None, + n_neutral_labels: Optional[int] = None, +) -> torch.Tensor: + """Flip the label map along ``flip_axis`` and (optionally) swap L/R labels. + + Port of ``lab2im.layers.RandomFlip(swap_labels=True)``. When + ``n_neutral_labels`` is provided, ``label_values`` is assumed ordered as + ``[neutral..., left-hemisphere..., right-hemisphere...]`` and the two + hemispheres are relabelled into each other after flipping (so anatomical + left/right stays correct). When ``n_neutral_labels`` is ``None`` a plain flip + is performed (no relabelling). + + ``flip_axis`` is the spatial axis index in ``(0, 1, 2) = (D, H, W)``. + """ + flipped = torch.flip(label_map, dims=(2 + flip_axis,)) + + if n_neutral_labels is None or label_values is None: + return flipped + + values = label_values.tolist() + n_labels = len(values) + if n_neutral_labels >= n_labels: + return flipped + n_sided = (n_labels - n_neutral_labels) // 2 + if n_sided == 0: + return flipped + + neutral = values[:n_neutral_labels] + left = values[n_neutral_labels:n_neutral_labels + n_sided] + right = values[n_neutral_labels + n_sided:n_neutral_labels + 2 * n_sided] + source = neutral + left + right + dest = neutral + right + left + return convert_labels(flipped, source, dest) + + +def convert_labels( + label_map: torch.Tensor, + source_values: Sequence[int], + dest_values: Sequence[int], +) -> torch.Tensor: + """Relabel ``label_map`` mapping ``source_values[i] -> dest_values[i]``. + + Port of ``lab2im.layers.ConvertLabels``. Labels not present in + ``source_values`` are left unchanged. + """ + device = label_map.device + source = torch.as_tensor(source_values, dtype=torch.long, device=device) + dest = torch.as_tensor(dest_values, dtype=torch.long, device=device) + max_label = int(max(int(label_map.max().item()), int(source.max().item()))) + lut = torch.arange(max_label + 1, dtype=torch.long, device=device) + lut[source] = dest + return lut[label_map.clamp(min=0, max=max_label)] diff --git a/auglab/transforms/synthseg/generator.py b/auglab/transforms/synthseg/generator.py new file mode 100644 index 0000000..b7cba44 --- /dev/null +++ b/auglab/transforms/synthseg/generator.py @@ -0,0 +1,409 @@ +"""SynthSeg generative model (``BrainGenerator``) as a torch module. + +``SynthSegGenerator`` faithfully reproduces the order of operations and the +default hyper-parameters of the SynthSeg "brain generator" +(``SynthSeg/brain_generator.py`` -> ``labels_to_image_model``), turning an +anatomical *label map* into a randomly-synthesised image together with the +matching (spatially-deformed, possibly relabelled) ground-truth label map: + + spatial deform (affine + diffeomorphic SVF, on labels, nearest) + -> [optional random crop] + -> left/right flip with label swap + -> GMM intensity sampling (labels -> image) + -> bias field + -> intensity augmentation (clip -> min-max norm -> gamma) + -> resolution randomisation (blur -> subsample -> resample), per channel + -> map generation labels to output/segmentation labels + +The default hyper-parameters match ``BrainGenerator`` (which overrides several +``labels_to_image_model`` signature defaults). See ``README.md`` for the table +and source citations. + +Note on label maps: SynthSeg derives its realism from a *dense* anatomical label +map (e.g. a FreeSurfer/SAMSEG segmentation covering every tissue). When fed a +sparse target segmentation (only a few foreground structures over a 0 +background) it still runs correctly, but the synthetic image will only contain +those structures over a single-Gaussian background. +""" + +from __future__ import annotations + +from typing import List, Optional, Sequence, Tuple, Union + +import torch +from torch import nn + +from auglab.transforms.synthseg import functional as FN + +Number = Union[int, float] + + +class SynthSegGenerator(nn.Module): + """Differentiable-free SynthSeg image generator for 3D label maps. + + Args: + generation_labels: the label values the GMM generates intensities for. + If ``None`` they are inferred from each input as the sorted unique + values (re-inferred per call). + output_labels: label values written to the output segmentation, aligned + with ``generation_labels`` (``generation_labels[i] -> output_labels[i]``). + ``None`` -> identity (output labels == generation labels). + n_neutral_labels: number of non-lateralised labels at the start of + ``generation_labels``; enables anatomically-correct L/R label swapping + on flip. ``None`` -> plain flip without relabelling. + generation_classes: per-label class index so several labels can share one + Gaussian (e.g. tie left/right). ``None`` -> every label independent. + n_channels: number of synthesised image channels (modalities). + prior_distributions: ``"uniform"`` or ``"normal"`` GMM parameter priors. + prior_means / prior_stds: priors for the GMM means/stds. ``None`` -> + ``U(0, 250)`` / ``U(0, 30)`` (full domain randomisation). May be a + scalar, ``[a, b]``, or a ``(2, n_classes)`` array. + flipping: enable random left/right flipping. + scaling_bounds / rotation_bounds / shearing_bounds / translation_bounds: + affine augmentation ranges (see :func:`functional.sample_affine_matrices`). + nonlin_std / nonlin_scale / svf_integration_steps: diffeomorphic elastic + deformation controls. + bias_field_std / bias_scale: multiplicative bias-field controls. + gamma_std / clip: intensity-augmentation controls. + normalise: min-max normalise to [0, 1] before gamma (SynthSeg always does). + randomise_res: randomise the acquisition resolution per channel. + max_res_iso / max_res_aniso: resolution ceilings for ``randomise_res``. + data_res / thickness: fixed acquisition resolution(s) when + ``randomise_res=False`` (``(3,)`` or per-channel ``(n_channels, 3)``). + blur_range: random jitter factor on the blur sigma (1.03 in SynthSeg). + atlas_res: native resolution of the input label map (mm). + output_shape: spatial size of the output (random crop of the label map + before generation). ``None`` -> keep the input shape. + em_label_completion: for sparse/incomplete label maps, subdivide every + label into intensity-coherent generation sub-labels by clustering the + paired real image with Expectation-Maximization (SynthSeg §5.4). + Requires ``image`` to be passed to :meth:`forward`. The sub-labels are + merged back to their parent labels in the output segmentation. + em_n_foreground_clusters: EM sub-clusters per foreground label (2 in the + paper). + em_background_clusters_range: ``(min, max)`` random number of EM clusters + for the background label ([3, 10] in the paper). + em_background_label: label value treated as background for EM. + em_n_iters: EM iterations. + em_max_fit_voxels: subsample size for fitting each EM (the full region is + still assigned); ``0`` -> use all voxels. + em_same_on_batch: share the random background cluster count across the batch. + apply_intensity_augmentation: toggle clip/normalise/gamma. + """ + + def __init__( + self, + generation_labels: Optional[Sequence[int]] = None, + output_labels: Optional[Sequence[int]] = None, + n_neutral_labels: Optional[int] = None, + generation_classes: Optional[Sequence[int]] = None, + n_channels: int = 1, + prior_distributions: str = "uniform", + prior_means=None, + prior_stds=None, + flipping: bool = True, + flip_axis: int = 2, + scaling_bounds: Union[bool, Number, Sequence[Number]] = 0.2, + rotation_bounds: Union[bool, Number, Sequence[Number]] = 15.0, + shearing_bounds: Union[bool, Number, Sequence[Number]] = 0.012, + translation_bounds: Union[bool, Number, Sequence[Number]] = False, + nonlin_std: float = 4.0, + nonlin_scale: float = 0.04, + svf_integration_steps: int = 7, + bias_field_std: float = 0.7, + bias_scale: float = 0.025, + gamma_std: float = 0.5, + clip: float = 300.0, + normalise: bool = True, + randomise_res: bool = True, + max_res_iso: float = 4.0, + max_res_aniso: float = 8.0, + data_res=None, + thickness=None, + blur_range: float = 1.03, + atlas_res: float = 1.0, + output_shape: Optional[Sequence[int]] = None, + em_label_completion: bool = False, + em_n_foreground_clusters: int = 2, + em_background_clusters_range: Sequence[int] = (3, 10), + em_background_label: int = 0, + em_n_iters: int = 20, + em_max_fit_voxels: int = 100000, + em_same_on_batch: bool = False, + apply_affine: bool = True, + apply_nonlinear: bool = True, + apply_bias_field: bool = True, + apply_intensity_augmentation: bool = True, + apply_resolution: bool = True, + ) -> None: + super().__init__() + self.generation_labels = list(generation_labels) if generation_labels is not None else None + self.output_labels = list(output_labels) if output_labels is not None else None + self.n_neutral_labels = n_neutral_labels + self.generation_classes = list(generation_classes) if generation_classes is not None else None + self.n_channels = int(n_channels) + self.prior_distributions = prior_distributions + self.prior_means = prior_means + self.prior_stds = prior_stds + + self.flipping = flipping + self.flip_axis = flip_axis + self.scaling_bounds = scaling_bounds + self.rotation_bounds = rotation_bounds + self.shearing_bounds = shearing_bounds + self.translation_bounds = translation_bounds + + self.nonlin_std = nonlin_std + self.nonlin_scale = nonlin_scale + self.svf_integration_steps = int(svf_integration_steps) + + self.bias_field_std = bias_field_std + self.bias_scale = bias_scale + + self.gamma_std = gamma_std + self.clip = clip + self.normalise = normalise + + self.randomise_res = randomise_res + self.max_res_iso = max_res_iso + self.max_res_aniso = max_res_aniso + self.data_res = data_res + self.thickness = thickness + self.blur_range = blur_range + self.atlas_res = float(atlas_res) + self.output_shape = list(output_shape) if output_shape is not None else None + + self.em_label_completion = em_label_completion + self.em_n_foreground_clusters = int(em_n_foreground_clusters) + self.em_background_clusters_range = tuple(em_background_clusters_range) + self.em_background_label = int(em_background_label) + self.em_n_iters = int(em_n_iters) + self.em_max_fit_voxels = int(em_max_fit_voxels) + self.em_same_on_batch = em_same_on_batch + self._warned_em = False + + self.apply_affine = apply_affine + self.apply_nonlinear = apply_nonlinear + self.apply_bias_field = apply_bias_field + self.apply_intensity_augmentation = apply_intensity_augmentation + self.apply_resolution = apply_resolution + + # ------------------------------------------------------------------ + @torch.no_grad() + def forward( + self, label_map: torch.Tensor, image: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate an image and its label map from an input label map. + + Args: + label_map: ``(B, 1, D, H, W)`` integer / one-hot / ``(B, D, H, W)`` + segmentation. Coerced via :func:`functional.to_label_map`. + image: ``(B, C, D, H, W)`` paired *real* image. Only used (and only + required) when ``em_label_completion=True``, to cluster unlabelled + tissue into generation sub-labels via EM (SynthSeg §5.4). + + Returns: + ``(image, labels)`` where ``image`` is ``(B, n_channels, *out)`` float + in [0, 1] (if normalised) and ``labels`` is ``(B, 1, *out)`` int. + """ + labels = FN.to_label_map(label_map) + device = labels.device + batch = labels.shape[0] + + # Label bookkeeping (defaults from config; overridden by EM completion). + gen_labels = self.generation_labels # list[int] or None + out_labels_cfg = self.output_labels # list[int] or None + gen_classes = self.generation_classes # list[int] or None + n_neutral = self.n_neutral_labels + randomise_bg = True + + # 0. EM label completion for sparse maps (uses the paired real image) -- + if self.em_label_completion: + if image is None: + if not self._warned_em: + print(f"{type(self).__name__}: em_label_completion is enabled but no " + f"image was provided; falling back to plain generation.", flush=True) + self._warned_em = True + else: + ref = image if image.dim() == 5 else image.unsqueeze(1) + labels, gen_labels, out_labels_cfg = FN.em_subdivide_labels( + ref.float(), labels, + n_foreground_clusters=self.em_n_foreground_clusters, + background_clusters_range=self.em_background_clusters_range, + background_label=self.em_background_label, + n_iters=self.em_n_iters, + max_fit_voxels=self.em_max_fit_voxels, + same_on_batch=self.em_same_on_batch, + ) + gen_classes = None # each sub-label gets its own Gaussian + n_neutral = None # plain flip (sub-labels carry no L/R structure) + randomise_bg = False # background is now modelled by its clusters + + # 1. random crop to output_shape (label space) ------------------------ + if self.output_shape is not None and tuple(self.output_shape) != tuple(labels.shape[2:]): + labels = self._random_crop(labels, self.output_shape) + + # 2. spatial deformation of the LABEL MAP (nearest) ------------------- + affine = None + if self.apply_affine and self._affine_active(): + affine = FN.sample_affine_matrices( + batch, device, + scaling_bounds=self.scaling_bounds, + rotation_bounds=self.rotation_bounds, + shearing_bounds=self.shearing_bounds, + translation_bounds=self.translation_bounds, + ) + displacement = None + if self.apply_nonlinear and self.nonlin_std and self.nonlin_std > 0: + displacement = FN.random_svf_field( + batch, tuple(labels.shape[2:]), device, + nonlin_std=self.nonlin_std, + nonlin_scale=self.nonlin_scale, + int_steps=self.svf_integration_steps, + ) + if affine is not None or displacement is not None: + labels = FN.warp_volume( + labels.float(), affine=affine, displacement=displacement, + interp="nearest", padding_mode="zeros", + ).round().long() + + # 3. left/right flipping (with optional label swap) ------------------- + if self.flipping and float(torch.rand((), device=device)) < 0.5: + label_values_flip = ( + torch.as_tensor(gen_labels, dtype=torch.long, device=device) + if gen_labels is not None else FN.infer_label_values(labels) + ) + labels = FN.flip_lr_with_swap( + labels, self.flip_axis, + label_values=label_values_flip, + n_neutral_labels=n_neutral, + ) + + # The generation labels (after potential relabelling) used by the GMM. + gen_values = ( + torch.as_tensor(gen_labels, dtype=torch.long, device=device) + if gen_labels is not None else FN.infer_label_values(labels) + ) + n_labels = gen_values.numel() + bg_index = None + if randomise_bg and (gen_values == 0).any(): + bg_index = int((gen_values == 0).nonzero(as_tuple=True)[0].item()) + + # 4. GMM intensity sampling ------------------------------------------- + means, stds = FN.sample_gmm_parameters( + n_labels, self.n_channels, batch, device, + prior_means=self.prior_means, + prior_stds=self.prior_stds, + prior_distributions=self.prior_distributions, + generation_classes=gen_classes, + background_label_index=bg_index, + ) + synth = FN.labels_to_image_gmm(labels, gen_values, means, stds) + + # 5. bias field ------------------------------------------------------- + if self.apply_bias_field and self.bias_field_std and self.bias_field_std > 0: + synth = FN.bias_field(synth, self.bias_field_std, self.bias_scale) + + # 6. intensity augmentation (clip -> normalise -> gamma) -------------- + if self.apply_intensity_augmentation: + synth = FN.intensity_augmentation( + synth, clip=self.clip, gamma_std=self.gamma_std, normalise=self.normalise + ) + + # 7. resolution randomisation, per channel ---------------------------- + if self.apply_resolution: + synth = self._simulate_resolution(synth) + + # 8. map generation labels -> output/segmentation labels -------------- + out_labels = labels + if out_labels_cfg is not None and gen_labels is not None: + out_labels = FN.convert_labels(labels, gen_labels, out_labels_cfg) + + return synth, out_labels + + # ------------------------------------------------------------------ + def _affine_active(self) -> bool: + return any( + b not in (False, None, 0, 0.0) + for b in (self.scaling_bounds, self.rotation_bounds, self.shearing_bounds, self.translation_bounds) + ) + + @staticmethod + def _random_crop(labels: torch.Tensor, output_shape: Sequence[int]) -> torch.Tensor: + B, _, D, H, W = labels.shape + out = [] + sizes = (D, H, W) + starts = [] + for dim, target in zip(sizes, output_shape): + target = min(int(target), dim) + start = int(torch.randint(0, dim - target + 1, (1,))) if dim > target else 0 + starts.append((start, target)) + (sd, td), (sh, th), (sw, tw) = starts + return labels[:, :, sd:sd + td, sh:sh + th, sw:sw + tw] + + def _simulate_resolution(self, image: torch.Tensor) -> torch.Tensor: + device = image.device + out_shape = tuple(image.shape[2:]) + atlas_res = torch.full((3,), self.atlas_res, device=device) + + channels = [] + for c in range(image.shape[1]): + ch = image[:, c:c + 1] + if self.randomise_res: + res, thickness = FN.sample_resolution( + atlas_res, self.max_res_iso, self.max_res_aniso + ) + else: + res = self._fixed_res(c, device) + thickness = self._fixed_thickness(c, device, res) + sigma = FN.blurring_sigma_for_downsampling(atlas_res, res, thickness) + ch = FN.gaussian_blur_3d(ch, sigma, blur_range=self.blur_range) + ch = FN.mimic_acquisition(ch, atlas_res, res, out_shape) + channels.append(ch) + return torch.cat(channels, dim=1) + + def _fixed_res(self, channel: int, device: torch.device) -> torch.Tensor: + if self.data_res is None: + return torch.full((3,), self.atlas_res, device=device) + arr = torch.as_tensor(self.data_res, dtype=torch.float32, device=device) + if arr.dim() == 2: + return arr[channel] + return arr + + def _fixed_thickness(self, channel: int, device: torch.device, res: torch.Tensor) -> torch.Tensor: + if self.thickness is None: + return res + arr = torch.as_tensor(self.thickness, dtype=torch.float32, device=device) + if arr.dim() == 2: + return arr[channel] + return arr + + +if __name__ == "__main__": + # Minimal self-contained smoke test on a synthetic label map (CPU-friendly). + torch.manual_seed(0) + device = "cuda" if torch.cuda.is_available() else "cpu" + + B, D, H, W = 2, 48, 56, 52 + zz, yy, xx = torch.meshgrid( + torch.arange(D), torch.arange(H), torch.arange(W), indexing="ij" + ) + centre = torch.tensor([D / 2, H / 2, W / 2]) + r = ((zz - centre[0]) ** 2 + (yy - centre[1]) ** 2 + (xx - centre[2]) ** 2).sqrt() + vol = torch.zeros(D, H, W, dtype=torch.long) + vol[r < 18] = 1 # "tissue A" + vol[r < 10] = 2 # "tissue B" + vol[(xx > W // 2) & (r < 18)] = 3 # right-side structure + labels = vol.view(1, 1, D, H, W).repeat(B, 1, 1, 1, 1) + + gen = SynthSegGenerator(generation_labels=[0, 1, 2, 3], n_channels=1).to(device) + image, out_labels = gen(labels.to(device)) + + print("input labels:", tuple(labels.shape), "values", torch.unique(labels).tolist()) + print("output image :", tuple(image.shape), "range", + (round(float(image.min()), 3), round(float(image.max()), 3))) + print("output labels:", tuple(out_labels.shape), "values", torch.unique(out_labels).tolist()) + assert image.shape[0] == B and image.shape[1] == 1 + assert out_labels.shape[2:] == image.shape[2:] + assert not torch.isnan(image).any(), "NaNs in generated image" + print("OK") diff --git a/auglab/transforms/synthseg/transforms.py b/auglab/transforms/synthseg/transforms.py new file mode 100644 index 0000000..9c838d2 --- /dev/null +++ b/auglab/transforms/synthseg/transforms.py @@ -0,0 +1,195 @@ +"""AugLab-style wrappers around the SynthSeg generative model. + +Two entry points are provided: + +* :class:`RandomSynthSegGPU` -- an :class:`ImageOnlyTransform` that *replaces* + the image with a GMM-synthesised one derived from ``params['seg']``. It is + intensity-only (no internal spatial deformation), so it composes with AugLab's + existing geometric transforms (``RandomAffine3DCustom``, ``RandomFlipTransformGPU``, + ...) inside an :class:`AugmentationSequentialCustom`: place those *before* it so + the mask is deformed first and SynthSeg generates from the deformed labels, + keeping image and label aligned. Drop it into a ``transform_params_gpu.json`` + pipeline like any other transform. + +* :class:`SynthSegTransformsGPU` -- a config-driven top-level module mirroring + :class:`auglab.transforms.gpu.transforms.AugTransformsGPU`. It runs the *full* + SynthSeg pipeline (spatial deform + flip + GMM + bias + intensity + resolution) + and returns ``(image, label)`` from ``forward(data, target)`` -- the calling + convention used by the nnUNet trainer and ``train_monai.py``. This is the + faithful end-to-end SynthSeg generator. +""" + +from __future__ import annotations + +import json +import os +from typing import Any, Dict, List, Optional + +import torch +from torch import nn +from kornia.core import Tensor + +from auglab.transforms.gpu.base import ImageOnlyTransform +from auglab.transforms.synthseg.generator import SynthSegGenerator + +# Keys understood from the JSON config / kwargs, forwarded to SynthSegGenerator. +_GENERATOR_KEYS = { + "generation_labels", "output_labels", "n_neutral_labels", "generation_classes", + "n_channels", "prior_distributions", "prior_means", "prior_stds", + "flipping", "flip_axis", "scaling_bounds", "rotation_bounds", "shearing_bounds", + "translation_bounds", "nonlin_std", "nonlin_scale", "svf_integration_steps", + "bias_field_std", "bias_scale", "gamma_std", "clip", "normalise", + "randomise_res", "max_res_iso", "max_res_aniso", "data_res", "thickness", + "blur_range", "atlas_res", "output_shape", + "em_label_completion", "em_n_foreground_clusters", "em_background_clusters_range", + "em_background_label", "em_n_iters", "em_max_fit_voxels", "em_same_on_batch", + "apply_affine", "apply_nonlinear", "apply_bias_field", + "apply_intensity_augmentation", "apply_resolution", +} + + +def _filter_generator_kwargs(params: Dict[str, Any]) -> Dict[str, Any]: + return {k: v for k, v in params.items() if k in _GENERATOR_KEYS} + + +class RandomSynthSegGPU(ImageOnlyTransform): + """Replace the image with a SynthSeg GMM synthesis of ``params['seg']``. + + Intensity-only: GMM sampling -> bias field -> intensity augmentation -> + resolution randomisation. Spatial deformation / flipping are disabled so that + geometry stays consistent with the segmentation propagated by the surrounding + :class:`AugmentationSequentialCustom` (use AugLab's geometric transforms for + that, placed before this one). + + Args: + apply_to_channel: image channels to overwrite with the synthesis. + p: application probability (Kornia convention). + Remaining kwargs are forwarded to :class:`SynthSegGenerator` (e.g. + ``n_channels``, ``prior_means``, ``bias_field_std``, ``gamma_std``, + ``randomise_res``, ``max_res_iso`` ...). + """ + + def __init__( + self, + apply_to_channel: Optional[List[int]] = None, + same_on_batch: bool = False, + p: float = 0.5, + keepdim: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim) + self.apply_to_channel = apply_to_channel if apply_to_channel is not None else [0] + gen_kwargs = _filter_generator_kwargs(kwargs) + # Intensity-only: never deform/flip internally (geometry comes from the + # surrounding sequential, which also transports the mask). + gen_kwargs.update(apply_affine=False, apply_nonlinear=False, flipping=False, output_shape=None) + self.generator = SynthSegGenerator(**gen_kwargs) + + @torch.no_grad() + def apply_transform( + self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None + ) -> Tensor: + seg = params.get("seg", None) + if seg is None: + return input + + # Pass the real image so EM label completion (if enabled) can cluster it. + synth, _ = self.generator(seg, image=input) # (B, n_channels, D, H, W) + synth = synth.to(device=input.device, dtype=input.dtype) + + for n, c in enumerate(self.apply_to_channel): + if c < 0 or c >= input.shape[1]: + continue + src = n if n < synth.shape[1] else synth.shape[1] - 1 + x = synth[:, src] + if torch.isnan(x).any() or torch.isinf(x).any(): + print(f"Warning nan: {self.__class__.__name__}", flush=True) + continue + input[:, c] = x + return input + + +class SynthSegTransformsGPU(nn.Module): + """Config-driven full SynthSeg brain generator. + + Mirrors :class:`AugTransformsGPU`: construct from a JSON path, move ``.to(device)``, + then call ``transforms(data, target)`` to obtain ``(synthetic_image, label)``. + The image input is ignored (SynthSeg synthesises purely from labels); ``target`` + supplies the label map (single-channel integer, one-hot, or ``(B, D, H, W)``). + + The JSON may either be a flat dict of :class:`SynthSegGenerator` parameters or + wrap them under a ``"SynthSeg"`` key, optionally with a top-level + ``"probability"`` controlling how often synthesis is applied (otherwise the + original ``(data, target)`` is returned unchanged). ``probability`` defaults + to ``1.0`` (always synthesise, as in the paper). + """ + + def __init__(self, json_path: Optional[str] = None, params: Optional[Dict[str, Any]] = None): + super().__init__() + if params is None: + if json_path is None: + raise ValueError("Provide either json_path or params.") + with open(os.path.join(json_path), "r") as f: + config = json.load(f) + else: + config = params + + if "SynthSeg" in config: + config = config["SynthSeg"] + + self.probability = float(config.get("probability", 1.0)) + self.return_onehot = bool(config.get("return_onehot", False)) + self.generator = SynthSegGenerator(**_filter_generator_kwargs(config)) + + @torch.no_grad() + def forward(self, data: Tensor, target: Tensor): + if self.probability < 1.0 and float(torch.rand((), device=data.device)) >= self.probability: + return data, target + + # ``data`` (the real image) is only consumed when em_label_completion is on. + image, labels = self.generator(target, image=data) + image = image.to(device=data.device, dtype=data.dtype) + + if self.return_onehot and target.dim() == 5 and target.shape[1] > 1: + labels = self._to_onehot(labels, target.shape[1]).to(dtype=target.dtype) + else: + labels = labels.to(dtype=target.dtype) + return image, labels + + @staticmethod + def _to_onehot(labels: Tensor, n_channels: int) -> Tensor: + # labels: (B, 1, D, H, W) with values in 1..n_channels for foreground. + B, _, D, H, W = labels.shape + onehot = torch.zeros(B, n_channels, D, H, W, device=labels.device) + for c in range(n_channels): + onehot[:, c] = (labels[:, 0] == (c + 1)).float() + return onehot + + +if __name__ == "__main__": + # Smoke test mirroring the generator's, exercising both wrappers. + torch.manual_seed(0) + device = "cuda" if torch.cuda.is_available() else "cpu" + + B, D, H, W = 2, 40, 44, 42 + zz, yy, xx = torch.meshgrid(torch.arange(D), torch.arange(H), torch.arange(W), indexing="ij") + r = ((zz - D / 2) ** 2 + (yy - H / 2) ** 2 + (xx - W / 2) ** 2).sqrt() + vol = torch.zeros(D, H, W, dtype=torch.long) + vol[r < 15] = 1 + vol[r < 8] = 2 + labels = vol.view(1, 1, D, H, W).repeat(B, 1, 1, 1, 1) + img = torch.randn(B, 1, D, H, W) + + # Full end-to-end driver + driver = SynthSegTransformsGPU(params={"generation_labels": [0, 1, 2], "n_channels": 1}).to(device) + out_img, out_lab = driver(img.to(device), labels.to(device)) + print("driver image", tuple(out_img.shape), "labels", tuple(out_lab.shape), + torch.unique(out_lab).tolist()) + assert out_img.shape[0] == B and not torch.isnan(out_img).any() + + # Intensity-only ImageOnlyTransform + t = RandomSynthSegGPU(generation_labels=[0, 1, 2], n_channels=1, p=1.0).to(device) + out = t.apply_transform(img.clone().to(device), {"seg": labels.to(device)}, {}) + print("imageonly", tuple(out.shape), "range", (round(float(out.min()), 3), round(float(out.max()), 3))) + assert out.shape == img.shape and not torch.isnan(out).any() + print("OK")