Skip to content

Commit 3f623df

Browse files
committed
Add support for keyed parameter in range and histgram aggregations
1 parent 931bab8 commit 3f623df

File tree

9 files changed

+64
-8
lines changed

9 files changed

+64
-8
lines changed

examples/aggregation.rs

+1
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ fn main() -> tantivy::Result<()> {
110110
(9f64..14f64).into(),
111111
(14f64..20f64).into(),
112112
],
113+
..Default::default()
113114
}),
114115
sub_aggregation: sub_agg_req_1.clone(),
115116
}),

src/aggregation/agg_req.rs

+10
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
//! Aggregation::Bucket(BucketAggregation {
2020
//! bucket_agg: BucketAggregationType::Range(RangeAggregation{
2121
//! field: "score".to_string(),
22+
//! keyed: false,
2223
//! ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
2324
//! }),
2425
//! sub_aggregation: Default::default(),
@@ -100,6 +101,12 @@ pub(crate) struct BucketAggregationInternal {
100101
}
101102

102103
impl BucketAggregationInternal {
104+
pub(crate) fn as_range(&self) -> Option<&RangeAggregation> {
105+
match &self.bucket_agg {
106+
BucketAggregationType::Range(range) => Some(range),
107+
_ => None,
108+
}
109+
}
103110
pub(crate) fn as_histogram(&self) -> Option<&HistogramAggregation> {
104111
match &self.bucket_agg {
105112
BucketAggregationType::Histogram(histogram) => Some(histogram),
@@ -264,6 +271,7 @@ mod tests {
264271
(7f64..20f64).into(),
265272
(20f64..f64::MAX).into(),
266273
],
274+
..Default::default()
267275
}),
268276
sub_aggregation: Default::default(),
269277
}),
@@ -312,6 +320,7 @@ mod tests {
312320
(7f64..20f64).into(),
313321
(20f64..f64::MAX).into(),
314322
],
323+
..Default::default()
315324
}),
316325
sub_aggregation: Default::default(),
317326
}),
@@ -337,6 +346,7 @@ mod tests {
337346
(7f64..20f64).into(),
338347
(20f64..f64::MAX).into(),
339348
],
349+
..Default::default()
340350
}),
341351
sub_aggregation: agg_req2,
342352
}),

src/aggregation/agg_req_with_accessor.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ impl BucketAggregationWithAccessor {
7777
let mut inverted_index = None;
7878
let (accessor, field_type) = match &bucket {
7979
BucketAggregationType::Range(RangeAggregation {
80-
field: field_name,
81-
ranges: _,
80+
field: field_name, ..
8281
}) => get_ff_reader_and_validate(reader, field_name, Cardinality::SingleValue)?,
8382
BucketAggregationType::Histogram(HistogramAggregation {
8483
field: field_name, ..

src/aggregation/agg_result.rs

+13-2
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ pub enum BucketResult {
104104
/// sub_aggregations.
105105
Range {
106106
/// The range buckets sorted by range.
107-
buckets: Vec<RangeBucketEntry>,
107+
buckets: BucketEntries<RangeBucketEntry>,
108108
},
109109
/// This is the histogram entry for a bucket, which contains a key, count, and optionally
110110
/// sub_aggregations.
@@ -114,7 +114,7 @@ pub enum BucketResult {
114114
/// If there are holes depends on the request, if min_doc_count is 0, then there are no
115115
/// holes between the first and last bucket.
116116
/// See [HistogramAggregation](super::bucket::HistogramAggregation)
117-
buckets: Vec<BucketEntry>,
117+
buckets: BucketEntries<BucketEntry>,
118118
},
119119
/// This is the term result
120120
Terms {
@@ -137,6 +137,17 @@ impl BucketResult {
137137
}
138138
}
139139

140+
/// This is the wrapper of buckets entries, which can be vector or hashmap
141+
/// depending on if it's keyed or not.
142+
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
143+
#[serde(untagged)]
144+
pub enum BucketEntries<T> {
145+
/// Vector format bucket entries
146+
Vec(Vec<T>),
147+
/// HashMap format bucket entries
148+
HashMap(HashMap<String, T>)
149+
}
150+
140151
/// This is the default entry for a bucket, which contains a key, count, and optionally
141152
/// sub_aggregations.
142153
///

src/aggregation/bucket/histogram/histogram.rs

+3
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ pub struct HistogramAggregation {
117117
/// Cannot be set in conjunction with min_doc_count > 0, since the empty buckets from extended
118118
/// bounds would not be returned.
119119
pub extended_bounds: Option<HistogramBounds>,
120+
/// Whether to return the buckets as a hash map
121+
#[serde(skip_serializing_if = "Option::is_none")]
122+
pub keyed: Option<bool>,
120123
}
121124

122125
impl HistogramAggregation {

src/aggregation/bucket/range.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ use crate::{DocId, TantivyError};
3535
/// # Limitations/Compatibility
3636
/// Overlapping ranges are not yet supported.
3737
///
38-
/// The keyed parameter (elasticsearch) is not yet supported.
39-
///
4038
/// # Request JSON Format
4139
/// ```json
4240
/// {
@@ -51,13 +49,16 @@ use crate::{DocId, TantivyError};
5149
/// }
5250
/// }
5351
/// ```
54-
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
52+
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
5553
pub struct RangeAggregation {
5654
/// The field to aggregate on.
5755
pub field: String,
5856
/// Note that this aggregation includes the from value and excludes the to value for each
5957
/// range. Extra buckets will be created until the first to, and last from, if necessary.
6058
pub ranges: Vec<RangeAggregationRange>,
59+
/// Whether to return the buckets as a hash map
60+
#[serde(skip_serializing_if = "Option::is_none")]
61+
pub keyed: Option<bool>,
6162
}
6263

6364
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@@ -406,6 +407,7 @@ mod tests {
406407
let req = RangeAggregation {
407408
field: "dummy".to_string(),
408409
ranges,
410+
..Default::default()
409411
};
410412

411413
SegmentRangeCollector::from_req_and_validate(
@@ -427,6 +429,7 @@ mod tests {
427429
bucket_agg: BucketAggregationType::Range(RangeAggregation {
428430
field: "fraction_f64".to_string(),
429431
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
432+
..Default::default()
430433
}),
431434
sub_aggregation: Default::default(),
432435
}),

src/aggregation/intermediate_agg_result.rs

+20-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use super::bucket::{
2121
use super::metric::{IntermediateAverage, IntermediateStats};
2222
use super::segment_agg_result::SegmentMetricResultCollector;
2323
use super::{Key, SerializedKey, VecWithNames};
24-
use crate::aggregation::agg_result::{AggregationResults, BucketEntry};
24+
use crate::aggregation::agg_result::{AggregationResults, BucketEntry, BucketEntries};
2525
use crate::aggregation::bucket::TermsAggregationInternal;
2626

2727
/// Contains the intermediate aggregation result, which is optimized to be merged with other
@@ -281,6 +281,16 @@ impl IntermediateBucketResult {
281281
.unwrap_or(f64::MIN)
282282
.total_cmp(&right.from.unwrap_or(f64::MIN))
283283
});
284+
285+
let buckets = if req.as_range().unwrap().keyed.is_some() {
286+
let mut bucket_map = HashMap::new();
287+
for bucket in buckets {
288+
bucket_map.insert(bucket.key.to_string(), bucket);
289+
}
290+
BucketEntries::HashMap(bucket_map)
291+
} else {
292+
BucketEntries::Vec(buckets)
293+
};
284294
Ok(BucketResult::Range { buckets })
285295
}
286296
IntermediateBucketResult::Histogram { buckets } => {
@@ -291,6 +301,15 @@ impl IntermediateBucketResult {
291301
&req.sub_aggregation,
292302
)?;
293303

304+
let buckets = if req.as_histogram().unwrap().keyed.is_some() {
305+
let mut bucket_map = HashMap::new();
306+
for bucket in buckets {
307+
bucket_map.insert(bucket.key.to_string(), bucket);
308+
}
309+
BucketEntries::HashMap(bucket_map)
310+
} else {
311+
BucketEntries::Vec(buckets)
312+
};
294313
Ok(BucketResult::Histogram { buckets })
295314
}
296315
IntermediateBucketResult::Terms(terms) => terms.into_final_result(

src/aggregation/metric/stats.rs

+1
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ mod tests {
285285
(7f64..19f64).into(),
286286
(19f64..20f64).into(),
287287
],
288+
..Default::default()
288289
}),
289290
sub_aggregation: iter::once((
290291
"stats".to_string(),

src/aggregation/mod.rs

+9
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@
132132
//! bucket_agg: BucketAggregationType::Range(RangeAggregation{
133133
//! field: "score".to_string(),
134134
//! ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
135+
//! keyed: None,
135136
//! }),
136137
//! sub_aggregation: sub_agg_req_1.clone(),
137138
//! }),
@@ -765,6 +766,7 @@ mod tests {
765766
bucket_agg: BucketAggregationType::Range(RangeAggregation {
766767
field: "score".to_string(),
767768
ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
769+
..Default::default()
768770
}),
769771
sub_aggregation: Default::default(),
770772
}),
@@ -775,6 +777,7 @@ mod tests {
775777
bucket_agg: BucketAggregationType::Range(RangeAggregation {
776778
field: "score_f64".to_string(),
777779
ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
780+
..Default::default()
778781
}),
779782
sub_aggregation: Default::default(),
780783
}),
@@ -785,6 +788,7 @@ mod tests {
785788
bucket_agg: BucketAggregationType::Range(RangeAggregation {
786789
field: "score_i64".to_string(),
787790
ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
791+
..Default::default()
788792
}),
789793
sub_aggregation: Default::default(),
790794
}),
@@ -941,6 +945,7 @@ mod tests {
941945
(7f64..19f64).into(),
942946
(19f64..20f64).into(),
943947
],
948+
..Default::default()
944949
}),
945950
sub_aggregation: sub_agg_req.clone(),
946951
}),
@@ -955,6 +960,7 @@ mod tests {
955960
(7f64..19f64).into(),
956961
(19f64..20f64).into(),
957962
],
963+
..Default::default()
958964
}),
959965
sub_aggregation: sub_agg_req.clone(),
960966
}),
@@ -969,6 +975,7 @@ mod tests {
969975
(7f64..19f64).into(),
970976
(19f64..20f64).into(),
971977
],
978+
..Default::default()
972979
}),
973980
sub_aggregation: sub_agg_req,
974981
}),
@@ -1416,6 +1423,7 @@ mod tests {
14161423
(40000f64..50000f64).into(),
14171424
(50000f64..60000f64).into(),
14181425
],
1426+
..Default::default()
14191427
}),
14201428
sub_aggregation: Default::default(),
14211429
}),
@@ -1575,6 +1583,7 @@ mod tests {
15751583
(7000f64..20000f64).into(),
15761584
(20000f64..60000f64).into(),
15771585
],
1586+
..Default::default()
15781587
}),
15791588
sub_aggregation: sub_agg_req_1.clone(),
15801589
}),

0 commit comments

Comments
 (0)