
Commit 92f122e

elias-ramzi and Borda authored
Fix average_precision metric (#2319)
* Fixed `average_precision` metric: parentheses were missing. Added a test that failed with the old implementation.
* Modified CHANGELOG.md
* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 63bd058 · commit 92f122e

3 files changed (+10 −7 lines)


CHANGELOG.md (+2 −0)
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298))
 
+- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319))
+
 
 ## [0.8.1] - 2020-06-19
 

pytorch_lightning/metrics/functional/classification.py (+1 −1)
@@ -844,7 +844,7 @@ def average_precision(
     # Return the step function integral
     # The following works because the last entry of precision is
     # guaranteed to be 1, as returned by precision_recall_curve
-    return -torch.sum(recall[1:] - recall[:-1] * precision[:-1])
+    return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])
 
 
 def dice_score(
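Why the single pair of parentheses matters: in Python, `*` binds more tightly than binary `-`, so the old expression evaluated `recall[1:] - (recall[:-1] * precision[:-1])` rather than the step-function integral AP = Σ (Rₙ − Rₙ₋₁) · Pₙ. A minimal sketch with made-up precision/recall tensors (chosen only for illustration; they are not values produced anywhere in this commit):

```python
import torch

# Toy tensors shaped like precision_recall_curve output:
# recall decreases to 0 and the final precision entry is 1.
precision = torch.tensor([0.5, 1.0, 1.0])
recall = torch.tensor([1.0, 0.5, 0.0])

# Old expression: `*` binds tighter than `-`, so this computes
# recall[1:] - (recall[:-1] * precision[:-1]) -- not a step integral.
buggy = -torch.sum(recall[1:] - recall[:-1] * precision[:-1])

# Fixed expression: sum of (recall step width) * (precision at that step).
fixed = -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])

print(buggy.item())  # 0.5
print(fixed.item())  # 0.75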

tests/metrics/functional/test_classification.py (+7 −6)
@@ -342,18 +342,19 @@ def test_auc(x, y, expected):
     assert auc(torch.tensor(x), torch.tensor(y)) == expected
 
 
-def test_average_precision_constant_values():
+@pytest.mark.parametrize(['scores', 'target', 'expected_score'], [
     # Check the average_precision_score of a constant predictor is
     # the TPR
-
     # Generate a dataset with 25% of positives
-    target = torch.zeros(100, dtype=torch.float)
-    target[::4] = 1
     # And a constant score
-    pred = torch.ones(100)
     # The precision is then the fraction of positive whatever the recall
     # is, as there is only one threshold:
-    assert average_precision(pred, target).item() == .25
+    pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
+    # With treshold .8 : 1 TP and 2 TN and one FN
+    pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
+])
+def test_average_precision(scores, target, expected_score):
+    assert average_precision(scores, target) == expected_score
 
 
 @pytest.mark.parametrize(['pred', 'target', 'expected'], [
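Both parametrized cases can be sanity-checked against scikit-learn's reference implementation. A hedged cross-check (sklearn is not used by the commit itself; this is shown only for verification):

```python
from sklearn.metrics import average_precision_score

# Constant predictor: AP equals the fraction of positives (1 of 4).
print(average_precision_score([0, 0, 0, 1], [1, 1, 1, 1]))     # 0.25

# Second case from the new test: one threshold step contributes
# 0.5 * 0.5 and the last contributes 0.5 * 1, giving 0.75.
print(average_precision_score([1, 0, 0, 1], [.6, .7, .8, 9]))  # 0.75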
