From 4a9889b64074bc6db8bfe7b846f122330a5412f6 Mon Sep 17 00:00:00 2001 From: Oleksandr Sanin Date: Wed, 17 Jun 2026 09:15:38 +0000 Subject: [PATCH] fix(RandGridDistortiond): only convert transform keys when skipping transform MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When `_do_transform` is False, `RandGridDistortiond.__call__` was calling `convert_to_tensor(d, ...)` on the entire input dict. This recursively converts *all* values — including non-image entries such as integers and scalars — into PyTorch tensors. The converted dict is then returned to the DataLoader which hands it to MONAI's `collate_meta_tensor_fn`. That collate path expects non-image entries to remain as their original Python types; receiving 0-d tensors instead triggers an `AttributeError: 'int' object has no attribute 'numel'` when the collate function iterates over what it believes to be a batch of tensors. Fix: iterate over `self.key_iterator(d)` and convert only those values, exactly as the transform loop further down in the same method already does. This matches the per-key pattern used in sibling transforms such as `RandAffined` and leaves unrelated dict entries unchanged. Also adds a regression test that verifies integer and string entries are preserved when the transform is skipped (prob=0.0). Closes #8604 Signed-off-by: Oleksandr Sanin --- monai/transforms/spatial/dictionary.py | 7 ++++--- tests/transforms/test_rand_grid_distortiond.py | 9 +++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 51ad0435fc..543d875c4e 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -2305,12 +2305,13 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc d = dict(data) self.randomize(None) if not self._do_transform: - out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta()) - return out + for key in self.key_iterator(d): + d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) + return d first_key: Hashable = self.first_key(d) if first_key == (): - out = convert_to_tensor(d, track_meta=get_track_meta()) + out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta()) return out if isinstance(d[first_key], MetaTensor) and d[first_key].pending_operations: # type: ignore warnings.warn(f"data['{first_key}'] has pending operations, transform may return incorrect results.") diff --git a/tests/transforms/test_rand_grid_distortiond.py b/tests/transforms/test_rand_grid_distortiond.py index 8f8de144f6..d1de93ea95 100644 --- a/tests/transforms/test_rand_grid_distortiond.py +++ b/tests/transforms/test_rand_grid_distortiond.py @@ -85,6 +85,15 @@ def test_rand_grid_distortiond(self, input_param, seed, input_data, expected_val assert_allclose(result["img"], expected_val_img, type_test=False, rtol=1e-4, atol=1e-4) assert_allclose(result["mask"], expected_val_mask, type_test=False, rtol=1e-4, atol=1e-4) + def test_no_transform_preserves_non_image_keys(self): + """Non-image dict entries must not be coerced to tensors when the transform is skipped.""" + img = np.indices([6, 6]).astype(np.float32) + data = {"img": img, "label": 42, "filename": "scan.nii"} + g = RandGridDistortiond(keys=["img"], prob=0.0) + result = g(data) + self.assertIsInstance(result["label"], int) + self.assertIsInstance(result["filename"], str) + if __name__ == "__main__": unittest.main()