diff --git a/monai/losses/cldice.py b/monai/losses/cldice.py
index 406cc3825f..64f2cbfb09 100644
--- a/monai/losses/cldice.py
+++ b/monai/losses/cldice.py
@@ -142,7 +142,7 @@ def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
         tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
             torch.sum(skel_true[:, 1:, ...]) + self.smooth
         )
-        cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
+        cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + 1e-7)
         return cl_dice
 
 
@@ -179,6 +179,6 @@ def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
         tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
             torch.sum(skel_true[:, 1:, ...]) + self.smooth
         )
-        cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
+        cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + 1e-7)
         total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice
         return total_loss
diff --git a/tests/losses/test_cldice_loss.py b/tests/losses/test_cldice_loss.py
index 14d3575e3b..33430d7480 100644
--- a/tests/losses/test_cldice_loss.py
+++ b/tests/losses/test_cldice_loss.py
@@ -46,6 +46,31 @@ def test_with_cuda(self):
         np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)
         np.testing.assert_allclose(output_dice.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)
 
+    def test_zero_input_no_nan(self):
+        """Test that zero-valued inputs do not produce NaN loss (division by zero guard)."""
+        loss = SoftclDiceLoss(smooth=0.0)
+        loss_dice = SoftDiceclDiceLoss(smooth=0.0)
+        y_pred = torch.zeros((2, 3, 8, 8))
+        y_true = torch.zeros((2, 3, 8, 8))
+        result = loss(y_true, y_pred)
+        result_dice = loss_dice(y_true, y_pred)
+        self.assertFalse(torch.isnan(result).any(), "SoftclDiceLoss produced NaN for zero inputs with smooth=0")
+        self.assertFalse(torch.isnan(result_dice).any(), "SoftDiceclDiceLoss produced NaN for zero inputs with smooth=0")
+
+    def test_no_overlap_no_nan(self):
+        """Test that non-overlapping pred/target do not produce NaN loss."""
+        loss = SoftclDiceLoss(smooth=0.0)
+        loss_dice = SoftDiceclDiceLoss(smooth=0.0)
+        # Create non-overlapping predictions and ground truth
+        y_pred = torch.zeros((2, 3, 16, 16))
+        y_true = torch.zeros((2, 3, 16, 16))
+        y_pred[:, 1:, :8, :] = 1.0  # prediction in left half
+        y_true[:, 1:, 8:, :] = 1.0  # ground truth in right half
+        result = loss(y_true, y_pred)
+        result_dice = loss_dice(y_true, y_pred)
+        self.assertFalse(torch.isnan(result).any(), "SoftclDiceLoss produced NaN for non-overlapping inputs")
+        self.assertFalse(torch.isnan(result_dice).any(), "SoftDiceclDiceLoss produced NaN for non-overlapping inputs")
+
 
 if __name__ == "__main__":
     unittest.main()
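
Note (illustration only, not part of the patch): a minimal standalone sketch of the failure mode the `1e-7` term guards against. When `smooth=0` and the skeleton/prediction overlap is empty, both `tprec` and `tsens` evaluate to zero, so the original expression divides 0 by 0 and propagates NaN into the loss; the epsilon keeps the denominator finite.

```python
import torch

# Both topology terms are zero when there is no overlap and smooth=0.
tprec = torch.tensor(0.0)
tsens = torch.tensor(0.0)

unguarded = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)        # 0/0 -> NaN
guarded = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + 1e-7)   # -> 1.0 (maximal loss)

print(torch.isnan(unguarded))  # tensor(True)
print(guarded)                 # tensor(1.)
```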