evaluation.py
import warnings

import numpy as np
from sklearn.metrics import (
    auc,
    confusion_matrix,
    f1_score,
    matthews_corrcoef,
    precision_recall_curve,
    roc_auc_score,
)

# The F-score computation below divides by zero whenever precision and
# recall are both 0; the resulting RuntimeWarning is silenced here and the
# NaNs are replaced with 0 afterwards.
warnings.filterwarnings("ignore", category=RuntimeWarning)


def calc_classification_metrics(pred_scores, pred_labels, labels):
    if len(np.unique(labels)) == 2:  # binary classification
        roc_auc_pred_score = roc_auc_score(labels, pred_scores)
        precisions, recalls, thresholds = precision_recall_curve(labels, pred_scores)
        # F1 at every candidate threshold; NaNs (from 0/0) are treated as 0.
        fscore = (2 * precisions * recalls) / (precisions + recalls)
        fscore[np.isnan(fscore)] = 0
        # Pick the threshold that maximizes F1. `thresholds` is one element
        # shorter than `precisions`/`recalls`, but the final point sklearn
        # appends is (precision=1, recall=0), whose F1 is 0, so argmax never
        # selects the out-of-range index.
        ix = np.argmax(fscore)
        threshold = thresholds[ix].item()
        pr_auc = auc(recalls, precisions)
        tn, fp, fn, tp = confusion_matrix(labels, pred_labels, labels=[0, 1]).ravel()
        result = {
            "roc_auc": roc_auc_pred_score,
            "threshold": threshold,
            "pr_auc": pr_auc,
            "recall": recalls[ix].item(),
            "precision": precisions[ix].item(),
            "f1": fscore[ix].item(),
            "tn": tn.item(),
            "fp": fp.item(),
            "fn": fn.item(),
            "tp": tp.item(),
        }
    else:  # multiclass classification
        acc = (pred_labels == labels).mean()
        # NOTE: the original called f1_score without `average`, which raises
        # a ValueError on multiclass targets; macro averaging is assumed here.
        f1 = f1_score(y_true=labels, y_pred=pred_labels, average="macro")
        result = {
            "acc": acc,
            "f1": f1,
            "acc_and_f1": (acc + f1) / 2,
            "mcc": matthews_corrcoef(labels, pred_labels),
        }
    return result
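

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original
# module): random binary scores and labels, with hard predictions taken at a
# hypothetical 0.5 cutoff, just to show the expected argument shapes.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    labels = rng.integers(0, 2, size=100)           # ground truth in {0, 1}
    pred_scores = rng.random(100)                   # model scores in [0, 1]
    pred_labels = (pred_scores >= 0.5).astype(int)  # hard predictions
    print(calc_classification_metrics(pred_scores, pred_labels, labels))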