@@ -138,10 +138,10 @@ def stat_scores_multiple_classes(
138
138
target : torch .Tensor ,
139
139
num_classes : Optional [int ] = None ,
140
140
argmax_dim : int = 1 ,
141
+ reduction : str = 'none' ,
141
142
) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]:
142
143
"""
143
- Calls the stat_scores function iteratively for all classes, thus
144
- calculating the number of true postive, false postive, true negative
144
+ Calculates the number of true postive, false postive, true negative
145
145
and false negative for each class
146
146
147
147
Args:
@@ -150,6 +150,12 @@ def stat_scores_multiple_classes(
150
150
num_classes: number of classes if known
151
151
argmax_dim: if pred is a tensor of probabilities, this indicates the
152
152
axis the argmax transformation will be applied over
153
+ reduction: method for reducing result values (default: none)
154
+ Available reduction methods:
155
+
156
+ - elementwise_mean: takes the mean
157
+ - none: pass array
158
+ - sum: add elements
153
159
154
160
Return:
155
161
True Positive, False Positive, True Negative, False Negative, Support
@@ -173,16 +179,58 @@ def stat_scores_multiple_classes(
173
179
if pred .ndim == target .ndim + 1 :
174
180
pred = to_categorical (pred , argmax_dim = argmax_dim )
175
181
176
- num_classes = get_num_classes (pred = pred , target = target ,
177
- num_classes = num_classes )
182
+ num_classes = get_num_classes (pred = pred , target = target , num_classes = num_classes )
178
183
179
- tps = torch .zeros ((num_classes ,), device = pred .device )
180
- fps = torch .zeros ((num_classes ,), device = pred .device )
181
- tns = torch .zeros ((num_classes ,), device = pred .device )
182
- fns = torch .zeros ((num_classes ,), device = pred .device )
183
- sups = torch .zeros ((num_classes ,), device = pred .device )
184
- for c in range (num_classes ):
185
- tps [c ], fps [c ], tns [c ], fns [c ], sups [c ] = stat_scores (pred = pred , target = target , class_index = c )
184
+ if pred .dtype != torch .bool :
185
+ pred .clamp_max_ (max = num_classes )
186
+ if target .dtype != torch .bool :
187
+ target .clamp_max_ (max = num_classes )
188
+
189
+ possible_reductions = ('none' , 'sum' , 'elementwise_mean' )
190
+ if reduction not in possible_reductions :
191
+ raise ValueError ("reduction type %s not supported" % reduction )
192
+
193
+ if reduction == 'none' :
194
+ pred = pred .view ((- 1 , )).long ()
195
+ target = target .view ((- 1 , )).long ()
196
+
197
+ tps = torch .zeros ((num_classes + 1 ,), device = pred .device )
198
+ fps = torch .zeros ((num_classes + 1 ,), device = pred .device )
199
+ tns = torch .zeros ((num_classes + 1 ,), device = pred .device )
200
+ fns = torch .zeros ((num_classes + 1 ,), device = pred .device )
201
+ sups = torch .zeros ((num_classes + 1 ,), device = pred .device )
202
+
203
+ match_true = (pred == target ).float ()
204
+ match_false = 1 - match_true
205
+
206
+ tps .scatter_add_ (0 , pred , match_true )
207
+ fps .scatter_add_ (0 , pred , match_false )
208
+ fns .scatter_add_ (0 , target , match_false )
209
+ tns = pred .size (0 ) - (tps + fps + fns )
210
+ sups .scatter_add_ (0 , target , torch .ones_like (match_true ))
211
+
212
+ tps = tps [:num_classes ]
213
+ fps = fps [:num_classes ]
214
+ tns = tns [:num_classes ]
215
+ fns = fns [:num_classes ]
216
+ sups = sups [:num_classes ]
217
+
218
+ elif reduction == 'sum' or reduction == 'elementwise_mean' :
219
+ count_match_true = (pred == target ).sum ().float ()
220
+ oob_tp , oob_fp , oob_tn , oob_fn , oob_sup = stat_scores (pred , target , num_classes , argmax_dim )
221
+
222
+ tps = count_match_true - oob_tp
223
+ fps = pred .nelement () - count_match_true - oob_fp
224
+ fns = pred .nelement () - count_match_true - oob_fn
225
+ tns = pred .nelement () * (num_classes + 1 ) - (tps + fps + fns + oob_tn )
226
+ sups = pred .nelement () - oob_sup .float ()
227
+
228
+ if reduction == 'elementwise_mean' :
229
+ tps /= num_classes
230
+ fps /= num_classes
231
+ fns /= num_classes
232
+ tns /= num_classes
233
+ sups /= num_classes
186
234
187
235
return tps , fps , tns , fns , sups
188
236
@@ -218,16 +266,13 @@ def accuracy(
218
266
tensor(0.7500)
219
267
220
268
"""
221
- tps , fps , tns , fns , sups = stat_scores_multiple_classes (
222
- pred = pred , target = target , num_classes = num_classes )
223
-
224
269
if not (target > 0 ).any () and num_classes is None :
225
270
raise RuntimeError ("cannot infer num_classes when target is all zero" )
226
271
227
- if reduction in ( 'elementwise_mean' , 'sum' ):
228
- return reduce ( sum ( tps ) / sum ( sups ) , reduction = reduction )
229
- if reduction == 'none' :
230
- return reduce ( tps / sups , reduction = reduction )
272
+ tps , fps , tns , fns , sups = stat_scores_multiple_classes (
273
+ pred = pred , target = target , num_classes = num_classes , reduction = reduction )
274
+
275
+ return tps / sups
231
276
232
277
233
278
def confusion_matrix (
0 commit comments