Skip to content

Commit ebc7812

Browse files
authored
Add BytesFilterCollector to support filtering based on a bytes fast field (#2075)
* Do some Clippy- and Cargo-related boy-scouting.
* Add BytesFilterCollector to support filtering based on a bytes fast field. This is basically a copy of the existing FilterCollector, but modified and specialised to work on a bytes fast field.
* Changed semantics of filter collectors to consider multi-valued fields.
1 parent 8199aa7 commit ebc7812

File tree

8 files changed

+223
-47
lines changed

8 files changed

+223
-47
lines changed

bitpacker/src/filter_vec/mod.rs

+1-1
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
use std::ops::RangeInclusive;
22

3-
#[cfg(any(target_arch = "x86_64"))]
3+
#[cfg(target_arch = "x86_64")]
44
mod avx2;
55

66
mod scalar;

columnar/Cargo.toml

+1-1
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ edition = "2021"
55
license = "MIT"
66
homepage = "https://github.com/quickwit-oss/tantivy"
77
repository = "https://github.com/quickwit-oss/tantivy"
8-
desciption = "column oriented storage for tantivy"
8+
description = "column oriented storage for tantivy"
99
categories = ["database-implementations", "data-structures", "compression"]
1010

1111
[dependencies]

src/collector/filter_collector_wrapper.rs

+214-40
Original file line number | Diff line number | Diff line change
@@ -6,32 +6,35 @@
66
//
77
// Of course, you can have a look at the tantivy's built-in collectors
88
// such as the `CountCollector` for more examples.
9-
10-
// ---
11-
// Importing tantivy...
9+
use std::fmt::Debug;
1210
use std::marker::PhantomData;
13-
use std::sync::Arc;
1411

15-
use columnar::{ColumnValues, DynamicColumn, HasAssociatedColumnType};
12+
use columnar::{BytesColumn, Column, DynamicColumn, HasAssociatedColumnType};
1613

1714
use crate::collector::{Collector, SegmentCollector};
1815
use crate::schema::Field;
19-
use crate::{Score, SegmentReader, TantivyError};
16+
use crate::{DocId, Score, SegmentReader, TantivyError};
2017

2118
/// The `FilterCollector` filters docs using a fast field value and a predicate.
22-
/// Only the documents for which the predicate returned "true" will be passed on to the next
23-
/// collector.
19+
///
20+
/// Only the documents containing at least one value for which the predicate returns `true`
21+
/// will be passed on to the next collector.
22+
///
23+
/// In other words,
24+
/// - documents with no values are filtered out.
25+
/// - documents with several values are accepted if at least one value matches the predicate.
26+
///
2427
///
2528
/// ```rust
2629
/// use tantivy::collector::{TopDocs, FilterCollector};
2730
/// use tantivy::query::QueryParser;
28-
/// use tantivy::schema::{Schema, TEXT, INDEXED, FAST};
31+
/// use tantivy::schema::{Schema, TEXT, FAST};
2932
/// use tantivy::{doc, DocAddress, Index};
3033
///
3134
/// # fn main() -> tantivy::Result<()> {
3235
/// let mut schema_builder = Schema::builder();
3336
/// let title = schema_builder.add_text_field("title", TEXT);
34-
/// let price = schema_builder.add_u64_field("price", INDEXED | FAST);
37+
/// let price = schema_builder.add_u64_field("price", FAST);
3538
/// let schema = schema_builder.build();
3639
/// let index = Index::create_in_ram(schema);
3740
///
@@ -47,20 +50,24 @@ use crate::{Score, SegmentReader, TantivyError};
4750
///
4851
/// let query_parser = QueryParser::for_index(&index, vec![title]);
4952
/// let query = query_parser.parse_query("diary")?;
50-
/// let no_filter_collector = FilterCollector::new(price, &|value: u64| value > 20_120u64, TopDocs::with_limit(2));
53+
/// let no_filter_collector = FilterCollector::new(price, |value: u64| value > 20_120u64, TopDocs::with_limit(2));
5154
/// let top_docs = searcher.search(&query, &no_filter_collector)?;
5255
///
5356
/// assert_eq!(top_docs.len(), 1);
5457
/// assert_eq!(top_docs[0].1, DocAddress::new(0, 1));
5558
///
56-
/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new(price, &|value| value < 5u64, TopDocs::with_limit(2));
59+
/// let filter_all_collector: FilterCollector<_, _, u64> = FilterCollector::new(price, |value| value < 5u64, TopDocs::with_limit(2));
5760
/// let filtered_top_docs = searcher.search(&query, &filter_all_collector)?;
5861
///
5962
/// assert_eq!(filtered_top_docs.len(), 0);
6063
/// # Ok(())
6164
/// # }
6265
/// ```
63-
pub struct FilterCollector<TCollector, TPredicate, TPredicateValue: Default>
66+
///
67+
/// Note that this is limited to fast fields which implement the [`FastValue`] trait,
68+
/// e.g. `u64` but not `&[u8]`. To filter based on a bytes fast field,
69+
/// use a [`BytesFilterCollector`] instead.
70+
pub struct FilterCollector<TCollector, TPredicate, TPredicateValue>
6471
where TPredicate: 'static + Clone
6572
{
6673
field: Field,
@@ -69,19 +76,15 @@ where TPredicate: 'static + Clone
6976
t_predicate_value: PhantomData<TPredicateValue>,
7077
}
7178

72-
impl<TCollector, TPredicate, TPredicateValue: Default>
79+
impl<TCollector, TPredicate, TPredicateValue>
7380
FilterCollector<TCollector, TPredicate, TPredicateValue>
7481
where
7582
TCollector: Collector + Send + Sync,
7683
TPredicate: Fn(TPredicateValue) -> bool + Send + Sync + Clone,
7784
{
78-
/// Create a new FilterCollector.
79-
pub fn new(
80-
field: Field,
81-
predicate: TPredicate,
82-
collector: TCollector,
83-
) -> FilterCollector<TCollector, TPredicate, TPredicateValue> {
84-
FilterCollector {
85+
/// Create a new `FilterCollector`.
86+
pub fn new(field: Field, predicate: TPredicate, collector: TCollector) -> Self {
87+
Self {
8588
field,
8689
predicate,
8790
collector,
@@ -90,16 +93,14 @@ where
9093
}
9194
}
9295

93-
impl<TCollector, TPredicate, TPredicateValue: Default> Collector
96+
impl<TCollector, TPredicate, TPredicateValue> Collector
9497
for FilterCollector<TCollector, TPredicate, TPredicateValue>
9598
where
9699
TCollector: Collector + Send + Sync,
97100
TPredicate: 'static + Fn(TPredicateValue) -> bool + Send + Sync + Clone,
98101
TPredicateValue: HasAssociatedColumnType,
99102
DynamicColumn: Into<Option<columnar::Column<TPredicateValue>>>,
100103
{
101-
// That's the type of our result.
102-
// Our standard deviation will be a float.
103104
type Fruit = TCollector::Fruit;
104105

105106
type Child = FilterSegmentCollector<TCollector::Child, TPredicate, TPredicateValue>;
@@ -108,7 +109,7 @@ where
108109
&self,
109110
segment_local_id: u32,
110111
segment_reader: &SegmentReader,
111-
) -> crate::Result<FilterSegmentCollector<TCollector::Child, TPredicate, TPredicateValue>> {
112+
) -> crate::Result<Self::Child> {
112113
let schema = segment_reader.schema();
113114
let field_entry = schema.get_field_entry(self.field);
114115
if !field_entry.is_fast() {
@@ -118,16 +119,16 @@ where
118119
)));
119120
}
120121

121-
let fast_field_reader = segment_reader
122+
let column_opt = segment_reader
122123
.fast_fields()
123-
.column_first_or_default(schema.get_field_name(self.field))?;
124+
.column_opt(field_entry.name())?;
124125

125126
let segment_collector = self
126127
.collector
127128
.for_segment(segment_local_id, segment_reader)?;
128129

129130
Ok(FilterSegmentCollector {
130-
fast_field_reader,
131+
column_opt,
131132
segment_collector,
132133
predicate: self.predicate.clone(),
133134
t_predicate_value: PhantomData,
@@ -146,35 +147,208 @@ where
146147
}
147148
}
148149

149-
pub struct FilterSegmentCollector<TSegmentCollector, TPredicate, TPredicateValue>
150-
where
151-
TPredicate: 'static,
152-
DynamicColumn: Into<Option<columnar::Column<TPredicateValue>>>,
153-
{
154-
fast_field_reader: Arc<dyn ColumnValues<TPredicateValue>>,
150+
pub struct FilterSegmentCollector<TSegmentCollector, TPredicate, TPredicateValue> {
151+
column_opt: Option<Column<TPredicateValue>>,
155152
segment_collector: TSegmentCollector,
156153
predicate: TPredicate,
157154
t_predicate_value: PhantomData<TPredicateValue>,
158155
}
159156

157+
impl<TSegmentCollector, TPredicate, TPredicateValue>
158+
FilterSegmentCollector<TSegmentCollector, TPredicate, TPredicateValue>
159+
where
160+
TPredicateValue: PartialOrd + Copy + Debug + Send + Sync + 'static,
161+
TPredicate: 'static + Fn(TPredicateValue) -> bool + Send + Sync,
162+
{
163+
#[inline]
164+
fn accept_document(&self, doc_id: DocId) -> bool {
165+
if let Some(column) = &self.column_opt {
166+
for val in column.values_for_doc(doc_id) {
167+
if (self.predicate)(val) {
168+
return true;
169+
}
170+
}
171+
}
172+
false
173+
}
174+
}
175+
160176
impl<TSegmentCollector, TPredicate, TPredicateValue> SegmentCollector
161177
for FilterSegmentCollector<TSegmentCollector, TPredicate, TPredicateValue>
162178
where
163179
TSegmentCollector: SegmentCollector,
164180
TPredicateValue: HasAssociatedColumnType,
165-
TPredicate: 'static + Fn(TPredicateValue) -> bool + Send + Sync,
166-
DynamicColumn: Into<Option<columnar::Column<TPredicateValue>>>,
181+
TPredicate: 'static + Fn(TPredicateValue) -> bool + Send + Sync, /* DynamicColumn: Into<Option<columnar::Column<TPredicateValue>>> */
182+
{
183+
type Fruit = TSegmentCollector::Fruit;
184+
185+
fn collect(&mut self, doc: u32, score: Score) {
186+
if self.accept_document(doc) {
187+
self.segment_collector.collect(doc, score);
188+
}
189+
}
190+
191+
fn harvest(self) -> TSegmentCollector::Fruit {
192+
self.segment_collector.harvest()
193+
}
194+
}
195+
196+
/// A variant of the [`FilterCollector`] specialized for bytes fast fields, i.e.
197+
/// it transparently wraps an inner [`Collector`] but filters documents
198+
/// based on the result of applying the predicate to the bytes fast field.
199+
///
200+
/// A document is accepted if and only if the predicate returns `true` for at least one value.
201+
///
202+
/// In other words,
203+
/// - documents with no values are filtered out.
204+
/// - documents with several values are accepted if at least one value matches the predicate.
205+
///
206+
/// ```rust
207+
/// use tantivy::collector::{TopDocs, BytesFilterCollector};
208+
/// use tantivy::query::QueryParser;
209+
/// use tantivy::schema::{Schema, TEXT, FAST};
210+
/// use tantivy::{doc, DocAddress, Index};
211+
///
212+
/// # fn main() -> tantivy::Result<()> {
213+
/// let mut schema_builder = Schema::builder();
214+
/// let title = schema_builder.add_text_field("title", TEXT);
215+
/// let barcode = schema_builder.add_bytes_field("barcode", FAST);
216+
/// let schema = schema_builder.build();
217+
/// let index = Index::create_in_ram(schema);
218+
///
219+
/// let mut index_writer = index.writer_with_num_threads(1, 10_000_000)?;
220+
/// index_writer.add_document(doc!(title => "The Name of the Wind", barcode => &b"010101"[..]))?;
221+
/// index_writer.add_document(doc!(title => "The Diary of Muadib", barcode => &b"110011"[..]))?;
222+
/// index_writer.add_document(doc!(title => "A Dairy Cow", barcode => &b"110111"[..]))?;
223+
/// index_writer.add_document(doc!(title => "The Diary of a Young Girl", barcode => &b"011101"[..]))?;
224+
/// index_writer.add_document(doc!(title => "Bridget Jones's Diary"))?;
225+
/// index_writer.commit()?;
226+
///
227+
/// let reader = index.reader()?;
228+
/// let searcher = reader.searcher();
229+
///
230+
/// let query_parser = QueryParser::for_index(&index, vec![title]);
231+
/// let query = query_parser.parse_query("diary")?;
232+
/// let filter_collector = BytesFilterCollector::new(barcode, |bytes: &[u8]| bytes.starts_with(b"01"), TopDocs::with_limit(2));
233+
/// let top_docs = searcher.search(&query, &filter_collector)?;
234+
///
235+
/// assert_eq!(top_docs.len(), 1);
236+
/// assert_eq!(top_docs[0].1, DocAddress::new(0, 3));
237+
/// # Ok(())
238+
/// # }
239+
/// ```
240+
pub struct BytesFilterCollector<TCollector, TPredicate>
241+
where TPredicate: 'static + Clone
242+
{
243+
field: Field,
244+
collector: TCollector,
245+
predicate: TPredicate,
246+
}
247+
248+
impl<TCollector, TPredicate> BytesFilterCollector<TCollector, TPredicate>
249+
where
250+
TCollector: Collector + Send + Sync,
251+
TPredicate: Fn(&[u8]) -> bool + Send + Sync + Clone,
252+
{
253+
/// Create a new `BytesFilterCollector`.
254+
pub fn new(field: Field, predicate: TPredicate, collector: TCollector) -> Self {
255+
Self {
256+
field,
257+
predicate,
258+
collector,
259+
}
260+
}
261+
}
262+
263+
impl<TCollector, TPredicate> Collector for BytesFilterCollector<TCollector, TPredicate>
264+
where
265+
TCollector: Collector + Send + Sync,
266+
TPredicate: 'static + Fn(&[u8]) -> bool + Send + Sync + Clone,
267+
{
268+
type Fruit = TCollector::Fruit;
269+
270+
type Child = BytesFilterSegmentCollector<TCollector::Child, TPredicate>;
271+
272+
fn for_segment(
273+
&self,
274+
segment_local_id: u32,
275+
segment_reader: &SegmentReader,
276+
) -> crate::Result<Self::Child> {
277+
let schema = segment_reader.schema();
278+
let field_name = schema.get_field_name(self.field);
279+
280+
let column_opt = segment_reader.fast_fields().bytes(field_name)?;
281+
282+
let segment_collector = self
283+
.collector
284+
.for_segment(segment_local_id, segment_reader)?;
285+
286+
Ok(BytesFilterSegmentCollector {
287+
column_opt,
288+
segment_collector,
289+
predicate: self.predicate.clone(),
290+
buffer: Vec::new(),
291+
})
292+
}
293+
294+
fn requires_scoring(&self) -> bool {
295+
self.collector.requires_scoring()
296+
}
297+
298+
fn merge_fruits(
299+
&self,
300+
segment_fruits: Vec<<TCollector::Child as SegmentCollector>::Fruit>,
301+
) -> crate::Result<TCollector::Fruit> {
302+
self.collector.merge_fruits(segment_fruits)
303+
}
304+
}
305+
306+
pub struct BytesFilterSegmentCollector<TSegmentCollector, TPredicate>
307+
where TPredicate: 'static
308+
{
309+
column_opt: Option<BytesColumn>,
310+
segment_collector: TSegmentCollector,
311+
predicate: TPredicate,
312+
buffer: Vec<u8>,
313+
}
314+
315+
impl<TSegmentCollector, TPredicate> BytesFilterSegmentCollector<TSegmentCollector, TPredicate>
316+
where
317+
TSegmentCollector: SegmentCollector,
318+
TPredicate: 'static + Fn(&[u8]) -> bool + Send + Sync,
319+
{
320+
#[inline]
321+
fn accept_document(&mut self, doc_id: DocId) -> bool {
322+
if let Some(column) = &self.column_opt {
323+
for ord in column.term_ords(doc_id) {
324+
self.buffer.clear();
325+
326+
let found = column.ord_to_bytes(ord, &mut self.buffer).unwrap_or(false);
327+
328+
if found && (self.predicate)(&self.buffer) {
329+
return true;
330+
}
331+
}
332+
}
333+
false
334+
}
335+
}
336+
337+
impl<TSegmentCollector, TPredicate> SegmentCollector
338+
for BytesFilterSegmentCollector<TSegmentCollector, TPredicate>
339+
where
340+
TSegmentCollector: SegmentCollector,
341+
TPredicate: 'static + Fn(&[u8]) -> bool + Send + Sync,
167342
{
168343
type Fruit = TSegmentCollector::Fruit;
169344

170345
fn collect(&mut self, doc: u32, score: Score) {
171-
let value = self.fast_field_reader.get_val(doc);
172-
if (self.predicate)(value) {
173-
self.segment_collector.collect(doc, score)
346+
if self.accept_document(doc) {
347+
self.segment_collector.collect(doc, score);
174348
}
175349
}
176350

177-
fn harvest(self) -> <TSegmentCollector as SegmentCollector>::Fruit {
351+
fn harvest(self) -> TSegmentCollector::Fruit {
178352
self.segment_collector.harvest()
179353
}
180354
}

src/collector/mod.rs

+1-1
Original file line number | Diff line number | Diff line change
@@ -112,7 +112,7 @@ mod docset_collector;
112112
pub use self::docset_collector::DocSetCollector;
113113

114114
mod filter_collector_wrapper;
115-
pub use self::filter_collector_wrapper::FilterCollector;
115+
pub use self::filter_collector_wrapper::{BytesFilterCollector, FilterCollector};
116116

117117
/// `Fruit` is the type for the result of our collection.
118118
/// e.g. `usize` for the `Count` collector.

0 commit comments

Comments
 (0)