Skip to content

Commit 44ea731

Browse files
committed
set max bucket size as parameter
1 parent 11ac451 commit 44ea731

10 files changed

+140
-73
lines changed

examples/aggregation.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ fn main() -> tantivy::Result<()> {
117117
.into_iter()
118118
.collect();
119119

120-
let collector = AggregationCollector::from_aggs(agg_req_1);
120+
let collector = AggregationCollector::from_aggs(agg_req_1, None);
121121

122122
let searcher = reader.searcher();
123123
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();

src/aggregation/agg_req_with_accessor.rs

+16-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::sync::Arc;
77
use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation};
88
use super::bucket::{HistogramAggregation, RangeAggregation, TermsAggregation};
99
use super::metric::{AverageAggregation, StatsAggregation};
10+
use super::segment_agg_result::BucketCount;
1011
use super::VecWithNames;
1112
use crate::fastfield::{
1213
type_and_cardinality, DynamicFastFieldReader, FastType, MultiValuedFastFieldReader,
@@ -62,7 +63,7 @@ pub struct BucketAggregationWithAccessor {
6263
pub(crate) field_type: Type,
6364
pub(crate) bucket_agg: BucketAggregationType,
6465
pub(crate) sub_aggregation: AggregationsWithAccessor,
65-
pub(crate) bucket_count: Rc<AtomicU32>,
66+
pub(crate) bucket_count: BucketCount,
6667
}
6768

6869
impl BucketAggregationWithAccessor {
@@ -71,6 +72,7 @@ impl BucketAggregationWithAccessor {
7172
sub_aggregation: &Aggregations,
7273
reader: &SegmentReader,
7374
bucket_count: Rc<AtomicU32>,
75+
max_bucket_count: u32,
7476
) -> crate::Result<BucketAggregationWithAccessor> {
7577
let mut inverted_index = None;
7678
let (accessor, field_type) = match &bucket {
@@ -96,10 +98,18 @@ impl BucketAggregationWithAccessor {
9698
Ok(BucketAggregationWithAccessor {
9799
accessor,
98100
field_type,
99-
sub_aggregation: get_aggs_with_accessor_and_validate(&sub_aggregation, reader)?,
101+
sub_aggregation: get_aggs_with_accessor_and_validate(
102+
&sub_aggregation,
103+
reader,
104+
bucket_count.clone(),
105+
max_bucket_count,
106+
)?,
100107
bucket_agg: bucket.clone(),
101108
inverted_index,
102-
bucket_count,
109+
bucket_count: BucketCount {
110+
bucket_count,
111+
max_bucket_count,
112+
},
103113
})
104114
}
105115
}
@@ -139,8 +149,9 @@ impl MetricAggregationWithAccessor {
139149
pub(crate) fn get_aggs_with_accessor_and_validate(
140150
aggs: &Aggregations,
141151
reader: &SegmentReader,
152+
bucket_count: Rc<AtomicU32>,
153+
max_bucket_count: u32,
142154
) -> crate::Result<AggregationsWithAccessor> {
143-
let bucket_count: Rc<AtomicU32> = Default::default();
144155
let mut metrics = vec![];
145156
let mut buckets = vec![];
146157
for (key, agg) in aggs.iter() {
@@ -152,6 +163,7 @@ pub(crate) fn get_aggs_with_accessor_and_validate(
152163
&bucket.sub_aggregation,
153164
reader,
154165
Rc::clone(&bucket_count),
166+
max_bucket_count,
155167
)?,
156168
)),
157169
Aggregation::Metric(metric) => metrics.push((

src/aggregation/bucket/histogram/histogram.rs

+3-5
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ use crate::aggregation::f64_from_fastfield_u64;
1313
use crate::aggregation::intermediate_agg_result::{
1414
IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry,
1515
};
16-
use crate::aggregation::segment_agg_result::{
17-
validate_bucket_count, SegmentAggregationResultsCollector,
18-
};
16+
use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector;
1917
use crate::fastfield::{DynamicFastFieldReader, FastFieldReader};
2018
use crate::schema::Type;
2119
use crate::{DocId, TantivyError};
@@ -254,8 +252,8 @@ impl SegmentHistogramCollector {
254252

255253
agg_with_accessor
256254
.bucket_count
257-
.fetch_add(buckets.len() as u32, std::sync::atomic::Ordering::Relaxed);
258-
validate_bucket_count(&agg_with_accessor.bucket_count)?;
255+
.add_count(buckets.len() as u32);
256+
agg_with_accessor.bucket_count.validate_bucket_count()?;
259257

260258
Ok(IntermediateBucketResult::Histogram { buckets })
261259
}

src/aggregation/bucket/range.rs

+5-9
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
use std::fmt::Debug;
22
use std::ops::Range;
3-
use std::rc::Rc;
4-
use std::sync::atomic::AtomicU32;
53

64
use fnv::FnvHashMap;
75
use serde::{Deserialize, Serialize};
@@ -12,9 +10,7 @@ use crate::aggregation::agg_req_with_accessor::{
1210
use crate::aggregation::intermediate_agg_result::{
1311
IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
1412
};
15-
use crate::aggregation::segment_agg_result::{
16-
validate_bucket_count, SegmentAggregationResultsCollector,
17-
};
13+
use crate::aggregation::segment_agg_result::{BucketCount, SegmentAggregationResultsCollector};
1814
use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key, SerializedKey};
1915
use crate::fastfield::FastFieldReader;
2016
use crate::schema::Type;
@@ -179,7 +175,7 @@ impl SegmentRangeCollector {
179175
pub(crate) fn from_req_and_validate(
180176
req: &RangeAggregation,
181177
sub_aggregation: &AggregationsWithAccessor,
182-
bucket_count: &Rc<AtomicU32>,
178+
bucket_count: &BucketCount,
183179
field_type: Type,
184180
) -> crate::Result<Self> {
185181
// The range input on the request is f64.
@@ -218,8 +214,8 @@ impl SegmentRangeCollector {
218214
})
219215
.collect::<crate::Result<_>>()?;
220216

221-
bucket_count.fetch_add(buckets.len() as u32, std::sync::atomic::Ordering::Relaxed);
222-
validate_bucket_count(bucket_count)?;
217+
bucket_count.add_count(buckets.len() as u32);
218+
bucket_count.validate_bucket_count()?;
223219

224220
Ok(SegmentRangeCollector {
225221
buckets,
@@ -438,7 +434,7 @@ mod tests {
438434
.into_iter()
439435
.collect();
440436

441-
let collector = AggregationCollector::from_aggs(agg_req);
437+
let collector = AggregationCollector::from_aggs(agg_req, None);
442438

443439
let reader = index.reader()?;
444440
let searcher = reader.searcher();

src/aggregation/bucket/term_agg.rs

+24-15
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ use crate::aggregation::agg_req_with_accessor::{
1111
use crate::aggregation::intermediate_agg_result::{
1212
IntermediateBucketResult, IntermediateTermBucketEntry, IntermediateTermBucketResult,
1313
};
14-
use crate::aggregation::segment_agg_result::{
15-
validate_bucket_count, SegmentAggregationResultsCollector,
16-
};
14+
use crate::aggregation::segment_agg_result::{BucketCount, SegmentAggregationResultsCollector};
1715
use crate::error::DataCorruption;
1816
use crate::fastfield::MultiValuedFastFieldReader;
1917
use crate::schema::Type;
@@ -246,23 +244,23 @@ impl TermBuckets {
246244
&mut self,
247245
term_ids: &[u64],
248246
doc: DocId,
249-
bucket_with_accessor: &BucketAggregationWithAccessor,
247+
sub_aggregation: &AggregationsWithAccessor,
248+
bucket_count: &BucketCount,
250249
blueprint: &Option<SegmentAggregationResultsCollector>,
251250
) -> crate::Result<()> {
252251
for &term_id in term_ids {
253252
let entry = self.entries.entry(term_id as u32).or_insert_with(|| {
254-
bucket_with_accessor
255-
.bucket_count
256-
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
253+
bucket_count.add_count(1);
257254

258255
TermBucketEntry::from_blueprint(blueprint)
259256
});
260257
entry.doc_count += 1;
261258
if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
262-
sub_aggregations.collect(doc, &bucket_with_accessor.sub_aggregation)?;
259+
sub_aggregations.collect(doc, &sub_aggregation)?;
263260
}
264261
}
265-
validate_bucket_count(&bucket_with_accessor.bucket_count)?;
262+
bucket_count.validate_bucket_count()?;
263+
266264
Ok(())
267265
}
268266

@@ -447,25 +445,29 @@ impl SegmentTermCollector {
447445
self.term_buckets.increment_bucket(
448446
&vals1,
449447
docs[0],
450-
bucket_with_accessor,
448+
&bucket_with_accessor.sub_aggregation,
449+
&bucket_with_accessor.bucket_count,
451450
&self.blueprint,
452451
)?;
453452
self.term_buckets.increment_bucket(
454453
&vals2,
455454
docs[1],
456-
bucket_with_accessor,
455+
&bucket_with_accessor.sub_aggregation,
456+
&bucket_with_accessor.bucket_count,
457457
&self.blueprint,
458458
)?;
459459
self.term_buckets.increment_bucket(
460460
&vals3,
461461
docs[2],
462-
bucket_with_accessor,
462+
&bucket_with_accessor.sub_aggregation,
463+
&bucket_with_accessor.bucket_count,
463464
&self.blueprint,
464465
)?;
465466
self.term_buckets.increment_bucket(
466467
&vals4,
467468
docs[3],
468-
bucket_with_accessor,
469+
&bucket_with_accessor.sub_aggregation,
470+
&bucket_with_accessor.bucket_count,
469471
&self.blueprint,
470472
)?;
471473
}
@@ -475,7 +477,8 @@ impl SegmentTermCollector {
475477
self.term_buckets.increment_bucket(
476478
&vals1,
477479
doc,
478-
bucket_with_accessor,
480+
&bucket_with_accessor.sub_aggregation,
481+
&bucket_with_accessor.bucket_count,
479482
&self.blueprint,
480483
)?;
481484
}
@@ -1326,9 +1329,15 @@ mod bench {
13261329
let mut collector = get_collector_with_buckets(total_terms);
13271330
let vals = get_rand_terms(total_terms, num_terms);
13281331
let aggregations_with_accessor: AggregationsWithAccessor = Default::default();
1332+
let bucket_count: BucketCount = BucketCount {
1333+
bucket_count: Default::default(),
1334+
max_bucket_count: 1_000_001u32,
1335+
};
13291336
b.iter(|| {
13301337
for &val in &vals {
1331-
collector.increment_bucket(&[val], 0, &aggregations_with_accessor, &None);
1338+
collector
1339+
.increment_bucket(&[val], 0, &aggregations_with_accessor, &bucket_count, &None)
1340+
.unwrap();
13321341
}
13331342
})
13341343
}

src/aggregation/collector.rs

+34-8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::rc::Rc;
2+
13
use super::agg_req::Aggregations;
24
use super::agg_req_with_accessor::AggregationsWithAccessor;
35
use super::agg_result::AggregationResults;
@@ -7,17 +9,25 @@ use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_valida
79
use crate::collector::{Collector, SegmentCollector};
810
use crate::SegmentReader;
911

12+
pub const MAX_BUCKET_COUNT: u32 = 65000;
13+
1014
/// Collector for aggregations.
1115
///
1216
/// The collector collects all aggregations by the underlying aggregation request.
1317
pub struct AggregationCollector {
1418
agg: Aggregations,
19+
max_bucket_count: u32,
1520
}
1621

1722
impl AggregationCollector {
1823
/// Create collector from aggregation request.
19-
pub fn from_aggs(agg: Aggregations) -> Self {
20-
Self { agg }
24+
///
25+
/// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset
26+
pub fn from_aggs(agg: Aggregations, max_bucket_count: Option<u32>) -> Self {
27+
Self {
28+
agg,
29+
max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT),
30+
}
2131
}
2232
}
2333

@@ -28,15 +38,21 @@ impl AggregationCollector {
2838
/// # Purpose
2939
/// AggregationCollector returns `IntermediateAggregationResults` and not the final
3040
/// `AggregationResults`, so that results from different indices can be merged and then converted
31-
/// into the final `AggregationResults` via the `into()` method.
41+
/// into the final `AggregationResults` via the `into_final_result()` method.
3242
pub struct DistributedAggregationCollector {
3343
agg: Aggregations,
44+
max_bucket_count: u32,
3445
}
3546

3647
impl DistributedAggregationCollector {
3748
/// Create collector from aggregation request.
38-
pub fn from_aggs(agg: Aggregations) -> Self {
39-
Self { agg }
49+
///
50+
/// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset
51+
pub fn from_aggs(agg: Aggregations, max_bucket_count: Option<u32>) -> Self {
52+
Self {
53+
agg,
54+
max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT),
55+
}
4056
}
4157
}
4258

@@ -50,7 +66,11 @@ impl Collector for DistributedAggregationCollector {
5066
_segment_local_id: crate::SegmentOrdinal,
5167
reader: &crate::SegmentReader,
5268
) -> crate::Result<Self::Child> {
53-
AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader)
69+
AggregationSegmentCollector::from_agg_req_and_reader(
70+
&self.agg,
71+
reader,
72+
self.max_bucket_count,
73+
)
5474
}
5575

5676
fn requires_scoring(&self) -> bool {
@@ -75,7 +95,11 @@ impl Collector for AggregationCollector {
7595
_segment_local_id: crate::SegmentOrdinal,
7696
reader: &crate::SegmentReader,
7797
) -> crate::Result<Self::Child> {
78-
AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader)
98+
AggregationSegmentCollector::from_agg_req_and_reader(
99+
&self.agg,
100+
reader,
101+
self.max_bucket_count,
102+
)
79103
}
80104

81105
fn requires_scoring(&self) -> bool {
@@ -117,8 +141,10 @@ impl AggregationSegmentCollector {
117141
pub fn from_agg_req_and_reader(
118142
agg: &Aggregations,
119143
reader: &SegmentReader,
144+
max_bucket_count: u32,
120145
) -> crate::Result<Self> {
121-
let aggs_with_accessor = get_aggs_with_accessor_and_validate(agg, reader)?;
146+
let aggs_with_accessor =
147+
get_aggs_with_accessor_and_validate(agg, reader, Rc::default(), max_bucket_count)?;
122148
let result =
123149
SegmentAggregationResultsCollector::from_req_and_validate(&aggs_with_accessor)?;
124150
Ok(AggregationSegmentCollector {

src/aggregation/intermediate_agg_result.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ pub struct IntermediateAggregationResults {
3535
}
3636

3737
impl IntermediateAggregationResults {
38-
/// Convert and intermediate result and its aggregation request to the final result
38+
/// Convert intermediate result and its aggregation request to the final result.
3939
pub(crate) fn into_final_bucket_result(
4040
self,
4141
req: Aggregations,
4242
) -> crate::Result<AggregationResults> {
4343
self.into_final_bucket_result_internal(&(req.into()))
4444
}
4545

46-
/// Convert and intermediate result and its aggregation request to the final result
46+
/// Convert intermediate result and its aggregation request to the final result.
4747
///
4848
/// Internal function, AggregationsInternal is used instead of Aggregations, which is optimized
4949
/// for internal processing, by splitting metric and buckets into separate groups.

src/aggregation/metric/stats.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ mod tests {
222222
.into_iter()
223223
.collect();
224224

225-
let collector = AggregationCollector::from_aggs(agg_req_1);
225+
let collector = AggregationCollector::from_aggs(agg_req_1, None);
226226

227227
let reader = index.reader()?;
228228
let searcher = reader.searcher();
@@ -299,7 +299,7 @@ mod tests {
299299
.into_iter()
300300
.collect();
301301

302-
let collector = AggregationCollector::from_aggs(agg_req_1);
302+
let collector = AggregationCollector::from_aggs(agg_req_1, None);
303303

304304
let searcher = reader.searcher();
305305
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();

0 commit comments

Comments
 (0)