import numpy as np
-
+ import sklearn.metrics as sk
+ recall_level_default = 0.95
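+ # the OOD false positive rate below is reported at this recall (true-positive-rate) level, i.e. the usual FPR95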

def calib_err(confidence, correct, p='2', beta=100):
    # beta is target bin size
@@ -90,15 +91,6 @@ def tune_temp(logits, labels, binary_search=True, lower=0.2, upper=5.0, eps=0.00
    return t


- def get_measures(confidence, correct):
-     rms = calib_err(confidence, correct, p='2')
-     aurra_metric = aurra(confidence, correct)
-     mad = calib_err(confidence, correct, p='1')  # secondary metric
-     sf1 = soft_f1(confidence, correct)  # secondary metric
-
-     return rms, aurra_metric, mad, sf1
-
-
def print_measures(rms, aurra_metric, mad, sf1, method_name='Baseline'):
    print('\t\t\t\t\t\t\t' + method_name)
    print('RMS Calib Error (%): \t\t{:.2f}'.format(100 * rms))
@@ -122,3 +114,89 @@ def show_calibration_results(confidence, correct, method_name='Baseline'):
    # print('Soft F1-Score (%): \t\t{:.2f}'.format(
    #     100 * soft_f1(confidence, correct)))

+ def fpr_and_fdr_at_recall(y_true, y_score, recall_level=recall_level_default, pos_label=None):
+     classes = np.unique(y_true)
+     if (pos_label is None and
+             not (np.array_equal(classes, [0, 1]) or
+                  np.array_equal(classes, [-1, 1]) or
+                  np.array_equal(classes, [0]) or
+                  np.array_equal(classes, [-1]) or
+                  np.array_equal(classes, [1]))):
+         raise ValueError("Data is not binary and pos_label is not specified")
+     elif pos_label is None:
+         pos_label = 1.
+
+     # make y_true a boolean vector
+     y_true = (y_true == pos_label)
+
+     # sort scores and corresponding truth values
+     desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
+     y_score = y_score[desc_score_indices]
+     y_true = y_true[desc_score_indices]
+
+     # y_score typically has many tied values. Here we extract
+     # the indices associated with the distinct values. We also
+     # concatenate a value for the end of the curve.
+     distinct_value_indices = np.where(np.diff(y_score))[0]
+     threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
+
+     # accumulate the true positives with decreasing threshold
+     tps = stable_cumsum(y_true)[threshold_idxs]
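+     # NOTE: `stable_cumsum` is assumed to be defined elsewhere in this module; an equivalent
+     # helper (a cumulative sum with a numerical-stability check) is sklearn.utils.extmath.stable_cumsum.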
+     fps = 1 + threshold_idxs - tps      # add one because of zero-based indexing
+
+     thresholds = y_score[threshold_idxs]
+
+     recall = tps / tps[-1]
+
+     last_ind = tps.searchsorted(tps[-1])
+     sl = slice(last_ind, None, -1)      # [last_ind::-1]
+     recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl]
+
+     cutoff = np.argmin(np.abs(recall - recall_level))
+
+     return fps[cutoff] / (np.sum(np.logical_not(y_true)))  # , fps[cutoff]/(fps[cutoff] + tps[cutoff])
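+
+ # Reading of the function above: scores are sorted in descending order, true positives are
+ # accumulated at each distinct threshold, and the FPR is returned at the threshold whose
+ # recall (TPR) is closest to `recall_level`. A tiny illustrative call:
+ #   fpr_and_fdr_at_recall(np.array([1, 1, 0, 0]), np.array([0.9, 0.4, 0.6, 0.1]), 0.95)
+ # evaluates to 0.5, since one of the two negatives scores above the ~95%-recall threshold.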
+
+ def get_measures(_pos, _neg, recall_level=recall_level_default):
+     pos = np.array(_pos[:]).reshape((-1, 1))
+     neg = np.array(_neg[:]).reshape((-1, 1))
+     examples = np.squeeze(np.vstack((pos, neg)))
+     labels = np.zeros(len(examples), dtype=np.int32)
+     labels[:len(pos)] += 1
+
+     auroc = sk.roc_auc_score(labels, examples)
+     aupr = sk.average_precision_score(labels, examples)
+     fpr = fpr_and_fdr_at_recall(labels, examples, recall_level)
+
+     return auroc, aupr, fpr
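+
+ # Usage sketch: `_pos` holds scores for the class labelled 1 and `_neg` for the class labelled 0,
+ # so AUROC/AUPR measure how well the scores rank `_pos` examples above `_neg` examples. Below it
+ # is called with the OOD scores first:
+ #   auroc, aupr, fpr = get_measures(out_score, in_score)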
+
+
+ def print_measures_old(auroc, aupr, fpr, method_name='Ours', recall_level=recall_level_default):
+     print('\t\t\t' + method_name)
+     print('FPR{:d}:\t{:.2f}'.format(int(100 * recall_level), 100 * fpr))
+     print('AUROC: \t{:.2f}'.format(100 * auroc))
+     print('AUPR: \t{:.2f}'.format(100 * aupr))
+
+
+ def print_measures_with_std(aurocs, auprs, fprs, method_name='Ours', recall_level=recall_level_default):
+     print('\t\t\t' + method_name)
+     print('FPR{:d}:\t{:.2f}\t+/- {:.2f}'.format(int(100 * recall_level), 100 * np.mean(fprs), 100 * np.std(fprs)))
+     print('AUROC: \t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(aurocs), 100 * np.std(aurocs)))
+     print('AUPR: \t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(auprs), 100 * np.std(auprs)))
+
+
+ def get_and_print_results(out_score, in_score, num_to_avg=1):
+
+     aurocs, auprs, fprs = [], [], []
+     # for _ in range(num_to_avg):
+     #     out_score = get_ood_scores(ood_loader)
+     measures = get_measures(out_score, in_score)
+     aurocs.append(measures[0]); auprs.append(measures[1]); fprs.append(measures[2])
+
+     auroc = np.mean(aurocs); aupr = np.mean(auprs); fpr = np.mean(fprs)
+     # auroc_list.append(auroc); aupr_list.append(aupr); fpr_list.append(fpr)
+
+     # if num_to_avg >= 5:
+     #     print_measures_with_std(aurocs, auprs, fprs, method_name='Ours')
+     # else:
+     #     print_measures(auroc, aupr, fpr, method_name='Ours')
+     return auroc, aupr, fpr
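+
+ # Illustrative end-to-end use, assuming `in_score` and `out_score` are 1-D arrays of detector
+ # scores for in-distribution and out-of-distribution data computed elsewhere:
+ #   auroc, aupr, fpr = get_and_print_results(out_score, in_score)
+ #   print_measures_old(auroc, aupr, fpr, method_name='Ours')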