Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,14 +243,19 @@ def sliding_window_inference(
for idx in slice_range
]
if sw_batch_size > 1:
win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
win_data = torch.cat([inputs[ensure_tuple(win_slice)] for win_slice in unravel_slice]).to(sw_device)
if condition is not None:
win_condition = torch.cat([condition[win_slice] for win_slice in unravel_slice]).to(sw_device)
win_condition = torch.cat([condition[ensure_tuple(win_slice)] for win_slice in unravel_slice]).to(
sw_device
)
kwargs["condition"] = win_condition
else:
win_data = inputs[unravel_slice[0]].to(sw_device)
s0 = unravel_slice[0]
s0_idx = ensure_tuple(s0)

win_data = inputs[s0_idx].to(sw_device)
if condition is not None:
win_condition = condition[unravel_slice[0]].to(sw_device)
win_condition = condition[s0_idx].to(sw_device)
kwargs["condition"] = win_condition

if with_coord:
Expand All @@ -277,7 +282,7 @@ def sliding_window_inference(
offset = s[buffer_dim + 2].start - c_start
s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim])
s[0] = slice(0, 1)
sw_device_buffer[0][s] += p * w_t
sw_device_buffer[0][ensure_tuple(s)] += p * w_t
b_i += len(unravel_slice)
if b_i < b_slices[b_s][0]:
continue
Expand Down Expand Up @@ -308,10 +313,11 @@ def sliding_window_inference(
o_slice[buffer_dim + 2] = slice(c_start, c_end)
img_b = b_s // n_per_batch # image batch index
o_slice[0] = slice(img_b, img_b + 1)
o_slice_idx = ensure_tuple(o_slice)
if non_blocking:
output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking)
output_image_list[0][o_slice_idx].copy_(sw_device_buffer[0], non_blocking=non_blocking)
else:
output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device)
output_image_list[0][o_slice_idx] += sw_device_buffer[0].to(device=device)
else:
sw_device_buffer[ss] *= w_t
sw_device_buffer[ss] = sw_device_buffer[ss].to(device)
Expand Down Expand Up @@ -387,7 +393,7 @@ def _compute_coords(coords, z_scale, out, patch):
idx_zm[axis] = slice(
int(original_idx[axis].start * z_scale[axis - 2]), int(original_idx[axis].stop * z_scale[axis - 2])
)
out[idx_zm] += p
out[ensure_tuple(idx_zm)] += p


def _get_scan_interval(
Expand Down
94 changes: 94 additions & 0 deletions tests/inferers/test_sliding_window_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from monai.data.utils import list_data_collate
from monai.inferers import SlidingWindowInferer, SlidingWindowInfererAdapt, sliding_window_inference
from monai.inferers.utils import _compute_coords
from monai.utils import optional_import
from tests.test_utils import TEST_TORCH_AND_META_TENSORS, skip_if_no_cuda, test_is_quick

Expand Down Expand Up @@ -704,6 +705,99 @@ def compute_dict(data, condition):
for rr, _ in zip(result_dict, expected_dict):
np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4)

@parameterized.expand([(1,), (4,)])
def test_conditioned_branches_and_buffered_parity(self, sw_batch_size):
"""Validate conditioned parity between buffered and non-buffered flows.

Args:
sw_batch_size (int): Sliding-window batch size.

Returns:
None.

Raises:
AssertionError: If device, conditioning alignment, or output parity checks fail.
"""
inputs = torch.arange(1 * 1 * 10 * 8, dtype=torch.float).reshape(1, 1, 10, 8)
condition = inputs + 100.0
roi_shape = (4, 4)

def compute(data, condition):
"""Compute output for a conditioned patch.

Args:
data (torch.Tensor): Input patch tensor.
condition (torch.Tensor): Conditioning patch tensor aligned to ``data``.

Returns:
torch.Tensor: Element-wise ``data + condition``.

Raises:
AssertionError: If device placement or conditioning alignment checks fail.
"""
self.assertEqual(data.device.type, "cpu")
self.assertEqual(condition.device.type, "cpu")
torch.testing.assert_close(condition - data, torch.full_like(data, 100.0))
return data + condition

# Non-buffered flow.
result_non_buffered = sliding_window_inference(
inputs, roi_shape, sw_batch_size, compute, overlap=0.5, mode="constant", condition=condition
)
# Buffered flow; should match the non-buffered output.
result_buffered = sliding_window_inference(
inputs,
roi_shape,
sw_batch_size,
compute,
overlap=0.5,
mode="constant",
condition=condition,
buffer_steps=2,
buffer_dim=0,
)

expected = inputs + condition
torch.testing.assert_close(result_non_buffered, expected)
torch.testing.assert_close(result_buffered, expected)
torch.testing.assert_close(result_buffered, result_non_buffered)


class TestSlidingWindowUtils(unittest.TestCase):
"""Tests for low-level sliding-window utility helpers.

Args:
None.

Returns:
None.

Raises:
None.
"""

def test_compute_coords_accepts_list_indices(self):
"""Ensure ``_compute_coords`` handles list-based index containers.

Args:
None.

Returns:
None.

Raises:
AssertionError: If computed output placement differs from expected placement.
"""
out = torch.zeros((1, 1, 12, 12), dtype=torch.float)
patch = torch.arange(16, dtype=torch.float).reshape(1, 1, 4, 4)
coords = [[slice(0, 1), slice(None), slice(1, 3), slice(2, 4)]]

_compute_coords(coords=coords, z_scale=[2.0, 2.0], out=out, patch=patch)

expected = torch.zeros_like(out)
expected[0, 0, 2:6, 4:8] = patch[0, 0]
torch.testing.assert_close(out, expected)


if __name__ == "__main__":
unittest.main()
Loading