diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 72ee1f27ee..025dc30e61 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -68,6 +68,7 @@ def __init__( batch: bool = False, weight: Sequence[float] | float | int | torch.Tensor | None = None, soft_label: bool = False, + ignore_index: int | None = None, ) -> None: """ Args: @@ -101,7 +102,8 @@ def __init__( The value/values should be no less than 0. Defaults to None. soft_label: whether the target contains non-binary values (soft labels) or not. If True a soft label formulation of the loss will be used. - + ignore_index: if not None, specifies a target index that is ignored and does not contribute to + the input gradient. Defaults to None. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. @@ -123,6 +125,7 @@ def __init__( self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) self.batch = batch + self.ignore_index = ignore_index weight = torch.as_tensor(weight) if weight is not None else None self.register_buffer("class_weight", weight) self.class_weight: None | torch.Tensor @@ -140,7 +143,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. 
Example: - >>> from monai.losses.dice import * # NOQA + >>> from monai.losses.dice import * # NOQA >>> import torch >>> from monai.losses.dice import DiceLoss >>> B, C, H, W = 7, 5, 3, 2 @@ -164,6 +167,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.other_act is not None: input = self.other_act(input) + # mask the ignore_index if specified, must be done before one_hot + mask: torch.Tensor | None = None + if self.ignore_index is not None: + mask = (target != self.ignore_index).float() + if self.to_onehot_y: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") @@ -181,6 +189,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != input.shape: raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") + if mask is not None: + input = input * mask + target = target * mask + # reducing only spatial dimensions (not batch nor channels) reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: @@ -204,11 +216,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes) else: if self.class_weight.shape[0] != num_of_classes: - raise ValueError( - """the length of the `weight` sequence should be the same as the number of classes. + raise ValueError("""the length of the `weight` sequence should be the same as the number of classes. 
If `include_background=False`, the weight should not include - the background category class 0.""" - ) + the background category class 0.""") if self.class_weight.min() < 0: raise ValueError("the value/values of the `weight` should be no less than 0.") # apply class_weight to loss @@ -280,6 +290,7 @@ def __init__( smooth_dr: float = 1e-5, batch: bool = False, soft_label: bool = False, + ignore_index: int | None = None, ) -> None: """ Args: @@ -305,6 +316,8 @@ def __init__( If True, the class-weighted intersection and union areas are first summed across the batches. soft_label: whether the target contains non-binary values (soft labels) or not. If True a soft label formulation of the loss will be used. + ignore_index: if not None, specifies a target index that is ignored and does not contribute to + the input gradient. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -330,6 +343,7 @@ def __init__( self.smooth_dr = float(smooth_dr) self.batch = batch self.soft_label = soft_label + self.ignore_index = ignore_index def w_func(self, grnd): if self.w_type == str(Weight.SIMPLE): @@ -360,6 +374,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.other_act is not None: input = self.other_act(input) + # Prepare mask before potential one-hot conversion + mask: torch.Tensor | None = None + if self.ignore_index is not None: + mask = (target != self.ignore_index).float() + if self.to_onehot_y: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") @@ -370,14 +389,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `include_background=False` ignored.") else: - # if skipping background, removing first channel target = target[:, 1:] input = input[:, 1:] if target.shape != input.shape: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") - # 
reducing only spatial dimensions (not batch nor channels) + # Exclude ignored regions from calculations + if mask is not None: + input = input * mask + target = target * mask + reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: reduce_axis = [0] + reduce_axis @@ -404,12 +426,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: f: torch.Tensor = 1.0 - (numer / denom) if self.reduction == LossReduction.MEAN.value: - f = torch.mean(f) # the batch and channel average + f = torch.mean(f) elif self.reduction == LossReduction.SUM.value: - f = torch.sum(f) # sum over the batch and channel dims + f = torch.sum(f) elif self.reduction == LossReduction.NONE.value: - # If we are not computing voxelwise loss components at least - # make sure a none reduction maintains a broadcastable shape broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2) f = f.view(broadcast_shape) else: @@ -442,11 +462,12 @@ def __init__( reduction: LossReduction | str = LossReduction.MEAN, smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, + ignore_index: int | None = None, ) -> None: """ Args: dist_matrix: 2d tensor or 2d numpy array; matrix of distances between the classes. - It must have dimension C x C where C is the number of classes. + It must have dimension C x C where C is the number of classes. weighting_mode: {``"default"``, ``"GDL"``} Specifies how to weight the class-specific sum of errors. Default to ``"default"``. @@ -466,27 +487,11 @@ def __init__( - ``"sum"``: the output will be summed. smooth_nr: a small constant added to the numerator to avoid zero. smooth_dr: a small constant added to the denominator to avoid nan. + ignore_index: if not None, specifies a target index that is ignored and does not contribute to + the input gradient. Raises: ValueError: When ``dist_matrix`` is not a square matrix. - - Example: - .. 
code-block:: python - - import torch - import numpy as np - from monai.losses import GeneralizedWassersteinDiceLoss - - # Example with 3 classes (including the background: label 0). - # The distance between the background class (label 0) and the other classes is the maximum, equal to 1. - # The distance between class 1 and class 2 is 0.5. - dist_mat = np.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.5], [1.0, 0.5, 0.0]], dtype=np.float32) - wass_loss = GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat) - - pred_score = torch.tensor([[1000, 0, 0], [0, 1000, 0], [0, 0, 1000]], dtype=torch.float32) - grnd = torch.tensor([0, 1, 2], dtype=torch.int64) - wass_loss(pred_score, grnd) # 0 - """ super().__init__(reduction=LossReduction(reduction).value) @@ -505,13 +510,13 @@ def __init__( self.num_classes = self.m.size(0) self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) + self.ignore_index = ignore_index def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: input: the shape should be BNH[WD]. target: the shape should be BNH[WD]. - """ # Aggregate spatial dimensions flat_input = input.reshape(input.size(0), input.size(1), -1) @@ -523,18 +528,20 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # Compute the Wasserstein distance map wass_dist_map = self.wasserstein_distance_map(probs, flat_target) + # Apply masking for ignore_index + if self.ignore_index is not None: + mask = (flat_target != self.ignore_index).float() + wass_dist_map = wass_dist_map * mask + # Compute the values of alpha to use alpha = self._compute_alpha_generalized_true_positives(flat_target) # Compute the numerator and denominator of the generalized Wasserstein Dice loss if self.alpha_mode == "GDL": # use GDL-style alpha weights (i.e. normalize by the volume of each class) - # contrary to the original definition we also use alpha in the "generalized all error". 
true_pos = self._compute_generalized_true_positive(alpha, flat_target, wass_dist_map) denom = self._compute_denominator(alpha, flat_target, wass_dist_map) else: # default: as in the original paper - # (i.e. alpha=1 for all foreground classes and 0 for the background). - # Compute the generalised number of true positives true_pos = self._compute_generalized_true_positive(alpha, flat_target, wass_dist_map) all_error = torch.sum(wass_dist_map, dim=1) denom = 2 * true_pos + all_error @@ -544,12 +551,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: wass_dice_loss: torch.Tensor = 1.0 - wass_dice if self.reduction == LossReduction.MEAN.value: - wass_dice_loss = torch.mean(wass_dice_loss) # the batch and channel average + wass_dice_loss = torch.mean(wass_dice_loss) elif self.reduction == LossReduction.SUM.value: - wass_dice_loss = torch.sum(wass_dice_loss) # sum over the batch and channel dims + wass_dice_loss = torch.sum(wass_dice_loss) elif self.reduction == LossReduction.NONE.value: # GWDL aggregates over classes internally, so wass_dice_loss has shape (B,) pass else: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') @@ -674,6 +681,7 @@ def __init__( lambda_dice: float = 1.0, lambda_ce: float = 1.0, label_smoothing: float = 0.0, + ignore_index: int | None = None, ) -> None: """ Args: @@ -715,6 +723,8 @@ def __init__( label_smoothing: a value in [0, 1] range. If > 0, the labels are smoothed by the given factor to reduce overfitting. Defaults to 0.0. + ignore_index: if not None, specifies a target index that is ignored and does not contribute to + the input gradient. 
""" super().__init__() @@ -737,8 +747,14 @@ def __init__( smooth_dr=smooth_dr, batch=batch, weight=dice_weight, + ignore_index=ignore_index, + ) + self.cross_entropy = nn.CrossEntropyLoss( + weight=weight, + reduction=reduction, + label_smoothing=label_smoothing, + ignore_index=ignore_index if ignore_index is not None else -100, ) - self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction, label_smoothing=label_smoothing) self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction) if lambda_dice < 0.0: raise ValueError("lambda_dice should be no less than 0.0.") @@ -746,6 +762,7 @@ def __init__( raise ValueError("lambda_ce should be no less than 0.0.") self.lambda_dice = lambda_dice self.lambda_ce = lambda_ce + self.ignore_index = ignore_index def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -801,7 +818,21 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ) dice_loss = self.dice(input, target) - ce_loss = self.ce(input, target) if input.shape[1] != 1 else self.bce(input, target) + + if input.shape[1] != 1: + # CrossEntropyLoss handles ignore_index natively + ce_loss = self.ce(input, target) + else: + # BCEWithLogitsLoss does not support ignore_index, handle manually + ce_loss = self.bce(input, target) + if self.ignore_index is not None: + mask = (target != self.ignore_index).float() + ce_loss = ce_loss * mask + if self.dice.reduction == "mean": + ce_loss = torch.mean(ce_loss) + elif self.dice.reduction == "sum": + ce_loss = torch.sum(ce_loss) + total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss return total_loss diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index caa237fca8..e90911fec3 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -73,6 +73,7 @@ def __init__( weight: Sequence[float] | float | int | torch.Tensor | None = None, reduction: LossReduction | str = 
LossReduction.MEAN, use_softmax: bool = False, + ignore_index: int | None = None, ) -> None: """ Args: @@ -92,13 +93,12 @@ The value/values should be no less than 0. Defaults to None. reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output. Defaults to ``"mean"``. - - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. - use_softmax: whether to use softmax to transform the original logits into probabilities. If True, softmax is used. If False, sigmoid is used. Defaults to False. + ignore_index: index of the class to ignore during calculation. Defaults to None. Example: >>> import torch @@ -114,6 +114,8 @@ self.gamma = gamma self.weight = weight self.use_softmax = use_softmax + self.ignore_index = ignore_index + self.alpha: float | torch.Tensor | None if alpha is None: self.alpha = None @@ -126,37 +129,36 @@ self.class_weight: None | torch.Tensor def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Args: - input: the shape should be BNH[WD], where N is the number of classes. - The input should be the original logits since it will be transformed by - a sigmoid/softmax in the forward function. - target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. - - Raises: - ValueError: When input and target (after one hot transform if set) - have different shapes. - ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. - ValueError: When ``self.weight`` is a sequence and the length is not equal to the - number of classes. - ValueError: When ``self.weight`` is/contains a value that is less than 0. 
- - """ n_pred_ch = input.shape[1] if self.to_onehot_y: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: + original_target = target target = one_hot(target, num_classes=n_pred_ch) + mask = None + if self.ignore_index is not None: + if self.to_onehot_y: + # spatial mask: (B, 1, H, W) + mask = (original_target != self.ignore_index).to(input.dtype).unsqueeze(1) + elif target.shape[1] == 1: + mask = (target != self.ignore_index).to(input.dtype) + else: + # multi-class one-hot target + mask = (1.0 - target[:, self.ignore_index : self.ignore_index + 1]).to(input.dtype) + if not self.include_background: if n_pred_ch == 1: warnings.warn("single channel prediction, `include_background=False` ignored.") else: - # if skipping background, removing first channel target = target[:, 1:] input = input[:, 1:] + if mask is not None: + _mask: torch.Tensor = mask + if _mask.shape[1] > 1: + mask = _mask[:, 1:] if target.shape != input.shape: raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") @@ -165,10 +167,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: input = input.float() target = target.float() alpha_arg = self.alpha + if self.use_softmax: if not self.include_background and self.alpha is not None: if isinstance(self.alpha, (float, int)): alpha_arg = None + # Move the warning INSIDE this block warnings.warn( "`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2 ) @@ -176,41 +180,47 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: else: loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg) - num_of_classes = target.shape[1] - if self.class_weight is not None and num_of_classes != 1: - # make sure the lengths of weights are equal to the number of classes - if self.class_weight.ndim == 0: - self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes) - else: - if 
self.class_weight.shape[0] != num_of_classes: + if mask is not None: + loss = loss * mask + + if self.class_weight is not None: + cw = torch.as_tensor(self.class_weight, device=loss.device, dtype=loss.dtype) + num_classes = loss.shape[1] + + if cw.ndim > 0: + if num_classes == 1: + raise ValueError("Per-class class_weight is not supported for single-channel outputs.") + if cw.numel() != num_classes: raise ValueError( - """the length of the `weight` sequence should be the same as the number of classes. - If `include_background=False`, the weight should not include - the background category class 0.""" + f"The number of class_weight ({cw.numel()}) must match the number of " + f"output channels ({num_classes})." ) - if self.class_weight.min() < 0: - raise ValueError("the value/values of the `weight` should be no less than 0.") - # apply class_weight to loss - self.class_weight = self.class_weight.to(loss) - broadcast_dims = [-1] + [1] * len(target.shape[2:]) - self.class_weight = self.class_weight.view(broadcast_dims) - loss = self.class_weight * loss + if (cw < 0).any(): + raise ValueError("class_weight values must be non-negative.") + else: + if cw < 0: + raise ValueError("class_weight values must be non-negative.") + + if cw.ndim == 0: + loss = loss * cw + else: + broadcast_shape = [1, num_classes] + [1] * (loss.ndim - 2) + loss = loss * cw.view(broadcast_shape) if self.reduction == LossReduction.SUM.value: - # Previously there was a mean over the last dimension, which did not - # return a compatible BCE loss. To maintain backwards compatible - # behavior we have a flag that performs this extra step, disable or - # parameterize if necessary. 
(Or justify why the mean should be there) - average_spatial_dims = True - if average_spatial_dims: - loss = loss.mean(dim=list(range(2, len(target.shape)))) loss = loss.sum() + elif self.reduction == LossReduction.MEAN.value: - loss = loss.mean() + if mask is not None: + # Ensure we only sum the loss where the mask is 1 + # Then divide by the actual number of 1s in the mask + loss = (loss * mask).sum() / mask.sum().clamp(min=1e-5) + else: + loss = loss.mean() + elif self.reduction == LossReduction.NONE.value: pass - else: - raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + return loss @@ -287,14 +297,12 @@ def sigmoid_focal_loss( raise ValueError( f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})." ) - # Reshape alpha for broadcasting: (1, C, 1, 1...) - broadcast_dims = [-1] + [1] * len(target.shape[2:]) + broadcast_dims = [1, -1] + [1] * len(target.shape[2:]) alpha_t = alpha_t.view(broadcast_dims) + # Apply per-class weight only to positive samples - # For positive samples (target==1): multiply by alpha[c] - # For negative samples (target==0): keep weight as 1.0 alpha_factor = torch.where(target == 1, alpha_t, torch.ones_like(alpha_t)) + # This multiplication now works for both Scalar and Tensor cases loss = alpha_factor * loss - return loss diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index 154f34c526..f2c15954c0 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -51,6 +51,7 @@ def __init__( smooth_dr: float = 1e-5, batch: bool = False, soft_label: bool = False, + ignore_index: int | None = None, ) -> None: """ Args: @@ -77,6 +78,7 @@ def __init__( before any `reduction`. soft_label: whether the target contains non-binary values (soft labels) or not. If True a soft label formulation of the loss will be used. + ignore_index: index of the class to ignore during calculation. 
Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -101,6 +103,7 @@ def __init__( self.smooth_dr = float(smooth_dr) self.batch = batch self.soft_label = soft_label + self.ignore_index = ignore_index def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -129,8 +132,21 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: + original_target = target target = one_hot(target, num_classes=n_pred_ch) + if self.ignore_index is not None: + mask_src = original_target if self.to_onehot_y and n_pred_ch > 1 else target + + if mask_src.shape[1] == 1: + mask = (mask_src != self.ignore_index).to(input.dtype) + else: + # Fallback for cases where target is already one-hot + mask = (1.0 - mask_src[:, self.ignore_index : self.ignore_index + 1]).to(input.dtype) + + input = input * mask + target = target * mask + if not self.include_background: if n_pred_ch == 1: warnings.warn("single channel prediction, `include_background=False` ignored.") diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 745513fec0..8ba2bce69c 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -39,6 +39,7 @@ def __init__( gamma: float = 0.75, epsilon: float = 1e-7, reduction: LossReduction | str = LossReduction.MEAN, + ignore_index: int | None = None, ) -> None: """ Args: @@ -46,12 +47,14 @@ def __init__( delta : weight of the background. Defaults to 0.7. gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. + ignore_index: class index to ignore from the loss computation. 
""" super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y self.delta = delta self.gamma = gamma self.epsilon = epsilon + self.ignore_index = ignore_index def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: n_pred_ch = y_pred.shape[1] @@ -65,22 +68,33 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if y_true.shape != y_pred.shape: raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - # clip the prediction to avoid NaN + # Handle ignore_index: + mask = torch.ones_like(y_true) + if self.ignore_index is not None: + # Identify valid pixels: where at least one channel is 1 + spatial_mask = (torch.sum(y_true, dim=1, keepdim=True) > 0).float() + mask = spatial_mask.expand_as(y_true) + y_pred = y_pred * mask + y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) axis = list(range(2, len(y_pred.shape))) # Calculate true positives (tp), false negatives (fn) and false positives (fp) tp = torch.sum(y_true * y_pred, dim=axis) - fn = torch.sum(y_true * (1 - y_pred), dim=axis) - fp = torch.sum((1 - y_true) * y_pred, dim=axis) + fn = torch.sum(y_true * (1 - y_pred) * mask, dim=axis) + fp = torch.sum((1 - y_true) * y_pred * mask, dim=axis) dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) # Calculate losses separately for each class, enhancing both classes back_dice = 1 - dice_class[:, 0] - fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) + fore_dice = torch.pow(1 - dice_class[:, 1], 1 - self.gamma) # Average class scores - loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) + loss = torch.stack([back_dice, fore_dice], dim=-1) + if self.reduction == LossReduction.MEAN.value: + return torch.mean(loss) + if self.reduction == LossReduction.SUM.value: + return torch.sum(loss) return loss @@ -103,6 +117,7 @@ def __init__( gamma: float = 2, 
epsilon: float = 1e-7, reduction: LossReduction | str = LossReduction.MEAN, + ignore_index: int | None = None, ): """ Args: @@ -110,12 +125,14 @@ def __init__( delta : weight of the background. Defaults to 0.7. gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2. epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7. + ignore_index: class index to ignore from the loss computation. """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y self.delta = delta self.gamma = gamma self.epsilon = epsilon + self.ignore_index = ignore_index def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: n_pred_ch = y_pred.shape[1] @@ -123,6 +140,11 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if self.to_onehot_y: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + elif self.ignore_index is not None: + mask = (y_true != self.ignore_index).float() + y_true_clean = torch.where(y_true == self.ignore_index, 0, y_true) + y_true = one_hot(y_true_clean, num_classes=n_pred_ch) + y_true = y_true * mask else: y_true = one_hot(y_true, num_classes=n_pred_ch) @@ -132,13 +154,24 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) cross_entropy = -y_true * torch.log(y_pred) + if self.ignore_index is not None: + spatial_mask = (torch.sum(y_true, dim=1, keepdim=True) > 0).float() + cross_entropy = cross_entropy * spatial_mask + back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] back_ce = (1 - self.delta) * back_ce fore_ce = cross_entropy[:, 1] fore_ce = self.delta * fore_ce - loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1)) + loss = torch.stack([back_ce, fore_ce], dim=1) # [B, 2, H, W] + if self.reduction == LossReduction.MEAN.value: + if 
self.ignore_index is not None: + # Normalize by the number of non-ignored pixels + return loss.sum() / spatial_mask.sum().clamp(min=1e-5) + return loss.mean() + if self.reduction == LossReduction.SUM.value: + return loss.sum() return loss @@ -162,6 +195,7 @@ def __init__( gamma: float = 0.5, delta: float = 0.7, reduction: LossReduction | str = LossReduction.MEAN, + ignore_index: int | None = None, ): """ Args: @@ -170,8 +204,7 @@ def __init__( weight : weight for each loss function. Defaults to 0.5. gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5. delta : weight of the background. Defaults to 0.7. - - + ignore_index: class index to ignore from the loss computation. Example: >>> import torch @@ -187,10 +220,12 @@ def __init__( self.gamma = gamma self.delta = delta self.weight: float = weight - self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) - self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) + self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta, ignore_index=ignore_index) + self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( + gamma=self.gamma, delta=self.delta, ignore_index=ignore_index + ) + self.ignore_index = ignore_index - # TODO: Implement this function to support multiple classes segmentation def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: """ Args: @@ -207,25 +242,32 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: ValueError: When num_classes ValueError: When the number of classes entered does not match the expected number """ - if y_pred.shape != y_true.shape: - raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") + # Transform binary inputs to 2-channel space if y_pred.shape[1] 
== 1: - y_pred = one_hot(y_pred, num_classes=self.num_classes) - y_true = one_hot(y_true, num_classes=self.num_classes) - - if torch.max(y_true) != self.num_classes - 1: - raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}") + y_pred = torch.cat([1 - y_pred, y_pred], dim=1) - n_pred_ch = y_pred.shape[1] + # Move one_hot conversion OUTSIDE the if y_pred.shape[1] == 1 block if self.to_onehot_y: - if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + if self.ignore_index is not None: + mask = (y_true != self.ignore_index).float() + y_true_clean = torch.where(y_true == self.ignore_index, 0, y_true) + y_true = one_hot(y_true_clean, num_classes=self.num_classes) + # Keep the channel-wise mask + y_true = y_true * mask else: - y_true = one_hot(y_true, num_classes=n_pred_ch) + y_true = one_hot(y_true, num_classes=self.num_classes) + + # Check if shapes match + if y_true.shape[1] == 1 and y_pred.shape[1] == 2: + y_true = torch.cat([1 - y_true, y_true], dim=1) + if y_true.shape != y_pred.shape: + raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") + + if torch.max(y_true) != self.num_classes - 1: + raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}") asy_focal_loss = self.asy_focal_loss(y_pred, y_true) asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index 26ec823081..51c671c9d3 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -69,6 +69,7 @@ def __init__( compute_sample: bool = False, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background @@ -76,6 +77,7 @@ def __init__( self.compute_sample = compute_sample self.reduction = reduction 
self.get_not_nans = get_not_nans + self.ignore_index = ignore_index def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ @@ -96,7 +98,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor warnings.warn("As for classification task, compute_sample should be False.") self.compute_sample = False - return get_confusion_matrix(y_pred=y_pred, y=y, include_background=self.include_background) + return get_confusion_matrix( + y_pred=y_pred, y=y, include_background=self.include_background, ignore_index=self.ignore_index + ) def aggregate( self, compute_sample: bool = False, reduction: MetricReduction | str | None = None @@ -131,7 +135,9 @@ def aggregate( return results -def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True) -> torch.Tensor: +def get_confusion_matrix( + y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_index: int | None = None +) -> torch.Tensor: """ Compute confusion matrix. A tensor with the shape [BC4] will be returned. Where, the third dimension represents the number of true positive, false positive, true negative and false negative values for @@ -145,6 +151,9 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou The values should be binarized. include_background: whether to include metric computation on the first channel of the predicted output. Defaults to True. + ignore_index: index of the class to ignore during calculation. + If ignore_index < number of classes, that class channel is excluded + else ignored regions are inferred from spatial locations where all label channels are zero. Raises: ValueError: when `y_pred` and `y` have different shapes. 
@@ -158,17 +167,42 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou # get confusion matrix related metric batch_size, n_class = y_pred.shape[:2] + + # Create spatial mask if ignore_index is provided + mask = None + if ignore_index is not None: + if ignore_index >= n_class: + # If ignore_index is outside channel range (e.g. 255), we assume it's a spatial mask + mask = y.sum(dim=1, keepdim=True) > 0 + else: + # If ignore_index is a valid channel, exclude that specific channel + mask = 1.0 - y[:, ignore_index : ignore_index + 1] + # convert to [BNS], where S is the number of pixels for one sample. - # As for classification tasks, S equals to 1. y_pred = y_pred.reshape(batch_size, n_class, -1) y = y.reshape(batch_size, n_class, -1) + + if mask is not None: + mask = mask.reshape(batch_size, 1, -1) + y_pred = y_pred * mask + y = y * mask + tp = (y_pred + y) == 2 tn = (y_pred + y) == 0 + if mask is not None: + # When masking, TN must only count locations where the mask is 1 + tn = tn * mask.bool() + tp = tp.sum(dim=[2]).float() tn = tn.sum(dim=[2]).float() p = y.sum(dim=[2]).float() - n = y.shape[-1] - p + + if mask is not None: + # n is total valid pixels (per sample) minus the positives for that class + n = mask.reshape(batch_size, -1).sum(dim=1, keepdim=True) - p + else: + n = y.shape[-1] - p fn = p - tp fp = n - tn diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index 05eb94af48..1c3e72d4e6 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -13,7 +13,7 @@ import torch -from monai.metrics.utils import do_metric_reduction, ignore_background +from monai.metrics.utils import do_metric_reduction from monai.utils import MetricReduction, Weight, deprecated_arg, look_up_option from .metric import CumulativeIterationMetric @@ -41,6 +41,7 @@ class GeneralizedDiceScore(CumulativeIterationMetric): Old versions computed `mean` when `mean_batch` was provided due to bug in 
reduction. weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. + ignore_index: class index to ignore from the metric computation. Raises: ValueError: When the `reduction` is not one of MetricReduction enum. @@ -51,11 +52,13 @@ def __init__( include_background: bool = True, reduction: MetricReduction | str = MetricReduction.MEAN, weight_type: Weight | str = Weight.SQUARE, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background self.reduction = look_up_option(reduction, MetricReduction) self.weight_type = look_up_option(weight_type, Weight) + self.ignore_index = ignore_index self.sum_over_classes = self.reduction in { MetricReduction.SUM, MetricReduction.MEAN, @@ -71,6 +74,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions. y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`. + ignore_index: class index to ignore from the metric computation. Returns: torch.Tensor: Generalized Dice Score averaged across batch and class @@ -84,6 +88,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor include_background=self.include_background, weight_type=self.weight_type, sum_over_classes=self.sum_over_classes, + ignore_index=self.ignore_index, ) @deprecated_arg( @@ -118,6 +123,7 @@ def compute_generalized_dice( include_background: bool = True, weight_type: Weight | str = Weight.SQUARE, sum_over_classes: bool = False, + ignore_index: int | None = None, ) -> torch.Tensor: """ Computes the Generalized Dice Score and returns a tensor with its per image values. 
@@ -132,6 +138,7 @@ def compute_generalized_dice( weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. sum_over_labels (bool): Whether to sum the numerator and denominator across all labels before the final computation. + ignore_index: class index to ignore from the metric computation. Returns: torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes]. @@ -147,52 +154,73 @@ def compute_generalized_dice( if y.shape != y_pred.shape: raise ValueError(f"y_pred - {y_pred.shape} - and y - {y.shape} - should have the same shapes.") - # Ignore background, if needed + n_channels = y_pred.shape[1] + channels_to_use = list(range(n_channels)) + if not include_background: - y_pred, y = ignore_background(y_pred=y_pred, y=y) + channels_to_use.pop(0) + + if ignore_index is not None: + # If background was 0 and we ignore class 2, we need the correct absolute index + if ignore_index in channels_to_use: + channels_to_use.remove(ignore_index) + + if not channels_to_use: + return torch.zeros(y_pred.shape[0], 1, device=y_pred.device) # Reducing only spatial dimensions (not batch nor channels), compute the intersection and non-weighted denominator reduce_axis = list(range(2, y_pred.dim())) - intersection = torch.sum(y * y_pred, dim=reduce_axis) - y_o = torch.sum(y, dim=reduce_axis) - y_pred_o = torch.sum(y_pred, dim=reduce_axis) + y_o_full = torch.sum(y, dim=reduce_axis) # shape: (B, C) + intersection = torch.sum(y[:, channels_to_use, ...] 
* y_pred[:, channels_to_use, ...], dim=reduce_axis) + y_o = torch.sum(y[:, channels_to_use, ...], dim=reduce_axis) + y_pred_o = torch.sum(y_pred[:, channels_to_use, ...], dim=reduce_axis) + denominator = y_o + y_pred_o # Set the class weights weight_type = look_up_option(weight_type, Weight) + y_o_float = y_o_full.float() + if weight_type == Weight.SIMPLE: - w = torch.reciprocal(y_o.float()) + w_full = torch.reciprocal(y_o_float) elif weight_type == Weight.SQUARE: - w = torch.reciprocal(y_o.float() * y_o.float()) + w_full = torch.reciprocal(y_o_float * y_o_float) else: - w = torch.ones_like(y_o.float()) + w_full = torch.ones_like(y_o_float) + + w = w_full[:, channels_to_use] # Replace infinite values for non-appearing classes by the maximum weight - for b in w: - infs = torch.isinf(b) - b[infs] = 0 - b[infs] = torch.max(b) + for b_idx in range(w.shape[0]): + batch_w = w[b_idx] + infs = torch.isinf(batch_w) + if infs.any(): + batch_w[infs] = 0 + max_w = torch.max(batch_w) + batch_w[infs] = max_w if max_w > 0 else 1.0 - # Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True if sum_over_classes: - numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True) - denom = (denominator * w).sum(dim=1, keepdim=True) - y_pred_o = y_pred_o.sum(dim=-1, keepdim=True) + intersection = (intersection * w).sum(dim=1, keepdim=True) + denominator = (denominator * w).sum(dim=1, keepdim=True) + numer = 2.0 * intersection + denom = denominator else: numer = 2.0 * (intersection * w) denom = denominator * w - y_pred_o = y_pred_o # Compute the score - generalized_dice_score = numer / denom + generalized_dice_score = numer / (denom + 1e-6) - # Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1. 
- # Where denom == 0 but the prediction volume is not 0, score is 0 + # Handle zero division denom_zeros = denom == 0 - generalized_dice_score[denom_zeros] = torch.where( - (y_pred_o == 0)[denom_zeros], - torch.tensor(1.0, device=generalized_dice_score.device), - torch.tensor(0.0, device=generalized_dice_score.device), - ) + if denom_zeros.any(): + if sum_over_classes: + generalized_dice_score[denom_zeros] = 1.0 + else: + generalized_dice_score[denom_zeros] = torch.where( + (y_pred_o * w)[denom_zeros] == 0, + torch.ones_like(generalized_dice_score[denom_zeros]), + torch.zeros_like(generalized_dice_score[denom_zeros]), + ) return generalized_dice_score diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 1b83c93e5b..85cd589f03 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -51,6 +51,7 @@ class HausdorffDistanceMetric(CumulativeIterationMetric): ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. + ignore_index: index of the class to ignore during calculation. Defaults to ``None``. 
""" @@ -62,6 +63,7 @@ def __init__( directed: bool = False, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background @@ -70,6 +72,7 @@ def __init__( self.directed = directed self.reduction = reduction self.get_not_nans = get_not_nans + self.ignore_index = ignore_index def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] """ @@ -97,6 +100,12 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) if dims < 3: raise ValueError("y_pred should have at least three dimensions.") + mask = None + if self.ignore_index is not None: + mask = (y != self.ignore_index).all(dim=1, keepdim=True).float() + y_pred = y_pred * mask + y = y * mask + # compute (BxC) for each channel for each batch return compute_hausdorff_distance( y_pred=y_pred, @@ -106,6 +115,8 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) percentile=self.percentile, directed=self.directed, spacing=kwargs.get("spacing"), + ignore_index=self.ignore_index, + mask=mask, ) def aggregate( @@ -137,6 +148,8 @@ def compute_hausdorff_distance( percentile: float | None = None, directed: bool = False, spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None, + mask: torch.Tensor | None = None, + ignore_index: int | None = None, ) -> torch.Tensor: """ Compute the Hausdorff distance. @@ -162,6 +175,7 @@ def compute_hausdorff_distance( If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. + ignore_index: index of the class to ignore during calculation. Defaults to ``None``. 
""" if not include_background: @@ -179,17 +193,35 @@ def compute_hausdorff_distance( spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) for b, c in np.ndindex(batch_size, n_class): + yp = y_pred[b, c] + yt = y[b, c] + + if ignore_index is not None: + valid_mask = y[b].sum(dim=0) > 0 + yp = yp * valid_mask + yt = yt * valid_mask + + # if everything is ignored, define distance as 0 + if not valid_mask.any(): + hd[b, c] = torch.tensor(0.0, device=y_pred.device) + continue + _, distances, _ = get_edge_surface_distance( - y_pred[b, c], - y[b, c], + yp, + yt, distance_metric=distance_metric, spacing=spacing_list[b], symmetric=not directed, - class_index=c, + mask=mask[b, 0] if mask is not None else None, ) + + if len(distances) == 0: + hd[b, c] = torch.tensor(0.0, device=y_pred.device) + continue + percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances] - max_distance = torch.max(torch.stack(percentile_distances)) - hd[b, c] = max_distance + + hd[b, c] = torch.max(torch.stack(percentile_distances)) return hd diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index fedd94fb93..d3553a2002 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -106,6 +106,7 @@ def __init__( ignore_empty: bool = True, num_classes: int | None = None, return_with_label: bool | list[str] = False, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background @@ -114,6 +115,7 @@ def __init__( self.ignore_empty = ignore_empty self.num_classes = num_classes self.return_with_label = return_with_label + self.ignore_index = ignore_index self.dice_helper = DiceHelper( include_background=self.include_background, reduction=MetricReduction.NONE, @@ -121,6 +123,7 @@ def __init__( apply_argmax=False, ignore_empty=self.ignore_empty, num_classes=self.num_classes, + ignore_index=self.ignore_index, ) def _compute_tensor(self, y_pred: torch.Tensor, y: 
torch.Tensor) -> torch.Tensor: # type: ignore[override] @@ -175,6 +178,7 @@ def compute_dice( include_background: bool = True, ignore_empty: bool = True, num_classes: int | None = None, + ignore_index: int | None = None, ) -> torch.Tensor: """ Computes Dice score metric for a batch of predictions. This performs the same computation as @@ -192,6 +196,7 @@ def compute_dice( num_classes: number of input channels (always including the background). When this is ``None``, ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are single-channel class indices and the number of classes is not automatically inferred from data. + ignore_index: index of the class to ignore during calculation. Returns: Dice scores per batch and per class, (shape: [batch_size, num_classes]). @@ -204,6 +209,7 @@ def compute_dice( apply_argmax=False, ignore_empty=ignore_empty, num_classes=num_classes, + ignore_index=ignore_index, )(y_pred=y_pred, y=y) @@ -262,6 +268,7 @@ def __init__( num_classes: int | None = None, sigmoid: bool | None = None, softmax: bool | None = None, + ignore_index: int | None = None, ) -> None: # handling deprecated arguments if sigmoid is not None: @@ -277,8 +284,9 @@ def __init__( self.activate = activate self.ignore_empty = ignore_empty self.num_classes = num_classes + self.ignore_index = ignore_index - def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: """ Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately for each batch item and for each channel of those items. @@ -286,7 +294,12 @@ def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor Args: y_pred: input predictions with shape HW[D]. y: ground truth with shape HW[D]. + mask: binary mask where 0 indicates voxels to ignore. 
""" + if mask is not None: + y_pred = y_pred * mask + y = y * mask + y_o = torch.sum(y) if y_o > 0: return (2.0 * torch.sum(torch.masked_select(y, y_pred))) / (y_o + torch.sum(y_pred)) @@ -322,6 +335,11 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl y_pred = torch.sigmoid(y_pred) y_pred = y_pred > 0.5 + # Create global mask for ignored voxels if ignore_index is set + mask = None + if self.ignore_index is not None: + mask = y != self.ignore_index + first_ch = 0 if self.include_background else 1 data = [] for b in range(y_pred.shape[0]): @@ -329,7 +347,11 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]: x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool() x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c] - c_list.append(self.compute_channel(x_pred, x)) + + # Extract the spatial mask for the current batch item + b_mask = mask[b, 0] if mask is not None else None + + c_list.append(self.compute_channel(x_pred, x, mask=b_mask)) data.append(torch.stack(c_list)) data = torch.stack(data, dim=0).contiguous() # type: ignore diff --git a/monai/metrics/meaniou.py b/monai/metrics/meaniou.py index 65c53f7aa5..069a8a3845 100644 --- a/monai/metrics/meaniou.py +++ b/monai/metrics/meaniou.py @@ -54,12 +54,14 @@ def __init__( reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, ignore_empty: bool = True, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background self.reduction = reduction self.get_not_nans = get_not_nans self.ignore_empty = ignore_empty + self.ignore_index = ignore_index def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ @@ -78,7 +80,11 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor raise ValueError(f"y_pred should have at least 3 
dimensions (batch, channel, spatial), got {dims}.") # compute IoU (BxC) for each channel for each batch return compute_iou( - y_pred=y_pred, y=y, include_background=self.include_background, ignore_empty=self.ignore_empty + y_pred=y_pred, + y=y, + include_background=self.include_background, + ignore_empty=self.ignore_empty, + ignore_index=self.ignore_index, ) def aggregate( @@ -103,7 +109,11 @@ def aggregate( def compute_iou( - y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_empty: bool = True + y_pred: torch.Tensor, + y: torch.Tensor, + include_background: bool = True, + ignore_empty: bool = True, + ignore_index: int | None = None, ) -> torch.Tensor: """Computes Intersection over Union (IoU) score metric from a batch of predictions. @@ -133,6 +143,13 @@ def compute_iou( if y.shape != y_pred.shape: raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") + if ignore_index is not None: + mask = (y != ignore_index).float() + if mask.shape != y_pred.shape: + mask = mask.expand_as(y_pred) + y_pred = y_pred * mask + y = torch.where(y == ignore_index, torch.tensor(0, device=y.device), y) + # reducing only spatial dimensions (not batch nor channels) n_len = len(y_pred.shape) reduce_axis = list(range(2, n_len)) diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index b20b47a1a5..88712f8474 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -57,6 +57,7 @@ class SurfaceDiceMetric(CumulativeIterationMetric): If set to ``True``, the function `aggregate` will return both the aggregated NSD and the `not_nans` count. If set to ``False``, `aggregate` will only return the aggregated NSD. use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``. + ignore_index: class index to ignore from the metric computation. 
""" def __init__( @@ -67,6 +68,7 @@ def __init__( reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, use_subvoxels: bool = False, + ignore_index: int | None = None, ) -> None: super().__init__() self.class_thresholds = class_thresholds @@ -75,6 +77,7 @@ def __init__( self.reduction = reduction self.get_not_nans = get_not_nans self.use_subvoxels = use_subvoxels + self.ignore_index = ignore_index def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] r""" @@ -94,6 +97,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``. + ignore_index: class index to ignore from the metric computation. Returns: @@ -108,6 +112,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) distance_metric=self.distance_metric, spacing=kwargs.get("spacing"), use_subvoxels=self.use_subvoxels, + ignore_index=self.ignore_index, ) def aggregate( @@ -142,6 +147,7 @@ def compute_surface_dice( distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None, use_subvoxels: bool = False, + ignore_index: int | None = None, ) -> torch.Tensor: r""" This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as @@ -199,6 +205,7 @@ def compute_surface_dice( else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``. + ignore_index: class index to ignore from the metric computation. 
Raises: ValueError: If `y_pred` and/or `y` are not PyTorch tensors. @@ -213,6 +220,11 @@ def compute_surface_dice( Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch index :math:`b` and class :math:`c`. """ + if ignore_index is not None: + mask = (y != ignore_index).all(dim=1, keepdim=True).float() + + y_pred = y_pred * mask + y = y * mask if not include_background: y_pred, y = ignore_background(y_pred=y_pred, y=y) @@ -255,6 +267,7 @@ def compute_surface_dice( use_subvoxels=use_subvoxels, symmetric=True, class_index=c, + mask=mask[b, 0] if ignore_index is not None else None, ) boundary_correct: int | torch.Tensor | float boundary_complete: int | torch.Tensor | float diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 3cb336d6a0..ef68c5c2c5 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -46,6 +46,7 @@ class SurfaceDistanceMetric(CumulativeIterationMetric): ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. + ignore_index: class index to ignore from the metric computation. 
""" @@ -56,6 +57,7 @@ def __init__( distance_metric: str = "euclidean", reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, + ignore_index: int | None = None, ) -> None: super().__init__() self.include_background = include_background @@ -63,6 +65,7 @@ def __init__( self.symmetric = symmetric self.reduction = reduction self.get_not_nans = get_not_nans + self.ignore_index = ignore_index def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] """ @@ -89,6 +92,13 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) if y_pred.dim() < 3: raise ValueError("y_pred should have at least three dimensions.") + mask = None + + if self.ignore_index is not None: + mask = (y != self.ignore_index).all(dim=1, keepdim=True).float() + y_pred = y_pred * mask + y = y * mask + # compute (BxC) for each channel for each batch return compute_average_surface_distance( y_pred=y_pred, @@ -97,6 +107,8 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) symmetric=self.symmetric, distance_metric=self.distance_metric, spacing=kwargs.get("spacing"), + mask=mask, + ignore_index=self.ignore_index, ) def aggregate( @@ -127,6 +139,8 @@ def compute_average_surface_distance( symmetric: bool = False, distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None, + mask: torch.Tensor | None = None, + ignore_index: int | None = None, ) -> torch.Tensor: """ This function is used to compute the Average Surface Distance from `y_pred` to `y` @@ -154,10 +168,12 @@ def compute_average_surface_distance( If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. 
+ ignore_index: class index to ignore from the metric computation. """ if not include_background: - y_pred, y = ignore_background(y_pred=y_pred, y=y) + if ignore_index != 0: + y_pred, y = ignore_background(y_pred=y_pred, y=y) y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0] y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0] @@ -172,15 +188,27 @@ def compute_average_surface_distance( spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) for b, c in np.ndindex(batch_size, n_class): + yp = y_pred[b, c] + yt = y[b, c] + + if ignore_index is not None: + valid_mask = y[b].sum(dim=0) > 0 + yp = yp * valid_mask + yt = yt * valid_mask + _, distances, _ = get_edge_surface_distance( - y_pred[b, c], - y[b, c], + yp, + yt, distance_metric=distance_metric, spacing=spacing_list[b], symmetric=symmetric, class_index=c, + mask=mask[b, 0] if mask is not None else None, ) + surface_distance = torch.cat(distances) - asd[b, c] = torch.tensor(np.nan) if surface_distance.shape == (0,) else surface_distance.mean() + asd[b, c] = ( + torch.tensor(float("nan"), device=asd.device) if surface_distance.numel() == 0 else surface_distance.mean() + ) return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index a451b1a770..abae52c1e8 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -41,6 +41,7 @@ __all__ = [ "ignore_background", + "ignore_index_mask", "do_metric_reduction", "get_mask_edges", "get_surface_distance", @@ -68,6 +69,27 @@ def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayT return y_pred, y +def ignore_index_mask( + y_pred: torch.Tensor, y: torch.Tensor, ignore_index: int | None = None +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Masks out the specified ignore_index from both predictions and ground truth. 
+ This is a helper for #8667 to allow 'Ignore Class' functionality in metrics. + """ + if ignore_index is None: + return y_pred, y + + # Create a spatial mask (B, 1, H, W, [D]) + # Elements are 0 where target == ignore_index, else 1 + mask = (y != ignore_index).float() + + # Apply mask to zero out the ignored regions + y_pred = y_pred * mask + y = y * mask + + return y_pred, y + + def do_metric_reduction( f: torch.Tensor, reduction: MetricReduction | str = MetricReduction.MEAN ) -> tuple[torch.Tensor | Any, torch.Tensor]: @@ -143,6 +165,7 @@ def get_mask_edges( crop: bool = True, spacing: Sequence | None = None, always_return_as_numpy: bool = False, + ignore_index: int | None = None, ) -> tuple[NdarrayTensor, NdarrayTensor]: """ Compute edges from binary segmentation masks. This @@ -244,6 +267,7 @@ def get_surface_distance( seg_gt: NdarrayOrTensor, distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float] | None = None, + mask: NdarrayOrTensor | None = None, ) -> NdarrayOrTensor: """ This function is used to compute the surface distances from `seg_pred` to `seg_gt`. @@ -262,6 +286,7 @@ def get_surface_distance( (1) If a single number, isotropic spacing with that value is used. (2) If a sequence of numbers, the length of the sequence must be equal to the image dimensions. (3) If ``None``, spacing of unity is used. Defaults to ``None``. + mask: optional boolean mask. Pixels where mask is False will be ignored in the distance computation. Note: If seg_pred or seg_gt is all 0, may result in nan/inf distance. 
@@ -275,14 +300,17 @@ def get_surface_distance( dis = np.inf * lib.ones_like(seg_gt, dtype=lib.float32) dis = dis[seg_gt] return convert_to_dst_type(dis, seg_pred, dtype=dis.dtype)[0] + if distance_metric == "euclidean": dis = monai_distance_transform_edt((~seg_gt)[None, ...], sampling=spacing)[0] # type: ignore elif distance_metric in {"chessboard", "taxicab"}: dis = distance_transform_cdt(convert_to_numpy(~seg_gt), metric=distance_metric) else: raise ValueError(f"distance_metric {distance_metric} is not implemented.") + dis = convert_to_dst_type(dis, seg_pred, dtype=lib.float32)[0] - return dis[seg_pred] # type: ignore + out = dis[seg_pred.bool()] + return out if out is not None else dis.new_empty((0,)) def get_edge_surface_distance( @@ -293,6 +321,8 @@ def get_edge_surface_distance( use_subvoxels: bool = False, symmetric: bool = False, class_index: int = -1, + mask: torch.Tensor | None = None, + ignore_index: int | None = None, ) -> tuple[ tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor], @@ -312,6 +342,7 @@ def get_edge_surface_distance( This will return the areas of the edges. symmetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`. class_index: The class-index used for context when warning about empty ground truth or prediction. + mask: optional boolean mask indicating valid pixels. Returns: (edges_pred, edges_gt), (distances_pred_to_gt, [distances_gt_to_pred]), (areas_pred, areas_gt) | tuple() @@ -320,19 +351,18 @@ def get_edge_surface_distance( edges_spacing = None if use_subvoxels: edges_spacing = spacing if spacing is not None else ([1] * len(y_pred.shape)) - (edges_pred, edges_gt, *areas) = get_mask_edges( - y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False - ) - if not edges_gt.any(): - warnings.warn( - f"the ground truth of class {class_index if class_index != -1 else 'Unknown'} is all 0," - " this may result in nan/inf distance." 
- ) - if not edges_pred.any(): - warnings.warn( - f"the prediction of class {class_index if class_index != -1 else 'Unknown'} is all 0," - " this may result in nan/inf distance." - ) + + edge_results = get_mask_edges(y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False) + edges_pred, edges_gt = edge_results[0], edge_results[1] + + if mask is not None: + if len(edge_results) > 2 and isinstance(edge_results[2], tuple): + slices = edge_results[2] + mask = mask[slices] + mask = mask.to(edges_pred.device).bool() + edges_pred = edges_pred & mask + edges_gt = edges_gt & mask + distances: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor] if symmetric: distances = ( @@ -341,7 +371,17 @@ def get_edge_surface_distance( ) # type: ignore else: distances = (get_surface_distance(edges_pred, edges_gt, distance_metric, spacing),) # type: ignore - return convert_to_tensor(((edges_pred, edges_gt), distances, tuple(areas)), device=y_pred.device) # type: ignore[no-any-return] + + distances = tuple(d if d is not None else edges_pred.new_empty((0,)) for d in distances) + + areas = edge_results[3:] if use_subvoxels else () + + out = convert_to_tensor(((edges_pred, edges_gt), distances, tuple(areas)), device=y_pred.device) # type: ignore[no-any-return] + + if out is None: + out = torch.empty((0,), device=y_pred.device) + + return out def is_binary_tensor(input: torch.Tensor, name: str) -> None: diff --git a/tests/losses/test_ignore_index_losses.py b/tests/losses/test_ignore_index_losses.py new file mode 100644 index 0000000000..b07ba5c98d --- /dev/null +++ b/tests/losses/test_ignore_index_losses.py @@ -0,0 +1,68 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.losses import AsymmetricUnifiedFocalLoss, DiceLoss, FocalLoss, TverskyLoss + +# Defining test cases: (LossClass, args) +TEST_CASES = [ + (DiceLoss, {"sigmoid": True}), + (FocalLoss, {"use_softmax": False}), + (TverskyLoss, {"sigmoid": True}), + (AsymmetricUnifiedFocalLoss, {}), +] + + +class TestIgnoreIndexLosses(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_loss_ignore_consistency(self, loss_class, kwargs): + ignore_index = 255 + loss_func = loss_class(ignore_index=ignore_index, **kwargs) + + # Create two inputs that are identical EXCEPT in the area designated as 'ignored' + # Input shape: [Batch, Channel, H, W] + input_base = torch.randn(1, 1, 4, 4) + input_alt = input_base.clone() + input_alt[0, 0, 2:, :] += 5.0 # Significant difference in the bottom half + + # Target: Top half is valid (0,1), Bottom half is ignored (255) + target = torch.tensor( + [[[[1, 0, 1, 0], [0, 1, 0, 1], [255, 255, 255, 255], [255, 255, 255, 255]]]], dtype=torch.float + ) + + # Execute + loss_base = loss_func(input_base, target) + loss_alt = loss_func(input_alt, target) + + # ASSERTION: The losses must be identical because the difference + # occurred only in the ignored region. 
+ torch.testing.assert_close(loss_base, loss_alt, atol=1e-5, rtol=1e-5) + + @parameterized.expand(TEST_CASES) + def test_no_ignore_behavior(self, loss_class, kwargs): + # Ensure that when ignore_index is None, the loss functions normally + loss_func = loss_class(ignore_index=None, **kwargs) + input_data = torch.randn(1, 1, 4, 4) + target = torch.randint(0, 2, (1, 1, 4, 4)).float() + + output = loss_func(input_data, target) + self.assertFalse(torch.isnan(output)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/metrics/test_ignore_index_metrics.py b/tests/metrics/test_ignore_index_metrics.py new file mode 100644 index 0000000000..01144b71f7 --- /dev/null +++ b/tests/metrics/test_ignore_index_metrics.py @@ -0,0 +1,88 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.metrics import ( + ConfusionMatrixMetric, + DiceMetric, + GeneralizedDiceScore, + HausdorffDistanceMetric, + MeanIoU, + SurfaceDiceMetric, + SurfaceDistanceMetric, +) +from monai.utils import optional_import + +scipy, has_scipy = optional_import("scipy") + +# Test cases for metrics with their specific required arguments +TEST_METRICS = [ + (DiceMetric, {"include_background": True, "reduction": "mean"}), + (MeanIoU, {"include_background": True, "reduction": "mean"}), + (GeneralizedDiceScore, {"include_background": True}), + (ConfusionMatrixMetric, {"metric_name": "accuracy"}), +] + +# Metrics that require SciPy (Hausdorff and Surface metrics) +SCIPY_METRICS = [ + (HausdorffDistanceMetric, {"include_background": True}), + (SurfaceDistanceMetric, {"include_background": True}), + (SurfaceDiceMetric, {"class_thresholds": [0.5, 0.5], "include_background": True}), +] + + +@unittest.skipUnless(has_scipy, "Scipy required for surface metrics") +class TestIgnoreIndexMetrics(unittest.TestCase): + + @parameterized.expand(TEST_METRICS + SCIPY_METRICS) + def test_metric_ignore_consistency(self, metric_class, kwargs): + # Initialize metric with ignore_index + metric = metric_class(ignore_index=255, **kwargs) + + # Batch size 1, 2 Classes, 4x4 Image + # y_pred1 and y_pred2 differ ONLY in the bottom half (the ignore zone) + y_pred1 = torch.zeros((1, 2, 4, 4)) + y_pred1[:, 1, 0:2, :] = 1.0 # Top half prediction + + y_pred2 = y_pred1.clone() + y_pred2[:, 1, 2:4, :] = 1.0 # Bottom half prediction (different!) 
+ + # Target: Top half is valid (0/1), Bottom half is 255 + y = torch.zeros((1, 2, 4, 4)) + y[:, 1, 0:2, 0:2] = 1.0 + y[:, :, 2:4, :] = 255 + + # Run metric for both predictions + metric.reset() + metric(y_pred=y_pred1, y=y) + res1 = metric.aggregate() + if isinstance(res1, list): + res1 = res1[0] + + metric.reset() + metric(y_pred=y_pred2, y=y) + res2 = metric.aggregate() + if isinstance(res2, list): + res2 = res2[0] + + # The result must be identical because the spatial difference + # is hidden by the ignore_index + torch.testing.assert_close(res1, res2, msg=f"Failed for {metric_class.__name__}") + + +if __name__ == "__main__": + unittest.main()