Skip to content

Commit c14cd67

Browse files
author
Younghun Roh
committed
Update for Lightning-AI#2781
1 parent 6df2018 commit c14cd67

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytorch_lightning/metrics/functional/classification.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,11 @@ def stat_scores_multiple_classes(
171171
>>> sups
172172
tensor([1., 0., 1., 1.])
173173
"""
174-
num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)
175-
176174
if pred.ndim == target.ndim + 1:
177175
pred = to_categorical(pred, argmax_dim=argmax_dim)
178176

177+
num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)
178+
179179
pred = pred.view((-1, )).long()
180180
target = target.view((-1, )).long()
181181

0 commit comments

Comments
 (0)