@@ -223,7 +223,7 @@ pub fn main() -> Result<(), Error> {
223
223
references. get_sketch_slice ( j, k) ,
224
224
references. sketch_size ,
225
225
) ;
226
- dist = if * ani { ani_pois ( dist, k_f32) } else { dist } ;
226
+ dist = if * ani { ani_pois ( dist, k_f32) } else { 1.0_f32 - dist } ;
227
227
dist_slice[ dist_idx] = dist;
228
228
} else {
229
229
let dist =
@@ -275,7 +275,7 @@ pub fn main() -> Result<(), Error> {
275
275
references. sketch_size ,
276
276
) ;
277
277
dist =
278
- if * ani { ani_pois ( dist, k_f32) } else { dist } ;
278
+ if * ani { ani_pois ( dist, k_f32) } else { 1.0_f32 - dist } ;
279
279
let dist_item = SparseJaccard ( j, dist) ;
280
280
if heap. len ( ) < nn
281
281
|| dist_item < * heap. peek ( ) . unwrap ( )
@@ -333,6 +333,7 @@ pub fn main() -> Result<(), Error> {
333
333
let mut distances =
334
334
DistanceMatrix :: new ( & references, Some ( & query_db) , dist_type) ;
335
335
let par_chunk = CHUNK_SIZE * distances. n_dist_cols ( ) ;
336
+ let nq = query_db. number_samples_loaded ( ) ;
336
337
distances
337
338
. dists_mut ( )
338
339
. par_chunks_mut ( par_chunk)
@@ -341,15 +342,15 @@ pub fn main() -> Result<(), Error> {
341
342
. for_each ( |( chunk_idx, dist_slice) | {
342
343
// Get first i, j index for the chunk
343
344
let start_dist_idx = chunk_idx * CHUNK_SIZE ;
344
- let ( mut i, mut j) = calc_query_indices ( start_dist_idx, n ) ;
345
+ let ( mut i, mut j) = calc_query_indices ( start_dist_idx, nq ) ;
345
346
for dist_idx in 0 ..CHUNK_SIZE {
346
347
if let Some ( k) = k_idx {
347
348
let mut dist = jaccard_dist (
348
349
references. get_sketch_slice ( i, k) ,
349
350
query_db. get_sketch_slice ( j, k) ,
350
351
references. sketch_size ,
351
352
) ;
352
- dist = if * ani { ani_pois ( dist, k_f32) } else { dist } ;
353
+ dist = if * ani { ani_pois ( dist, k_f32) } else { 1.0_f32 - dist } ;
353
354
dist_slice[ dist_idx] = dist;
354
355
} else {
355
356
let dist = core_acc_dist ( & references, & query_db, i, j) ;
@@ -359,7 +360,7 @@ pub fn main() -> Result<(), Error> {
359
360
360
361
// Move to next index
361
362
j += 1 ;
362
- if j >= n {
363
+ if j >= nq {
363
364
i += 1 ;
364
365
j = 0 ;
365
366
// End of all dists reached (final chunk)
0 commit comments