
Commit 0777070

Add AUPR measure code for ImageNet-O
1 parent fe092a9 commit 0777070

File tree

1 file changed: +88 -10 lines changed

calibration_tools.py: +88 -10
@@ -1,5 +1,6 @@
 import numpy as np
-
+import sklearn.metrics as sk
+recall_level_default = 0.95

 def calib_err(confidence, correct, p='2', beta=100):
     # beta is target bin size
@@ -90,15 +91,6 @@ def tune_temp(logits, labels, binary_search=True, lower=0.2, upper=5.0, eps=0.00
     return t


-def get_measures(confidence, correct):
-    rms = calib_err(confidence, correct, p='2')
-    aurra_metric = aurra(confidence, correct)
-    mad = calib_err(confidence, correct, p='1')  # secondary metric
-    sf1 = soft_f1(confidence, correct)  # secondary metric
-
-    return rms, aurra_metric, mad, sf1
-
-
 def print_measures(rms, aurra_metric, mad, sf1, method_name='Baseline'):
     print('\t\t\t\t\t\t\t' + method_name)
     print('RMS Calib Error (%): \t\t{:.2f}'.format(100 * rms))
@@ -122,3 +114,89 @@ def show_calibration_results(confidence, correct, method_name='Baseline'):
     # print('Soft F1-Score (%): \t\t{:.2f}'.format(
     #     100 * soft_f1(confidence, correct)))

+def fpr_and_fdr_at_recall(y_true, y_score, recall_level=recall_level_default, pos_label=None):
+    classes = np.unique(y_true)
+    if (pos_label is None and
+            not (np.array_equal(classes, [0, 1]) or
+                 np.array_equal(classes, [-1, 1]) or
+                 np.array_equal(classes, [0]) or
+                 np.array_equal(classes, [-1]) or
+                 np.array_equal(classes, [1]))):
+        raise ValueError("Data is not binary and pos_label is not specified")
+    elif pos_label is None:
+        pos_label = 1.
+
+    # make y_true a boolean vector
+    y_true = (y_true == pos_label)
+
+    # sort scores and corresponding truth values
+    desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
+    y_score = y_score[desc_score_indices]
+    y_true = y_true[desc_score_indices]
+
+    # y_score typically has many tied values. Here we extract
+    # the indices associated with the distinct values. We also
+    # concatenate a value for the end of the curve.
+    distinct_value_indices = np.where(np.diff(y_score))[0]
+    threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
+
+    # accumulate the true positives with decreasing threshold
+    tps = stable_cumsum(y_true)[threshold_idxs]
+    fps = 1 + threshold_idxs - tps  # add one because of zero-based indexing
+
+    thresholds = y_score[threshold_idxs]
+
+    recall = tps / tps[-1]
+
+    last_ind = tps.searchsorted(tps[-1])
+    sl = slice(last_ind, None, -1)  # [last_ind::-1]
+    recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl]
+
+    cutoff = np.argmin(np.abs(recall - recall_level))
+
+    return fps[cutoff] / (np.sum(np.logical_not(y_true)))  # , fps[cutoff]/(fps[cutoff] + tps[cutoff])
+
+
+def get_measures(_pos, _neg, recall_level=recall_level_default):
+    pos = np.array(_pos[:]).reshape((-1, 1))
+    neg = np.array(_neg[:]).reshape((-1, 1))
+    examples = np.squeeze(np.vstack((pos, neg)))
+    labels = np.zeros(len(examples), dtype=np.int32)
+    labels[:len(pos)] += 1
+
+    auroc = sk.roc_auc_score(labels, examples)
+    aupr = sk.average_precision_score(labels, examples)
+    fpr = fpr_and_fdr_at_recall(labels, examples, recall_level)
+
+    return auroc, aupr, fpr
+
+
+def print_measures_old(auroc, aupr, fpr, method_name='Ours', recall_level=recall_level_default):
+    print('\t\t\t' + method_name)
+    print('FPR{:d}:\t{:.2f}'.format(int(100 * recall_level), 100 * fpr))
+    print('AUROC: \t{:.2f}'.format(100 * auroc))
+    print('AUPR: \t{:.2f}'.format(100 * aupr))
+
+
+def print_measures_with_std(aurocs, auprs, fprs, method_name='Ours', recall_level=recall_level_default):
+    print('\t\t\t' + method_name)
+    print('FPR{:d}:\t{:.2f}\t+/- {:.2f}'.format(int(100 * recall_level), 100 * np.mean(fprs), 100 * np.std(fprs)))
+    print('AUROC: \t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(aurocs), 100 * np.std(aurocs)))
+    print('AUPR: \t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(auprs), 100 * np.std(auprs)))
+
+
+def get_and_print_results(out_score, in_score, num_to_avg=1):
+
+    aurocs, auprs, fprs = [], [], []
+    # for _ in range(num_to_avg):
+    #     out_score = get_ood_scores(ood_loader)
+    measures = get_measures(out_score, in_score)
+    aurocs.append(measures[0]); auprs.append(measures[1]); fprs.append(measures[2])
+
+    auroc = np.mean(aurocs); aupr = np.mean(auprs); fpr = np.mean(fprs)
+    # auroc_list.append(auroc); aupr_list.append(aupr); fpr_list.append(fpr)
+
+    # if num_to_avg >= 5:
+    #     print_measures_with_std(aurocs, auprs, fprs, method_name='Ours')
+    # else:
+    #     print_measures(auroc, aupr, fpr, method_name='Ours')
+    return auroc, aupr, fpr
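
The new fpr_and_fdr_at_recall helper calls stable_cumsum, which this hunk neither defines nor imports, so calibration_tools.py presumably provides (or should provide) an equivalent of sklearn's stable_cumsum utility. Below is a minimal sketch of such a helper together with a toy usage of the new OOD measures; the file name, the synthetic scores, and the "Toy Detector" label are illustrative assumptions, not part of this commit.

# sketch_ood_metrics.py -- illustrative companion sketch, not part of this commit
import numpy as np

from calibration_tools import get_measures, print_measures_old


def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
    # Cumulative sum that checks the final element against the true sum,
    # mirroring sklearn.utils.extmath.stable_cumsum, which the new
    # fpr_and_fdr_at_recall function appears to assume is available.
    out = np.cumsum(arr, dtype=np.float64)
    expected = np.sum(arr, dtype=np.float64)
    if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
        raise RuntimeError('cumsum was found to be unstable: its last element '
                           'does not correspond to sum')
    return out


if __name__ == '__main__':
    rng = np.random.RandomState(0)
    # Toy anomaly scores: ImageNet-O images are the positive class in
    # get_measures, so they should score higher than in-distribution images.
    out_score = rng.normal(loc=1.0, scale=1.0, size=2000)   # OOD / ImageNet-O scores
    in_score = rng.normal(loc=0.0, scale=1.0, size=10000)   # in-distribution scores

    # Returns AUROC, AUPR, and FPR at 95% recall (recall_level_default = 0.95).
    auroc, aupr, fpr = get_measures(out_score, in_score)
    print_measures_old(auroc, aupr, fpr, method_name='Toy Detector')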
