Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 72 additions & 41 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
batch: bool = False,
weight: Sequence[float] | float | int | torch.Tensor | None = None,
soft_label: bool = False,
ignore_index: int | None = None,
) -> None:
"""
Args:
Expand Down Expand Up @@ -101,7 +102,8 @@ def __init__(
The value/values should be no less than 0. Defaults to None.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.

ignore_index: if not None, specifies a target index that is ignored and does not contribute to
the input gradient. Defaults to None.
Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
Expand All @@ -123,6 +125,7 @@ def __init__(
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.ignore_index = ignore_index
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
Expand All @@ -140,7 +143,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

Example:
>>> from monai.losses.dice import * # NOQA
>>> from monai.losses.dice import * # NOQA
>>> import torch
>>> from monai.losses.dice import DiceLoss
>>> B, C, H, W = 7, 5, 3, 2
Expand All @@ -164,6 +167,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if self.other_act is not None:
input = self.other_act(input)

# mask the ignore_index if specified, must be done before one_hot
mask: torch.Tensor | None = None
if self.ignore_index is not None:
mask = (target != self.ignore_index).float()

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
Expand All @@ -181,6 +189,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != input.shape:
raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

if mask is not None:
input = input * mask
target = target * mask

# reducing only spatial dimensions (not batch nor channels)
reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
if self.batch:
Expand All @@ -204,11 +216,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
else:
if self.class_weight.shape[0] != num_of_classes:
raise ValueError(
"""the length of the `weight` sequence should be the same as the number of classes.
raise ValueError("""the length of the `weight` sequence should be the same as the number of classes.
If `include_background=False`, the weight should not include
the background category class 0."""
)
the background category class 0.""")
if self.class_weight.min() < 0:
raise ValueError("the value/values of the `weight` should be no less than 0.")
# apply class_weight to loss
Expand Down Expand Up @@ -280,6 +290,7 @@ def __init__(
smooth_dr: float = 1e-5,
batch: bool = False,
soft_label: bool = False,
ignore_index: int | None = None,
) -> None:
"""
Args:
Expand All @@ -305,6 +316,8 @@ def __init__(
If True, the class-weighted intersection and union areas are first summed across the batches.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.
ignore_index: if not None, specifies a target index that is ignored and does not contribute to
the input gradient. Defaults to None.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand All @@ -330,6 +343,7 @@ def __init__(
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.soft_label = soft_label
self.ignore_index = ignore_index

def w_func(self, grnd):
if self.w_type == str(Weight.SIMPLE):
Expand Down Expand Up @@ -360,6 +374,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if self.other_act is not None:
input = self.other_act(input)

# Prepare mask before potential one-hot conversion
mask: torch.Tensor | None = None
if self.ignore_index is not None:
mask = (target != self.ignore_index).float()

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
Expand All @@ -370,14 +389,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
else:
# if skipping background, removing first channel
target = target[:, 1:]
input = input[:, 1:]

if target.shape != input.shape:
raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

# reducing only spatial dimensions (not batch nor channels)
# Exclude ignored regions from calculations
if mask is not None:
input = input * mask
target = target * mask

reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
if self.batch:
reduce_axis = [0] + reduce_axis
Expand All @@ -404,12 +426,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
f: torch.Tensor = 1.0 - (numer / denom)

if self.reduction == LossReduction.MEAN.value:
f = torch.mean(f) # the batch and channel average
f = torch.mean(f)
elif self.reduction == LossReduction.SUM.value:
f = torch.sum(f) # sum over the batch and channel dims
f = torch.sum(f)
elif self.reduction == LossReduction.NONE.value:
# If we are not computing voxelwise loss components at least
# make sure a none reduction maintains a broadcastable shape
broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2)
f = f.view(broadcast_shape)
else:
Expand Down Expand Up @@ -442,11 +462,12 @@ def __init__(
reduction: LossReduction | str = LossReduction.MEAN,
smooth_nr: float = 1e-5,
smooth_dr: float = 1e-5,
ignore_index: int | None = None,
) -> None:
"""
Args:
dist_matrix: 2d tensor or 2d numpy array; matrix of distances between the classes.
It must have dimension C x C where C is the number of classes.
It must have dimension C x C where C is the number of classes.
weighting_mode: {``"default"``, ``"GDL"``}
Specifies how to weight the class-specific sum of errors.
Default to ``"default"``.
Expand All @@ -466,27 +487,11 @@ def __init__(
- ``"sum"``: the output will be summed.
smooth_nr: a small constant added to the numerator to avoid zero.
smooth_dr: a small constant added to the denominator to avoid nan.
ignore_index: if not None, specifies a target index that is ignored and does not contribute to
the input gradient. Defaults to None.

Raises:
ValueError: When ``dist_matrix`` is not a square matrix.

Example:
.. code-block:: python

import torch
import numpy as np
from monai.losses import GeneralizedWassersteinDiceLoss

# Example with 3 classes (including the background: label 0).
# The distance between the background class (label 0) and the other classes is the maximum, equal to 1.
# The distance between class 1 and class 2 is 0.5.
dist_mat = np.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.5], [1.0, 0.5, 0.0]], dtype=np.float32)
wass_loss = GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat)

pred_score = torch.tensor([[1000, 0, 0], [0, 1000, 0], [0, 0, 1000]], dtype=torch.float32)
grnd = torch.tensor([0, 1, 2], dtype=torch.int64)
wass_loss(pred_score, grnd) # 0

"""
super().__init__(reduction=LossReduction(reduction).value)

Expand All @@ -505,13 +510,13 @@ def __init__(
self.num_classes = self.m.size(0)
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)
self.ignore_index = ignore_index

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
input: the shape should be BNH[WD].
target: the shape should be BNH[WD].

"""
# Aggregate spatial dimensions
flat_input = input.reshape(input.size(0), input.size(1), -1)
Expand All @@ -523,18 +528,20 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# Compute the Wasserstein distance map
wass_dist_map = self.wasserstein_distance_map(probs, flat_target)

# Apply masking for ignore_index
if self.ignore_index is not None:
mask = (flat_target != self.ignore_index).float()
wass_dist_map = wass_dist_map * mask

# Compute the values of alpha to use
alpha = self._compute_alpha_generalized_true_positives(flat_target)

# Compute the numerator and denominator of the generalized Wasserstein Dice loss
if self.alpha_mode == "GDL":
# use GDL-style alpha weights (i.e. normalize by the volume of each class)
# contrary to the original definition we also use alpha in the "generalized all error".
true_pos = self._compute_generalized_true_positive(alpha, flat_target, wass_dist_map)
denom = self._compute_denominator(alpha, flat_target, wass_dist_map)
else: # default: as in the original paper
# (i.e. alpha=1 for all foreground classes and 0 for the background).
# Compute the generalised number of true positives
true_pos = self._compute_generalized_true_positive(alpha, flat_target, wass_dist_map)
all_error = torch.sum(wass_dist_map, dim=1)
denom = 2 * true_pos + all_error
Expand All @@ -544,12 +551,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
wass_dice_loss: torch.Tensor = 1.0 - wass_dice

if self.reduction == LossReduction.MEAN.value:
wass_dice_loss = torch.mean(wass_dice_loss) # the batch and channel average
wass_dice_loss = torch.mean(wass_dice_loss)
elif self.reduction == LossReduction.SUM.value:
wass_dice_loss = torch.sum(wass_dice_loss) # sum over the batch and channel dims
wass_dice_loss = torch.sum(wass_dice_loss)
elif self.reduction == LossReduction.NONE.value:
# GWDL aggregates over classes internally, so wass_dice_loss has shape (B,)
pass
broadcast_shape = input.shape[0:2] + (1,) * (len(input.shape) - 2)
wass_dice_loss = wass_dice_loss.view(broadcast_shape)
else:
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

Expand Down Expand Up @@ -674,6 +681,7 @@ def __init__(
lambda_dice: float = 1.0,
lambda_ce: float = 1.0,
label_smoothing: float = 0.0,
ignore_index: int | None = None,
) -> None:
"""
Args:
Expand Down Expand Up @@ -715,6 +723,8 @@ def __init__(
label_smoothing: a value in [0, 1] range. If > 0, the labels are smoothed
by the given factor to reduce overfitting.
Defaults to 0.0.
ignore_index: if not None, specifies a target index that is ignored and does not contribute to
the input gradient. Defaults to None.

"""
super().__init__()
Expand All @@ -737,15 +747,22 @@ def __init__(
smooth_dr=smooth_dr,
batch=batch,
weight=dice_weight,
ignore_index=ignore_index,
)
self.cross_entropy = nn.CrossEntropyLoss(
weight=weight,
reduction=reduction,
label_smoothing=label_smoothing,
ignore_index=ignore_index if ignore_index is not None else -100,
)
self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction, label_smoothing=label_smoothing)
self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction)
if lambda_dice < 0.0:
raise ValueError("lambda_dice should be no less than 0.0.")
if lambda_ce < 0.0:
raise ValueError("lambda_ce should be no less than 0.0.")
self.lambda_dice = lambda_dice
self.lambda_ce = lambda_ce
self.ignore_index = ignore_index

def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -801,7 +818,21 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
)

dice_loss = self.dice(input, target)
ce_loss = self.ce(input, target) if input.shape[1] != 1 else self.bce(input, target)

if input.shape[1] != 1:
# CrossEntropyLoss handles ignore_index natively
ce_loss = self.ce(input, target)
else:
# BCEWithLogitsLoss does not support ignore_index, handle manually
ce_loss = self.bce(input, target)
if self.ignore_index is not None:
mask = (target != self.ignore_index).float()
ce_loss = ce_loss * mask
if self.dice.reduction == "mean":
ce_loss = torch.mean(ce_loss)
elif self.dice.reduction == "sum":
ce_loss = torch.sum(ce_loss)

total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss

return total_loss
Expand Down
Loading
Loading