Update Dice Metric/Loss Documentation #8385
Comments
+1. I would also like to know the formula used in `DiceMetric`. Say, if:

```python
y_true = ops.one_hot(
    ops.squeeze(ops.cast(y_true, "int32"), axis=-1),
    num_classes=self.num_classes,
)
y_true_reshaped = ops.reshape(y_true, [-1, self.num_classes])
y_pred = ops.cast(y_pred, y_true.dtype)
y_pred = ops.nn.softmax(y_pred)
y_pred_reshaped = ops.reshape(y_pred, [-1, self.num_classes])
intersection = ops.sum(y_true_reshaped * y_pred_reshaped, axis=0)
union = ops.sum(y_true_reshaped, axis=0) + ops.sum(y_pred_reshaped, axis=0)
dice = (2 * intersection + smooth) / (union + smooth)
```

I tried to compare results between `DiceMetric` and the above approach, and they didn't match.
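One likely contributor to the mismatch (an assumption on my part, not confirmed in this thread): the snippet above passes softmax probabilities into the Dice formula, whereas a metric computed on binarized (argmax/one-hot) predictions gives different values for the same logits. A minimal pure-Python sketch of that difference:

```python
# Pure-Python illustration (not MONAI or Keras code): Dice computed over soft
# softmax probabilities differs from Dice over the hard one-hot prediction
# derived from the same logits, even when the prediction is "correct".
def dice(pred, gt, smooth=1e-5):
    inter = sum(p * g for p, g in zip(pred, gt))
    union = sum(pred) + sum(gt)
    return (2.0 * inter + smooth) / (union + smooth)

gt = [1.0, 0.0]    # one-hot ground truth for a single pixel, 2 classes
soft = [0.8, 0.2]  # softmax probabilities for that pixel
hard = [1.0, 0.0]  # argmax/one-hot of the same prediction

print(dice(soft, gt))  # soft Dice is about 0.8, not 1, despite a correct argmax
print(dice(hard, gt))  # hard Dice is 1.0 for the same prediction
```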
I'm not totally sure how to demonstrate based off your code snippet, but I will be updating the comments with something like the following:

```python
import torch
from monai.metrics import DiceMetric
from monai.networks.utils import one_hot

batch_size, n_classes, h, w = 7, 5, 128, 128

y_pred = torch.rand(batch_size, n_classes, h, w)  # network predictions
y_pred = torch.argmax(y_pred, 1, True)  # convert to label map
y_pred = one_hot(y_pred, n_classes)  # and then to one-hot

# ground truth as one-hot
y = torch.randint(0, 2, size=(batch_size, n_classes, h, w))

dm = DiceMetric(reduction="mean_batch")
raw_scores = dm(y_pred, y)  # dice scores with shape (batch_size, n_classes)
print(raw_scores)

# compute the equivalent metric value
equivalent = torch.sum(y * y_pred, dim=(2, 3)).mul(2) / torch.sum(y + y_pred, dim=(2, 3))
print(equivalent)  # same as raw_scores
```

This shows what the computation is for a specific case with one-hot prediction and ground truth. The actual computation is done on a per-item and per-channel basis here.
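A related source of confusion raised in this issue is that the loss includes smoothing terms the metric does not, so `1 - loss` will not exactly equal the metric even on identical binarized inputs. A minimal pure-Python sketch of that effect (not MONAI code; the `1e-5` defaults are assumed to match `DiceLoss`'s `smooth_nr`/`smooth_dr` defaults):

```python
# Pure-Python sketch: the metric has no smoothing, while the loss adds small
# constants to the numerator and denominator, so the two disagree slightly.
def dice_metric(pred, gt):
    inter = sum(p * g for p, g in zip(pred, gt))
    denom = sum(pred) + sum(gt)
    return 2.0 * inter / denom

def dice_loss(pred, gt, smooth_nr=1e-5, smooth_dr=1e-5):
    inter = sum(p * g for p, g in zip(pred, gt))
    denom = sum(pred) + sum(gt)
    return 1.0 - (2.0 * inter + smooth_nr) / (denom + smooth_dr)

pred = [1, 0, 1, 1]  # binarized prediction, flattened
gt = [1, 1, 1, 0]    # ground truth, flattened

metric = dice_metric(pred, gt)       # 2*2 / (3+3) = 0.666...
recovered = 1.0 - dice_loss(pred, gt)
print(metric, recovered)  # close, but not identical, because of the smoothing
```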
**Is your feature request related to a problem? Please describe.**

The docstring descriptions for `DiceLoss`, `DiceMetric`, `DiceHelper`, and `compute_dice` should include more information about the expectations of inputs, the implications of various arguments, and better examples. There has been some confusion about why results from the loss don't match expectations given the value of the metrics; this is related to extra features in the loss, such as the smoothing values, as well as misunderstanding about activation or thresholding/one-hot formatting.

**Describe the solution you'd like**
A few things should be changed or added:

- `DiceMetric` should make it clear the ground truth and prediction can be either single-channel label maps or multi-channel one-hot tensors.
- The `softmax` parameter of `DiceHelper` is confusing in that it appears to indicate softmax will be applied if it's true, rather than the expectation that this was already done to the prediction. This should be renamed to `argmax` and the old name deprecated. It should also be noted that this parameter becomes the opposite value to `sigmoid` if not given a value.
- `compute_dice` should probably be discouraged in favour of the `DiceMetric` class, and it should be mentioned that they do the same computation. It's possible this function isn't used by anyone and can be deprecated entirely.