Add explicit spatial_ndim tracking to MetaTensor #8765
aymuos15 wants to merge 3 commits into Project-MONAI:dev from …/metatensor-einops
Conversation
…#6397) Fixes dimension-mismatch crashes when einops.rearrange() or other reshape operations change tensor ndim by decoupling spatial rank from tensor shape.

- Add _spatial_ndim attribute to MetaObj, derived from affine in MetaTensor
- Expose spatial_ndim property with getter/setter and validation
- Sync spatial_ndim on affine assignment and propagate through collation
- Update transforms to use spatial_ndim instead of ndim-1 heuristic
- Add 18 new tests for spatial_ndim behavior

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
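The commit message above describes deriving a spatial rank from the affine and exposing it as a validated property. A minimal, self-contained sketch of that pattern (hypothetical `Meta` class for illustration, not MONAI's actual implementation):

```python
class Meta:
    """Toy sketch of explicit spatial-rank tracking, decoupled from tensor shape."""

    def __init__(self, affine_rows: int):
        # an (N+1, N+1) homogeneous affine implies N spatial dimensions
        self._spatial_ndim = affine_rows - 1

    @property
    def spatial_ndim(self) -> int:
        return self._spatial_ndim

    @spatial_ndim.setter
    def spatial_ndim(self, value: int) -> None:
        # validation on assignment, as the commit message describes
        if not isinstance(value, int) or value < 1:
            raise ValueError(f"spatial_ndim must be a positive int, got {value!r}")
        self._spatial_ndim = value


m = Meta(4)            # a (4, 4) affine -> 3 spatial dims
print(m.spatial_ndim)  # 3
```

Because the rank is stored explicitly, a reshape that changes tensor ndim no longer silently changes the assumed spatial rank.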
📝 Walkthrough

Adds a spatial_ndim attribute and property to MetaTensor (and initializes _spatial_ndim on MetaObj), extends MetaTensor.__new__/__init__ to accept spatial_ndim, and introduces get_spatial_ndim, exported from monai.data. Affine handling, pixdim, pending-affine/rank logic, and lazy apply_pending are updated to respect spatial_ndim. Many transforms and utilities (collation, spatial/intensity/post/utility modules) now derive spatial dimensionality via get_spatial_ndim instead of tensor shape. Collate/decollate preserve spatial_ndim. Tests updated and a new test module verifies spatial_ndim behavior and propagation.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ❌ 1 failed (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/transforms/spatial/array.py (1)
3034-3039: ⚠️ Potential issue | 🟠 Major

Use one dimension source in GridDistortion to avoid index errors. Line 3038 now sizes num_cells from get_spatial_ndim(img), but the loop still indexes over img.shape[1:]. If those diverge, num_cells[dim_idx] can go out of range.

💡 Proposed fix

-        if len(img.shape) != len(distort_steps) + 1:
+        spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
+        if len(spatial_shape) != len(distort_steps):
             raise ValueError("the spatial size of `img` does not match with the length of `distort_steps`")
         all_ranges = []
-        num_cells = ensure_tuple_rep(self.num_cells, get_spatial_ndim(img))
+        num_cells = ensure_tuple_rep(self.num_cells, len(spatial_shape))
         if isinstance(img, MetaTensor) and img.pending_operations:
             warnings.warn("MetaTensor img has pending operations, transform may return incorrect results.")
-        for dim_idx, dim_size in enumerate(img.shape[1:]):
+        for dim_idx, dim_size in enumerate(spatial_shape):
             dim_distort_steps = distort_steps[dim_idx]

Also applies to: 3041-3045
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/transforms/spatial/array.py` around lines 3034 - 3039, The loop indexing in GridDistortion uses img.shape[1:] but num_cells is created from get_spatial_ndim(img), which can differ and cause out-of-range access; fix by deriving a single spatial_ndim = len(img.shape) - 1 (or compute once via get_spatial_ndim(img) and use that same value) and use this spatial_ndim both when building num_cells (ensure_tuple_rep(self.num_cells, spatial_ndim)) and when iterating over dimensions (iterate over range(spatial_ndim) or img.shape[1:spatial_ndim+1]) so the dimension counts always match; apply the same single-source change to the later loop handling lines 3041-3045 (the code that constructs all_ranges and indexes num_cells[dim_idx]).
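The fix above reduces to one rule: derive the spatial dimension count once and use that same value both to size num_cells and to drive the loop. A standalone toy sketch of that single-source pattern (names are illustrative, not MONAI's API):

```python
def grid_ranges(shape, num_cells):
    """Derive spatial dims once from a channel-first shape and reuse the count."""
    spatial_shape = shape[1:]  # channel-first layout: drop the channel dim
    spatial_ndim = len(spatial_shape)
    # broadcast a scalar num_cells to one entry per spatial dim
    cells = (num_cells,) * spatial_ndim if isinstance(num_cells, int) else tuple(num_cells)
    if len(cells) != spatial_ndim:
        raise ValueError("num_cells length must match the number of spatial dims")
    # loop and lookup now share one dimension count, so no index can go out of range
    return [(dim_size, cells[dim_idx]) for dim_idx, dim_size in enumerate(spatial_shape)]


print(grid_ranges((1, 64, 64, 32), 5))  # [(64, 5), (64, 5), (32, 5)]
```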
🧹 Nitpick comments (1)
tests/transforms/utility/test_splitdim.py (1)
51-83: Add a 2D split case to cover clamped spatial_ndim behavior.

Please add a case like (C, H, W) with default affine and SplitDim(dim=1, keepdim=False), expecting spatial_ndim == 1.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/transforms/utility/test_splitdim.py` around lines 51 - 83, Add a new unit test in tests/transforms/utility/test_splitdim.py that constructs a 2D MetaTensor with shape (C, H, W) (e.g., MetaTensor(torch.randn(3, 8, 7))) using the default affine, then calls SplitDim(dim=1, keepdim=False) and asserts that for each returned MetaTensor item its spatial_ndim equals 1; name the test something like test_spatial_ndim_2d_split_clamped and use the same pattern as the existing tests (import torch, create MetaTensor, assert initial spatial_ndim, run SplitDim, loop items and assert item.spatial_ndim == 1).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@monai/data/utils.py`:
- Line 435: Validate that all items in batch share the same spatial_ndim before
setting collated.spatial_ndim: compute the set of values via something like
{getattr(item, "spatial_ndim", 3) for item in batch}, if the set has more than
one entry raise a ValueError (or TypeError) describing the differing values (and
optionally indices), otherwise assign the single value to collated.spatial_ndim;
reference the collated.spatial_ndim assignment and the batch items when
implementing the check.
In `@monai/transforms/lazy/functional.py`:
- Around line 259-263: The cumulative affine matrix produced by
affine_from_pending can have mismatched rank and must be deterministically
normalized to the exact (_rank + 1) square shape before resample; update the
code around symbols _rank, cumulative_xform, affine_from_pending and
to_affine_nd so that instead of only calling to_affine_nd when
cumulative_xform.shape[0] < _rank + 1 you always normalize/bind the affine to
the target rank (call to_affine_nd(_rank, cumulative_xform) whenever
cumulative_xform.shape != (_rank + 1, _rank + 1)), and apply the same change at
the other occurrence (the block around lines 288-289) so resample always
receives a correctly sized affine. Ensure the normalized matrix preserves the
original transformation semantics (cropping extra rows/columns or padding with
identity as to_affine_nd implements).
In `@monai/transforms/spatial/functional.py`:
- Around line 102-105: The computed spatial_rank can exceed the actual tensor
spatial dimensions causing affine/resample rank mismatches; after computing
spatial_rank (from get_spatial_ndim(img) and possibly overrode by spatial_size),
clamp it to the tensor's current spatial dims by setting spatial_rank =
min(spatial_rank, max(0, img.ndim - 1)) before calling to_affine_nd(src_affine);
update the block where spatial_rank is set (the code around spatial_rank,
get_spatial_ndim, ensure_tuple, spatial_size, and to_affine_nd) to apply this
clamp.
In `@monai/transforms/utility/array.py`:
- Around line 317-318: The code currently normalizes self.dim into dim and then
indexes img.shape[dim], which silently accepts out-of-range negative dims (e.g.,
-5 for ndim=4); update the validation before using dim: verify that self.dim is
within [-img.ndim, img.ndim-1] (or check normalized dim in [0, img.ndim-1]) and
raise a clear ValueError/IndexError if not, then compute dim = self.dim if
self.dim >= 0 else self.dim + img.ndim and continue to use n_out =
img.shape[dim]; reference the variables self.dim, dim, img.ndim, and n_out when
making the change.
- Around line 333-338: In SplitDim, the affine setter call re-syncs spatial_ndim
from the affine rank which can overwrite the clamped value that you intend to
adjust; before updating out.affine capture the current out.spatial_ndim (e.g.,
orig_spatial = out.spatial_ndim), perform the affine update (out.affine =
out.affine @ shift), then restore or adjust out.spatial_ndim based on
orig_spatial and self.keepdim (instead of relying on the affine setter to set
spatial_ndim), ensuring the decrement at the current code (when not
self.keepdim) is applied relative to the preserved baseline.
In `@tests/data/meta_tensor/test_spatial_ndim.py`:
- Around line 37-43: The "squeeze" entry in PRESERVATION_CASES currently ignores
the test input `t` by constructing a new MetaTensor and squeezing that; change
the case to operate on the provided input (e.g. replace the current lambda with
one that calls t.squeeze(1)) so the preservation check actually exercises the
operation on the test input. Update the tuple with the same name "squeeze" and
arity 2 but use lambda t: t.squeeze(1) (or an equivalent call that uses `t`) to
ensure the test validates preservation through the operation on the input under
test.
In `@tests/transforms/utility/test_splitdim.py`:
- Around line 58-60: The tests currently iterate over out and call
self.assertEqual(item.spatial_ndim, 2) guarded by isinstance(item, MetaTensor)
which can lead to vacuous passes if no MetaTensor is present; update each test
block (the loops over out at the three locations) to first assert that a
MetaTensor is present (e.g., self.assertTrue(any(isinstance(i, MetaTensor) for i
in out))) and then either filter meta_items = [i for i in out if isinstance(i,
MetaTensor)] and assert on each meta_item.spatial_ndim (e.g., for m in
meta_items: self.assertEqual(m.spatial_ndim, 2)) so the test fails when no
MetaTensor exists and only then checks spatial_ndim.
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (16)
monai/data/__init__.py
monai/data/meta_obj.py
monai/data/meta_tensor.py
monai/data/utils.py
monai/transforms/croppad/functional.py
monai/transforms/intensity/array.py
monai/transforms/inverse.py
monai/transforms/lazy/functional.py
monai/transforms/post/array.py
monai/transforms/spatial/array.py
monai/transforms/spatial/functional.py
monai/transforms/utility/array.py
tests/data/meta_tensor/test_meta_tensor.py
tests/data/meta_tensor/test_spatial_ndim.py
tests/transforms/test_squeezedim.py
tests/transforms/utility/test_splitdim.py
    collated.meta = default_collate(meta_dicts)
    collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]
    collated.is_batch = True
    collated.spatial_ndim = getattr(batch[0], "spatial_ndim", 3)  # assumes uniform spatial_ndim
Validate spatial_ndim consistency across the batch before assigning.
Using batch[0] silently can propagate wrong rank metadata if items disagree, which can break later spatial ops.
Proposed fix
- collated.spatial_ndim = getattr(batch[0], "spatial_ndim", 3) # assumes uniform spatial_ndim
+ spatial_ndims = {getattr(i, "spatial_ndim", 3) for i in batch}
+ if len(spatial_ndims) != 1:
+ raise ValueError(f"Inconsistent spatial_ndim values in batch: {sorted(spatial_ndims)}.")
+ collated.spatial_ndim = spatial_ndims.pop()
This would crash much before if that was the case, Unnecessary I feel
@aymuos15, understood. You're right that collate_tensor_fn(batch) would fail earlier if there were shape or metadata incompatibilities. The validation would be redundant in practice.
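For reference, the validation that was proposed and then dropped amounts to a few lines; a standalone sketch of what it would have looked like (the Item class is a stand-in for illustration):

```python
def check_uniform_spatial_ndim(batch, default=3):
    """Sketch of the proposed (and ultimately dropped) batch consistency check."""
    values = {getattr(item, "spatial_ndim", default) for item in batch}
    if len(values) != 1:
        raise ValueError(f"Inconsistent spatial_ndim values in batch: {sorted(values)}.")
    return values.pop()


class Item:
    """Stand-in for a MetaTensor carrying spatial_ndim metadata."""

    def __init__(self, n):
        self.spatial_ndim = n


print(check_uniform_spatial_ndim([Item(3), Item(3)]))  # 3
```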
    dim = self.dim if self.dim >= 0 else self.dim + img.ndim
    n_out = img.shape[dim]
Validate normalized split-dimension bounds.
Line 317 can silently remap invalid negative dims (e.g., -5 with ndim=4) to a valid axis, producing wrong splits instead of failing fast.
Suggested fix
- dim = self.dim if self.dim >= 0 else self.dim + img.ndim
+ dim = self.dim
+ if dim < 0:
+ dim += img.ndim
+ if dim < 0 or dim >= img.ndim:
+ raise ValueError(f"dim must be within [-{img.ndim}, {img.ndim - 1}], got {self.dim}.")
     n_out = img.shape[dim]
torch.split already raises a good error
@aymuos15, you're absolutely right! The existing error handling is sufficient:
- img.shape[dim] at line 318 will raise an IndexError if dim is out of bounds after normalization
- torch.split and np.split also validate the dimension parameter
The negative index normalization (self.dim + img.ndim) follows standard Python indexing conventions. Any truly invalid dimension will trigger an error at the shape access or split operation.
My suggested validation was unnecessary. Thanks for the clarification!
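The edge case discussed above can be seen with plain Python indexing; a hypothetical helper mirroring the normalization at line 317:

```python
def normalize_dim(dim, ndim):
    """Mirror of the SplitDim normalization: map a negative dim to an index
    by adding ndim. Note the corner case the review flagged: a doubly-negative
    dim like -5 with ndim=4 maps to -1, which Python indexing still accepts,
    so the downstream shape lookup and torch.split do the real validation."""
    return dim if dim >= 0 else dim + ndim


print(normalize_dim(1, 4))   # 1
print(normalize_dim(-1, 4))  # 3
print(normalize_dim(-5, 4))  # -1 (still a valid Python index)
```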
    ndim = len(out.affine)
    shift = torch.eye(ndim, device=out.affine.device, dtype=out.affine.dtype)
    shift[dim - 1, -1] = idx
    out.affine = out.affine @ shift
    if not self.keepdim:
        out.spatial_ndim = max(1, out.spatial_ndim - 1)
Preserve spatial_ndim across affine updates in SplitDim.
Line 336 calls affine setter, which re-syncs spatial_ndim from affine rank; this can overwrite clamped values and make Line 338 decrement from the wrong baseline.
Suggested fix
if dim == 0: # don't update affine if channel dim
continue
+ prev_spatial_ndim = out.spatial_ndim
ndim = len(out.affine)
shift = torch.eye(ndim, device=out.affine.device, dtype=out.affine.dtype)
shift[dim - 1, -1] = idx
out.affine = out.affine @ shift
- if not self.keepdim:
- out.spatial_ndim = max(1, out.spatial_ndim - 1)
+ if not self.keepdim and 0 < dim <= prev_spatial_ndim:
+ out.spatial_ndim = max(1, prev_spatial_ndim - 1)
+ else:
+ out.spatial_ndim = prev_spatial_ndim
That scenario will not come up unless a major logic change in the affine updates
@aymuos15, understood. Since the affine shape doesn't change during the matrix multiplication, the setter won't alter spatial_ndim before line 338 decrements it. The logic is sound.
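The reasoning in the reply above can be checked directly: the translation update right-multiplies the affine by a same-sized matrix, so the matrix shape (and hence any rank derived from it) is unchanged. A quick sketch with NumPy standing in for the torch code:

```python
import numpy as np

# a (4, 4) homogeneous affine and a same-sized translation, as in SplitDim
affine = np.eye(4)
shift = np.eye(4)
shift[1, -1] = 2.0  # translate the second spatial axis by idx = 2

updated = affine @ shift
print(updated.shape)   # (4, 4): the shape the setter derives rank from is unchanged
print(updated[1, -1])  # 2.0: the translation landed in the last column
```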
    for item in out:
        if isinstance(item, MetaTensor):
            self.assertEqual(item.spatial_ndim, 2)
Avoid vacuous passes in new spatial_ndim tests.
These checks should assert MetaTensor explicitly before checking spatial_ndim.
Suggested fix
out = SplitDim(dim=1, keepdim=False)(arr)
for item in out:
- if isinstance(item, MetaTensor):
- self.assertEqual(item.spatial_ndim, 2)
+ self.assertIsInstance(item, MetaTensor)
+ self.assertEqual(item.spatial_ndim, 2)
@@
out = SplitDim(dim=-1, keepdim=False)(arr)
for item in out:
- if isinstance(item, MetaTensor):
- self.assertEqual(item.spatial_ndim, 2)
+ self.assertIsInstance(item, MetaTensor)
+ self.assertEqual(item.spatial_ndim, 2)
@@
out = SplitDim(dim=0, keepdim=False)(arr)
for item in out:
- if isinstance(item, MetaTensor):
- self.assertEqual(item.spatial_ndim, 2)
+ self.assertIsInstance(item, MetaTensor)
+ self.assertEqual(item.spatial_ndim, 2)

Also applies to: 69-71, 80-82
aymuos15 force-pushed the …/metatensor-einops branch from dc94d16 to 545184b
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
force-pushed from 545184b to 36cda68
🧹 Nitpick comments (1)
tests/data/meta_tensor/test_spatial_ndim.py (1)
47-57: Consider adding docstrings to parameterized test methods.

Coding guidelines require docstrings for all definitions. Brief descriptions of what each test validates would improve maintainability.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/data/meta_tensor/test_spatial_ndim.py` around lines 47 - 57, Add brief docstrings to the two parameterized test methods so they comply with coding guidelines: add a one-line description to test_construction explaining it validates MetaTensor.spatial_ndim for various construction cases using CONSTRUCTION_CASES, and add a one-line description to test_preserved_through_op explaining it verifies spatial_ndim is preserved after applying ops from PRESERVATION_CASES (reference the test function names test_construction and test_preserved_through_op and the MetaTensor usage in the body to locate where to insert the docstrings).
📒 Files selected for processing (3)
monai/transforms/spatial/functional.py
tests/data/meta_tensor/test_spatial_ndim.py
tests/transforms/utility/test_splitdim.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/transforms/utility/test_splitdim.py
- monai/transforms/spatial/functional.py
Summary

Fixes #6397

- Adds a _spatial_ndim: int attribute to MetaTensor that explicitly tracks the number of spatial dimensions, preventing dimension-mismatch crashes when einops.rearrange() or other reshape operations change ndim
- Propagated through copy_meta_from (via __dict__ copy) and preserved through arbitrary torch operations
- Updates transforms (Resize, Rotate, Zoom, Flip, Affine, SplitDim, AddCoordinateChannels, etc.) and lazy resampling to use spatial_ndim instead of hardcoded 3

Key design decisions

- Default construction: spatial_ndim = min(affine.shape[-1] - 1, ndim - 1), clamped by actual tensor dims
- Explicit spatial_ndim: spatial_ndim = affine.shape[-1] - 1, no clamping (user is explicit)
- peek_pending_affine: uses the affine's inner matrix shape (fixes the batched (1, 4, 4) case)
- spatial_resample: min(spatial_ndim, ndim - 1, 3), adds the ndim-1 constraint as a safety net

Files changed (16)

- monai/data/meta_obj.py, meta_tensor.py, utils.py, __init__.py: core MetaTensor changes
- monai/transforms/: spatial, croppad, intensity, inverse, lazy, post, utility transforms updated
- tests/data/meta_tensor/test_spatial_ndim.py: 18 new tests
- Existing tests updated with spatial_ndim assertions

Test plan

- Tests cover the spatial_ndim property (construction, affine sync, propagation, einops reshape, transforms)
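The clamping rules in the key-design list can be summarized in one small function; a sketch under the assumption that the summary's formulas are exactly as stated:

```python
def derived_spatial_ndim(affine_rows, tensor_ndim, explicit=None):
    """Sketch of the stated construction rule: rank comes from the affine,
    clamped by actual tensor dims unless the caller passed it explicitly."""
    if explicit is not None:
        return explicit  # user is explicit: no clamping
    return min(affine_rows - 1, tensor_ndim - 1)


print(derived_spatial_ndim(4, 4))              # (C, H, W, D) with a 4x4 affine -> 3
print(derived_spatial_ndim(4, 3))              # (C, H, W) with the default 4x4 affine -> 2
print(derived_spatial_ndim(4, 3, explicit=3))  # explicit value wins -> 3
```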