1
+ import snap
2
+ import numpy as np
3
+ import sys
4
+ from sklearn .ensemble import RandomForestClassifier
5
+ from sklearn .linear_model import LogisticRegression
6
+ from sklearn .neighbors import KNeighborsClassifier
7
+ from sklearn .neural_network import MLPClassifier
8
+ from weight_evolution import EvolModel
9
+ from train_graph import Train_Graph
10
+ from test_graph import Test_Graph
11
+ from feature_extractors import *
12
+
13
+ def get_all_features (feature_funcs , max_scc , train_examples , test_examples ):
14
+ all_train_features = []
15
+ all_test_features = []
16
+ for func in feature_funcs :
17
+ print 'Extracting features with' , func
18
+ all_train_features .append (get_features (max_scc , train_examples , func ))
19
+ all_test_features .append (get_features (max_scc , test_examples , func ))
20
+ # Transpose since sklearn takes nsamples, nfeatures shape
21
+ all_train_features = np .array (all_train_features ).T
22
+ all_test_features = np .array (all_test_features ).T
23
+ return all_train_features , all_test_features
24
+
25
+ def get_features (max_scc , examples , func ):
26
+ all_features = []
27
+ for ex in examples :
28
+ result = func (max_scc , ex [0 ], ex [1 ])
29
+ all_features .append (result )
30
+ return all_features
31
+
32
+ def print_metrics (gt , pred ):
33
+ print 'Accuracy:' , sklearn .metrics .accuracy_score (gt , pred )
34
+ print 'Precision:' , sklearn .metrics .precision_score (gt , pred )
35
+ print 'Recall:' , sklearn .metrics .recall_score (gt , pred )
36
+ print 'F1 Score:' , sklearn .metrics .f1_score (gt , pred )
37
+
38
+ def test_classifiers (train_examples , train_labels , test_examples , test_labels ):
39
+ knn = KNeighborsClassifier ()
40
+ logistic = LogisticRegression ()
41
+ rf = RandomForestClassifier (n_estimators = 100 )
42
+ my_nn = MLPClassifier (hidden_layer_sizes = (100 , 50 , 50 ))
43
+ bliss_model = EvolModel ()
44
+ models = [bliss_model , knn , logistic , rf , my_nn ]
45
+ for model in models :
46
+ print 'Training model' , model
47
+ model .fit (train_examples , train_labels )
48
+ preds = model .predict (test_examples )
49
+ gt = [elem for elem in test_labels ]
50
+ print ''
51
+ print 'Evaluating Testing Set:'
52
+ print_metrics (gt , preds )
53
+
54
+ print ''
55
+ print 'Evaluating Training Set:'
56
+ preds_train = model .predict (train_examples )
57
+ gt_train = [elem for elem in train_labels ]
58
+ print_metrics (gt_train , preds_train )
59
+
60
+ def main (temp_train_feats , temp_train_ex , temp_test_feats , temp_test_ex , graph_file ):
61
+ train_graph_obj = Train_Graph (graph_file_root = graph_file )
62
+ graph = train_graph_obj .pgraph
63
+ train_examples = temp_train_ex [:, 0 ].tolist ()
64
+ train_labels = temp_train_ex [:, 1 ].tolist ()
65
+ test_examples = temp_test_ex [:, 0 ].tolist ()
66
+ test_labels = temp_test_ex [:, 1 ].tolist ()
67
+ feature_funcs = [preferential_attachment ]
68
+ # feature_funcs = [get_graph_distance, get_ev_centr_sum, get_page_rank_sum, \
69
+ # preferential_attachment, get_2_hops, get_degree_sum, \
70
+ # std_nbr_degree_sum, mean_nbr_deg_sum, adamic_adar_2, \
71
+ # common_neighbors_2]
72
+ print 'Extracting features'
73
+ norm_train_features , norm_test_features = get_all_features (feature_funcs , graph , train_examples , test_examples )
74
+ all_train_feats = np .hstack ([norm_train_features , temp_train_feats ])
75
+ all_test_feats = np .hstack ([norm_test_features , temp_test_feats ])
76
+ all_train_feats = sklearn .preprocessing .scale (all_train_feats )
77
+ all_test_feats = sklearn .preprocessing .scale (all_test_feats )
78
+ print 'Testing Classifiers'
79
+ test_classifiers (all_train_features , train_labels , all_test_features , test_labels )
80
+
81
+ if __name__ == '__main__' :
82
+ temp_train_feats = np .load (sys .argv [1 ])
83
+ temp_train_ex = np .load (sys .argv [2 ])
84
+ temp_test_feats = np .load (sys .argv [3 ])
85
+ temp_test_ex = np .load (sys .argv [4 ])
86
+ graph_file = np .load (sys .argv [5 ])
87
+ main (temp_train_feats , temp_train_ex , temp_test_feats , temp_test_ex , graph_file )
0 commit comments