Skip to content

Commit 8459efa

Browse files
authored
split term collection count and sub_agg (#1921)
use unrolled ColumnValues::get_vals
1 parent 61cfd8d commit 8459efa

File tree

9 files changed

+124
-44
lines changed

9 files changed

+124
-44
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@ benchmark
1313
.idea
1414
trace.dat
1515
cargo-timing*
16+
control
17+
variable

columnar/src/column_values/mod.rs

+15-4
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,21 @@ pub trait ColumnValues<T: PartialOrd = u64>: Send + Sync {
5858
/// # Panics
5959
///
6060
/// May panic if `idx` is greater than the column length.
61-
fn get_vals(&self, idxs: &[u32], output: &mut [T]) {
62-
assert!(idxs.len() == output.len());
63-
for (out, &idx) in output.iter_mut().zip(idxs.iter()) {
64-
*out = self.get_val(idx);
61+
fn get_vals(&self, indexes: &[u32], output: &mut [T]) {
62+
assert!(indexes.len() == output.len());
63+
let out_and_idx_chunks = output.chunks_exact_mut(4).zip(indexes.chunks_exact(4));
64+
for (out_x4, idx_x4) in out_and_idx_chunks {
65+
out_x4[0] = self.get_val(idx_x4[0]);
66+
out_x4[1] = self.get_val(idx_x4[1]);
67+
out_x4[2] = self.get_val(idx_x4[2]);
68+
out_x4[3] = self.get_val(idx_x4[3]);
69+
}
70+
71+
let step_size = 4;
72+
let cutoff = indexes.len() - indexes.len() % step_size;
73+
74+
for idx in cutoff..indexes.len() {
75+
output[idx] = self.get_val(indexes[idx] as u32);
6576
}
6677
}
6778

columnar/src/column_values/monotonic_column.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ where
5050
Input: PartialOrd + Send + Debug + Sync + Clone,
5151
Output: PartialOrd + Send + Debug + Sync + Clone,
5252
{
53-
#[inline]
53+
#[inline(always)]
5454
fn get_val(&self, idx: u32) -> Output {
5555
let from_val = self.from_column.get_val(idx);
5656
self.monotonic_mapping.mapping(from_val)

columnar/src/column_values/u64_based/tests.rs

+14
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,28 @@ pub(crate) fn create_and_validate<TColumnCodec: ColumnCodec>(
9999

100100
let reader = TColumnCodec::load(OwnedBytes::new(buffer)).unwrap();
101101
assert_eq!(reader.num_vals(), vals.len() as u32);
102+
let mut buffer = Vec::new();
102103
for (doc, orig_val) in vals.iter().copied().enumerate() {
103104
let val = reader.get_val(doc as u32);
104105
assert_eq!(
105106
val, orig_val,
106107
"val `{val}` does not match orig_val {orig_val:?}, in data set {name}, data `{vals:?}`",
107108
);
109+
110+
buffer.resize(1, 0);
111+
reader.get_vals(&[doc as u32], &mut buffer);
112+
let val = buffer[0];
113+
assert_eq!(
114+
val, orig_val,
115+
"val `{val}` does not match orig_val {orig_val:?}, in data set {name}, data `{vals:?}`",
116+
);
108117
}
109118

119+
let all_docs: Vec<u32> = (0..vals.len() as u32).collect();
120+
buffer.resize(all_docs.len(), 0);
121+
reader.get_vals(&all_docs, &mut buffer);
122+
assert_eq!(vals, buffer);
123+
110124
if !vals.is_empty() {
111125
let test_rand_idx = rand::thread_rng().gen_range(0..=vals.len() - 1);
112126
let expected_positions: Vec<u32> = vals

src/aggregation/bucket/histogram/histogram.rs

+2
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
230230
})
231231
}
232232

233+
#[inline]
233234
fn collect(
234235
&mut self,
235236
doc: crate::DocId,
@@ -238,6 +239,7 @@ impl SegmentAggregationCollector for SegmentHistogramCollector {
238239
self.collect_block(&[doc], agg_with_accessor)
239240
}
240241

242+
#[inline]
241243
fn collect_block(
242244
&mut self,
243245
docs: &[crate::DocId],

src/aggregation/bucket/range.rs

+2
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
208208
})
209209
}
210210

211+
#[inline]
211212
fn collect(
212213
&mut self,
213214
doc: crate::DocId,
@@ -216,6 +217,7 @@ impl SegmentAggregationCollector for SegmentRangeCollector {
216217
self.collect_block(&[doc], agg_with_accessor)
217218
}
218219

220+
#[inline]
219221
fn collect_block(
220222
&mut self,
221223
docs: &[crate::DocId],

src/aggregation/bucket/term_agg.rs

+76-36
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,8 @@ impl TermsAggregationInternal {
205205
#[derive(Clone, Debug, Default)]
206206
/// Container to store term_ids/or u64 values and their buckets.
207207
struct TermBuckets {
208-
pub(crate) entries: FxHashMap<u64, TermBucketEntry>,
208+
pub(crate) entries: FxHashMap<u64, u64>,
209+
pub(crate) sub_aggs: FxHashMap<u64, Box<dyn SegmentAggregationCollector>>,
209210
}
210211

211212
#[derive(Clone, Default)]
@@ -249,10 +250,8 @@ impl TermBucketEntry {
249250

250251
impl TermBuckets {
251252
fn force_flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) -> crate::Result<()> {
252-
for entry in &mut self.entries.values_mut() {
253-
if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
254-
sub_aggregations.flush(agg_with_accessor)?;
255-
}
253+
for sub_aggregations in &mut self.sub_aggs.values_mut() {
254+
sub_aggregations.as_mut().flush(agg_with_accessor)?;
256255
}
257256
Ok(())
258257
}
@@ -268,6 +267,7 @@ pub struct SegmentTermCollector {
268267
blueprint: Option<Box<dyn SegmentAggregationCollector>>,
269268
field_type: ColumnType,
270269
accessor_idx: usize,
270+
val_cache: Vec<u64>,
271271
}
272272

273273
pub(crate) fn get_agg_name_and_property(name: &str) -> (&str, &str) {
@@ -292,6 +292,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
292292
})
293293
}
294294

295+
#[inline]
295296
fn collect(
296297
&mut self,
297298
doc: crate::DocId,
@@ -300,6 +301,7 @@ impl SegmentAggregationCollector for SegmentTermCollector {
300301
self.collect_block(&[doc], agg_with_accessor)
301302
}
302303

304+
#[inline]
303305
fn collect_block(
304306
&mut self,
305307
docs: &[crate::DocId],
@@ -310,28 +312,35 @@ impl SegmentAggregationCollector for SegmentTermCollector {
310312
&agg_with_accessor.buckets.values[self.accessor_idx].sub_aggregation;
311313

312314
if accessor.get_cardinality() == Cardinality::Full {
313-
for doc in docs {
314-
let term_id = accessor.values.get_val(*doc);
315-
let entry = self
316-
.term_buckets
317-
.entries
318-
.entry(term_id)
319-
.or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
320-
entry.doc_count += 1;
321-
if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
315+
self.val_cache.resize(docs.len(), 0);
316+
accessor.values.get_vals(docs, &mut self.val_cache);
317+
for (doc, term_id) in docs.iter().zip(self.val_cache.iter().cloned()) {
318+
let entry = self.term_buckets.entries.entry(term_id).or_default();
319+
*entry += 1;
320+
}
321+
// has subagg
322+
if let Some(blueprint) = self.blueprint.as_ref() {
323+
for (doc, term_id) in docs.iter().zip(self.val_cache.iter().cloned()) {
324+
let sub_aggregations = self
325+
.term_buckets
326+
.sub_aggs
327+
.entry(term_id)
328+
.or_insert_with(|| blueprint.clone());
322329
sub_aggregations.collect(*doc, sub_aggregation_accessor)?;
323330
}
324331
}
325332
} else {
326333
for doc in docs {
327334
for term_id in accessor.values_for_doc(*doc) {
328-
let entry = self
329-
.term_buckets
330-
.entries
331-
.entry(term_id)
332-
.or_insert_with(|| TermBucketEntry::from_blueprint(&self.blueprint));
333-
entry.doc_count += 1;
334-
if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
335+
let entry = self.term_buckets.entries.entry(term_id).or_default();
336+
*entry += 1;
337+
// TODO: check if separate loop is faster (may depend on the codec)
338+
if let Some(blueprint) = self.blueprint.as_ref() {
339+
let sub_aggregations = self
340+
.term_buckets
341+
.sub_aggs
342+
.entry(term_id)
343+
.or_insert_with(|| blueprint.clone());
335344
sub_aggregations.collect(*doc, sub_aggregation_accessor)?;
336345
}
337346
}
@@ -386,15 +395,16 @@ impl SegmentTermCollector {
386395
blueprint,
387396
field_type,
388397
accessor_idx,
398+
val_cache: Default::default(),
389399
})
390400
}
391401

402+
#[inline]
392403
pub(crate) fn into_intermediate_bucket_result(
393-
self,
404+
mut self,
394405
agg_with_accessor: &BucketAggregationWithAccessor,
395406
) -> crate::Result<IntermediateBucketResult> {
396-
let mut entries: Vec<(u64, TermBucketEntry)> =
397-
self.term_buckets.entries.into_iter().collect();
407+
let mut entries: Vec<(u64, u64)> = self.term_buckets.entries.into_iter().collect();
398408

399409
let order_by_sub_aggregation =
400410
matches!(self.req.order.target, OrderTarget::SubAggregation(_));
@@ -417,9 +427,9 @@ impl SegmentTermCollector {
417427
}
418428
OrderTarget::Count => {
419429
if self.req.order.order == Order::Desc {
420-
entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.doc_count()));
430+
entries.sort_unstable_by_key(|bucket| std::cmp::Reverse(bucket.1));
421431
} else {
422-
entries.sort_unstable_by_key(|bucket| bucket.doc_count());
432+
entries.sort_unstable_by_key(|bucket| bucket.1);
423433
}
424434
}
425435
}
@@ -432,24 +442,51 @@ impl SegmentTermCollector {
432442

433443
let mut dict: FxHashMap<Key, IntermediateTermBucketEntry> = Default::default();
434444
dict.reserve(entries.len());
445+
446+
let mut into_intermediate_bucket_entry =
447+
|id, doc_count| -> crate::Result<IntermediateTermBucketEntry> {
448+
let intermediate_entry = if let Some(blueprint) = self.blueprint.as_ref() {
449+
IntermediateTermBucketEntry {
450+
doc_count,
451+
sub_aggregation: self
452+
.term_buckets
453+
.sub_aggs
454+
.remove(&id)
455+
.expect(&format!(
456+
"Internal Error: could not find subaggregation for id {}",
457+
id
458+
))
459+
.into_intermediate_aggregations_result(
460+
&agg_with_accessor.sub_aggregation,
461+
)?,
462+
}
463+
} else {
464+
IntermediateTermBucketEntry {
465+
doc_count,
466+
sub_aggregation: Default::default(),
467+
}
468+
};
469+
Ok(intermediate_entry)
470+
};
471+
435472
if self.field_type == ColumnType::Str {
436473
let term_dict = agg_with_accessor
437474
.str_dict_column
438475
.as_ref()
439476
.expect("internal error: term dictionary not found for term aggregation");
440477

441478
let mut buffer = String::new();
442-
for (term_id, entry) in entries {
479+
for (term_id, doc_count) in entries {
443480
if !term_dict.ord_to_str(term_id, &mut buffer)? {
444481
return Err(TantivyError::InternalError(format!(
445482
"Couldn't find term_id {} in dict",
446483
term_id
447484
)));
448485
}
449-
dict.insert(
450-
Key::Str(buffer.to_string()),
451-
entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
452-
);
486+
487+
let intermediate_entry = into_intermediate_bucket_entry(term_id, doc_count)?;
488+
489+
dict.insert(Key::Str(buffer.to_string()), intermediate_entry);
453490
}
454491
if self.req.min_doc_count == 0 {
455492
// TODO: Handle rev streaming for descending sorting by keys
@@ -468,12 +505,10 @@ impl SegmentTermCollector {
468505
}
469506
}
470507
} else {
471-
for (val, entry) in entries {
508+
for (val, doc_count) in entries {
509+
let intermediate_entry = into_intermediate_bucket_entry(val, doc_count)?;
472510
let val = f64_from_fastfield_u64(val, &self.field_type);
473-
dict.insert(
474-
Key::F64(val),
475-
entry.into_intermediate_bucket_entry(&agg_with_accessor.sub_aggregation)?,
476-
);
511+
dict.insert(Key::F64(val), intermediate_entry);
477512
}
478513
};
479514

@@ -495,6 +530,11 @@ impl GetDocCount for (u32, TermBucketEntry) {
495530
self.1.doc_count
496531
}
497532
}
533+
impl GetDocCount for (u64, u64) {
534+
fn doc_count(&self) -> u64 {
535+
self.1
536+
}
537+
}
498538
impl GetDocCount for (u64, TermBucketEntry) {
499539
fn doc_count(&self) -> u64 {
500540
self.1.doc_count

src/aggregation/buf_collector.rs

+4
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@ impl BufAggregationCollector {
3434
}
3535

3636
impl SegmentAggregationCollector for BufAggregationCollector {
37+
#[inline]
3738
fn into_intermediate_aggregations_result(
3839
self: Box<Self>,
3940
agg_with_accessor: &AggregationsWithAccessor,
4041
) -> crate::Result<IntermediateAggregationResults> {
4142
Box::new(self.collector).into_intermediate_aggregations_result(agg_with_accessor)
4243
}
4344

45+
#[inline]
4446
fn collect(
4547
&mut self,
4648
doc: crate::DocId,
@@ -56,6 +58,7 @@ impl SegmentAggregationCollector for BufAggregationCollector {
5658
Ok(())
5759
}
5860

61+
#[inline]
5962
fn collect_block(
6063
&mut self,
6164
docs: &[crate::DocId],
@@ -67,6 +70,7 @@ impl SegmentAggregationCollector for BufAggregationCollector {
6770
Ok(())
6871
}
6972

73+
#[inline]
7074
fn flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) -> crate::Result<()> {
7175
self.collector
7276
.collect_block(&self.staged_docs[..self.num_staged_docs], agg_with_accessor)?;

src/aggregation/metric/stats.rs

+8-3
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ pub(crate) struct SegmentStatsCollector {
156156
pub(crate) collecting_for: SegmentStatsType,
157157
pub(crate) stats: IntermediateStats,
158158
pub(crate) accessor_idx: usize,
159+
val_cache: Vec<u64>,
159160
}
160161

161162
impl SegmentStatsCollector {
@@ -169,14 +170,16 @@ impl SegmentStatsCollector {
169170
collecting_for,
170171
stats: IntermediateStats::default(),
171172
accessor_idx,
173+
val_cache: Default::default(),
172174
}
173175
}
174176
#[inline]
175177
pub(crate) fn collect_block_with_field(&mut self, docs: &[DocId], field: &Column<u64>) {
176178
if field.get_cardinality() == Cardinality::Full {
177-
for doc in docs {
178-
let val = field.values.get_val(*doc);
179-
let val1 = f64_from_fastfield_u64(val, &self.field_type);
179+
self.val_cache.resize(docs.len(), 0);
180+
field.values.get_vals(docs, &mut self.val_cache);
181+
for val in self.val_cache.iter() {
182+
let val1 = f64_from_fastfield_u64(*val, &self.field_type);
180183
self.stats.collect(val1);
181184
}
182185
} else {
@@ -191,6 +194,7 @@ impl SegmentStatsCollector {
191194
}
192195

193196
impl SegmentAggregationCollector for SegmentStatsCollector {
197+
#[inline]
194198
fn into_intermediate_aggregations_result(
195199
self: Box<Self>,
196200
agg_with_accessor: &AggregationsWithAccessor,
@@ -227,6 +231,7 @@ impl SegmentAggregationCollector for SegmentStatsCollector {
227231
})
228232
}
229233

234+
#[inline]
230235
fn collect(
231236
&mut self,
232237
doc: crate::DocId,

0 commit comments

Comments
 (0)