
Update Dice Metric/Loss Documentation #8385

Open
ericspod opened this issue Mar 11, 2025 · 2 comments · May be fixed by #8388
@ericspod (Member)
Is your feature request related to a problem? Please describe.
The docstring descriptions for DiceLoss, DiceMetric, DiceHelper, and compute_dice should include more information about the expected inputs, the implications of the various arguments, and better examples. There has been some confusion about why results from the loss don't match expectations given the metric values; this stems from the extra features in the loss, such as the smoothing values, as well as misunderstandings about activation and thresholding/one-hot formatting.
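As an aside, the loss/metric discrepancy from smoothing is easy to demonstrate. The following is a minimal sketch in plain NumPy (not MONAI's actual implementation) of one common case: with smoothing terms in numerator and denominator, an empty class yields a smoothed Dice of 1 (loss of 0), while the unsmoothed metric form is undefined (0/0).

```python
import numpy as np

# Hypothetical illustration, not MONAI code: an empty class under a
# smoothed Dice (as used in losses) vs. an unsmoothed Dice (as in metrics).
pred = np.zeros((8, 8))  # class absent from the prediction
gt = np.zeros((8, 8))    # and absent from the ground truth

inter = (pred * gt).sum()
denom = pred.sum() + gt.sum()

smooth = 1e-5
smoothed_dice = (2 * inter + smooth) / (denom + smooth)  # 1.0, so loss = 0
with np.errstate(invalid="ignore"):
    raw_dice = 2 * inter / denom  # 0/0 -> nan

print(smoothed_dice, raw_dice)
```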

Describe the solution you'd like
A few things should be changed or added:

  • Docstring for DiceMetric should make it clear the ground truth and prediction can be either single-channel label maps or multi-channel one-hot tensors.
  • The values of the inputs do matter since no thresholding is done in the metrics themselves.
  • The softmax parameter of DiceHelper is confusing in that it appears to indicate that softmax will be applied when true, rather than signalling that softmax was already applied to the prediction. It should be renamed to argmax and the old name deprecated. It should also be noted that this parameter defaults to the opposite of sigmoid when not given a value.
  • Using compute_dice probably should be discouraged in favour of the DiceMetric class, and it should be mentioned that they do the same computation. It's possible this function isn't used by anyone and can be deprecated entirely.
  • An example in the docstring of how to use the metric and the loss to get expected results should be included.
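To make the second bullet concrete, here is a hedged sketch in plain PyTorch (illustrative shapes and names, not the MONAI API) of the preprocessing the caller is expected to do, since no thresholding happens in the metric itself:

```python
import torch

# Illustrative only: a Dice metric expects discrete inputs, so the caller
# converts logits to a one-hot prediction before calling it.
batch, n_classes, h, w = 2, 3, 4, 4
logits = torch.randn(batch, n_classes, h, w)             # raw network output
labels = torch.randint(0, n_classes, (batch, 1, h, w))   # label-map ground truth

# argmax then one-hot: the discrete prediction a Dice metric expects
pred_labels = logits.argmax(dim=1, keepdim=True)
pred_onehot = torch.zeros_like(logits).scatter_(1, pred_labels, 1.0)

# a single-channel label-map ground truth can be one-hot encoded the same way
gt_onehot = torch.zeros_like(logits).scatter_(1, labels, 1.0)

print(pred_onehot.shape, gt_onehot.shape)  # both (2, 3, 4, 4)
```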
@innat commented Mar 12, 2025

+1.

I would also like to know the formula used in DiceMetric. Say, if y_true contains sparse labels (1, 2, 3) and the model gives logits, then:

from keras import ops  # Keras 3 backend-agnostic ops

y_true = ops.one_hot(
    ops.squeeze(ops.cast(y_true, 'int32'), axis=-1),
    num_classes=self.num_classes,
)
y_true_reshaped = ops.reshape(y_true, [-1, self.num_classes])

y_pred = ops.cast(y_pred, y_true.dtype)
y_pred = ops.nn.softmax(y_pred)
y_pred_reshaped = ops.reshape(y_pred, [-1, self.num_classes])

intersection = ops.sum(y_true_reshaped * y_pred_reshaped, axis=0)
union = ops.sum(y_true_reshaped, axis=0) + ops.sum(y_pred_reshaped, axis=0)
dice = (2 * intersection + smooth) / (union + smooth)

I tried to compare results between DiceMetric and the approach above, but they didn't match.
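One likely source of the mismatch, shown as a hedged NumPy sketch (not MONAI code): reshaping to [-1, num_classes] pools intersection and union over the entire batch, whereas DiceMetric computes Dice per item and per channel before reducing, and the two aggregations generally give different numbers.

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=(4, 2, 8, 8)).astype(float)  # one-hot-style ground truth
p = rng.integers(0, 2, size=(4, 2, 8, 8)).astype(float)  # binary prediction

# pooled over the whole batch, as in the reshape-to-[-1, C] snippet above
inter = (y * p).sum(axis=(0, 2, 3))
union = y.sum(axis=(0, 2, 3)) + p.sum(axis=(0, 2, 3))
pooled = 2 * inter / union

# per item and per channel, then averaged over the batch
inter_i = (y * p).sum(axis=(2, 3))
union_i = y.sum(axis=(2, 3)) + p.sum(axis=(2, 3))
per_item = (2 * inter_i / union_i).mean(axis=0)

print(pooled, per_item)  # generally not equal
```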

@ericspod (Member, Author)

> +1.
>
> I would also like to know the formula used in DiceMetric.

I'm not totally sure how to demonstrate this based on your code snippet, but I will be updating the documentation with something like the following:

import torch
from monai.metrics import DiceMetric
from monai.networks.utils import one_hot

batch_size, n_classes, h, w = 7, 5, 128, 128

y_pred = torch.rand(batch_size, n_classes, h, w)  # raw network predictions
y_pred = torch.argmax(y_pred, 1, True)  # convert to a label map
y_pred = one_hot(y_pred, n_classes)  # and then to one-hot

# ground truth as one_hot
y = torch.randint(0, 2, size=(batch_size, n_classes, h, w))

dm = DiceMetric(reduction="mean_batch")

raw_scores = dm(y_pred, y)  # dice scores with shape (batch_size, n_classes)
print(raw_scores)

# compute the equivalent metric value
equivalent = torch.sum(y * y_pred, dim=(2, 3)).mul(2) / torch.sum(y + y_pred, dim=(2, 3))
print(equivalent)  # same as raw_scores

This shows what the computation is for a specific case with one-hot prediction and ground truth. The actual computation is done on a per-item, per-channel basis here.

@ericspod ericspod linked a pull request Mar 12, 2025 that will close this issue