@@ -205,7 +205,8 @@ impl TermsAggregationInternal {
205
205
#[ derive( Clone , Debug , Default ) ]
206
206
/// Container to store term_ids/or u64 values and their buckets.
207
207
struct TermBuckets {
208
- pub ( crate ) entries : FxHashMap < u64 , TermBucketEntry > ,
208
+ pub ( crate ) entries : FxHashMap < u64 , u64 > ,
209
+ pub ( crate ) sub_aggs : FxHashMap < u64 , Box < dyn SegmentAggregationCollector > > ,
209
210
}
210
211
211
212
#[ derive( Clone , Default ) ]
@@ -249,10 +250,8 @@ impl TermBucketEntry {
249
250
250
251
impl TermBuckets {
251
252
fn force_flush ( & mut self , agg_with_accessor : & AggregationsWithAccessor ) -> crate :: Result < ( ) > {
252
- for entry in & mut self . entries . values_mut ( ) {
253
- if let Some ( sub_aggregations) = entry. sub_aggregations . as_mut ( ) {
254
- sub_aggregations. flush ( agg_with_accessor) ?;
255
- }
253
+ for sub_aggregations in & mut self . sub_aggs . values_mut ( ) {
254
+ sub_aggregations. as_mut ( ) . flush ( agg_with_accessor) ?;
256
255
}
257
256
Ok ( ( ) )
258
257
}
@@ -268,6 +267,7 @@ pub struct SegmentTermCollector {
268
267
blueprint : Option < Box < dyn SegmentAggregationCollector > > ,
269
268
field_type : ColumnType ,
270
269
accessor_idx : usize ,
270
+ val_cache : Vec < u64 > ,
271
271
}
272
272
273
273
pub ( crate ) fn get_agg_name_and_property ( name : & str ) -> ( & str , & str ) {
@@ -292,6 +292,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
292
292
} )
293
293
}
294
294
295
+ #[ inline]
295
296
fn collect (
296
297
& mut self ,
297
298
doc : crate :: DocId ,
@@ -300,6 +301,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
300
301
self . collect_block ( & [ doc] , agg_with_accessor)
301
302
}
302
303
304
+ #[ inline]
303
305
fn collect_block (
304
306
& mut self ,
305
307
docs : & [ crate :: DocId ] ,
@@ -310,28 +312,35 @@ impl SegmentAggregationCollector for SegmentTermCollector {
310
312
& agg_with_accessor. buckets . values [ self . accessor_idx ] . sub_aggregation ;
311
313
312
314
if accessor. get_cardinality ( ) == Cardinality :: Full {
313
- for doc in docs {
314
- let term_id = accessor. values . get_val ( * doc) ;
315
- let entry = self
316
- . term_buckets
317
- . entries
318
- . entry ( term_id)
319
- . or_insert_with ( || TermBucketEntry :: from_blueprint ( & self . blueprint ) ) ;
320
- entry. doc_count += 1 ;
321
- if let Some ( sub_aggregations) = entry. sub_aggregations . as_mut ( ) {
315
+ self . val_cache . resize ( docs. len ( ) , 0 ) ;
316
+ accessor. values . get_vals ( docs, & mut self . val_cache ) ;
317
+ for ( doc, term_id) in docs. iter ( ) . zip ( self . val_cache . iter ( ) . cloned ( ) ) {
318
+ let entry = self . term_buckets . entries . entry ( term_id) . or_default ( ) ;
319
+ * entry += 1 ;
320
+ }
321
+ // has subagg
322
+ if let Some ( blueprint) = self . blueprint . as_ref ( ) {
323
+ for ( doc, term_id) in docs. iter ( ) . zip ( self . val_cache . iter ( ) . cloned ( ) ) {
324
+ let sub_aggregations = self
325
+ . term_buckets
326
+ . sub_aggs
327
+ . entry ( term_id)
328
+ . or_insert_with ( || blueprint. clone ( ) ) ;
322
329
sub_aggregations. collect ( * doc, sub_aggregation_accessor) ?;
323
330
}
324
331
}
325
332
} else {
326
333
for doc in docs {
327
334
for term_id in accessor. values_for_doc ( * doc) {
328
- let entry = self
329
- . term_buckets
330
- . entries
331
- . entry ( term_id)
332
- . or_insert_with ( || TermBucketEntry :: from_blueprint ( & self . blueprint ) ) ;
333
- entry. doc_count += 1 ;
334
- if let Some ( sub_aggregations) = entry. sub_aggregations . as_mut ( ) {
335
+ let entry = self . term_buckets . entries . entry ( term_id) . or_default ( ) ;
336
+ * entry += 1 ;
337
+ // TODO: check if a separate loop is faster (may depend on the codec)
338
+ if let Some ( blueprint) = self . blueprint . as_ref ( ) {
339
+ let sub_aggregations = self
340
+ . term_buckets
341
+ . sub_aggs
342
+ . entry ( term_id)
343
+ . or_insert_with ( || blueprint. clone ( ) ) ;
335
344
sub_aggregations. collect ( * doc, sub_aggregation_accessor) ?;
336
345
}
337
346
}
@@ -386,15 +395,16 @@ impl SegmentTermCollector {
386
395
blueprint,
387
396
field_type,
388
397
accessor_idx,
398
+ val_cache : Default :: default ( ) ,
389
399
} )
390
400
}
391
401
402
+ #[ inline]
392
403
pub ( crate ) fn into_intermediate_bucket_result (
393
- self ,
404
+ mut self ,
394
405
agg_with_accessor : & BucketAggregationWithAccessor ,
395
406
) -> crate :: Result < IntermediateBucketResult > {
396
- let mut entries: Vec < ( u64 , TermBucketEntry ) > =
397
- self . term_buckets . entries . into_iter ( ) . collect ( ) ;
407
+ let mut entries: Vec < ( u64 , u64 ) > = self . term_buckets . entries . into_iter ( ) . collect ( ) ;
398
408
399
409
let order_by_sub_aggregation =
400
410
matches ! ( self . req. order. target, OrderTarget :: SubAggregation ( _) ) ;
@@ -417,9 +427,9 @@ impl SegmentTermCollector {
417
427
}
418
428
OrderTarget :: Count => {
419
429
if self . req . order . order == Order :: Desc {
420
- entries. sort_unstable_by_key ( |bucket| std:: cmp:: Reverse ( bucket. doc_count ( ) ) ) ;
430
+ entries. sort_unstable_by_key ( |bucket| std:: cmp:: Reverse ( bucket. 1 ) ) ;
421
431
} else {
422
- entries. sort_unstable_by_key ( |bucket| bucket. doc_count ( ) ) ;
432
+ entries. sort_unstable_by_key ( |bucket| bucket. 1 ) ;
423
433
}
424
434
}
425
435
}
@@ -432,24 +442,51 @@ impl SegmentTermCollector {
432
442
433
443
let mut dict: FxHashMap < Key , IntermediateTermBucketEntry > = Default :: default ( ) ;
434
444
dict. reserve ( entries. len ( ) ) ;
445
+
446
+ let mut into_intermediate_bucket_entry =
447
+ |id, doc_count| -> crate :: Result < IntermediateTermBucketEntry > {
448
+ let intermediate_entry = if let Some ( blueprint) = self . blueprint . as_ref ( ) {
449
+ IntermediateTermBucketEntry {
450
+ doc_count,
451
+ sub_aggregation : self
452
+ . term_buckets
453
+ . sub_aggs
454
+ . remove ( & id)
455
+ . expect ( & format ! (
456
+ "Internal Error: could not find subaggregation for id {}" ,
457
+ id
458
+ ) )
459
+ . into_intermediate_aggregations_result (
460
+ & agg_with_accessor. sub_aggregation ,
461
+ ) ?,
462
+ }
463
+ } else {
464
+ IntermediateTermBucketEntry {
465
+ doc_count,
466
+ sub_aggregation : Default :: default ( ) ,
467
+ }
468
+ } ;
469
+ Ok ( intermediate_entry)
470
+ } ;
471
+
435
472
if self . field_type == ColumnType :: Str {
436
473
let term_dict = agg_with_accessor
437
474
. str_dict_column
438
475
. as_ref ( )
439
476
. expect ( "internal error: term dictionary not found for term aggregation" ) ;
440
477
441
478
let mut buffer = String :: new ( ) ;
442
- for ( term_id, entry ) in entries {
479
+ for ( term_id, doc_count ) in entries {
443
480
if !term_dict. ord_to_str ( term_id, & mut buffer) ? {
444
481
return Err ( TantivyError :: InternalError ( format ! (
445
482
"Couldn't find term_id {} in dict" ,
446
483
term_id
447
484
) ) ) ;
448
485
}
449
- dict . insert (
450
- Key :: Str ( buffer . to_string ( ) ) ,
451
- entry . into_intermediate_bucket_entry ( & agg_with_accessor . sub_aggregation ) ? ,
452
- ) ;
486
+
487
+ let intermediate_entry = into_intermediate_bucket_entry ( term_id , doc_count ) ? ;
488
+
489
+ dict . insert ( Key :: Str ( buffer . to_string ( ) ) , intermediate_entry ) ;
453
490
}
454
491
if self . req . min_doc_count == 0 {
455
492
// TODO: Handle rev streaming for descending sorting by keys
@@ -468,12 +505,10 @@ impl SegmentTermCollector {
468
505
}
469
506
}
470
507
} else {
471
- for ( val, entry) in entries {
508
+ for ( val, doc_count) in entries {
509
+ let intermediate_entry = into_intermediate_bucket_entry ( val, doc_count) ?;
472
510
let val = f64_from_fastfield_u64 ( val, & self . field_type ) ;
473
- dict. insert (
474
- Key :: F64 ( val) ,
475
- entry. into_intermediate_bucket_entry ( & agg_with_accessor. sub_aggregation ) ?,
476
- ) ;
511
+ dict. insert ( Key :: F64 ( val) , intermediate_entry) ;
477
512
}
478
513
} ;
479
514
@@ -495,6 +530,11 @@ impl GetDocCount for (u32, TermBucketEntry) {
495
530
self . 1 . doc_count
496
531
}
497
532
}
533
+ impl GetDocCount for ( u64 , u64 ) {
534
+ fn doc_count ( & self ) -> u64 {
535
+ self . 1
536
+ }
537
+ }
498
538
impl GetDocCount for ( u64 , TermBucketEntry ) {
499
539
fn doc_count ( & self ) -> u64 {
500
540
self . 1 . doc_count
0 commit comments