Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Dice Metric Docs #8388

Open
wants to merge 7 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 132 additions & 76 deletions monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch

from monai.metrics.utils import do_metric_reduction
from monai.utils import MetricReduction
from monai.utils import MetricReduction, deprecated_arg

from .metric import CumulativeIterationMetric

Expand All @@ -23,35 +23,76 @@

class DiceMetric(CumulativeIterationMetric):
"""
Compute average Dice score for a set of pairs of prediction-groundtruth segmentations.
Computes Dice score for a set of pairs of prediction-groundtruth labels. It supports single-channel label maps
or multi-channel images with class segmentations per channel. This allows the computation for both multi-class
and multi-label tasks.

It supports both multi-classes and multi-labels tasks.
Input `y_pred` is compared with ground truth `y`.
`y_pred` is expected to have binarized predictions and `y` can be single-channel class indices or in the
one-hot format. The `include_background` parameter can be set to ``False`` to exclude
the first category (channel index 0) which is by convention assumed to be background. If the non-background
segmentations are small compared to the total image size they can get overwhelmed by the signal from the
background. `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]),
`y` can also be in the format of `B1HW[D]`.
If either prediction ``y_pred`` or ground truth ``y`` have shape BCHW[D], it is expected that these represent one-
hot segmentations for C number of classes. If either shape is B1HW[D], it is expected that these are label maps
and the number of classes must be specified by the ``num_classes`` parameter. In either case for either inputs,
this metric applies no activations and so non-binary values will produce unexpected results if this metric is used
for binary overlap measurement (ie. either was expected to be one-hot formatted). Soft labels are thus permitted by
this metric. Typically this implies that raw predictions from a network must first be activated and possibly made
into label maps, eg. for a multi-class prediction tensor softmax and then argmax should be applied over the channel
dimensions to produce a label map.

The ``include_background`` parameter can be set to `False` to exclude the first category (channel index 0) which
is by convention assumed to be background. If the non-background segmentations are small compared to the total
image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction
and ground truth is BCHW[D].

The typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

Further information can be found in the official
`MONAI Dice Overview <https://github.com/Project-MONAI/tutorials/blob/main/modules/dice_loss_metric_notes.ipynb>`.

Example:

.. code-block:: python

import torch
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.networks import one_hot

batch_size, n_classes, h, w = 7, 5, 128, 128

y_pred = torch.rand(batch_size, n_classes, h, w) # network predictions
y_pred = torch.argmax(y_pred, 1, True) # convert to label map

# ground truth as label map
y = torch.randint(0, n_classes, size=(batch_size, 1, h, w))

dm = DiceMetric(
reduction="mean_batch", return_with_label=True, num_classes=n_classes
)

raw_scores = dm(y_pred, y)
print(dm.aggregate())

# now compute the Dice loss which should be the same as 1 - raw_scores
dl = DiceLoss(to_onehot_y=True, reduction="none")
loss = dl(one_hot(y_pred, n_classes), y).squeeze()

print(1.0 - loss) # same as raw_scores

Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

Args:
include_background: whether to include Dice computation on the first channel of
the predicted output. Defaults to ``True``.
reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.
ignore_empty: whether to ignore empty ground truth cases during calculation.
If `True`, NaN value will be set for empty ground truth cases.
If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
num_classes: number of input channels (always including the background). When this is None,
include_background: whether to include Dice computation on the first channel/category of the prediction and
ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The
available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none", is
selected, the metric will not do reduction.
get_not_nans: whether to return the `not_nans` count. If True, aggregate() returns `(metric, not_nans)` where
`not_nans` counts the number of valid values in the result, and will have the same shape.
ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
are also empty.
num_classes: number of input channels (always including the background). When this is ``None``,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
return_with_label: whether to return the metrics with label, only works when reduction is "mean_batch".
If `True`, use "label_{index}" as the key corresponding to C channels; if 'include_background' is True,
If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
the index begins at "0", otherwise at "1". It can also take a list of label names.
The outcome will then be returned as a dictionary.

Expand All @@ -77,22 +118,21 @@ def __init__(
include_background=self.include_background,
reduction=MetricReduction.NONE,
get_not_nans=False,
softmax=False,
apply_argmax=False,
ignore_empty=self.ignore_empty,
num_classes=self.num_classes,
)

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""
Compute the dice value using ``DiceHelper``.

Args:
y_pred: input data to compute, typical segmentation model output.
It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
should be binarized.
y: ground truth to compute mean Dice metric. `y` can be single-channel class indices or
in the one-hot format.
y_pred: prediction value, see class docstring for format definition.
y: ground truth label.

Raises:
ValueError: when `y_pred` has less than three dimensions.
ValueError: when `y_pred` has fewer than three dimensions.
"""
dims = y_pred.ndimension()
if dims < 3:
Expand All @@ -107,10 +147,8 @@ def aggregate(
Execute reduction and aggregation logic for the output of `compute_dice`.

Args:
reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.

reduction: defines mode of reduction as enumerated in :py:class:`monai.utils.enums.MetricReduction`.
By default this will do no reduction.
"""
data = self.get_buffer()
if not isinstance(data, torch.Tensor):
Expand Down Expand Up @@ -138,18 +176,20 @@ def compute_dice(
ignore_empty: bool = True,
num_classes: int | None = None,
) -> torch.Tensor:
"""Computes Dice score metric for a batch of predictions.
"""
Computes Dice score metric for a batch of predictions. This performs the same computation as
:py:class:`monai.metrics.DiceMetric`, which is preferrable to use over this function. For input formats, see the
documentation for that class .

Args:
y_pred: input data to compute, typical segmentation model output.
`y_pred` can be single-channel class indices or in the one-hot format.
y: ground truth to compute mean dice metric. `y` can be single-channel class indices or in the one-hot format.
include_background: whether to include Dice computation on the first channel of
the predicted output. Defaults to True.
ignore_empty: whether to ignore empty ground truth cases during calculation.
If `True`, NaN value will be set for empty ground truth cases.
If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
num_classes: number of input channels (always including the background). When this is None,
y: ground truth to compute mean dice metric.
include_background: whether to include Dice computation on the first channel/category of the prediction and
ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
are also empty.
num_classes: number of input channels (always including the background). When this is ``None``,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.

Expand All @@ -161,16 +201,16 @@ def compute_dice(
include_background=include_background,
reduction=MetricReduction.NONE,
get_not_nans=False,
softmax=False,
apply_argmax=False,
ignore_empty=ignore_empty,
num_classes=num_classes,
)(y_pred=y_pred, y=y)


class DiceHelper:
"""
Compute Dice score between two tensors `y_pred` and `y`.
`y_pred` and `y` can be single-channel class indices or in the one-hot format.
Compute Dice score between two tensors ``y_pred`` and ``y``. This is used by :py:class:`monai.metrics.DiceMetric`,
see the documentation for that class for input formats.

Example:

Expand All @@ -188,49 +228,65 @@ class DiceHelper:
score, not_nans = DiceHelper(include_background=False, sigmoid=True, softmax=True)(y_pred, y)
print(score, not_nans)

Args:
include_background: whether to include Dice computation on the first channel/category of the prediction and
ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
threshold: if ``True`, ``y_pred`` will be thresholded at a value of 0.5. Defaults to False.
apply_argmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
get the discrete prediction. Defaults to the value of ``not threshold``.
activate: if this and ``threshold` are ``True``, sigmoid activation is applied to ``y_pred`` before
thresholding. Defaults to False.
get_not_nans: whether to return the number of not-nan values.
reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The
available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none", is
selected, the metric will not do reduction.
ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
are also empty.
num_classes: number of input channels (always including the background). When this is ``None``,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
"""

@deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
@deprecated_arg("sigmoid", "1.5", "1.7", "Use `threshold` instead.", new_name="threshold")
def __init__(
self,
include_background: bool | None = None,
sigmoid: bool = False,
softmax: bool | None = None,
threshold: bool = False,
apply_argmax: bool | None = None,
activate: bool = False,
get_not_nans: bool = True,
reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
ignore_empty: bool = True,
num_classes: int | None = None,
sigmoid: bool | None = None,
softmax: bool | None = None,
) -> None:
"""
# handling deprecated arguments
if sigmoid is not None:
threshold = sigmoid
if softmax is not None:
apply_argmax = softmax

Args:
include_background: whether to include the score on the first channel
(default to the value of `sigmoid`, False).
sigmoid: whether ``y_pred`` are/will be sigmoid activated outputs. If True, thresholding at 0.5
will be performed to get the discrete prediction. Defaults to False.
softmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
get the discrete prediction. Defaults to the value of ``not sigmoid``.
activate: whether to apply sigmoid to ``y_pred`` if ``sigmoid`` is True. Defaults to False.
This option is only valid when ``sigmoid`` is True.
get_not_nans: whether to return the number of not-nan values.
reduction: define mode of reduction to the metrics
ignore_empty: if `True`, NaN value will be set for empty ground truth cases.
If `False`, 1 will be set if the Union of ``y_pred`` and ``y`` is empty.
num_classes: number of input channels (always including the background). When this is None,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
"""
self.sigmoid = sigmoid
self.threshold = threshold
self.reduction = reduction
self.get_not_nans = get_not_nans
self.include_background = sigmoid if include_background is None else include_background
self.softmax = not sigmoid if softmax is None else softmax
self.include_background = threshold if include_background is None else include_background
self.apply_argmax = not threshold if apply_argmax is None else apply_argmax
self.activate = activate
self.ignore_empty = ignore_empty
self.num_classes = num_classes

def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
""""""
"""
Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
for each batch item and for each channel of those items.

Args:
y_pred: input predictions with shape HW[D].
y: ground truth with shape HW[D].
"""
y_o = torch.sum(y)
if y_o > 0:
return (2.0 * torch.sum(torch.masked_select(y, y_pred))) / (y_o + torch.sum(y_pred))
Expand All @@ -243,25 +299,25 @@ def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor

def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Compute the metric for the given prediction and ground truth.

Args:
y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
"""
_softmax, _sigmoid = self.softmax, self.sigmoid
_apply_argmax, _threshold = self.apply_argmax, self.threshold
if self.num_classes is None:
n_pred_ch = y_pred.shape[1] # y_pred is in one-hot format or multi-channel scores
else:
n_pred_ch = self.num_classes
if y_pred.shape[1] == 1 and self.num_classes > 1: # y_pred is single-channel class indices
_softmax = _sigmoid = False
_apply_argmax = _threshold = False

if _softmax:
if n_pred_ch > 1:
y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
if _apply_argmax and n_pred_ch > 1:
y_pred = torch.argmax(y_pred, dim=1, keepdim=True)

elif _sigmoid:
elif _threshold:
if self.activate:
y_pred = torch.sigmoid(y_pred)
y_pred = y_pred > 0.5
Expand Down
6 changes: 3 additions & 3 deletions tests/metrics/test_compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,15 @@ def test_nans(self, input_data, expected_value):
@parameterized.expand([TEST_CASE_3])
def test_helper(self, input_data, _unused):
vals = {"y_pred": dict(input_data).pop("y_pred"), "y": dict(input_data).pop("y")}
result = DiceHelper(sigmoid=True)(**vals)
result = DiceHelper(threshold=True)(**vals)
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4)
np.testing.assert_allclose(sorted(result[1].cpu().numpy()), [0.0, 1.0, 2.0], atol=1e-4)
result = DiceHelper(softmax=True, get_not_nans=False)(**vals)
result = DiceHelper(apply_argmax=True, get_not_nans=False)(**vals)
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0], atol=1e-4)

num_classes = vals["y_pred"].shape[1]
vals["y_pred"] = torch.argmax(vals["y_pred"], dim=1, keepdim=True)
result = DiceHelper(sigmoid=True, num_classes=num_classes)(**vals)
result = DiceHelper(threshold=True, num_classes=num_classes)(**vals)
np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4)

# DiceMetric class tests
Expand Down
Loading