Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monai/losses/cldice.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_true[:, 1:, ...]) + self.smooth
)
cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + 1e-7)
return cl_dice


Expand Down Expand Up @@ -179,6 +179,6 @@ def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_true[:, 1:, ...]) + self.smooth
)
cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + 1e-7)
total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice
return total_loss
25 changes: 25 additions & 0 deletions tests/losses/test_cldice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,31 @@ def test_with_cuda(self):
np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(output_dice.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)

def test_zero_input_no_nan(self):
"""Test that zero-valued inputs do not produce NaN loss (division by zero guard)."""
loss = SoftclDiceLoss(smooth=0.0)
loss_dice = SoftDiceclDiceLoss(smooth=0.0)
y_pred = torch.zeros((2, 3, 8, 8))
y_true = torch.zeros((2, 3, 8, 8))
result = loss(y_true, y_pred)
result_dice = loss_dice(y_true, y_pred)
self.assertFalse(torch.isnan(result).any(), "SoftclDiceLoss produced NaN for zero inputs with smooth=0")
self.assertFalse(torch.isnan(result_dice).any(), "SoftDiceclDiceLoss produced NaN for zero inputs with smooth=0")

def test_no_overlap_no_nan(self):
"""Test that non-overlapping pred/target do not produce NaN loss."""
loss = SoftclDiceLoss(smooth=0.0)
loss_dice = SoftDiceclDiceLoss(smooth=0.0)
# Create non-overlapping predictions and ground truth
y_pred = torch.zeros((2, 3, 16, 16))
y_true = torch.zeros((2, 3, 16, 16))
y_pred[:, 1:, :8, :] = 1.0 # prediction in left half
y_true[:, 1:, 8:, :] = 1.0 # ground truth in right half
result = loss(y_true, y_pred)
result_dice = loss_dice(y_true, y_pred)
self.assertFalse(torch.isnan(result).any(), "SoftclDiceLoss produced NaN for non-overlapping inputs")
self.assertFalse(torch.isnan(result_dice).any(), "SoftDiceclDiceLoss produced NaN for non-overlapping inputs")


if __name__ == "__main__":
unittest.main()
Loading