Commit 11ac451

abort aggregation when too many buckets are created
Validation happens in different phases depending on the aggregation:
- Term: during segment collection
- Histogram: at the end, when converting into intermediate buckets (we preallocate empty buckets for the range); revisit after #1370
- Range: when validating the request

Also updates the CHANGELOG.
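In essence, the commit threads a shared atomic bucket counter through the aggregation collectors: it is incremented whenever buckets are materialized and checked against a hard limit of 65,000. Below is a condensed, illustrative sketch of that guard, not code from the diff; the error type and the `MAX_BUCKET_COUNT` name are simplifications, while the real helper `validate_bucket_count` (added in src/aggregation/segment_agg_result.rs further down) returns `TantivyError::InvalidArgument`.

use std::rc::Rc;
use std::sync::atomic::{AtomicU32, Ordering};

// Hard cap on the number of buckets a single aggregation request may create.
const MAX_BUCKET_COUNT: u32 = 65_000;

// One counter is shared (via Rc) by the bucket aggregations it guards.
fn validate_bucket_count(bucket_count: &Rc<AtomicU32>) -> Result<(), String> {
    if bucket_count.load(Ordering::Relaxed) > MAX_BUCKET_COUNT {
        return Err("Aborting aggregation because too many buckets were created".to_string());
    }
    Ok(())
}

fn main() {
    let bucket_count: Rc<AtomicU32> = Default::default();
    // A segment collector bumps the counter whenever it materializes buckets...
    bucket_count.fetch_add(70_000, Ordering::Relaxed);
    // ...and the next validation aborts the aggregation instead of letting it grow unbounded.
    assert!(validate_bucket_count(&bucket_count).is_err());
}

Because the counter is shared through an `Rc` it never crosses thread boundaries, so the `Relaxed` ordering is sufficient.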
1 parent 6a46322 commit 11ac451

12 files changed: +118 -38 lines changed


CHANGELOG.md

+1
@@ -11,6 +11,7 @@ Unreleased
 - Add [histogram](https://github.com/quickwit-oss/tantivy/pull/1306) aggregation (@PSeitz)
 - Add support for fastfield on text fields (@PSeitz)
 - Add terms aggregation (@PSeitz)
+- API Change: `SegmentCollector.collect` changed to return a `Result`.

 Tantivy 0.17
 ================================

examples/custom_collector.rs

+1 -1
@@ -102,7 +102,7 @@ struct StatsSegmentCollector {
 impl SegmentCollector for StatsSegmentCollector {
     type Fruit = Option<Stats>;

-    fn collect(&mut self, doc: u32, _score: Score) -> crate::Result<()> {
+    fn collect(&mut self, doc: u32, _score: Score) -> tantivy::Result<()> {
         let value = self.fast_field_reader.get(doc) as f64;
         self.stats.count += 1;
         self.stats.sum += value;

src/aggregation/agg_req_with_accessor.rs

+5
@@ -62,13 +62,15 @@ pub struct BucketAggregationWithAccessor {
     pub(crate) field_type: Type,
     pub(crate) bucket_agg: BucketAggregationType,
     pub(crate) sub_aggregation: AggregationsWithAccessor,
+    pub(crate) bucket_count: Rc<AtomicU32>,
 }

 impl BucketAggregationWithAccessor {
     fn try_from_bucket(
         bucket: &BucketAggregationType,
         sub_aggregation: &Aggregations,
         reader: &SegmentReader,
+        bucket_count: Rc<AtomicU32>,
     ) -> crate::Result<BucketAggregationWithAccessor> {
         let mut inverted_index = None;
         let (accessor, field_type) = match &bucket {
@@ -97,6 +99,7 @@ impl BucketAggregationWithAccessor {
             sub_aggregation: get_aggs_with_accessor_and_validate(&sub_aggregation, reader)?,
             bucket_agg: bucket.clone(),
             inverted_index,
+            bucket_count,
         })
     }
 }
@@ -137,6 +140,7 @@ pub(crate) fn get_aggs_with_accessor_and_validate(
     aggs: &Aggregations,
     reader: &SegmentReader,
 ) -> crate::Result<AggregationsWithAccessor> {
+    let bucket_count: Rc<AtomicU32> = Default::default();
     let mut metrics = vec![];
     let mut buckets = vec![];
     for (key, agg) in aggs.iter() {
@@ -147,6 +151,7 @@ pub(crate) fn get_aggs_with_accessor_and_validate(
                     &bucket.bucket_agg,
                     &bucket.sub_aggregation,
                     reader,
+                    Rc::clone(&bucket_count),
                 )?,
             )),
             Aggregation::Metric(metric) => metrics.push((

src/aggregation/bucket/histogram/histogram.rs

+8 -1
@@ -13,7 +13,9 @@ use crate::aggregation::f64_from_fastfield_u64;
 use crate::aggregation::intermediate_agg_result::{
     IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry,
 };
-use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector;
+use crate::aggregation::segment_agg_result::{
+    validate_bucket_count, SegmentAggregationResultsCollector,
+};
 use crate::fastfield::{DynamicFastFieldReader, FastFieldReader};
 use crate::schema::Type;
 use crate::{DocId, TantivyError};
@@ -250,6 +252,11 @@ impl SegmentHistogramCollector {
             );
         };

+        agg_with_accessor
+            .bucket_count
+            .fetch_add(buckets.len() as u32, std::sync::atomic::Ordering::Relaxed);
+        validate_bucket_count(&agg_with_accessor.bucket_count)?;
+
         Ok(IntermediateBucketResult::Histogram { buckets })
     }

src/aggregation/bucket/range.rs

+20 -6
@@ -1,6 +1,9 @@
 use std::fmt::Debug;
 use std::ops::Range;
+use std::rc::Rc;
+use std::sync::atomic::AtomicU32;

+use fnv::FnvHashMap;
 use serde::{Deserialize, Serialize};

 use crate::aggregation::agg_req_with_accessor::{
@@ -9,8 +12,10 @@ use crate::aggregation::agg_req_with_accessor::{
 use crate::aggregation::intermediate_agg_result::{
     IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
 };
-use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector;
-use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key};
+use crate::aggregation::segment_agg_result::{
+    validate_bucket_count, SegmentAggregationResultsCollector,
+};
+use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key, SerializedKey};
 use crate::fastfield::FastFieldReader;
 use crate::schema::Type;
 use crate::{DocId, TantivyError};
@@ -153,7 +158,7 @@ impl SegmentRangeCollector {
     ) -> crate::Result<IntermediateBucketResult> {
         let field_type = self.field_type;

-        let buckets = self
+        let buckets: FnvHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
             .buckets
             .into_iter()
             .map(move |range_bucket| {
@@ -174,12 +179,13 @@ impl SegmentRangeCollector {
     pub(crate) fn from_req_and_validate(
         req: &RangeAggregation,
         sub_aggregation: &AggregationsWithAccessor,
+        bucket_count: &Rc<AtomicU32>,
         field_type: Type,
     ) -> crate::Result<Self> {
         // The range input on the request is f64.
         // We need to convert to u64 ranges, because we read the values as u64.
         // The mapping from the conversion is monotonic so ordering is preserved.
-        let buckets = extend_validate_ranges(&req.ranges, &field_type)?
+        let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)?
             .iter()
             .map(|range| {
                 let to = if range.end == u64::MAX {
@@ -212,6 +218,9 @@ impl SegmentRangeCollector {
             })
             .collect::<crate::Result<_>>()?;

+        bucket_count.fetch_add(buckets.len() as u32, std::sync::atomic::Ordering::Relaxed);
+        validate_bucket_count(bucket_count)?;
+
         Ok(SegmentRangeCollector {
             buckets,
             field_type,
@@ -403,8 +412,13 @@ mod tests {
             ranges,
         };

-        SegmentRangeCollector::from_req_and_validate(&req, &Default::default(), field_type)
-            .expect("unexpected error")
+        SegmentRangeCollector::from_req_and_validate(
+            &req,
+            &Default::default(),
+            &Default::default(),
+            field_type,
+        )
+        .expect("unexpected error")
     }

     #[test]

src/aggregation/bucket/term_agg.rs

+45 -12
@@ -11,7 +11,9 @@ use crate::aggregation::agg_req_with_accessor::{
 use crate::aggregation::intermediate_agg_result::{
     IntermediateBucketResult, IntermediateTermBucketEntry, IntermediateTermBucketResult,
 };
-use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector;
+use crate::aggregation::segment_agg_result::{
+    validate_bucket_count, SegmentAggregationResultsCollector,
+};
 use crate::error::DataCorruption;
 use crate::fastfield::MultiValuedFastFieldReader;
 use crate::schema::Type;
@@ -244,19 +246,23 @@ impl TermBuckets {
         &mut self,
         term_ids: &[u64],
         doc: DocId,
-        bucket_with_accessor: &AggregationsWithAccessor,
+        bucket_with_accessor: &BucketAggregationWithAccessor,
         blueprint: &Option<SegmentAggregationResultsCollector>,
     ) -> crate::Result<()> {
        for &term_id in term_ids {
-            let entry = self
-                .entries
-                .entry(term_id as u32)
-                .or_insert_with(|| TermBucketEntry::from_blueprint(blueprint));
+            let entry = self.entries.entry(term_id as u32).or_insert_with(|| {
+                bucket_with_accessor
+                    .bucket_count
+                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+
+                TermBucketEntry::from_blueprint(blueprint)
+            });
             entry.doc_count += 1;
             if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
-                sub_aggregations.collect(doc, bucket_with_accessor)?;
+                sub_aggregations.collect(doc, &bucket_with_accessor.sub_aggregation)?;
             }
         }
+        validate_bucket_count(&bucket_with_accessor.bucket_count)?;
         Ok(())
     }

@@ -441,25 +447,25 @@ impl SegmentTermCollector {
                 self.term_buckets.increment_bucket(
                     &vals1,
                     docs[0],
-                    &bucket_with_accessor.sub_aggregation,
+                    bucket_with_accessor,
                     &self.blueprint,
                 )?;
                 self.term_buckets.increment_bucket(
                     &vals2,
                     docs[1],
-                    &bucket_with_accessor.sub_aggregation,
+                    bucket_with_accessor,
                     &self.blueprint,
                 )?;
                 self.term_buckets.increment_bucket(
                     &vals3,
                     docs[2],
-                    &bucket_with_accessor.sub_aggregation,
+                    bucket_with_accessor,
                     &self.blueprint,
                 )?;
                 self.term_buckets.increment_bucket(
                     &vals4,
                     docs[3],
-                    &bucket_with_accessor.sub_aggregation,
+                    bucket_with_accessor,
                     &self.blueprint,
                 )?;
             }
@@ -469,7 +475,7 @@ impl SegmentTermCollector {
             self.term_buckets.increment_bucket(
                 &vals1,
                 doc,
-                &bucket_with_accessor.sub_aggregation,
+                bucket_with_accessor,
                 &self.blueprint,
             )?;
         }
@@ -1175,6 +1181,33 @@ mod tests {
         Ok(())
     }

+    #[test]
+    fn terms_aggregation_term_bucket_limit() -> crate::Result<()> {
+        let terms: Vec<String> = (0..100_000).map(|el| el.to_string()).collect();
+        let terms_per_segment = vec![terms.iter().map(|el| el.as_str()).collect()];
+
+        let index = get_test_index_from_terms(true, &terms_per_segment)?;
+
+        let agg_req: Aggregations = vec![(
+            "my_texts".to_string(),
+            Aggregation::Bucket(BucketAggregation {
+                bucket_agg: BucketAggregationType::Terms(TermsAggregation {
+                    field: "string_id".to_string(),
+                    min_doc_count: Some(0),
+                    ..Default::default()
+                }),
+                sub_aggregation: Default::default(),
+            }),
+        )]
+        .into_iter()
+        .collect();
+
+        let res = exec_request_with_query(agg_req, &index, None);
+        assert!(res.is_err());
+
+        Ok(())
+    }
+
     #[test]
     fn test_json_format() -> crate::Result<()> {
         let agg_req: Aggregations = vec![(

src/aggregation/mod.rs

+10 -5
@@ -417,7 +417,9 @@ mod tests {
         let mut schema_builder = Schema::builder();
         let text_fieldtype = crate::schema::TextOptions::default()
             .set_indexing_options(
-                TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs),
+                TextFieldIndexing::default()
+                    .set_index_option(IndexRecordOption::Basic)
+                    .set_fieldnorms(false),
             )
             .set_fast()
             .set_stored();
@@ -435,7 +437,8 @@
         );
         let index = Index::create_in_ram(schema_builder.build());
         {
-            let mut index_writer = index.writer_for_tests()?;
+            // let mut index_writer = index.writer_for_tests()?;
+            let mut index_writer = index.writer_with_num_threads(1, 30_000_000)?;
             for values in segment_and_values {
                 for (i, term) in values {
                     let i = *i;
@@ -457,9 +460,11 @@
             let segment_ids = index
                 .searchable_segment_ids()
                 .expect("Searchable segments failed.");
-            let mut index_writer = index.writer_for_tests()?;
-            index_writer.merge(&segment_ids).wait()?;
-            index_writer.wait_merging_threads()?;
+            if segment_ids.len() > 1 {
+                let mut index_writer = index.writer_for_tests()?;
+                index_writer.merge(&segment_ids).wait()?;
+                index_writer.wait_merging_threads()?;
+            }
         }

         Ok(index)

src/aggregation/segment_agg_result.rs

+13 -1
@@ -4,6 +4,8 @@
 //! merging.

 use std::fmt::Debug;
+use std::rc::Rc;
+use std::sync::atomic::AtomicU32;

 use super::agg_req::MetricAggregation;
 use super::agg_req_with_accessor::{
@@ -16,7 +18,7 @@ use super::metric::{
 };
 use super::VecWithNames;
 use crate::aggregation::agg_req::BucketAggregationType;
-use crate::DocId;
+use crate::{DocId, TantivyError};

 pub(crate) const DOC_BLOCK_SIZE: usize = 64;
 pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
@@ -236,6 +238,7 @@ impl SegmentBucketResultCollector {
                 Ok(Self::Range(SegmentRangeCollector::from_req_and_validate(
                     range_req,
                     &req.sub_aggregation,
+                    &req.bucket_count,
                     req.field_type,
                 )?))
             }
@@ -273,3 +276,12 @@ impl SegmentBucketResultCollector {
         Ok(())
     }
 }
+
+pub(crate) fn validate_bucket_count(bucket_count: &Rc<AtomicU32>) -> crate::Result<()> {
+    if bucket_count.load(std::sync::atomic::Ordering::Relaxed) > 65000 {
+        return Err(TantivyError::InvalidArgument(
+            "Aborting aggregation because too many buckets were created".to_string(),
+        ));
+    }
+    Ok(())
+}

src/collector/mod.rs

+4 -2
@@ -175,12 +175,14 @@ pub trait Collector: Sync + Send {
         if let Some(alive_bitset) = reader.alive_bitset() {
             weight.for_each(reader, &mut |doc, score| {
                 if alive_bitset.is_alive(doc) {
-                    segment_collector.collect(doc, score).unwrap(); // TODO
+                    segment_collector.collect(doc, score)?;
                 }
+                Ok(())
             })?;
         } else {
             weight.for_each(reader, &mut |doc, score| {
-                segment_collector.collect(doc, score).unwrap(); // TODO
+                segment_collector.collect(doc, score)?;
+                Ok(())
             })?;
         }
         Ok(segment_collector.harvest())

src/query/boolean_query/boolean_weight.rs

+3 -3
@@ -186,17 +186,17 @@ impl Weight for BooleanWeight {
     fn for_each(
         &self,
         reader: &SegmentReader,
-        callback: &mut dyn FnMut(DocId, Score),
+        callback: &mut dyn FnMut(DocId, Score) -> crate::Result<()>,
     ) -> crate::Result<()> {
         let scorer = self.complex_scorer::<SumWithCoordsCombiner>(reader, 1.0)?;
         match scorer {
             SpecializedScorer::TermUnion(term_scorers) => {
                 let mut union_scorer =
                     Union::<TermScorer, SumWithCoordsCombiner>::from(term_scorers);
-                for_each_scorer(&mut union_scorer, callback);
+                for_each_scorer(&mut union_scorer, callback)?;
             }
             SpecializedScorer::Other(mut scorer) => {
-                for_each_scorer(scorer.as_mut(), callback);
+                for_each_scorer(scorer.as_mut(), callback)?;
             }
         }
         Ok(())

src/query/term_query/term_weight.rs

+2 -2
@@ -49,10 +49,10 @@ impl Weight for TermWeight {
     fn for_each(
         &self,
         reader: &SegmentReader,
-        callback: &mut dyn FnMut(DocId, Score),
+        callback: &mut dyn FnMut(DocId, Score) -> crate::Result<()>,
     ) -> crate::Result<()> {
         let mut scorer = self.specialized_scorer(reader, 1.0)?;
-        for_each_scorer(&mut scorer, callback);
+        for_each_scorer(&mut scorer, callback)?;
         Ok(())
     }
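For context, an illustrative sketch (not code from this commit) of why the `Weight::for_each` callback now returns a `Result`: a driver loop in the style of tantivy's `for_each_scorer` can propagate a collector failure, such as the bucket-limit error above, and stop scanning the segment early. The trait and names below are simplified stand-ins, not tantivy's actual types.

// Simplified stand-ins for tantivy's DocSet/Scorer machinery, for illustration only.
const TERMINATED: u32 = u32::MAX;

trait SimpleScorer {
    fn doc(&self) -> u32;
    fn advance(&mut self) -> u32;
    fn score(&self) -> f32;
}

fn for_each_scorer_sketch(
    scorer: &mut dyn SimpleScorer,
    callback: &mut dyn FnMut(u32, f32) -> Result<(), String>,
) -> Result<(), String> {
    let mut doc = scorer.doc();
    while doc != TERMINATED {
        // A failing callback (e.g. "too many buckets") bubbles the error up and ends the scan.
        callback(doc, scorer.score())?;
        doc = scorer.advance();
    }
    Ok(())
}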
