diff --git a/monai/losses/spectral_loss.py b/monai/losses/spectral_loss.py index 06714f3993..fcba03f132 100644 --- a/monai/losses/spectral_loss.py +++ b/monai/losses/spectral_loss.py @@ -55,8 +55,8 @@ def __init__( self.fft_norm = fft_norm def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - input_amplitude = self._get_fft_amplitude(target) - target_amplitude = self._get_fft_amplitude(input) + input_amplitude = self._get_fft_amplitude(input) + target_amplitude = self._get_fft_amplitude(target) # Compute distance between amplitude of frequency components # See Section 3.3 from https://arxiv.org/abs/2005.00341 diff --git a/monai/losses/ssim_loss.py b/monai/losses/ssim_loss.py index 8ee1da7267..3fa578da29 100644 --- a/monai/losses/ssim_loss.py +++ b/monai/losses/ssim_loss.py @@ -111,17 +111,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # 2D data x = torch.ones([1,1,10,10])/2 y = torch.ones([1,1,10,10])/2 - print(1-SSIMLoss(spatial_dims=2)(x,y)) + print(SSIMLoss(spatial_dims=2)(x,y)) # pseudo-3D data x = torch.ones([1,5,10,10])/2 # 5 could represent number of slices y = torch.ones([1,5,10,10])/2 - print(1-SSIMLoss(spatial_dims=2)(x,y)) + print(SSIMLoss(spatial_dims=2)(x,y)) # 3D data x = torch.ones([1,1,10,10,10])/2 y = torch.ones([1,1,10,10,10])/2 - print(1-SSIMLoss(spatial_dims=3)(x,y)) + print(SSIMLoss(spatial_dims=3)(x,y)) """ ssim_value = self.ssim_metric._compute_tensor(input, target).view(-1, 1) loss: torch.Tensor = 1 - ssim_value