diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 3dfbc2032c..de53108d1d 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -243,14 +243,19 @@ def sliding_window_inference( for idx in slice_range ] if sw_batch_size > 1: - win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) + win_data = torch.cat([inputs[ensure_tuple(win_slice)] for win_slice in unravel_slice]).to(sw_device) if condition is not None: - win_condition = torch.cat([condition[win_slice] for win_slice in unravel_slice]).to(sw_device) + win_condition = torch.cat([condition[ensure_tuple(win_slice)] for win_slice in unravel_slice]).to( + sw_device + ) kwargs["condition"] = win_condition else: - win_data = inputs[unravel_slice[0]].to(sw_device) + s0 = unravel_slice[0] + s0_idx = ensure_tuple(s0) + + win_data = inputs[s0_idx].to(sw_device) if condition is not None: - win_condition = condition[unravel_slice[0]].to(sw_device) + win_condition = condition[s0_idx].to(sw_device) kwargs["condition"] = win_condition if with_coord: @@ -277,7 +282,7 @@ def sliding_window_inference( offset = s[buffer_dim + 2].start - c_start s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim]) s[0] = slice(0, 1) - sw_device_buffer[0][s] += p * w_t + sw_device_buffer[0][ensure_tuple(s)] += p * w_t b_i += len(unravel_slice) if b_i < b_slices[b_s][0]: continue @@ -308,10 +313,11 @@ def sliding_window_inference( o_slice[buffer_dim + 2] = slice(c_start, c_end) img_b = b_s // n_per_batch # image batch index o_slice[0] = slice(img_b, img_b + 1) + o_slice_idx = ensure_tuple(o_slice) if non_blocking: - output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking) + output_image_list[0][o_slice_idx].copy_(sw_device_buffer[0], non_blocking=non_blocking) else: - output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device) + output_image_list[0][o_slice_idx] += sw_device_buffer[0].to(device=device) else: sw_device_buffer[ss] *= w_t sw_device_buffer[ss] = sw_device_buffer[ss].to(device) @@ -387,7 +393,7 @@ def _compute_coords(coords, z_scale, out, patch): idx_zm[axis] = slice( int(original_idx[axis].start * z_scale[axis - 2]), int(original_idx[axis].stop * z_scale[axis - 2]) ) - out[idx_zm] += p + out[ensure_tuple(idx_zm)] += p def _get_scan_interval( diff --git a/tests/inferers/test_sliding_window_inference.py b/tests/inferers/test_sliding_window_inference.py index 8700c4fcd0..5a624c787f 100644 --- a/tests/inferers/test_sliding_window_inference.py +++ b/tests/inferers/test_sliding_window_inference.py @@ -20,6 +20,7 @@ from monai.data.utils import list_data_collate from monai.inferers import SlidingWindowInferer, SlidingWindowInfererAdapt, sliding_window_inference +from monai.inferers.utils import _compute_coords from monai.utils import optional_import from tests.test_utils import TEST_TORCH_AND_META_TENSORS, skip_if_no_cuda, test_is_quick @@ -704,6 +705,99 @@ def compute_dict(data, condition): for rr, _ in zip(result_dict, expected_dict): np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4) + @parameterized.expand([(1,), (4,)]) + def test_conditioned_branches_and_buffered_parity(self, sw_batch_size): + """Validate conditioned parity between buffered and non-buffered flows. + + Args: + sw_batch_size (int): Sliding-window batch size. + + Returns: + None. + + Raises: + AssertionError: If device, conditioning alignment, or output parity checks fail. + """ + inputs = torch.arange(1 * 1 * 10 * 8, dtype=torch.float).reshape(1, 1, 10, 8) + condition = inputs + 100.0 + roi_shape = (4, 4) + + def compute(data, condition): + """Compute output for a conditioned patch. + + Args: + data (torch.Tensor): Input patch tensor. + condition (torch.Tensor): Conditioning patch tensor aligned to ``data``. + + Returns: + torch.Tensor: Element-wise ``data + condition``. + + Raises: + AssertionError: If device placement or conditioning alignment checks fail. + """ + self.assertEqual(data.device.type, "cpu") + self.assertEqual(condition.device.type, "cpu") + torch.testing.assert_close(condition - data, torch.full_like(data, 100.0)) + return data + condition + + # Non-buffered flow. + result_non_buffered = sliding_window_inference( + inputs, roi_shape, sw_batch_size, compute, overlap=0.5, mode="constant", condition=condition + ) + # Buffered flow; should match the non-buffered output. + result_buffered = sliding_window_inference( + inputs, + roi_shape, + sw_batch_size, + compute, + overlap=0.5, + mode="constant", + condition=condition, + buffer_steps=2, + buffer_dim=0, + ) + + expected = inputs + condition + torch.testing.assert_close(result_non_buffered, expected) + torch.testing.assert_close(result_buffered, expected) + torch.testing.assert_close(result_buffered, result_non_buffered) + + +class TestSlidingWindowUtils(unittest.TestCase): + """Tests for low-level sliding-window utility helpers. + + Args: + None. + + Returns: + None. + + Raises: + None. + """ + + def test_compute_coords_accepts_list_indices(self): + """Ensure ``_compute_coords`` handles list-based index containers. + + Args: + None. + + Returns: + None. + + Raises: + AssertionError: If computed output placement differs from expected placement. + """ + out = torch.zeros((1, 1, 12, 12), dtype=torch.float) + patch = torch.arange(16, dtype=torch.float).reshape(1, 1, 4, 4) + coords = [[slice(0, 1), slice(None), slice(1, 3), slice(2, 4)]] + + _compute_coords(coords=coords, z_scale=[2.0, 2.0], out=out, patch=patch) + + expected = torch.zeros_like(out) + expected[0, 0, 2:6, 4:8] = patch[0, 0] + torch.testing.assert_close(out, expected) + if __name__ == "__main__": unittest.main()