
Commit ac4a215

Diuven (Younghun Roh) and Borda (Jirka Borovec) authored
Faster Accuracy metric (#2775)
* Faster classification stats
* Faster accuracy metric
* minor change on cls metric
* Add out-of-bound class clamping
* Add more tests and minor fixes
* Resolve code style warning
* Update for #2781
* hotfix
* Update pytorch_lightning/metrics/functional/classification.py
  Co-authored-by: Jirka Borovec <[email protected]>
* Update about conversation
* Add docstring on stat_scores_multiple_classes

Co-authored-by: Younghun Roh <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent dd78be5 commit ac4a215

File tree: 2 files changed (+73 -24 lines)


pytorch_lightning/metrics/functional/classification.py (+63 -18)
@@ -138,10 +138,10 @@ def stat_scores_multiple_classes(
         target: torch.Tensor,
         num_classes: Optional[int] = None,
         argmax_dim: int = 1,
+        reduction: str = 'none',
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
-    Calls the stat_scores function iteratively for all classes, thus
-    calculating the number of true postive, false postive, true negative
+    Calculates the number of true postive, false postive, true negative
     and false negative for each class
 
     Args:
@@ -150,6 +150,12 @@ def stat_scores_multiple_classes(
         num_classes: number of classes if known
         argmax_dim: if pred is a tensor of probabilities, this indicates the
             axis the argmax transformation will be applied over
+        reduction: method for reducing result values (default: none)
+            Available reduction methods:
+
+            - elementwise_mean: takes the mean
+            - none: pass array
+            - sum: add elements
 
     Return:
         True Positive, False Positive, True Negative, False Negative, Support
@@ -173,16 +179,58 @@ def stat_scores_multiple_classes(
     if pred.ndim == target.ndim + 1:
         pred = to_categorical(pred, argmax_dim=argmax_dim)
 
-    num_classes = get_num_classes(pred=pred, target=target,
-                                  num_classes=num_classes)
+    num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)
 
-    tps = torch.zeros((num_classes,), device=pred.device)
-    fps = torch.zeros((num_classes,), device=pred.device)
-    tns = torch.zeros((num_classes,), device=pred.device)
-    fns = torch.zeros((num_classes,), device=pred.device)
-    sups = torch.zeros((num_classes,), device=pred.device)
-    for c in range(num_classes):
-        tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred, target=target, class_index=c)
+    if pred.dtype != torch.bool:
+        pred.clamp_max_(max=num_classes)
+    if target.dtype != torch.bool:
+        target.clamp_max_(max=num_classes)
+
+    possible_reductions = ('none', 'sum', 'elementwise_mean')
+    if reduction not in possible_reductions:
+        raise ValueError("reduction type %s not supported" % reduction)
+
+    if reduction == 'none':
+        pred = pred.view((-1, )).long()
+        target = target.view((-1, )).long()
+
+        tps = torch.zeros((num_classes + 1,), device=pred.device)
+        fps = torch.zeros((num_classes + 1,), device=pred.device)
+        tns = torch.zeros((num_classes + 1,), device=pred.device)
+        fns = torch.zeros((num_classes + 1,), device=pred.device)
+        sups = torch.zeros((num_classes + 1,), device=pred.device)
+
+        match_true = (pred == target).float()
+        match_false = 1 - match_true
+
+        tps.scatter_add_(0, pred, match_true)
+        fps.scatter_add_(0, pred, match_false)
+        fns.scatter_add_(0, target, match_false)
+        tns = pred.size(0) - (tps + fps + fns)
+        sups.scatter_add_(0, target, torch.ones_like(match_true))
+
+        tps = tps[:num_classes]
+        fps = fps[:num_classes]
+        tns = tns[:num_classes]
+        fns = fns[:num_classes]
+        sups = sups[:num_classes]
+
+    elif reduction == 'sum' or reduction == 'elementwise_mean':
+        count_match_true = (pred == target).sum().float()
+        oob_tp, oob_fp, oob_tn, oob_fn, oob_sup = stat_scores(pred, target, num_classes, argmax_dim)
+
+        tps = count_match_true - oob_tp
+        fps = pred.nelement() - count_match_true - oob_fp
+        fns = pred.nelement() - count_match_true - oob_fn
+        tns = pred.nelement() * (num_classes + 1) - (tps + fps + fns + oob_tn)
+        sups = pred.nelement() - oob_sup.float()
+
+        if reduction == 'elementwise_mean':
+            tps /= num_classes
+            fps /= num_classes
+            fns /= num_classes
+            tns /= num_classes
+            sups /= num_classes
 
     return tps, fps, tns, fns, sups
 
@@ -218,16 +266,13 @@ def accuracy(
         tensor(0.7500)
 
     """
-    tps, fps, tns, fns, sups = stat_scores_multiple_classes(
-        pred=pred, target=target, num_classes=num_classes)
-
     if not (target > 0).any() and num_classes is None:
         raise RuntimeError("cannot infer num_classes when target is all zero")
 
-    if reduction in ('elementwise_mean', 'sum'):
-        return reduce(sum(tps) / sum(sups), reduction=reduction)
-    if reduction == 'none':
-        return reduce(tps / sups, reduction=reduction)
+    tps, fps, tns, fns, sups = stat_scores_multiple_classes(
+        pred=pred, target=target, num_classes=num_classes, reduction=reduction)
+
+    return tps / sups
 
 
 def confusion_matrix(
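
For reference, a minimal sketch (not the library function itself) of the counting trick the rewritten stat_scores_multiple_classes relies on: instead of looping over classes and calling stat_scores once per class, every (pred, target) pair is binned in a single pass with Tensor.scatter_add_. The values mirror the 'none' test case below; the extra out-of-bound bucket (num_classes + 1) and the clamping the real function performs are omitted for brevity.

import torch

pred = torch.tensor([0, 2, 4, 4])      # predicted class per sample
target = torch.tensor([0, 4, 3, 4])    # true class per sample
num_classes = 5

match_true = (pred == target).float()  # 1.0 where the prediction is correct
match_false = 1 - match_true           # 1.0 where it is wrong

# bin matches/mismatches by predicted class (TP/FP) and by true class (FN/support)
tps = torch.zeros(num_classes).scatter_add_(0, pred, match_true)
fps = torch.zeros(num_classes).scatter_add_(0, pred, match_false)
fns = torch.zeros(num_classes).scatter_add_(0, target, match_false)
sups = torch.zeros(num_classes).scatter_add_(0, target, torch.ones_like(match_true))
tns = pred.numel() - (tps + fps + fns)  # everything else counts as a true negative

print(tps)   # tensor([1., 0., 0., 0., 1.])
print(fps)   # tensor([0., 0., 1., 0., 1.])
print(tns)   # tensor([3., 4., 3., 3., 1.])
print(fns)   # tensor([0., 0., 0., 1., 1.])
print(sups)  # tensor([1., 0., 0., 1., 2.])

Because all five statistics come out of a handful of vectorized tensor ops, the cost no longer grows with one stat_scores call per class, which is where the speed-up in this commit comes from.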

tests/metrics/functional/test_classification.py (+10 -6)
@@ -121,15 +121,19 @@ def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expect
     assert sup.item() == expected_support
 
 
-@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp',
+@pytest.mark.parametrize(['pred', 'target', 'reduction', 'expected_tp', 'expected_fp',
                           'expected_tn', 'expected_fn', 'expected_support'], [
-    pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]),
+    pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 'none',
+                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
+    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'none',
                  [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
-    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]),
-                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2])
+    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'sum',
+                 torch.tensor(2), torch.tensor(2), torch.tensor(14), torch.tensor(2), torch.tensor(4)),
+    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'elementwise_mean',
+                 torch.tensor(0.4), torch.tensor(0.4), torch.tensor(2.8), torch.tensor(0.4), torch.tensor(0.8))
 ])
-def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
-    tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target)
+def test_stat_scores_multiclass(pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
+    tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction=reduction)
 
     assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
     assert torch.allclose(torch.tensor(expected_fp).to(fp), fp)
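
For completeness, a hedged usage sketch of the updated function using the plain class-index form of the same example (the 'none' test cases above show it behaves identically to the one-hot form). It assumes pytorch_lightning at this commit is importable under the metrics path shown in the diff; the printed values are the ones the new 'sum' parametrization asserts, and 'elementwise_mean' simply divides them by the 5 inferred classes.

import torch
from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes

pred = torch.tensor([0., 2., 4., 4.])
target = torch.tensor([0., 4., 3., 4.])

# per-class statistics reduced by summation over the 5 inferred classes
tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction='sum')
print(tp, fp, tn, fn, sup)  # expected: tensor(2.) tensor(2.) tensor(14.) tensor(2.) tensor(4.)

# 'elementwise_mean' divides the same sums by the number of classes
tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction='elementwise_mean')
print(tp, fp, tn, fn, sup)  # expected: 0.4, 0.4, 2.8, 0.4, 0.8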
