We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 6df2018 commit c14cd67Copy full SHA for c14cd67
pytorch_lightning/metrics/functional/classification.py
@@ -171,11 +171,11 @@ def stat_scores_multiple_classes(
171
>>> sups
172
tensor([1., 0., 1., 1.])
173
"""
174
- num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)
175
-
176
if pred.ndim == target.ndim + 1:
177
pred = to_categorical(pred, argmax_dim=argmax_dim)
178
+ num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)
+
179
pred = pred.view((-1, )).long()
180
target = target.view((-1, )).long()
181
0 commit comments