-
Notifications
You must be signed in to change notification settings - Fork 1.4k
4609: Add AUC-Margin Loss for AUROC optimization #8719
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
3ee0c07
f1d38f4
c550c29
2a56f54
448c5df
8ff079c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,150 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) MONAI Consortium | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # You may obtain a copy of the License at | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch.nn as nn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch.nn.modules.loss import _Loss | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from monai.utils import LossReduction | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class AUCMLoss(_Loss): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AUC-Margin loss with squared-hinge surrogate loss for optimizing AUROC. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The loss optimizes the Area Under the ROC Curve (AUROC) by using margin-based constraints | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| on positive and negative predictions. It supports two versions: 'v1' includes class prior | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| information, while 'v2' removes this dependency for better generalization. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Reference: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Yuan, Zhuoning, Yan, Yan, Sonka, Milan, and Yang, Tianbao. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Large-scale robust deep auc maximization: A new surrogate loss and empirical studies on medical image classification." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| https://arxiv.org/abs/2012.03173 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Implementation based on: https://github.com/Optimization-AI/LibAUC/blob/1.4.0/libauc/losses/auc.py | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Example: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> from monai.losses import AUCMLoss | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> loss_fn = AUCMLoss() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> input = torch.randn(32, 1, requires_grad=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> target = torch.randint(0, 2, (32, 1)).float() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> loss = loss_fn(input, target) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| margin: float = 1.0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| imratio: float | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| version: str = "v1", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| reduction: LossReduction | str = LossReduction.MEAN, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| margin: margin for squared-hinge surrogate loss (default: ``1.0``). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| imratio: the ratio of the number of positive samples to the number of total samples in the training dataset. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| If this value is not given, it will be automatically calculated with mini-batch samples. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| This value is ignored when ``version`` is set to ``'v2'``. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| version: whether to include prior class information in the objective function (default: ``'v1'``). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 'v1' includes class prior, 'v2' removes this dependency. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| reduction: {``"none"``, ``"mean"``, ``"sum"``} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Specifies the reduction to apply to the output. Defaults to ``"mean"``. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Note: This loss is computed at the batch level and always returns a scalar. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The reduction parameter is accepted for API consistency but has no effect. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Raises: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ValueError: When ``version`` is not one of ["v1", "v2"]. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ValueError: When ``imratio`` is not in [0, 1]. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Example: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> from monai.losses import AUCMLoss | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> loss_fn = AUCMLoss(version='v2') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> input = torch.randn(32, 1, requires_grad=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> target = torch.randint(0, 2, (32, 1)).float() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| >>> loss = loss_fn(input, target) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__(reduction=LossReduction(reduction).value) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if version not in ["v1", "v2"]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"version should be 'v1' or 'v2', got {version}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if imratio is not None and not (0.0 <= imratio <= 1.0): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"imratio must be in [0, 1], got {imratio}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.margin = margin | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.imratio = imratio | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.version = version | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.a = nn.Parameter(torch.tensor(0.0)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.b = nn.Parameter(torch.tensor(0.0)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.alpha = nn.Parameter(torch.tensor(0.0)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input: the shape should be B1HW[D], where the channel dimension is 1 for binary classification. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| target: the shape should be B1HW[D], with values 0 or 1. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.Tensor: scalar AUCM loss. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Raises: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ValueError: When input or target have incorrect shapes. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ValueError: When input or target have fewer than 2 dimensions. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ValueError: When target contains non-binary values. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if input.ndim < 2 or target.ndim < 2: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Input and target must have at least 2 dimensions (B, C, ...)") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if input.shape[1] != 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Input should have 1 channel for binary classification, got {input.shape[1]}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if target.shape[1] != 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Target should have 1 channel, got {target.shape[1]}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if input.shape != target.shape: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Input and target shapes do not match: {input.shape} vs {target.shape}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input = input.flatten() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| target = target.flatten() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not torch.all((target == 0) | (target == 1)): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Target must contain only binary values (0 or 1)") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pos_mask = (target == 1).float() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| neg_mask = (target == 0).float() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.version == "v1": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| p = float(self.imratio) if self.imratio is not None else float(pos_mask.mean().item()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loss = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (1 - p) * self._safe_mean((input - self.a) ** 2, pos_mask) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| + p * self._safe_mean((input - self.b) ** 2, neg_mask) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| + 2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * self.alpha | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| p * (1 - p) * self.margin | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| + self._safe_mean(p * input * neg_mask - (1 - p) * input * pos_mask, pos_mask + neg_mask) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - p * (1 - p) * self.alpha**2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loss = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._safe_mean((input - self.a) ** 2, pos_mask) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| + self._safe_mean((input - self.b) ** 2, neg_mask) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| + 2 * self.alpha * (self.margin + self._safe_mean(input, neg_mask) - self._safe_mean(input, pos_mask)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - self.alpha**2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+122
to
+141
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
If I haven't messed up the refactor of the loss equations, this would help simplify the code and reduce redundant calculations. With a bit more commentary this may be easier to understand. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return loss | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _safe_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Compute mean safely over masked elements.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| denom = mask.sum() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if denom == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return (tensor * mask).sum() / denom | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,111 @@ | ||
| # Copyright (c) MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
|
|
||
| from monai.losses import AUCMLoss | ||
| from tests.test_utils import test_script_save | ||
|
|
||
|
|
||
| class TestAUCMLoss(unittest.TestCase): | ||
| """Test cases for AUCMLoss.""" | ||
|
|
||
| def test_v1(self): | ||
| """Test AUCMLoss with version 'v1'.""" | ||
| loss_fn = AUCMLoss(version="v1") | ||
| input = torch.randn(32, 1, requires_grad=True) | ||
| target = torch.randint(0, 2, (32, 1)).float() | ||
| loss = loss_fn(input, target) | ||
| self.assertIsInstance(loss, torch.Tensor) | ||
| self.assertEqual(loss.ndim, 0) | ||
|
|
||
| def test_v2(self): | ||
| """Test AUCMLoss with version 'v2'.""" | ||
| loss_fn = AUCMLoss(version="v2") | ||
| input = torch.randn(32, 1, requires_grad=True) | ||
| target = torch.randint(0, 2, (32, 1)).float() | ||
| loss = loss_fn(input, target) | ||
| self.assertIsInstance(loss, torch.Tensor) | ||
| self.assertEqual(loss.ndim, 0) | ||
|
Comment on lines
+25
to
+41
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These tests are a good start but I think we need a few which tests the output values themselves to ensure the calculation is what it should be. You can precompute some values and store them as globals at the top of this file, look at other test files to see how this is done with |
||
|
|
||
| def test_invalid_version(self): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The following tests are good too but could be condensed using |
||
| """Test that invalid version raises ValueError.""" | ||
| with self.assertRaises(ValueError): | ||
| AUCMLoss(version="invalid") | ||
|
|
||
| def test_invalid_imratio(self): | ||
| """Test that invalid imratio raises ValueError.""" | ||
| with self.assertRaises(ValueError): | ||
| AUCMLoss(imratio=1.5) | ||
| with self.assertRaises(ValueError): | ||
| AUCMLoss(imratio=-0.1) | ||
|
|
||
| def test_invalid_input_shape(self): | ||
| """Test that invalid input shape raises ValueError.""" | ||
| loss_fn = AUCMLoss() | ||
| input = torch.randn(32, 2) # Wrong channel | ||
| target = torch.randint(0, 2, (32, 1)).float() | ||
| with self.assertRaises(ValueError): | ||
| loss_fn(input, target) | ||
|
|
||
| def test_invalid_target_shape(self): | ||
| """Test that invalid target shape raises ValueError.""" | ||
| loss_fn = AUCMLoss() | ||
| input = torch.randn(32, 1) | ||
| target = torch.randint(0, 2, (32, 2)).float() # Wrong channel | ||
| with self.assertRaises(ValueError): | ||
| loss_fn(input, target) | ||
|
|
||
| def test_insufficient_dimensions(self): | ||
| """Test that tensors with insufficient dimensions raise ValueError.""" | ||
| loss_fn = AUCMLoss() | ||
| input = torch.randn(32) # 1D tensor | ||
| target = torch.randint(0, 2, (32, 1)).float() | ||
| with self.assertRaises(ValueError): | ||
| loss_fn(input, target) | ||
|
|
||
| def test_shape_mismatch(self): | ||
| """Test that mismatched shapes raise ValueError.""" | ||
| loss_fn = AUCMLoss() | ||
| input = torch.randn(32, 1) | ||
| target = torch.randint(0, 2, (16, 1)).float() | ||
| with self.assertRaises(ValueError): | ||
| loss_fn(input, target) | ||
|
|
||
| def test_non_binary_target(self): | ||
| """Test that non-binary target values raise ValueError.""" | ||
| loss_fn = AUCMLoss() | ||
| input = torch.randn(32, 1) | ||
| target = torch.tensor([[0.5], [1.0], [2.0], [0.0]] * 8) # 32x1, still non-binary | ||
| with self.assertRaises(ValueError): | ||
| loss_fn(input, target) | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def test_backward(self): | ||
| """Test that gradients can be computed.""" | ||
| loss_fn = AUCMLoss() | ||
| input = torch.randn(32, 1, requires_grad=True) | ||
| target = torch.randint(0, 2, (32, 1)).float() | ||
| loss = loss_fn(input, target) | ||
| loss.backward() | ||
| self.assertIsNotNone(input.grad) | ||
|
|
||
| def test_script_save(self): | ||
| """Test that the loss can be saved as TorchScript.""" | ||
| loss_fn = AUCMLoss() | ||
| test_script_save(loss_fn, torch.randn(32, 1), torch.randint(0, 2, (32, 1)).float()) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
Uh oh!
There was an error while loading. Please reload this page.