Skip to content

Commit da0f78e

Browse files
authored
Merge pull request #1424 from k-yomo/support-keyed-parameter-in-aggregation
Add support for keyed parameter in range and histgram aggregations
2 parents 931bab8 + 9b6b60c commit da0f78e

File tree

9 files changed

+159
-12
lines changed

9 files changed

+159
-12
lines changed

examples/aggregation.rs

+1
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ fn main() -> tantivy::Result<()> {
110110
(9f64..14f64).into(),
111111
(14f64..20f64).into(),
112112
],
113+
..Default::default()
113114
}),
114115
sub_aggregation: sub_agg_req_1.clone(),
115116
}),

src/aggregation/agg_req.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
//! bucket_agg: BucketAggregationType::Range(RangeAggregation{
2121
//! field: "score".to_string(),
2222
//! ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
23+
//! keyed: false,
2324
//! }),
2425
//! sub_aggregation: Default::default(),
2526
//! }),
@@ -100,6 +101,12 @@ pub(crate) struct BucketAggregationInternal {
100101
}
101102

102103
impl BucketAggregationInternal {
104+
pub(crate) fn as_range(&self) -> Option<&RangeAggregation> {
105+
match &self.bucket_agg {
106+
BucketAggregationType::Range(range) => Some(range),
107+
_ => None,
108+
}
109+
}
103110
pub(crate) fn as_histogram(&self) -> Option<&HistogramAggregation> {
104111
match &self.bucket_agg {
105112
BucketAggregationType::Histogram(histogram) => Some(histogram),
@@ -264,6 +271,7 @@ mod tests {
264271
(7f64..20f64).into(),
265272
(20f64..f64::MAX).into(),
266273
],
274+
keyed: true,
267275
}),
268276
sub_aggregation: Default::default(),
269277
}),
@@ -290,7 +298,8 @@ mod tests {
290298
{
291299
"from": 20.0
292300
}
293-
]
301+
],
302+
"keyed": true
294303
}
295304
}
296305
}"#;
@@ -312,6 +321,7 @@ mod tests {
312321
(7f64..20f64).into(),
313322
(20f64..f64::MAX).into(),
314323
],
324+
..Default::default()
315325
}),
316326
sub_aggregation: Default::default(),
317327
}),
@@ -337,6 +347,7 @@ mod tests {
337347
(7f64..20f64).into(),
338348
(20f64..f64::MAX).into(),
339349
],
350+
..Default::default()
340351
}),
341352
sub_aggregation: agg_req2,
342353
}),

src/aggregation/agg_req_with_accessor.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ impl BucketAggregationWithAccessor {
7777
let mut inverted_index = None;
7878
let (accessor, field_type) = match &bucket {
7979
BucketAggregationType::Range(RangeAggregation {
80-
field: field_name,
81-
ranges: _,
80+
field: field_name, ..
8281
}) => get_ff_reader_and_validate(reader, field_name, Cardinality::SingleValue)?,
8382
BucketAggregationType::Histogram(HistogramAggregation {
8483
field: field_name, ..

src/aggregation/agg_result.rs

+14-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
77
use std::collections::HashMap;
88

9+
use fnv::FnvHashMap;
910
use serde::{Deserialize, Serialize};
1011

1112
use super::agg_req::BucketAggregationInternal;
@@ -104,7 +105,7 @@ pub enum BucketResult {
104105
/// sub_aggregations.
105106
Range {
106107
/// The range buckets sorted by range.
107-
buckets: Vec<RangeBucketEntry>,
108+
buckets: BucketEntries<RangeBucketEntry>,
108109
},
109110
/// This is the histogram entry for a bucket, which contains a key, count, and optionally
110111
/// sub_aggregations.
@@ -114,7 +115,7 @@ pub enum BucketResult {
114115
/// If there are holes depends on the request, if min_doc_count is 0, then there are no
115116
/// holes between the first and last bucket.
116117
/// See [HistogramAggregation](super::bucket::HistogramAggregation)
117-
buckets: Vec<BucketEntry>,
118+
buckets: BucketEntries<BucketEntry>,
118119
},
119120
/// This is the term result
120121
Terms {
@@ -137,6 +138,17 @@ impl BucketResult {
137138
}
138139
}
139140

141+
/// This is the wrapper of buckets entries, which can be vector or hashmap
142+
/// depending on if it's keyed or not.
143+
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
144+
#[serde(untagged)]
145+
pub enum BucketEntries<T> {
146+
/// Vector format bucket entries
147+
Vec(Vec<T>),
148+
/// HashMap format bucket entries
149+
HashMap(FnvHashMap<String, T>),
150+
}
151+
140152
/// This is the default entry for a bucket, which contains a key, count, and optionally
141153
/// sub_aggregations.
142154
///

src/aggregation/bucket/histogram/histogram.rs

+45-2
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@ use crate::{DocId, TantivyError};
4848
///
4949
/// # Limitations/Compatibility
5050
///
51-
/// The keyed parameter (elasticsearch) is not yet supported.
52-
///
5351
/// # JSON Format
5452
/// ```json
5553
/// {
@@ -117,6 +115,9 @@ pub struct HistogramAggregation {
117115
/// Cannot be set in conjunction with min_doc_count > 0, since the empty buckets from extended
118116
/// bounds would not be returned.
119117
pub extended_bounds: Option<HistogramBounds>,
118+
/// Whether to return the buckets as a hash map
119+
#[serde(default)]
120+
pub keyed: bool,
120121
}
121122

122123
impl HistogramAggregation {
@@ -1395,4 +1396,46 @@ mod tests {
13951396

13961397
Ok(())
13971398
}
1399+
1400+
#[test]
1401+
fn histogram_keyed_buckets_test() -> crate::Result<()> {
1402+
let index = get_test_index_with_num_docs(false, 100)?;
1403+
1404+
let agg_req: Aggregations = vec![(
1405+
"histogram".to_string(),
1406+
Aggregation::Bucket(BucketAggregation {
1407+
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
1408+
field: "score_f64".to_string(),
1409+
interval: 50.0,
1410+
keyed: true,
1411+
..Default::default()
1412+
}),
1413+
sub_aggregation: Default::default(),
1414+
}),
1415+
)]
1416+
.into_iter()
1417+
.collect();
1418+
1419+
let res = exec_request(agg_req, &index)?;
1420+
1421+
assert_eq!(
1422+
res,
1423+
json!({
1424+
"histogram": {
1425+
"buckets": {
1426+
"0": {
1427+
"key": 0.0,
1428+
"doc_count": 50
1429+
},
1430+
"50": {
1431+
"key": 50.0,
1432+
"doc_count": 50
1433+
}
1434+
}
1435+
}
1436+
})
1437+
);
1438+
1439+
Ok(())
1440+
}
13981441
}

src/aggregation/bucket/range.rs

+49-3
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ use crate::{DocId, TantivyError};
3535
/// # Limitations/Compatibility
3636
/// Overlapping ranges are not yet supported.
3737
///
38-
/// The keyed parameter (elasticsearch) is not yet supported.
39-
///
4038
/// # Request JSON Format
4139
/// ```json
4240
/// {
@@ -51,13 +49,16 @@ use crate::{DocId, TantivyError};
5149
/// }
5250
/// }
5351
/// ```
54-
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
52+
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
5553
pub struct RangeAggregation {
5654
/// The field to aggregate on.
5755
pub field: String,
5856
/// Note that this aggregation includes the from value and excludes the to value for each
5957
/// range. Extra buckets will be created until the first to, and last from, if necessary.
6058
pub ranges: Vec<RangeAggregationRange>,
59+
/// Whether to return the buckets as a hash map
60+
#[serde(default)]
61+
pub keyed: bool,
6162
}
6263

6364
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@@ -406,6 +407,7 @@ mod tests {
406407
let req = RangeAggregation {
407408
field: "dummy".to_string(),
408409
ranges,
410+
..Default::default()
409411
};
410412

411413
SegmentRangeCollector::from_req_and_validate(
@@ -427,6 +429,7 @@ mod tests {
427429
bucket_agg: BucketAggregationType::Range(RangeAggregation {
428430
field: "fraction_f64".to_string(),
429431
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
432+
..Default::default()
430433
}),
431434
sub_aggregation: Default::default(),
432435
}),
@@ -454,6 +457,49 @@ mod tests {
454457
Ok(())
455458
}
456459

460+
#[test]
461+
fn range_keyed_buckets_test() -> crate::Result<()> {
462+
let index = get_test_index_with_num_docs(false, 100)?;
463+
464+
let agg_req: Aggregations = vec![(
465+
"range".to_string(),
466+
Aggregation::Bucket(BucketAggregation {
467+
bucket_agg: BucketAggregationType::Range(RangeAggregation {
468+
field: "fraction_f64".to_string(),
469+
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
470+
keyed: true,
471+
}),
472+
sub_aggregation: Default::default(),
473+
}),
474+
)]
475+
.into_iter()
476+
.collect();
477+
478+
let collector = AggregationCollector::from_aggs(agg_req, None);
479+
480+
let reader = index.reader()?;
481+
let searcher = reader.searcher();
482+
let agg_res = searcher.search(&AllQuery, &collector).unwrap();
483+
484+
let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
485+
486+
assert_eq!(
487+
res,
488+
json!({
489+
"range": {
490+
"buckets": {
491+
"*-0": { "key": "*-0", "doc_count": 0, "to": 0.0},
492+
"0-0.1": {"key": "0-0.1", "doc_count": 10, "from": 0.0, "to": 0.1},
493+
"0.1-0.2": {"key": "0.1-0.2", "doc_count": 10, "from": 0.1, "to": 0.2},
494+
"0.2-*": {"key": "0.2-*", "doc_count": 80, "from": 0.2},
495+
}
496+
}
497+
})
498+
);
499+
500+
Ok(())
501+
}
502+
457503
#[test]
458504
fn bucket_test_extend_range_hole() {
459505
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];

src/aggregation/intermediate_agg_result.rs

+26-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use super::bucket::{
2121
use super::metric::{IntermediateAverage, IntermediateStats};
2222
use super::segment_agg_result::SegmentMetricResultCollector;
2323
use super::{Key, SerializedKey, VecWithNames};
24-
use crate::aggregation::agg_result::{AggregationResults, BucketEntry};
24+
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
2525
use crate::aggregation::bucket::TermsAggregationInternal;
2626

2727
/// Contains the intermediate aggregation result, which is optimized to be merged with other
@@ -281,6 +281,21 @@ impl IntermediateBucketResult {
281281
.unwrap_or(f64::MIN)
282282
.total_cmp(&right.from.unwrap_or(f64::MIN))
283283
});
284+
285+
let is_keyed = req
286+
.as_range()
287+
.expect("unexpected aggregation, expected range aggregation")
288+
.keyed;
289+
let buckets = if is_keyed {
290+
let mut bucket_map =
291+
FnvHashMap::with_capacity_and_hasher(buckets.len(), Default::default());
292+
for bucket in buckets {
293+
bucket_map.insert(bucket.key.to_string(), bucket);
294+
}
295+
BucketEntries::HashMap(bucket_map)
296+
} else {
297+
BucketEntries::Vec(buckets)
298+
};
284299
Ok(BucketResult::Range { buckets })
285300
}
286301
IntermediateBucketResult::Histogram { buckets } => {
@@ -291,6 +306,16 @@ impl IntermediateBucketResult {
291306
&req.sub_aggregation,
292307
)?;
293308

309+
let buckets = if req.as_histogram().unwrap().keyed {
310+
let mut bucket_map =
311+
FnvHashMap::with_capacity_and_hasher(buckets.len(), Default::default());
312+
for bucket in buckets {
313+
bucket_map.insert(bucket.key.to_string(), bucket);
314+
}
315+
BucketEntries::HashMap(bucket_map)
316+
} else {
317+
BucketEntries::Vec(buckets)
318+
};
294319
Ok(BucketResult::Histogram { buckets })
295320
}
296321
IntermediateBucketResult::Terms(terms) => terms.into_final_result(

src/aggregation/metric/stats.rs

+1
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ mod tests {
285285
(7f64..19f64).into(),
286286
(19f64..20f64).into(),
287287
],
288+
..Default::default()
288289
}),
289290
sub_aggregation: iter::once((
290291
"stats".to_string(),

0 commit comments

Comments
 (0)