From c9436fa7787c394f207bcac6032091df560b0a96 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 11 Mar 2025 18:12:15 +0000 Subject: [PATCH 1/3] Updating docstrings Signed-off-by: Eric Kerfoot --- monai/metrics/meandice.py | 124 +++++++++++++++++++------------------- 1 file changed, 62 insertions(+), 62 deletions(-) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index f21040d58e..a8270215d1 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -23,35 +23,39 @@ class DiceMetric(CumulativeIterationMetric): """ - Compute average Dice score for a set of pairs of prediction-groundtruth segmentations. + Computes Dice score for a set of pairs of prediction-groundtruth labels. It supports single-channel label maps + or multi-channel images with class segmentations per channel. This allows the computation for both multi-class + and multi-label tasks. - It supports both multi-classes and multi-labels tasks. - Input `y_pred` is compared with ground truth `y`. - `y_pred` is expected to have binarized predictions and `y` can be single-channel class indices or in the - one-hot format. The `include_background` parameter can be set to ``False`` to exclude - the first category (channel index 0) which is by convention assumed to be background. If the non-background - segmentations are small compared to the total image size they can get overwhelmed by the signal from the - background. `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]), - `y` can also be in the format of `B1HW[D]`. + If either prediction ``y_pred`` or ground truth ``y`` have shape BCHW[D], it is expected that these represent one- + hot segmentations for C number of classes. If either shape is B1HW[D], it is expected that these are label maps + and the number of classes must be specified by the ``num_classes`` parameter. In either case for either inputs, + this metric applies no activations and so non-binary values will produce unexpected results if this metric is used + for binary overlap measurement. Soft labels are thus permitted by this metric. - Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. + The ``include_background`` parameter can be set to `False` to exclude the first category (channel index 0) which + is by convention assumed to be background. If the non-background segmentations are small compared to the total + image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction + and ground truth is BCHW[D]. + + An example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. Args: - include_background: whether to include Dice computation on the first channel of - the predicted output. Defaults to ``True``. - reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, - available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. - get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). - Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. - ignore_empty: whether to ignore empty ground truth cases during calculation. - If `True`, NaN value will be set for empty ground truth cases. - If `False`, 1 will be set if the predictions of empty ground truth cases are also empty. - num_classes: number of input channels (always including the background). When this is None, + include_background: whether to include Dice computation on the first channel/category of the prediction and + ground truth. Defaults to ``True``, use ``False`` to exclude the background class. + reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The + available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none", is + selected, the metric will not do reduction. + get_not_nans: whether to return the `not_nans` count. If True, aggregate() returns `(metric, not_nans)` where + `not_nans` counts the number of valid values in the result, and will have the same shape. + ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be + set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases + are also empty. + num_classes: number of input channels (always including the background). When this is ``None``, ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are single-channel class indices and the number of classes is not automatically inferred from data. return_with_label: whether to return the metrics with label, only works when reduction is "mean_batch". - If `True`, use "label_{index}" as the key corresponding to C channels; if 'include_background' is True, + If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True, the index begins at "0", otherwise at "1". It can also take a list of label names. The outcome will then be returned as a dictionary. @@ -84,15 +88,14 @@ def __init__( def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ + Compute the dice value using ``DiceHelper``. + Args: - y_pred: input data to compute, typical segmentation model output. - It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values - should be binarized. - y: ground truth to compute mean Dice metric. `y` can be single-channel class indices or - in the one-hot format. + y_pred: prediction value, see class docstring for format definition. + y: ground truth label. Raises: - ValueError: when `y_pred` has less than three dimensions. + ValueError: when `y_pred` has fewer than three dimensions. """ dims = y_pred.ndimension() if dims < 3: @@ -107,10 +110,8 @@ def aggregate( Execute reduction and aggregation logic for the output of `compute_dice`. Args: - reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, - available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, - ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction. - + reduction: defines mode of reduction as enumerated in :py:class:`monai.utils.enums.MetricReduction`. + By default this will do no reduction. """ data = self.get_buffer() if not isinstance(data, torch.Tensor): @@ -138,18 +139,19 @@ def compute_dice( ignore_empty: bool = True, num_classes: int | None = None, ) -> torch.Tensor: - """Computes Dice score metric for a batch of predictions. + """ + Computes Dice score metric for a batch of predictions. This performs the same computation as + :py:class:`monai.metrics.DiceMetric`, see the documentation for that class for input formats. Args: y_pred: input data to compute, typical segmentation model output. - `y_pred` can be single-channel class indices or in the one-hot format. - y: ground truth to compute mean dice metric. `y` can be single-channel class indices or in the one-hot format. - include_background: whether to include Dice computation on the first channel of - the predicted output. Defaults to True. - ignore_empty: whether to ignore empty ground truth cases during calculation. - If `True`, NaN value will be set for empty ground truth cases. - If `False`, 1 will be set if the predictions of empty ground truth cases are also empty. - num_classes: number of input channels (always including the background). When this is None, + y: ground truth to compute mean dice metric. + include_background: whether to include Dice computation on the first channel/category of the prediction and + ground truth. Defaults to ``True``, use ``False`` to exclude the background class. + ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be + set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases + are also empty. + num_classes: number of input channels (always including the background). When this is ``None``, ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are single-channel class indices and the number of classes is not automatically inferred from data. @@ -169,8 +171,8 @@ def compute_dice( class DiceHelper: """ - Compute Dice score between two tensors `y_pred` and `y`. - `y_pred` and `y` can be single-channel class indices or in the one-hot format. + Compute Dice score between two tensors ``y_pred`` and ``y``. This is used by :py:class:`monai.metrics.DiceMetric`, + see the documentation for that class for input formats. Example: @@ -188,6 +190,23 @@ class DiceHelper: score, not_nans = DiceHelper(include_background=False, sigmoid=True, softmax=True)(y_pred, y) print(score, not_nans) + Args: + include_background: whether to include Dice computation on the first channel/category of the prediction and + ground truth. Defaults to ``True``, use ``False`` to exclude the background class. + sigmoid: if ``True`, ``y_pred`` will be thresholded at a value of 0.5. Defaults to False. + softmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to + get the discrete prediction. Defaults to the value of ``not sigmoid``. + activate: if this and ``sigmoid` are ``True``, sigmoid activation is applied to ``y_pred``. Defaults to False. + get_not_nans: whether to return the number of not-nan values. + reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The + available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none", is + selected, the metric will not do reduction. + ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be + set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases + are also empty. + num_classes: number of input channels (always including the background). When this is ``None``, + ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are + single-channel class indices and the number of classes is not automatically inferred from data. """ def __init__( @@ -201,25 +220,6 @@ def __init__( ignore_empty: bool = True, num_classes: int | None = None, ) -> None: - """ - - Args: - include_background: whether to include the score on the first channel - (default to the value of `sigmoid`, False). - sigmoid: whether ``y_pred`` are/will be sigmoid activated outputs. If True, thresholding at 0.5 - will be performed to get the discrete prediction. Defaults to False. - softmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to - get the discrete prediction. Defaults to the value of ``not sigmoid``. - activate: whether to apply sigmoid to ``y_pred`` if ``sigmoid`` is True. Defaults to False. - This option is only valid when ``sigmoid`` is True. - get_not_nans: whether to return the number of not-nan values. - reduction: define mode of reduction to the metrics - ignore_empty: if `True`, NaN value will be set for empty ground truth cases. - If `False`, 1 will be set if the Union of ``y_pred`` and ``y`` is empty. - num_classes: number of input channels (always including the background). When this is None, - ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are - single-channel class indices and the number of classes is not automatically inferred from data. - """ self.sigmoid = sigmoid self.reduction = reduction self.get_not_nans = get_not_nans From b11540b204b3faf666acde5a7000ea397ca07b75 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 12 Mar 2025 16:49:08 +0000 Subject: [PATCH 2/3] Updating docstrings Signed-off-by: Eric Kerfoot --- monai/metrics/meandice.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index a8270215d1..0bf54f6aed 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -40,6 +40,29 @@ class DiceMetric(CumulativeIterationMetric): An example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. + Example: + + .. code-block:: python + + import torch + from monai.metrics import DiceMetric + + batch_size, n_classes, h, w = 7, 5, 128, 128 + + y_pred = torch.rand(batch_size, n_classes, h, w) # network predictions + y_pred = torch.argmax(y_pred, 1, True) # convert to label map + + # ground truth as label map + y = torch.randint(0, n_classes, size=(batch_size, 1, h, w)) + + dm = DiceMetric( + reduction="mean_batch", return_with_label=True, num_classes=n_classes + ) + + raw_scores = dm(y_pred, y) + print(dm.aggregate()) + + Args: include_background: whether to include Dice computation on the first channel/category of the prediction and ground truth. Defaults to ``True``, use ``False`` to exclude the background class. From 5e066e28d3d527226597419b9b931c8ffa0709ae Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 12 Mar 2025 21:23:42 +0000 Subject: [PATCH 3/3] Amending documentation and tests Signed-off-by: Eric Kerfoot --- monai/metrics/meandice.py | 117 ++++++++++++++++--------- tests/metrics/test_compute_meandice.py | 6 +- 2 files changed, 78 insertions(+), 45 deletions(-) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 0bf54f6aed..0802cc3364 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -14,7 +14,7 @@ import torch from monai.metrics.utils import do_metric_reduction -from monai.utils import MetricReduction +from monai.utils import MetricReduction, deprecated_arg from .metric import CumulativeIterationMetric @@ -24,21 +24,27 @@ class DiceMetric(CumulativeIterationMetric): """ Computes Dice score for a set of pairs of prediction-groundtruth labels. It supports single-channel label maps - or multi-channel images with class segmentations per channel. This allows the computation for both multi-class - and multi-label tasks. + or multi-channel images with class segmentations per channel. This allows the computation for both multi-class + and multi-label tasks. If either prediction ``y_pred`` or ground truth ``y`` have shape BCHW[D], it is expected that these represent one- hot segmentations for C number of classes. If either shape is B1HW[D], it is expected that these are label maps - and the number of classes must be specified by the ``num_classes`` parameter. In either case for either inputs, - this metric applies no activations and so non-binary values will produce unexpected results if this metric is used - for binary overlap measurement. Soft labels are thus permitted by this metric. - - The ``include_background`` parameter can be set to `False` to exclude the first category (channel index 0) which - is by convention assumed to be background. If the non-background segmentations are small compared to the total + and the number of classes must be specified by the ``num_classes`` parameter. In either case for either inputs, + this metric applies no activations and so non-binary values will produce unexpected results if this metric is used + for binary overlap measurement (ie. either was expected to be one-hot formatted). Soft labels are thus permitted by + this metric. Typically this implies that raw predictions from a network must first be activated and possibly made + into label maps, eg. for a multi-class prediction tensor softmax and then argmax should be applied over the channel + dimensions to produce a label map. + + The ``include_background`` parameter can be set to `False` to exclude the first category (channel index 0) which + is by convention assumed to be background. If the non-background segmentations are small compared to the total image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction and ground truth is BCHW[D]. - An example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. + The typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. + + Further information can be found in the official + `MONAI Dice Overview `. Example: @@ -46,6 +52,8 @@ class DiceMetric(CumulativeIterationMetric): import torch from monai.metrics import DiceMetric + from monai.losses import DiceLoss + from monai.networks import one_hot batch_size, n_classes, h, w = 7, 5, 128, 128 @@ -53,7 +61,7 @@ class DiceMetric(CumulativeIterationMetric): y_pred = torch.argmax(y_pred, 1, True) # convert to label map # ground truth as label map - y = torch.randint(0, n_classes, size=(batch_size, 1, h, w)) + y = torch.randint(0, n_classes, size=(batch_size, 1, h, w)) dm = DiceMetric( reduction="mean_batch", return_with_label=True, num_classes=n_classes @@ -62,16 +70,22 @@ class DiceMetric(CumulativeIterationMetric): raw_scores = dm(y_pred, y) print(dm.aggregate()) + # now compute the Dice loss which should be the same as 1 - raw_scores + dl = DiceLoss(to_onehot_y=True, reduction="none") + loss = dl(one_hot(y_pred, n_classes), y).squeeze() + + print(1.0 - loss) # same as raw_scores + Args: include_background: whether to include Dice computation on the first channel/category of the prediction and - ground truth. Defaults to ``True``, use ``False`` to exclude the background class. + ground truth. Defaults to ``True``, use ``False`` to exclude the background class. reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none", is selected, the metric will not do reduction. get_not_nans: whether to return the `not_nans` count. If True, aggregate() returns `(metric, not_nans)` where `not_nans` counts the number of valid values in the result, and will have the same shape. - ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be + ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases are also empty. num_classes: number of input channels (always including the background). When this is ``None``, @@ -104,14 +118,14 @@ def __init__( include_background=self.include_background, reduction=MetricReduction.NONE, get_not_nans=False, - softmax=False, + apply_argmax=False, ignore_empty=self.ignore_empty, num_classes=self.num_classes, ) def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ - Compute the dice value using ``DiceHelper``. + Compute the dice value using ``DiceHelper``. Args: y_pred: prediction value, see class docstring for format definition. @@ -133,7 +147,7 @@ def aggregate( Execute reduction and aggregation logic for the output of `compute_dice`. Args: - reduction: defines mode of reduction as enumerated in :py:class:`monai.utils.enums.MetricReduction`. + reduction: defines mode of reduction as enumerated in :py:class:`monai.utils.enums.MetricReduction`. By default this will do no reduction. """ data = self.get_buffer() @@ -163,15 +177,16 @@ def compute_dice( num_classes: int | None = None, ) -> torch.Tensor: """ - Computes Dice score metric for a batch of predictions. This performs the same computation as - :py:class:`monai.metrics.DiceMetric`, see the documentation for that class for input formats. + Computes Dice score metric for a batch of predictions. This performs the same computation as + :py:class:`monai.metrics.DiceMetric`, which is preferrable to use over this function. For input formats, see the + documentation for that class . Args: y_pred: input data to compute, typical segmentation model output. - y: ground truth to compute mean dice metric. + y: ground truth to compute mean dice metric. include_background: whether to include Dice computation on the first channel/category of the prediction and - ground truth. Defaults to ``True``, use ``False`` to exclude the background class. - ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be + ground truth. Defaults to ``True``, use ``False`` to exclude the background class. + ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases are also empty. num_classes: number of input channels (always including the background). When this is ``None``, @@ -186,7 +201,7 @@ def compute_dice( include_background=include_background, reduction=MetricReduction.NONE, get_not_nans=False, - softmax=False, + apply_argmax=False, ignore_empty=ignore_empty, num_classes=num_classes, )(y_pred=y_pred, y=y) @@ -194,8 +209,8 @@ def compute_dice( class DiceHelper: """ - Compute Dice score between two tensors ``y_pred`` and ``y``. This is used by :py:class:`monai.metrics.DiceMetric`, - see the documentation for that class for input formats. + Compute Dice score between two tensors ``y_pred`` and ``y``. This is used by :py:class:`monai.metrics.DiceMetric`, + see the documentation for that class for input formats. Example: @@ -215,16 +230,17 @@ class DiceHelper: Args: include_background: whether to include Dice computation on the first channel/category of the prediction and - ground truth. Defaults to ``True``, use ``False`` to exclude the background class. - sigmoid: if ``True`, ``y_pred`` will be thresholded at a value of 0.5. Defaults to False. - softmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to - get the discrete prediction. Defaults to the value of ``not sigmoid``. - activate: if this and ``sigmoid` are ``True``, sigmoid activation is applied to ``y_pred``. Defaults to False. + ground truth. Defaults to ``True``, use ``False`` to exclude the background class. + threshold: if ``True`, ``y_pred`` will be thresholded at a value of 0.5. Defaults to False. + apply_argmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to + get the discrete prediction. Defaults to the value of ``not threshold``. + activate: if this and ``threshold` are ``True``, sigmoid activation is applied to ``y_pred`` before + thresholding. Defaults to False. get_not_nans: whether to return the number of not-nan values. reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none", is selected, the metric will not do reduction. - ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be + ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases are also empty. num_classes: number of input channels (always including the background). When this is ``None``, @@ -232,28 +248,45 @@ class DiceHelper: single-channel class indices and the number of classes is not automatically inferred from data. """ + @deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax") + @deprecated_arg("sigmoid", "1.5", "1.7", "Use `threshold` instead.", new_name="threshold") def __init__( self, include_background: bool | None = None, - sigmoid: bool = False, - softmax: bool | None = None, + threshold: bool = False, + apply_argmax: bool | None = None, activate: bool = False, get_not_nans: bool = True, reduction: MetricReduction | str = MetricReduction.MEAN_BATCH, ignore_empty: bool = True, num_classes: int | None = None, + sigmoid: bool | None = None, + softmax: bool | None = None, ) -> None: - self.sigmoid = sigmoid + # handling deprecated arguments + if sigmoid is not None: + threshold = sigmoid + if softmax is not None: + apply_argmax = softmax + + self.threshold = threshold self.reduction = reduction self.get_not_nans = get_not_nans - self.include_background = sigmoid if include_background is None else include_background - self.softmax = not sigmoid if softmax is None else softmax + self.include_background = threshold if include_background is None else include_background + self.apply_argmax = not threshold if apply_argmax is None else apply_argmax self.activate = activate self.ignore_empty = ignore_empty self.num_classes = num_classes def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """""" + """ + Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately + for each batch item and for each channel of those items. + + Args: + y_pred: input predictions with shape HW[D]. + y: ground truth with shape HW[D]. + """ y_o = torch.sum(y) if y_o > 0: return (2.0 * torch.sum(torch.masked_select(y, y_pred))) / (y_o + torch.sum(y_pred)) @@ -266,25 +299,25 @@ def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ + Compute the metric for the given prediction and ground truth. Args: y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...). the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``. y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...). """ - _softmax, _sigmoid = self.softmax, self.sigmoid + _apply_argmax, _threshold = self.apply_argmax, self.threshold if self.num_classes is None: n_pred_ch = y_pred.shape[1] # y_pred is in one-hot format or multi-channel scores else: n_pred_ch = self.num_classes if y_pred.shape[1] == 1 and self.num_classes > 1: # y_pred is single-channel class indices - _softmax = _sigmoid = False + _apply_argmax = _threshold = False - if _softmax: - if n_pred_ch > 1: - y_pred = torch.argmax(y_pred, dim=1, keepdim=True) + if _apply_argmax and n_pred_ch > 1: + y_pred = torch.argmax(y_pred, dim=1, keepdim=True) - elif _sigmoid: + elif _threshold: if self.activate: y_pred = torch.sigmoid(y_pred) y_pred = y_pred > 0.5 diff --git a/tests/metrics/test_compute_meandice.py b/tests/metrics/test_compute_meandice.py index aae15483b5..04c81ff9a7 100644 --- a/tests/metrics/test_compute_meandice.py +++ b/tests/metrics/test_compute_meandice.py @@ -267,15 +267,15 @@ def test_nans(self, input_data, expected_value): @parameterized.expand([TEST_CASE_3]) def test_helper(self, input_data, _unused): vals = {"y_pred": dict(input_data).pop("y_pred"), "y": dict(input_data).pop("y")} - result = DiceHelper(sigmoid=True)(**vals) + result = DiceHelper(threshold=True)(**vals) np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4) np.testing.assert_allclose(sorted(result[1].cpu().numpy()), [0.0, 1.0, 2.0], atol=1e-4) - result = DiceHelper(softmax=True, get_not_nans=False)(**vals) + result = DiceHelper(apply_argmax=True, get_not_nans=False)(**vals) np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0], atol=1e-4) num_classes = vals["y_pred"].shape[1] vals["y_pred"] = torch.argmax(vals["y_pred"], dim=1, keepdim=True) - result = DiceHelper(sigmoid=True, num_classes=num_classes)(**vals) + result = DiceHelper(threshold=True, num_classes=num_classes)(**vals) np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4) # DiceMetric class tests