@@ -13,30 +13,28 @@ def __init__(self, uri, user, password):
13
13
def close(self):
    """Close the driver held in ``self._driver``."""
    self._driver.close()
15
15
16
def compute_and_store_distances(self, k, exact, distance_function=None, relationship_name=None):
    """Find the k nearest neighbours of every transaction vector and store them.

    The vectors come from ``self.get_transaction_vectors()``, the neighbours
    from ``self.compute_knn`` (exact) or ``self.compute_ann`` (approximate),
    and the result is handed to ``self.store_ann``. The wall-clock time of
    each of the three phases is printed.

    Args:
        k: number of neighbours requested per vector.
        exact: truthy to use ``compute_knn``, falsy to use ``compute_ann``.
        distance_function: metric name forwarded unchanged to the selected
            search method. When omitted, the value that used to be
            hard-coded is used: "mahalanobis" for the exact search and
            "l2" for the approximate one.
        relationship_name: name forwarded to ``store_ann``. When omitted,
            the previously hard-coded "DISTANT_FROM_EXACT" (exact) or
            "DISTANT_FROM" (approximate) is used.
    """
    # Defaults keep the old two-argument call style working.
    if distance_function is None:
        distance_function = "mahalanobis" if exact else "l2"
    if relationship_name is None:
        relationship_name = "DISTANT_FROM_EXACT" if exact else "DISTANT_FROM"

    start = time.time()
    data, data_labels = self.get_transaction_vectors()
    print("Time to get vectors:", time.time() - start)

    start = time.time()
    # Both search methods share the same signature, so pick one and call it.
    search = self.compute_knn if exact else self.compute_ann
    ann_labels, ann_distances = search(data, data_labels, k, distance_function)
    print("Time to compute nearest neighbors:", time.time() - start)

    start = time.time()
    self.store_ann(data_labels, ann_labels, ann_distances, relationship_name)
    print("Time to store nearest neighbors:", time.time() - start)
    print("done")
34
32
35
- def compute_ann (self , data , data_labels , k ):
33
+ def compute_ann (self , data , data_labels , k , distance_function ):
36
34
dim = len (data [0 ])
37
35
num_elements = len (data_labels )
38
36
# Declaring index
39
- p = hnswlib .Index (space = 'l2' , dim = dim ) # possible options are l2, cosine or ip
37
+ p = hnswlib .Index (space = distance_function , dim = dim ) # possible options for distance_function are l2, cosine or ip
40
38
# Initing index - the maximum number of elements should be known beforehand
41
39
p .init_index (max_elements = num_elements , ef_construction = 800 , M = 200 )
42
40
# Element insertion (can be called several times):
@@ -47,9 +45,9 @@ def compute_ann(self, data, data_labels, k):
47
45
labels , distances = p .knn_query (data , k = k )
48
46
return labels , distances
49
47
50
- def compute_knn (self , data , data_labels , k ):
48
+ def compute_knn (self , data , data_labels , k , distance_function ):
51
49
pre_processed_data = [np .array (item ) for item in data ]
52
- nbrs = NearestNeighbors (n_neighbors = k , algorithm = 'brute' , metric = 'mahalanobis' , n_jobs = - 1 ).fit (pre_processed_data )
50
+ nbrs = NearestNeighbors (n_neighbors = k , algorithm = 'brute' , metric = distance_function , n_jobs = - 1 ).fit (pre_processed_data )
53
51
knn_distances , knn_labels = nbrs .kneighbors (pre_processed_data )
54
52
distances = knn_distances
55
53
labels = [[data_labels [element ] for element in item ] for item in knn_labels ]
@@ -114,8 +112,11 @@ def store_ann(self, data_labels, ann_labels, ann_distances, label): #ADD the opp
114
112
115
113
if __name__ == '__main__':
    uri = "bolt://localhost:7687"
    # Metric for the approximate search; use 'mahalanobis' for the exact one.
    distance_formula_value = "l2"
    # Use "DISTANT_FROM_EXACT" instead when storing exact nearest neighbours.
    relationship_name_value = "DISTANT_FROM"
    analyzer = DistanceBasedAnalysis(uri=uri, user="neo4j", password="q1")
    try:
        analyzer.compute_and_store_distances(25, False, distance_formula_value, relationship_name_value)
        # Uncomment this line to calculate exact NNs, but it will take a lot of time!
        # analyzer.compute_and_store_distances(25, True, "mahalanobis", "DISTANT_FROM_EXACT")
    finally:
        # Close the driver even when the computation above raises.
        analyzer.close()
0 commit comments