From 73b65ef93fd73ce91eaa8bfb60279d0523239c09 Mon Sep 17 00:00:00 2001 From: yang Date: Sun, 7 Jun 2026 22:34:49 -0700 Subject: [PATCH] Use centerline_extraction_3d_cuda in SoftclDiceLoss for both prob and target Signed-off-by: yang --- monai/losses/cldice.py | 53 +++++++++++++++++++++++++-- setup.cfg | 3 ++ tests/losses/test_cldice_loss.py | 62 ++++++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 3 deletions(-) diff --git a/monai/losses/cldice.py b/monai/losses/cldice.py index 7d7e447c54..5d54e1a62a 100644 --- a/monai/losses/cldice.py +++ b/monai/losses/cldice.py @@ -20,9 +20,11 @@ from monai.losses.dice import DiceLoss from monai.networks import one_hot -from monai.utils import LossReduction +from monai.utils import LossReduction, optional_import from monai.utils.deprecate_utils import deprecated_arg +centerline_extraction_3d, _has_thinning = optional_import("centerline_extraction_3d_cuda") + def soft_erode(img: torch.Tensor) -> torch.Tensor: # type: ignore """ @@ -129,6 +131,8 @@ def __init__( softmax: bool = False, other_act: Callable | None = None, reduction: LossReduction | str = LossReduction.MEAN, + use_hard_target: bool = False, + use_hard_prob: bool = False, ) -> None: """ Args: @@ -151,6 +155,10 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. + use_hard_target: if True, use the exact CUDA 3D binary thinning for the target skeleton instead of soft skeletonization. + Requires centerline_extraction_3d_cuda package and a CUDA 3D target. Defaults to False. + use_hard_prob: if True, use the CUDA 3D prob map thinning with backward for the prediction skeleton instead of soft skeletonization. + Requires centerline_extraction_3d_cuda package and a CUDA 3D input. Defaults to False. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -181,6 +189,8 @@ def __init__( self.sigmoid = sigmoid self.softmax = softmax self.other_act = other_act + self.use_hard_target = use_hard_target + self.use_hard_prob = use_hard_prob @deprecated_arg("y_pred", since="1.5", removed="1.8", new_name="input", msg_suffix="please use `input` instead.") @deprecated_arg("y_true", since="1.5", removed="1.8", new_name="target", msg_suffix="please use `target` instead.") @@ -193,6 +203,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: AssertionError: When input and target (after one hot transform if set) have different shapes. + ValueError: When `use_hard_prob` or `use_hard_target` is enabled but the tensor is not 5D CUDA + or `centerline_extraction_3d_cuda` is unavailable. """ n_pred_ch = input.shape[1] @@ -225,8 +237,33 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != input.shape: raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") - skel_pred = soft_skel(input, self.iter) - skel_true = soft_skel(target, self.iter) + if self.use_hard_prob: + if not (input.dim() == 5 and _has_thinning and input.is_cuda): + raise ValueError( + "use_hard_prob=True but conditions not met. " + "Requires 5D CUDA tensor and centerline_extraction_3d_cuda package." + ) + pred_mask = (input >= 0.5).to(torch.uint8).contiguous() + skel_pred = torch.zeros_like(input) + for b in range(input.shape[0]): + for c in range(input.shape[1]): + skel_pred[b, c] = centerline_extraction_3d.extract_centerline(pred_mask[b, c], input[b, c], 0) + else: + skel_pred = soft_skel(input, self.iter) + + if self.use_hard_target: + if not (target.dim() == 5 and _has_thinning and target.is_cuda): + raise ValueError( + "use_hard_target=True but conditions not met. " + "Requires 5D CUDA tensor and centerline_extraction_3d_cuda package." + ) + skel_true = (target > 0).to(torch.uint8).contiguous() + for b in range(target.shape[0]): + for c in range(target.shape[1]): + centerline_extraction_3d.binary_thinning(skel_true[b, c], 0) + skel_true = skel_true.to(target.dtype) + else: + skel_true = soft_skel(target, self.iter) # Compute per-batch clDice by reducing over channel and spatial dimensions # reduce_axis includes all dimensions except batch (dim 0) @@ -279,6 +316,8 @@ def __init__( softmax: bool = False, other_act: Callable | None = None, reduction: LossReduction | str = LossReduction.MEAN, + use_hard_target: bool = False, + use_hard_prob: bool = False, ) -> None: """ Args: @@ -304,6 +343,10 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. + use_hard_target: if True, use the exact CUDA 3D binary thinning for the target skeleton instead of soft skeletonization. + Requires MONAI C++ extensions and a 3D target. Defaults to False. + use_hard_prob: if True, use the CUDA 3D prob map thinning with backward for the prediction skeleton instead of soft skeletonization. + Requires centerline_extraction_3d_cuda package and a CUDA 3D input. Defaults to False. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -336,6 +379,8 @@ def __init__( softmax=softmax, other_act=other_act, reduction=reduction, + use_hard_target=use_hard_target, + use_hard_prob=use_hard_prob, ) self.alpha = alpha self.to_onehot_y = to_onehot_y @@ -351,6 +396,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When number of dimensions for input and target are different. ValueError: When number of channels for target is neither 1 nor the same as input. + ValueError: When `use_hard_prob` or `use_hard_target` is enabled but the tensor is not 5D CUDA + or `centerline_extraction_3d_cuda` is unavailable. """ if input.dim() != target.dim(): diff --git a/setup.cfg b/setup.cfg index d987141d0b..e472fe170f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -90,6 +90,7 @@ all = nvidia-ml-py huggingface_hub pyamg>=5.0.0, <5.3.0 + centerline_extraction_3d_cuda nibabel = nibabel ninja = @@ -179,6 +180,8 @@ huggingface_hub = huggingface_hub pyamg = pyamg>=5.0.0, <5.3.0 +centerline_extraction = + centerline_extraction_3d_cuda # segment-anything = # segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything diff --git a/tests/losses/test_cldice_loss.py b/tests/losses/test_cldice_loss.py index cb17cb81ad..f2d07bf36c 100644 --- a/tests/losses/test_cldice_loss.py +++ b/tests/losses/test_cldice_loss.py @@ -85,6 +85,39 @@ def test_cuda(self): result = loss(ONES_2D["input"].cuda(), ONES_2D["target"].cuda()) np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4) + @skip_if_no_cuda + def test_hard_target(self): + """Test SoftclDiceLoss with use_hard_target=True using binary thinning on 3D CUDA tensors.""" + # Skip if thinning not available + from monai.losses.cldice import _has_thinning + + if not _has_thinning: + self.skipTest("centerline_extraction_3d_cuda not available") + + loss = SoftclDiceLoss(use_hard_target=True) + # MUST BE 3D for hard target logic to trigger! (shape: B, N, H, W, D) + result = loss(ONES_3D["input"].cuda(), ONES_3D["target"].cuda()) + np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4) + + @skip_if_no_cuda + def test_hard_prob(self): + """Test SoftclDiceLoss with use_hard_prob=True using prob thinning on 3D CUDA tensors.""" + # Skip if thinning not available + from monai.losses.cldice import _has_thinning + + if not _has_thinning: + self.skipTest("centerline_extraction_3d_cuda not available") + + loss = SoftclDiceLoss(use_hard_prob=True) + # MUST BE 3D for hard prob logic to trigger! (shape: B, N, H, W, D) + input_tensor = torch.ones_like(ONES_3D["input"]).cuda() + input_tensor.requires_grad = True + target = ONES_3D["target"].cuda() + result = loss(input_tensor, target) + np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4) + result.backward() + self.assertIsNotNone(input_tensor.grad) + def test_reduction_shapes(self): input_tensor = torch.ones((4, 2, 8, 8)) target = torch.ones((4, 2, 8, 8)) @@ -128,6 +161,35 @@ def test_cuda(self): result = loss(ONES_2D["input"].cuda(), ONES_2D["target"].cuda()) np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4) + @skip_if_no_cuda + def test_hard_target(self): + """Test SoftDiceclDiceLoss with use_hard_target=True.""" + from monai.losses.cldice import _has_thinning + + if not _has_thinning: + self.skipTest("centerline_extraction_3d_cuda not available") + + loss = SoftDiceclDiceLoss(use_hard_target=True) + result = loss(ONES_3D["input"].cuda(), ONES_3D["target"].cuda()) + np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4) + + @skip_if_no_cuda + def test_hard_prob(self): + """Test SoftDiceclDiceLoss with use_hard_prob=True.""" + from monai.losses.cldice import _has_thinning + + if not _has_thinning: + self.skipTest("centerline_extraction_3d_cuda not available") + + loss = SoftDiceclDiceLoss(use_hard_prob=True) + input_tensor = torch.ones_like(ONES_3D["input"]).cuda() + input_tensor.requires_grad = True + target = ONES_3D["target"].cuda() + result = loss(input_tensor, target) + np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4) + result.backward() + self.assertIsNotNone(input_tensor.grad) + def test_dimension_mismatch(self): loss = SoftDiceclDiceLoss() with self.assertRaises(ValueError):