16 changes: 9 additions & 7 deletions monai/auto3dseg/analyzer.py
@@ -468,21 +468,23 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor]:
"""
d: dict[Hashable, MetaTensor] = dict(data)
start = time.time()
if isinstance(d[self.image_key], (torch.Tensor, MetaTensor)) and d[self.image_key].device.type == "cuda":
using_cuda = True
else:
using_cuda = False
image_tensor = d[self.image_key]
label_tensor = d[self.label_key]
using_cuda = any(
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda"
for t in (image_tensor, label_tensor)
)
restore_grad_state = torch.is_grad_enabled()
torch.set_grad_enabled(False)

ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore
ndas_label: MetaTensor = d[self.label_key].astype(torch.int16) # (H,W,D)
ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)
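For illustration, the new `any()`-based check can be exercised without MONAI or a GPU. This is a minimal sketch using hypothetical stand-ins (`FakeTensor`, `Device`) that mimic only `torch.Tensor`'s `.device.type` attribute, not real PyTorch objects:

```python
from dataclasses import dataclass

# Hypothetical stand-ins for torch.Tensor's `.device.type`; not MONAI code.
@dataclass
class Device:
    type: str

@dataclass
class FakeTensor:
    device: Device

def detect_cuda(image, label):
    # Mirrors the patched logic: take the CUDA path if *either* tensor is on a GPU.
    return any(t.device.type == "cuda" for t in (image, label))

cpu = FakeTensor(Device("cpu"))
gpu = FakeTensor(Device("cuda"))

print(detect_cuda(cpu, cpu))  # False
print(detect_cuda(cpu, gpu))  # True: the old image-only check would miss this
print(detect_cuda(gpu, cpu))  # True
```

The middle case is the scenario the diff addresses: the old check consulted only the image tensor, so a CUDA label paired with a CPU image was treated as CPU-only.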
Comment on lines 471 to 481
⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Scripts executed (Repository: Project-MONAI/MONAI):

rg -n "LabelStats|FgImageStats|using_cuda|empty_cache|get_label_ccp" -g "*test*.py" -i

# Also check the actual code at the specified lines
head -500 monai/auto3dseg/analyzer.py | tail -50

# Find the LabelStats class and its __call__ method
rg -n "class LabelStats|def __call__" monai/auto3dseg/analyzer.py

sed -n '319,360p' tests/apps/test_auto3dseg.py
sed -n '341,370p' tests/apps/test_auto3dseg.py
sed -n '420,500p' tests/apps/test_auto3dseg.py

head -50 tests/apps/test_auto3dseg.py | grep -E "device|import"
grep -n "^device\s*=" tests/apps/test_auto3dseg.py

# Check the beginning of the test file for device setup
head -100 tests/apps/test_auto3dseg.py
grep -n "SIM_GPU_TEST_CASES\|@parameterized\|skip_if_no_cuda" tests/apps/test_auto3dseg.py

# Check if there are GPU parametrized test methods
grep -B 5 -A 10 "test_label_stats.*gpu\|test_label_stats.*cuda" tests/apps/test_auto3dseg.py

sed -n '214,280p' tests/apps/test_auto3dseg.py

# Check what the GPU test actually tests
grep -A 40 "@skip_if_no_cuda" tests/apps/test_auto3dseg.py | head -60

Add tests for mixed-device LabelStats handling.

Tests currently cover CPU-only and GPU-only scenarios, but not mixed devices (image on CPU with label on CUDA, or vice versa). Add parameterized test cases to verify that LabelStats handles these mixed-device combinations correctly, per the coding guidelines.
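A torch-free sketch of what such a parameterization could look like. All names here (`FakeTensor`, `detect_cuda`, the test class) are illustrative stand-ins, not the MONAI test suite; a real test would build MetaTensors with explicit `.to(device)` calls, guard CUDA cases with `torch.cuda.is_available()`, and invoke LabelStats itself:

```python
import itertools
import unittest

class FakeTensor:
    """Hypothetical stand-in exposing only torch.Tensor's `.device.type`."""
    def __init__(self, device_type: str):
        self.device = type("Device", (), {"type": device_type})()

def detect_cuda(image, label):
    # Same rule as the patched analyzer: CUDA if either tensor is on a GPU.
    return any(t.device.type == "cuda" for t in (image, label))

class TestMixedDeviceDetection(unittest.TestCase):
    def test_all_device_combinations(self):
        # Four combinations: cpu/cpu, cpu/cuda, cuda/cpu, cuda/cuda.
        for img_dev, lbl_dev in itertools.product(("cpu", "cuda"), repeat=2):
            with self.subTest(image=img_dev, label=lbl_dev):
                expected = "cuda" in (img_dev, lbl_dev)
                got = detect_cuda(FakeTensor(img_dev), FakeTensor(lbl_dev))
                self.assertEqual(got, expected)

if __name__ == "__main__":
    unittest.main(exit=False)
```

`subTest` keeps all four combinations visible in one test method, so a failure reports exactly which device pairing broke.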

🤖 Prompt for AI Agents
In `monai/auto3dseg/analyzer.py` around lines 471-478: add parameterized unit tests covering mixed-device scenarios for LabelStats, with cases where image_tensor is on CPU and label_tensor is on CUDA and vice versa. Build batches the same way ndas and ndas_label are derived (MetaTensor/torch.Tensor placed on explicit torch.device settings), call the LabelStats code paths that consume ndas/ndas_label, and assert both the expected statistics and that no device-related errors occur. Guard CUDA cases with torch.cuda.is_available(), move tensors with explicit .to(device) calls, and verify the results match the existing CPU-only and GPU-only assertions so mixed-device handling is validated.


         if ndas_label.shape != ndas[0].shape:
             raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")

         nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
-        nda_foregrounds = [nda if nda.numel() > 0 else torch.Tensor([0]) for nda in nda_foregrounds]
+        nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]

         unique_label = unique(ndas_label)
         if isinstance(ndas_label, (MetaTensor, torch.Tensor)):