Avoid zeros in dice and iou #2567

Merged 13 commits on Jul 10, 2020
Changes from all commits
22 changes: 10 additions & 12 deletions pytorch_lightning/metrics/functional/classification.py
@@ -1,3 +1,4 @@
import sys
from collections import Sequence
from functools import wraps
from typing import Optional, Tuple, Callable
@@ -6,7 +7,7 @@
from torch.nn import functional as F

from pytorch_lightning.metrics.functional.reduction import reduce
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn, FLOAT16_EPSILON


def to_onehot(
@@ -893,8 +894,8 @@ def dice_score(
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> average_precision(pred, target)
tensor(0.2500)
>>> dice_score(pred, target)
tensor(0.3333)

"""
num_classes = pred.shape[1]
@@ -907,14 +908,9 @@ def dice_score(
continue

tp, fp, tn, fn, sup = stat_scores(pred=pred, target=target, class_index=i)

denom = (2 * tp + fp + fn + 1e-15).to(torch.float)

if torch.isclose(denom, torch.zeros_like(denom)).any():
# nan result
score_cls = nan_score
else:
score_cls = (2 * tp).to(torch.float) / denom
denom = (2 * tp + fp + fn).to(torch.float)
# nan result
score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else nan_score

scores[i - bg] += score_cls
return reduce(scores, reduction=reduction)
@@ -963,5 +959,7 @@ def iou(
tps = tps[1:]
fps = fps[1:]
fns = fns[1:]
iou = tps / (fps + fns + tps + 1e-15)
denom = fps + fns + tps
denom[denom == 0] = torch.tensor(FLOAT16_EPSILON).type_as(denom)
iou = tps / denom
return reduce(iou, reduction=reduction)
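For iou the counts are per-class tensors, so the scalar is_nonzero guard does not apply; instead, zero entries in the denominator are swapped for the float16 epsilon before the division. A short sketch of that pattern with made-up counts:

```python
import torch
import numpy

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps  # ~0.000977

# Hypothetical per-class counts; the last class never appears in pred or target.
tps = torch.tensor([5., 3., 0.])
fps = torch.tensor([1., 0., 0.])
fns = torch.tensor([2., 1., 0.])

denom = fps + fns + tps
# Only degenerate classes get the epsilon, so the other IoU values stay exact.
denom[denom == 0] = torch.tensor(FLOAT16_EPSILON).type_as(denom)
print(tps / denom)  # tensor([0.6250, 0.7500, 0.0000])
```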
5 changes: 5 additions & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -1,5 +1,6 @@
"""General utilities"""

import numpy
import torch

from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
@@ -14,3 +15,7 @@
APEX_AVAILABLE = True

NATIVE_AMP_AVALAIBLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps
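These constants simply re-export NumPy's machine epsilons for the three float widths; their values for reference (a quick check, not part of the diff):

```python
import numpy

# Gap between 1.0 and the next representable float at each precision.
print(numpy.finfo(numpy.float16).eps)  # 0.000977               (2 ** -10)
print(numpy.finfo(numpy.float32).eps)  # 1.1920929e-07          (2 ** -23)
print(numpy.finfo(numpy.float64).eps)  # 2.220446049250313e-16  (2 ** -52)
```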
91 changes: 19 additions & 72 deletions tests/metrics/functional/test_classification.py
@@ -47,57 +47,23 @@ def test_against_sklearn(sklearn_metric, torch_metric):
"""Compare PL metrics to sklearn version."""
device = 'cuda' if torch.cuda.is_available() else 'cpu'

pred = torch.randint(10, (500,), device=device)
target = torch.randint(10, (500,), device=device)
# iterate over different label counts in predictions and target
for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]:
pred = torch.randint(n_cls_pred, (300,), device=device)
target = torch.randint(n_cls_target, (300,), device=device)

assert torch.allclose(
torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy()), dtype=torch.float, device=device),
torch_metric(pred, target))

pred = torch.randint(10, (200,), device=device)
target = torch.randint(5, (200,), device=device)

assert torch.allclose(
torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy()), dtype=torch.float, device=device),
torch_metric(pred, target))

pred = torch.randint(5, (200,), device=device)
target = torch.randint(10, (200,), device=device)

assert torch.allclose(
torch.tensor(sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy()), dtype=torch.float, device=device),
torch_metric(pred, target))
sk_score = sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy())
sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
pl_score = torch_metric(pred, target)
assert torch.allclose(sk_score, pl_score)
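The rewritten test iterates over mismatched label counts so each metric is also exercised when pred contains classes absent from target and vice versa. The comparison pattern in isolation, using plain accuracy as a stand-in for the parametrized (sklearn_metric, torch_metric) pair:

```python
import torch
from sklearn.metrics import accuracy_score

def torch_accuracy(pred, target):
    # Stand-in torch metric: fraction of matching predictions.
    return (pred == target).float().mean()

for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]:
    pred = torch.randint(n_cls_pred, (300,))
    target = torch.randint(n_cls_target, (300,))
    sk_score = torch.tensor(accuracy_score(target.numpy(), pred.numpy()), dtype=torch.float)
    assert torch.allclose(sk_score, torch_accuracy(pred, target))
```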


def test_onehot():
test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
expected = torch.tensor([
[
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]
], [
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1]
]
expected = torch.stack([
torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]),
torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
])

assert test_tensor.shape == (2, 5)
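The hard-coded 2 x 10 x 5 one-hot literal is replaced by the same tensor built from torch.eye, torch.zeros, torch.cat and torch.stack. A smaller instance of the construction, shrunk to 2 classes per block purely for illustration:

```python
import torch

# Two batch rows, each one-hot encoding into its own band of a 4-class axis.
expected = torch.stack([
    torch.cat([torch.eye(2, dtype=int), torch.zeros((2, 2), dtype=int)]),
    torch.cat([torch.zeros((2, 2), dtype=int), torch.eye(2, dtype=int)]),
])
print(expected.shape)  # torch.Size([2, 4, 2])
print(expected[0])
# tensor([[1, 0],
#         [0, 1],
#         [0, 0],
#         [0, 0]])
```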
@@ -116,30 +82,9 @@ def test_onehot():


def test_to_categorical():
test_tensor = torch.tensor([
[
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]
], [
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1]
]
test_tensor = torch.stack([
torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]),
torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
]).to(torch.float)

expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
@@ -260,7 +205,9 @@ def test_fbeta_score(pred, target, beta, exp_score):


@pytest.mark.parametrize(['pred', 'target', 'exp_score'], [
pytest.param([0., 0., 0., 0.], [1., 1., 1., 1.], [0.0, 0.0]),
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], [0.5, 0.5]),
pytest.param([1., 0., 1., 0.], [1., 0., 1., 0.], [1.0, 1.0]),
])
def test_f1_score(pred, target, exp_score):
score = f1_score(torch.tensor(pred), torch.tensor(target), reduction='none')
@@ -324,7 +271,7 @@ def test_roc_curve(pred, target, expected_tpr, expected_fpr):


@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([0, 0, 1, 1], [0, 0, 1, 1], 1.),
pytest.param([0, 1, 0, 1], [0, 1, 0, 1], 1.),
pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.),
pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5),
pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.),
@@ -355,7 +302,7 @@ def test_auc(x, y, expected):
# The precision is then the fraction of positive whatever the recall
# is, as there is only one threshold:
pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
# With treshold .8 : 1 TP and 2 TN and one FN
# With threshold 0.8 : 1 TP and 2 TN and one FN
pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
])
def test_average_precision(scores, target, expected_score):
16 changes: 10 additions & 6 deletions tests/metrics/functional/test_regression.py
@@ -14,43 +14,47 @@

@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25),
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 3.0)
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 3.0),
])
def test_mse(pred, target, expected):
score = mse(torch.tensor(pred), torch.tensor(target))
assert score.item() == expected


@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.5),
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.7321)
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.7321),
])
def test_rmse(pred, target, expected):
score = rmse(torch.tensor(pred), torch.tensor(target))
assert torch.allclose(score, torch.tensor(expected), atol=1e-3)


@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25),
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.5)
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.5),
])
def test_mae(pred, target, expected):
score = mae(torch.tensor(pred), torch.tensor(target))
assert score.item() == expected


@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.0207),
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 0.2841)
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 0.2841),
])
def test_rmsle(pred, target, expected):
score = rmsle(torch.tensor(pred), torch.tensor(target))
assert torch.allclose(score, torch.tensor(expected), atol=1e-3)


@pytest.mark.parametrize(['pred', 'target'], [
pytest.param([0., 1., 2., 3.], [0., 1., 2., 3.]),
pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.]),
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.])
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.]),
])
def test_psnr_with_skimage(pred, target):
score = psnr(pred=torch.tensor(pred),
@@ -61,7 +65,7 @@ def test_psnr_with_skimage(pred, target):

@pytest.mark.parametrize(['pred', 'target'], [
pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.]),
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.])
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.]),
])
def test_psnr_base_e_wider_range(pred, target):
score = psnr(pred=torch.tensor(pred),
30 changes: 20 additions & 10 deletions tests/metrics/test_sklearn.py
@@ -43,38 +43,48 @@ def new_func(*args, **kwargs):

@pytest.mark.parametrize(['metric_class', 'sklearn_func', 'inputs'], [
pytest.param(Accuracy(), sk_accuracy,
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))},
id='Accuracy'),
pytest.param(AUC(), sk_auc,
{'x': torch.arange(10, dtype=torch.float) / 10,
'y': torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.5, 0.6, 0.7])},
id='AUC'),
pytest.param(AveragePrecision(), sk_average_precision,
{'y_score': torch.randint(2, size=(128,)), 'y_true': torch.randint(2, size=(128,))},
{'y_score': torch.randint(2, size=(128,)),
'y_true': torch.randint(2, size=(128,))},
id='AveragePrecision'),
pytest.param(ConfusionMatrix(), sk_confusion_matrix,
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))},
id='ConfusionMatrix'),
pytest.param(F1(average='macro'), partial(sk_f1_score, average='macro'),
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))},
id='F1'),
pytest.param(FBeta(beta=0.5, average='macro'), partial(sk_fbeta_score, beta=0.5, average='macro'),
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))},
id='FBeta'),
pytest.param(Precision(average='macro'), partial(sk_precision, average='macro'),
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))},
id='Precision'),
pytest.param(Recall(average='macro'), partial(sk_recall, average='macro'),
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))},
id='Recall'),
pytest.param(PrecisionRecallCurve(), _xy_only(sk_precision_recall_curve),
{'probas_pred': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
{'probas_pred': torch.rand(size=(128,)),
'y_true': torch.randint(2, size=(128,))},
id='PrecisionRecallCurve'),
pytest.param(ROC(), _xy_only(sk_roc_curve),
{'y_score': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
{'y_score': torch.rand(size=(128,)),
'y_true': torch.randint(2, size=(128,))},
id='ROC'),
pytest.param(AUROC(), sk_roc_auc_score,
{'y_score': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
{'y_score': torch.rand(size=(128,)),
'y_true': torch.randint(2, size=(128,))},
id='AUROC'),
])
def test_sklearn_metric(metric_class, sklearn_func, inputs):