
MalisLoss: pass eroded gt_seg through data pipeline (skip per-step CC, fix crop topology)#211

Merged
donglaiw merged 4 commits into master from feat/malis-gt-passthrough on Jun 7, 2026

Conversation

@donglaiw

Collaborator

Summary

Passes the eroded GT segmentation from the data pipeline through to MalisLoss so it can:

  1. Skip its internal connected_components_affgraph(gt_affs, nhood) call every training step (a small gain: roughly 5–10% of the remaining MALIS step cost once cropping is enabled).
  2. Preserve global instance IDs under malis_crop_size (the primary motivation; fixes a correctness artifact in the prior crop PR).

Opt-in via label_transform.emit_gt_seg: true on the YAML side; default off → bit-for-bit identical to current behaviour for configs without it.

Why

With malis_crop_size enabled (#crop PR), the current MALIS path runs CC on the cropped gt_affs. When a single GT instance spans the crop boundary so that two of its pieces appear inside the crop but are not connected within that window, CC labels them as distinct components. MALIS then injects spurious negative-constraint edges inside one true instance.

Passing the eroded gt_seg from upstream (where global instance IDs are preserved) and cropping it with the same origin fixes this.
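
The fragmentation can be reproduced on a toy 2D grid. The sketch below is a simplification under stated assumptions: it labels 4-connected foreground components in pure Python instead of calling connected_components_affgraph on affinities, and the helper names (`label_components`, `crop`) are hypothetical, not part of the codebase.

```python
from collections import deque


def label_components(mask):
    """Label 4-connected foreground components of a 2D boolean grid."""
    h, w = len(mask), len(mask[0])
    labels = [[0] * w for _ in range(h)]
    count = 0
    for y in range(h):
        for x in range(w):
            if not mask[y][x] or labels[y][x]:
                continue
            count += 1
            labels[y][x] = count
            queue = deque([(y, x)])
            while queue:  # breadth-first flood fill
                cy, cx = queue.popleft()
                for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                    if 0 <= ny < h and 0 <= nx < w and mask[ny][nx] and not labels[ny][nx]:
                        labels[ny][nx] = count
                        queue.append((ny, nx))
    return labels, count


def crop(grid, y0, x0, size):
    """Crop a square window starting at (y0, x0)."""
    return [row[x0:x0 + size] for row in grid[y0:y0 + size]]


# One U-shaped instance with global ID 7; the two arms join only in the last row.
gt_seg = [
    [7, 0, 7, 0],
    [7, 0, 7, 0],
    [7, 0, 7, 0],
    [7, 7, 7, 0],
]

_, n_full = label_components([[v != 0 for v in row] for row in gt_seg])

# The crop window drops the last row, so the arms are disconnected inside it.
gt_seg_crop = crop(gt_seg, 0, 0, 3)
_, n_crop = label_components([[v != 0 for v in row] for row in gt_seg_crop])

# Cropping the upstream segmentation keeps the single global ID instead.
ids_in_crop = {v for row in gt_seg_crop for v in row if v}
```

Recomputing components inside the window yields two labels for what is one instance, while the cropped upstream segmentation still carries the single ID 7.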

Speedup story (measured + documented)

Production config: MedNeXt-L, batch 2, 128³ patch on L40S.

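| Configuration                    | it/s   | sec/step | h/epoch (5000 steps) | Speedup vs original |
|----------------------------------|--------|----------|----------------------|---------------------|
| BCE only (no MALIS)              | ~0.71  | ~1.4     | ~1.95                | (reference)         |
| Full-volume MALIS (original)     | ~0.17  | ~5.9     | ~7.3                 | 1.0×                |
| MALIS + malis_crop_size: 64      | ~0.78  | ~1.3     | ~1.78                | ~4.6×               |
| MALIS + crop + emit_gt_seg: true | ~0.78+ | ~1.3     | ~1.78                | ~4.6× plus a few %  |
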
  • Crop alone gives ~4.6× speedup vs the original full-volume MALIS (slurm 2505814 vs 2487040).
  • emit_gt_seg adds a small additional speedup AND is primarily a correctness fix for the cropped case.

See docs/source/notes/malis.rst (new) and the MalisLoss class docstring for the full table.
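
As a quick sanity check on the table, the sec/step and speedup columns follow directly from the it/s column. The snippet below only redoes that arithmetic on the rounded figures quoted above; the dictionary keys are illustrative labels, not config names.

```python
# Approximate throughput (iterations per second) as quoted in the table.
its_per_sec = {
    "bce_only": 0.71,
    "malis_full": 0.17,
    "malis_crop64": 0.78,
}

# Seconds per step is the reciprocal of throughput.
sec_per_step = {name: 1.0 / its for name, its in its_per_sec.items()}

# Speedup is measured against the original full-volume MALIS run.
speedup_vs_full = {
    name: its / its_per_sec["malis_full"] for name, its in its_per_sec.items()
}

for name in its_per_sec:
    print(f"{name}: {sec_per_step[name]:.1f} s/step, {speedup_vs_full[name]:.1f}x")
```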

Implementation

Data-pipeline boundary

  • LabelTransformConfig.emit_gt_seg: bool = False (strict dataclass field).
  • CopyItemsd(keys="label", names="gt_seg") inserted in connectomics/data/augmentation/build.py immediately after SegErosionInstanced (both train and val), gated by label_cfg.emit_gt_seg. Makes "post-augment, post-erode" the canonical snapshot point.
  • MultiTaskLabelTransformd is untouched.
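
The snapshot point can be illustrated with a minimal sketch that uses plain functions on dicts rather than MONAI transforms. Everything here is a stand-in: `erode_instances`, `copy_item`, `to_binary_target`, and `build_label_pipeline` are hypothetical names, the 1D erosion is a toy, and the assumption is that a later step overwrites `label` with a training target.

```python
import copy


def erode_instances(sample):
    """Toy stand-in for SegErosionInstanced: zero voxels touching another label."""
    label = sample["label"]
    eroded = []
    for i, v in enumerate(label):
        neighbours = label[max(i - 1, 0): i + 2]
        eroded.append(v if all(n == v for n in neighbours) else 0)
    return {**sample, "label": eroded}


def copy_item(sample, key, name):
    """Toy stand-in for CopyItemsd(keys=key, names=name): deep-copy one entry."""
    return {**sample, name: copy.deepcopy(sample[key])}


def to_binary_target(sample):
    """Assumed later step that replaces instance IDs with a training target."""
    return {**sample, "label": [int(v > 0) for v in sample["label"]]}


def build_label_pipeline(emit_gt_seg=False):
    steps = [erode_instances]
    if emit_gt_seg:
        # Snapshot right after erosion, before the label is converted.
        steps.append(lambda s: copy_item(s, "label", "gt_seg"))
    steps.append(to_binary_target)
    return steps


def run(steps, sample):
    for step in steps:
        sample = step(sample)
    return sample


sample = {"label": [1, 1, 1, 2, 2, 2]}
legacy = run(build_label_pipeline(emit_gt_seg=False), sample)
passthrough = run(build_label_pipeline(emit_gt_seg=True), sample)
```

With the flag off the output has no `gt_seg` key and `label` is unchanged from the legacy result; with it on, `gt_seg` keeps the eroded instance IDs.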

Loss orchestrator

  • LossMetadata.gt_seg_arg: Optional[str]. Set to "gt_seg" for MalisLoss; all other losses unaffected.
  • compute_standard_loss(..., gt_seg=None) plumbs the batch's gt_seg to any term whose metadata declares the arg, via the existing extra_loss_kwargs extension point.
  • compute_deep_supervision_loss(..., gt_seg=None) forces gt_seg = None for every head — DS lower heads work on downsampled targets that gt_seg can't match label-correctly. MalisLoss + DS falls back to the legacy CC-recompute path. Test-pinned.
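
The metadata-driven routing described above might look roughly like the following. This is a sketch, not the orchestrator's actual code: the `LossMetadata` here has only two fields, and `build_extra_loss_kwargs` and its `deep_supervision` flag are hypothetical names standing in for the two compute paths.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class LossMetadata:
    name: str
    gt_seg_arg: Optional[str] = None  # kwarg name the loss expects, if any


def build_extra_loss_kwargs(
    meta: LossMetadata, gt_seg: Any, deep_supervision: bool = False
) -> Dict[str, Any]:
    """Return the extra kwargs to pass to one loss term."""
    if deep_supervision:
        # Downsampled heads cannot be matched to gt_seg, so drop it.
        gt_seg = None
    if meta.gt_seg_arg is None or gt_seg is None:
        return {}
    return {meta.gt_seg_arg: gt_seg}


malis_meta = LossMetadata("MalisLoss", gt_seg_arg="gt_seg")
bce_meta = LossMetadata("BCEWithLogitsLoss")
batch_gt_seg = [[0, 7], [7, 7]]

standard_malis = build_extra_loss_kwargs(malis_meta, batch_gt_seg)
standard_bce = build_extra_loss_kwargs(bce_meta, batch_gt_seg)
ds_malis = build_extra_loss_kwargs(malis_meta, batch_gt_seg, deep_supervision=True)
```

Only the term that declares `gt_seg_arg` receives the segmentation, and the deep-supervision path yields no extra kwargs so the loss falls back to recomputing components.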

MalisLoss

  • forward(pred, target, mask=None, gt_seg=None) accepts the optional kwarg.
  • _prepare_gt_seg normalizes shape (accepts [B, Z, Y, X] and [B, 1, Z, Y, X]), validates against pred spatial dims.
  • _apply_crop_if_configured now returns a 4-tuple and crops gt_seg at the same origin as pred/target/mask.
  • _compute_malis_weights(..., *, gt_seg=None) uses the supplied seg per sample when provided; falls back to connected_components_affgraph otherwise.
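
A rough sketch of the shape handling and the shared crop origin, working on shape tuples and slices only so it stays framework-free. It assumes pred is laid out as [B, C, Z, Y, X]; `normalize_gt_seg_shape` and `crop_slices` are illustrative names, not the private helpers listed above.

```python
def normalize_gt_seg_shape(gt_seg_shape, pred_shape):
    """Return the [B, Z, Y, X] shape for gt_seg, or raise ValueError.

    Assumes pred_shape is [B, C, Z, Y, X].
    """
    shape = tuple(gt_seg_shape)
    if len(shape) == 5 and shape[1] == 1:
        shape = (shape[0],) + shape[2:]  # drop the singleton channel axis
    expected = (pred_shape[0],) + tuple(pred_shape[2:])
    if shape != expected:
        raise ValueError(
            f"gt_seg shape {tuple(gt_seg_shape)} does not match pred {tuple(pred_shape)}"
        )
    return shape


def crop_slices(origin, crop_size):
    """Build one (z, y, x) slice tuple to reuse for pred, target, mask and gt_seg."""
    return tuple(slice(o, o + crop_size) for o in origin)


pred_shape = (2, 3, 128, 128, 128)
with_channel = normalize_gt_seg_shape((2, 1, 128, 128, 128), pred_shape)
without_channel = normalize_gt_seg_shape((2, 128, 128, 128), pred_shape)

try:
    normalize_gt_seg_shape((2, 64, 128, 128), pred_shape)
    mismatch_raised = False
except ValueError:
    mismatch_raised = True

window = crop_slices(origin=(10, 20, 30), crop_size=64)
```

Building the slice tuple once and indexing every tensor with it is what keeps the cropped gt_seg voxel-aligned with the cropped prediction.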

YAML (tutorials/neuron_nisb/base_banis+_malis.yaml)

default:
  data:
    label_transform:
      emit_gt_seg: true   # opt-in; pairs with MalisLoss
  ...
            malis_crop_size: 64

Tests

python -m pytest tests/unit/test_malis_loss.py tests/unit/test_data_factory.py tests/unit/test_loss_orchestrator.py -q → 76 passed, 1 skipped.

  • test_malis_loss.py — 6 new tests:
    • Metadata declares gt_seg_arg.
    • Uncropped CC equivalence (gt_seg supplied vs reconstructed match within rtol=1e-5).
    • gt_seg=None strict-equality preservation of the legacy path.
    • Cropped-instance-fragmentation bug fix: Path A spies on connected_components_affgraph and asserts the CC labels fragment; Path B asserts CC was NOT called and the supplied gt_seg retains a single instance label; losses differ.
    • Shape validation (ValueError on mismatch).
    • No grad flows through gt_seg.
  • test_data_factory.py — 2 new tests for the CopyItemsd insertion (train + val symmetry, plus a runtime end-to-end equivalence to a manual erosion).
  • test_loss_orchestrator.py — 2 new tests pinning the standard-loss gt_seg plumbing and the DS=legacy fallback.
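
The Path A / Path B spy pattern can be sketched with the standard library's unittest.mock. The functions below are toy stand-ins (`connected_components`, `compute_seg`), and passing the CC function as an argument is a simplification chosen here for brevity; the real tests may patch the module attribute instead.

```python
from unittest import mock


def connected_components(affs):
    """Toy stand-in for the CC call; returns a dummy labelling."""
    return [1 for _ in affs]


def compute_seg(affs, gt_seg=None, cc_fn=connected_components):
    """Use the supplied segmentation when given, otherwise recompute it."""
    return gt_seg if gt_seg is not None else cc_fn(affs)


affs = [0.9, 0.8, 0.1]

# Path A: no gt_seg supplied, so the CC function must be called.
spy_a = mock.Mock(wraps=connected_components)
seg_a = compute_seg(affs, cc_fn=spy_a)
spy_a.assert_called_once_with(affs)

# Path B: gt_seg supplied, so the CC function must not be called.
spy_b = mock.Mock(wraps=connected_components)
seg_b = compute_seg(affs, gt_seg=[7, 7, 7], cc_fn=spy_b)
spy_b.assert_not_called()
```

Wrapping the real function keeps its return value intact, so the spy records calls without changing behaviour.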

Docs

  • New docs/source/notes/malis.rst covering MALIS performance and correctness knobs (linked from docs/source/index.rst "Get Started" toctree).
  • MalisLoss class docstring extended with a Performance section citing the measured speedup.

Backward compatibility

  • Default behaviour (no emit_gt_seg) is bit-for-bit identical to current MalisLoss.
  • Configs without MALIS are completely unaffected.
  • Deep-supervision configs are unaffected (orchestrator forces legacy CC fallback when DS is on).

CCC design history

Under .agent/features/malis_gt_passthrough/ in the worktree (gitignored). Plan rounds: 2 (plan_v0 NEEDS_CHANGES → plan_v1 APPROVE_WITH_MINOR_COMMENTS). Code rounds: 1 (code_v0 → review_v0 APPROVE_WITH_MINOR_COMMENTS). 3 minor review_v0 findings were applied as small follow-up commits or kept as observational (see review_v0 artifact).

🤖 Generated with Claude Code

Donglai Wei and others added 4 commits May 27, 2026 23:16
…, fix crop topology)

Adds an opt-in `gt_seg` plumbing path that pipes the eroded GT
segmentation from the data pipeline through to MalisLoss, so MALIS
can skip its internal `connected_components_affgraph(gt_affs, nhood)`
call and use the supplied seg directly.

Primary motivation (correctness): with `malis_crop_size` enabled
(landed in commit 5b0451f), the current MALIS path runs CC on the
*cropped* gt_affs. When a single GT instance spans the crop boundary
so that two of its pieces appear inside the crop but are not
connected within that window, CC labels them as distinct components,
and MALIS then injects spurious negative-constraint edges inside one
true instance. Passing the eroded gt_seg from upstream (where global
instance IDs are preserved) and cropping it with the same origin
fixes this.

Secondary motivation (small speedup): removes the per-step CC inside
`_compute_malis_weights`. Roughly 5-10% of the remaining MALIS step
cost on top of cropping.

Pipeline boundary
- `LabelTransformConfig.emit_gt_seg: bool = False` (strict dataclass
  field, defaults to off).
- `MultiTaskLabelTransformd` is left untouched. Instead, a
  `CopyItemsd(keys="label", names="gt_seg")` MONAI transform is
  inserted in `data/augmentation/build.py` immediately after the
  existing `SegErosionInstanced` step (both train and val), gated by
  `label_cfg.emit_gt_seg`. This makes "post-augment, post-erode" the
  canonical snapshot point.
- Default behavior (no `emit_gt_seg`) is bit-for-bit identical to
  pre-PR `MalisLoss`.

Loss orchestrator
- New `LossMetadata.gt_seg_arg: Optional[str]`. Set to `"gt_seg"`
  for MalisLoss; all other losses unaffected.
- `compute_standard_loss(..., gt_seg=None)` plumbs the batch's
  `gt_seg` to any term whose metadata declares the arg, via the
  existing `extra_loss_kwargs` extension point.
- `compute_deep_supervision_loss(..., gt_seg=None)` forces
  `gt_seg=None` for every head (DS lower heads work on downsampled
  targets that gt_seg can't match label-correctly). MalisLoss + DS
  falls back to the legacy CC-recompute path. Pinned by a unit
  test.

MalisLoss
- `forward(pred, target, mask=None, gt_seg=None)` accepts the
  optional kwarg.
- `_prepare_gt_seg` normalizes shape to `[B, Z, Y, X]` (accepts
  both `[B, Z, Y, X]` and `[B, 1, Z, Y, X]`), validates against
  pred spatial dims.
- `_apply_crop_if_configured` now returns a 4-tuple and crops
  gt_seg at the same origin as pred/target/mask.
- `_compute_malis_weights(..., *, gt_seg=None)` uses the supplied
  seg per sample when provided; falls back to
  `connected_components_affgraph` otherwise.

Speedup story (measured + documented)

Production config: MedNeXt-L, batch 2, 128^3 patch on L40S.

| Config                              | it/s | hours/epoch (5000 steps) |
|-------------------------------------|------|--------------------------|
| BCE only (no MALIS)                 | ~0.71| ~1.95                    |
| Full-volume MALIS (original)        | ~0.17| ~7.3                     |
| MALIS + malis_crop_size=64          | ~0.78| ~1.78                    |
| MALIS + crop=64 + emit_gt_seg=true  | ~0.78+| ~1.78                   |

Crop alone gives ~4.6x speedup vs the original full-volume MALIS
(slurm 2505814 vs 2487040). emit_gt_seg adds a small additional
speedup on top, and is primarily a correctness fix for the cropped
case. See `docs/source/notes/malis.rst` and `MalisLoss` class
docstring.

YAML opt-in (`tutorials/neuron_nisb/base_banis+_malis.yaml`)
- `default.data.label_transform.emit_gt_seg: true` enables the
  passthrough path.
- `malis_crop_size: 64` uncommented to enable the crop in
  production.

Tests
- `tests/unit/test_malis_loss.py` — 6 new tests covering metadata,
  uncropped CC equivalence, `gt_seg=None` legacy preservation, the
  cropped-instance-fragmentation bug fix (Path A spy asserts CC was
  called and labels fragment; Path B asserts CC was NOT called and
  the supplied gt_seg retains a single instance label), shape
  validation, and that no grad flows through gt_seg.
- `tests/unit/test_data_factory.py` — 2 new tests for the
  CopyItemsd insertion (train + val symmetry, and a runtime
  end-to-end equivalence to a manual erosion).
- `tests/unit/test_loss_orchestrator.py` — 2 new tests pinning
  the standard-loss gt_seg plumbing and the DS=legacy fallback.
- All existing test_malis_loss cases unchanged.
- `python -m pytest tests/unit/test_malis_loss.py tests/unit/test_data_factory.py tests/unit/test_loss_orchestrator.py -q`
  -> 76 passed, 1 skipped.

Docs
- New `docs/source/notes/malis.rst` covering MALIS speedup and
  correctness knobs (linked from `docs/source/index.rst` Get
  Started toctree).
- `MalisLoss` class docstring extended with a Performance section
  citing the measured speedup.

CCC design history under `.agent/features/malis_gt_passthrough/`
(gitignored).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
CI lint runs mypy (no_implicit_optional) on all changed connectomics
files. This PR touches build.py, pulling its latent
`keys: list[str] = None` signatures into mypy scope. Annotate them
`list[str] | None` to make the lint job green.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The macOS test job (the only platform that finishes before fail-fast
cancels the others) failed on three env-independent test failures.
cc3d cupy tests skip on CI (no cupy), so these were the real blockers:

1. test_training_step_uses_deep_supervision_branch — this PR adds a
   `gt_seg` kwarg to compute_deep_supervision_loss; update the test's
   fake to accept it.
2. test_connectomics_inference_public_api_snapshot — branch was behind
   master; restore `is_external_chunk_sharding_enabled` (present in
   inference.__all__ and in master's snapshot).
3. test_decode_waterz_reuses_agglomeration_graph_for_dust — stale on
   master: decode_waterz now calls dust_merge_from_region_graph, not
   waterz.merge_segments. Mock and assert on the actual call.

Verified: full suite 633 passed, 2 skipped (cupy file ignored — skips
on CI).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…solve split json)

PR #211's macOS / minimal-install CI lane is red on pre-existing failures
unrelated to the malis change: several optional deps (numba, zarr, optuna,
em_erl, waterz) and the ABISS `ws` binary are absent there, and config_io
did not resolve per-split `json` paths.

- config_io.resolve_data_paths: resolve per-split `json` like image/label/mask
  (fixes test_resolve_data_paths_resolves_test_json). Also drops an unused
  inference_cfg local and wraps a long message (flake8 F841/E501) so the file
  stays lint-clean now that it is in the diff.
- guard optional-dependency tests with importorskip/skipif so they skip rather
  than error when the dep is absent, matching the existing zarr/tifffile idiom:
  waterz (both decoder test modules), optuna trial-timeout tests, zarr aux-cache
  test, em_erl nerl/skeleton tests, numba threshold-sensitivity, and the ABISS
  relative-script test (guarded on lib/abiss/build/ws).

Verified: under a simulated minimal install the previously-failing tests now
skip; with the deps installed they still pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@donglaiw donglaiw merged commit 799afd5 into master Jun 7, 2026
3 of 4 checks passed