Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 101 additions & 170 deletions monai/losses/unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,221 +20,152 @@
from monai.utils import LossReduction


class AsymmetricFocalTverskyLoss(_Loss):
class AsymmetricUnifiedFocalLoss(_Loss):
"""
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
AsymmetricUnifiedFocalLoss is a variant of Focal Loss that combines Asymmetric Focal Loss
and Asymmetric Focal Tversky Loss to handle imbalanced medical image segmentation.

Actually, it's only supported for binary image segmentation now.
It supports multi-class segmentation by treating channel 0 as background and
channels 1..N as foreground, applying asymmetric weighting controlled by `delta`.

Reimplementation of the Asymmetric Focal Tversky Loss described in:
Reimplementation of the Asymmetric Unified Focal Loss described in:

- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
Michael Yeung, Computerized Medical Imaging and Graphics

Example:
>>> import torch
>>> from monai.losses import AsymmetricUnifiedFocalLoss
>>> # B, C, H, W = 1, 3, 32, 32
>>> pred_logits = torch.randn(1, 3, 32, 32)
>>> # Ground truth indices (B, 1, H, W)
>>> grnd = torch.randint(0, 3, (1, 1, 32, 32))
>>> # Use softmax=True if input is logits
>>> loss_func = AsymmetricUnifiedFocalLoss(to_onehot_y=True, use_softmax=True)
>>> loss = loss_func(pred_logits, grnd)
"""

def __init__(
self,
weight: float = 0.5,
delta: float = 0.6,
gamma: float = 0.5,
include_background: bool = True,
to_onehot_y: bool = False,
delta: float = 0.7,
gamma: float = 0.75,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
epsilon: float = 1e-7,
) -> None:
"""
Args:
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
weight: The weighting factor between Asymmetric Focal Loss and Asymmetric Focal Tversky Loss.
Final Loss = weight * AFL + (1 - weight) * AFTL. Defaults to 0.5.
delta: The balancing factor controls the weight of background vs foreground classes.
Values > 0.5 give more weight to foreground (False Negatives). Defaults to 0.6.
gamma: The focal exponent. Higher values focus more on hard examples. Defaults to 0.5.
include_background: If False, channel index 0 (background category) is excluded from the loss calculation.
Defaults to True.
to_onehot_y: Whether to convert the label `target` into the one-hot format. Defaults to False.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
use_softmax: Whether to use softmax to transform the original logits into probabilities.
If True, softmax is used. If False, assumes input is already probabilities. Defaults to False.
epsilon: Small value to prevent division by zero or log(0). Defaults to 1e-7.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
if not 0 <= weight <= 1:
raise ValueError(f"weight must be in [0, 1], got {weight}")
if not 0 <= delta <= 1:
raise ValueError(f"delta must be in [0, 1], got {delta}")
if gamma <= 0:
raise ValueError(f"gamma must be > 0, got {gamma}")
self.weight = weight
self.delta = delta
self.gamma = gamma
self.include_background = include_background
self.to_onehot_y = to_onehot_y
self.use_softmax = use_softmax
self.epsilon = epsilon

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
n_pred_ch = y_pred.shape[1]

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

# clip the prediction to avoid NaN
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
axis = list(range(2, len(y_pred.shape)))

# Calculate true positives (tp), false negatives (fn) and false positives (fp)
tp = torch.sum(y_true * y_pred, dim=axis)
fn = torch.sum(y_true * (1 - y_pred), dim=axis)
fp = torch.sum((1 - y_true) * y_pred, dim=axis)
dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

# Calculate losses separately for each class, enhancing both classes
back_dice = 1 - dice_class[:, 0]
fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)

# Average class scores
loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
return loss


class AsymmetricFocalLoss(_Loss):
"""
AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.

Actually, it's only supported for binary image segmentation now.

Reimplementation of the Asymmetric Focal Loss described in:

- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
Michael Yeung, Computerized Medical Imaging and Graphics
"""

def __init__(
self,
to_onehot_y: bool = False,
delta: float = 0.7,
gamma: float = 2,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
):
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
input: the shape should be BNH[WD], where N is the number of classes.
target: the shape should be BNH[WD] or B1H[WD].

Raises:
ValueError: When input and target have incompatible shapes.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon
if self.use_softmax:
input = torch.nn.functional.softmax(input, dim=1)

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
n_pred_ch = y_pred.shape[1]
n_pred_ch = input.shape[1]

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
cross_entropy = -y_true * torch.log(y_pred)

back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
back_ce = (1 - self.delta) * back_ce

fore_ce = cross_entropy[:, 1]
fore_ce = self.delta * fore_ce

loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
return loss
if target.shape[1] == 1:
target = one_hot(target, num_classes=n_pred_ch)

if target.shape != input.shape:
raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

class AsymmetricUnifiedFocalLoss(_Loss):
"""
AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
# Clip values for numerical stability
input = torch.clamp(input, self.epsilon, 1.0 - self.epsilon)

Actually, it's only supported for binary image segmentation now
# Part A: Asymmetric Focal Loss
# Cross Entropy: -target * log(input)
cross_entropy = -target * torch.log(input)

Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:
# Background (Channel 0): (1 - delta) * (1 - p)^gamma * CE
back_ce = (1 - self.delta) * torch.pow(1 - input[:, 0:1], self.gamma) * cross_entropy[:, 0:1]

- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
Michael Yeung, Computerized Medical Imaging and Graphics
"""
# Foreground (Channel 1..N): delta * CE
fore_ce = self.delta * cross_entropy[:, 1:]

def __init__(
self,
to_onehot_y: bool = False,
num_classes: int = 2,
weight: float = 0.5,
gamma: float = 0.5,
delta: float = 0.7,
reduction: LossReduction | str = LossReduction.MEAN,
):
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
num_classes : number of classes, it only supports 2 now. Defaults to 2.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.

Example:
>>> import torch
>>> from monai.losses import AsymmetricUnifiedFocalLoss
>>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
>>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
>>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
>>> fl(pred, grnd)
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.num_classes = num_classes
self.gamma = gamma
self.delta = delta
self.weight: float = weight
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
# Combine
if self.include_background:
asy_focal_loss = torch.cat([back_ce, fore_ce], dim=1)
else:
asy_focal_loss = fore_ce

# TODO: Implement this function to support multiple classes segmentation
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
y_pred : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.
The input should be the original logits since it will be transformed by
a sigmoid in the forward function.
y_true : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.
# Part B: Asymmetric Focal Tversky Loss
# Sum over spatial dimensions (Batch and Channel dims are preserved)
reduce_axis = list(range(2, input.dim()))

Raises:
ValueError: When input and target are different shape
ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5
ValueError: When num_classes
ValueError: When the number of classes entered does not match the expected number
"""
if y_pred.shape != y_true.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
tp = torch.sum(target * input, dim=reduce_axis)
fn = torch.sum(target * (1 - input), dim=reduce_axis)
fp = torch.sum((1 - target) * input, dim=reduce_axis)

if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
# Tversky Index
dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

if y_pred.shape[1] == 1:
y_pred = one_hot(y_pred, num_classes=self.num_classes)
y_true = one_hot(y_true, num_classes=self.num_classes)
# Background: 1 - Dice
back_dice_loss = 1 - dice_class[:, 0:1]

if torch.max(y_true) != self.num_classes - 1:
raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")
# Foreground: (1 - Dice)^(1 - gamma)
fore_dice_loss = torch.pow(torch.clamp(1 - dice_class[:, 1:], min=self.epsilon), 1 - self.gamma)

n_pred_ch = y_pred.shape[1]
if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)
# Combine
if self.include_background:
asy_focal_tversky_loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1)
else:
asy_focal_tversky_loss = fore_dice_loss

asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
# Part C: Unified Combination & Reduction
# Aggregate Focal Loss spatial dimensions to match Tversky Loss shape (B, C)
if asy_focal_loss.dim() > 2:
asy_focal_loss = torch.mean(asy_focal_loss, dim=reduce_axis)

loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss
# Weighted sum
total_loss = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss

if self.reduction == LossReduction.SUM.value:
return torch.sum(loss) # sum over the batch and channel dims
return torch.sum(total_loss)
if self.reduction == LossReduction.NONE.value:
return loss # returns [N, num_classes] losses
return total_loss
if self.reduction == LossReduction.MEAN.value:
return torch.mean(loss)
return torch.mean(total_loss)

raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
Loading
Loading