Skip to content

Commit 2cd44fb

Browse files
committed
Adding _by, by_key, largest variants of k_smallest
1 parent c68e6b4 commit 2cd44fb

File tree

3 files changed

+226
-31
lines changed

3 files changed

+226
-31
lines changed

src/k_smallest.rs

+89-15
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,96 @@
1-
use alloc::collections::BinaryHeap;
2-
use core::cmp::Ord;
1+
use alloc::vec::Vec;
2+
use core::cmp::Ordering;
3+
4+
/// Consumes a given iterator, returning the minimum elements in **ascending** order.
5+
pub(crate) fn k_smallest_general<I, F>(mut iter: I, k: usize, mut comparator: F) -> Vec<I::Item>
6+
where
7+
I: Iterator,
8+
F: FnMut(&I::Item, &I::Item) -> Ordering,
9+
{
10+
/// Sift the element currently at `origin` away from the root until it is properly ordered.
11+
///
12+
/// This will leave **larger** elements closer to the root of the heap.
13+
fn sift_down<T, F>(heap: &mut [T], is_less_than: &mut F, mut origin: usize)
14+
where
15+
F: FnMut(&T, &T) -> bool,
16+
{
17+
#[inline]
18+
fn children_of(n: usize) -> (usize, usize) {
19+
(2 * n + 1, 2 * n + 2)
20+
}
21+
22+
while origin < heap.len() {
23+
let (left_idx, right_idx) = children_of(origin);
24+
if left_idx >= heap.len() {
25+
return;
26+
}
27+
28+
let replacement_idx =
29+
if right_idx < heap.len() && is_less_than(&heap[left_idx], &heap[right_idx]) {
30+
right_idx
31+
} else {
32+
left_idx
33+
};
34+
35+
if is_less_than(&heap[origin], &heap[replacement_idx]) {
36+
heap.swap(origin, replacement_idx);
37+
origin = replacement_idx;
38+
} else {
39+
return;
40+
}
41+
}
42+
}
343

4-
pub(crate) fn k_smallest<T: Ord, I: Iterator<Item = T>>(mut iter: I, k: usize) -> BinaryHeap<T> {
544
if k == 0 {
6-
return BinaryHeap::new();
45+
return Vec::new();
746
}
47+
let mut storage: Vec<I::Item> = iter.by_ref().take(k).collect();
848

9-
let mut heap = iter.by_ref().take(k).collect::<BinaryHeap<_>>();
49+
let mut is_less_than = move |a: &_, b: &_| comparator(a, b) == Ordering::Less;
1050

11-
iter.for_each(|i| {
12-
debug_assert_eq!(heap.len(), k);
13-
// Equivalent to heap.push(min(i, heap.pop())) but more efficient.
14-
// This should be done with a single `.peek_mut().unwrap()` but
15-
// `PeekMut` sifts-down unconditionally on Rust 1.46.0 and prior.
16-
if *heap.peek().unwrap() > i {
17-
*heap.peek_mut().unwrap() = i;
18-
}
19-
});
51+
// Rearrange the storage into a valid heap by reordering from the second-bottom-most layer up to the root.
52+
// Slightly faster than ordering on each insert, but only by a factor of lg(k).
53+
// The resulting heap has the **largest** item on top.
54+
for i in (0..=(storage.len() / 2)).rev() {
55+
sift_down(&mut storage, &mut is_less_than, i);
56+
}
57+
58+
if k == storage.len() {
59+
// If we fill the storage, there may still be iterator elements left so feed them into the heap.
60+
// Also avoids unexpected behaviour with restartable iterators.
61+
iter.for_each(|val| {
62+
if is_less_than(&val, &storage[0]) {
63+
// Treating this as a push-and-pop saves having to write a sift-up implementation.
64+
// https://en.wikipedia.org/wiki/Binary_heap#Insert_then_extract
65+
storage[0] = val;
66+
// We retain the smallest items we've seen so far, but ordered largest first so we can drop the largest efficiently.
67+
sift_down(&mut storage, &mut is_less_than, 0);
68+
}
69+
});
70+
}
71+
72+
// Ultimately the items need to be in least-first, strict order, but the heap is currently largest-first.
73+
// To achieve this, repeatedly,
74+
// 1) "pop" the largest item off the heap into the tail slot of the underlying storage,
75+
// 2) shrink the logical size of the heap by 1,
76+
// 3) restore the heap property over the remaining items.
77+
let mut heap = &mut storage[..];
78+
while heap.len() > 1 {
79+
let last_idx = heap.len() - 1;
80+
heap.swap(0, last_idx);
81+
// Sifting over a truncated slice means that the sifting will not disturb already popped elements.
82+
heap = &mut heap[..last_idx];
83+
sift_down(heap, &mut is_less_than, 0);
84+
}
85+
86+
storage
87+
}
2088

21-
heap
89+
#[inline]
90+
pub(crate) fn key_to_cmp<T, K, F>(key: F) -> impl Fn(&T, &T) -> Ordering
91+
where
92+
F: Fn(&T) -> K,
93+
K: Ord,
94+
{
95+
move |a, b| key(a).cmp(&key(b))
2296
}

src/lib.rs

+98-4
Original file line numberDiff line numberDiff line change
@@ -2950,14 +2950,108 @@ pub trait Itertools: Iterator {
29502950
/// itertools::assert_equal(five_smallest, 0..5);
29512951
/// ```
29522952
#[cfg(feature = "use_alloc")]
2953-
fn k_smallest(self, k: usize) -> VecIntoIter<Self::Item>
2953+
fn k_smallest(mut self, k: usize) -> VecIntoIter<Self::Item>
29542954
where
29552955
Self: Sized,
29562956
Self::Item: Ord,
29572957
{
2958-
crate::k_smallest::k_smallest(self, k)
2959-
.into_sorted_vec()
2960-
.into_iter()
2958+
// The stdlib heap has optimised handling of "holes", which is not included in our heap implementation in k_smallest_general.
2959+
// While the difference is unlikely to have practical impact unless `Self::Item` is very large, this method uses the stdlib structure
2960+
// to maintain performance compared to previous versions of the crate.
2961+
use alloc::collections::BinaryHeap;
2962+
2963+
if k == 0 {
2964+
return Vec::new().into_iter();
2965+
}
2966+
2967+
let mut heap = self.by_ref().take(k).collect::<BinaryHeap<_>>();
2968+
2969+
self.for_each(|i| {
2970+
debug_assert_eq!(heap.len(), k);
2971+
// Equivalent to heap.push(min(i, heap.pop())) but more efficient.
2972+
// This should be done with a single `.peek_mut().unwrap()` but
2973+
// `PeekMut` sifts-down unconditionally on Rust 1.46.0 and prior.
2974+
if *heap.peek().unwrap() > i {
2975+
*heap.peek_mut().unwrap() = i;
2976+
}
2977+
});
2978+
2979+
heap.into_sorted_vec().into_iter()
2980+
}
2981+
2982+
/// Sort the k smallest elements into a new iterator using the provided comparison.
2983+
///
2984+
/// This corresponds to `self.sorted_by(cmp).take(k)` in the same way that
2985+
/// [Itertools::k_smallest] corresponds to `self.sorted().take(k)`, in both semantics and complexity.
2986+
/// Particularly, a custom heap implementation ensures the comparison is not cloned.
2987+
#[cfg(feature = "use_alloc")]
2988+
fn k_smallest_by<F>(self, k: usize, cmp: F) -> VecIntoIter<Self::Item>
2989+
where
2990+
Self: Sized,
2991+
F: Fn(&Self::Item, &Self::Item) -> Ordering,
2992+
{
2993+
k_smallest::k_smallest_general(self, k, cmp).into_iter()
2994+
}
2995+
2996+
/// Return the elements producing the k smallest outputs of the provided function.
2997+
///
2998+
/// This corresponds to `self.sorted_by_key(key).take(k)` in the same way that
2999+
/// [Itertools::k_smallest] corresponds to `self.sorted().take(k)`, in both semantics and time complexity.
3000+
#[cfg(feature = "use_alloc")]
3001+
fn k_smallest_by_key<F, K>(self, k: usize, key: F) -> VecIntoIter<Self::Item>
3002+
where
3003+
Self: Sized,
3004+
F: Fn(&Self::Item) -> K,
3005+
K: Ord,
3006+
{
3007+
self.k_smallest_by(k, k_smallest::key_to_cmp(key))
3008+
}
3009+
3010+
/// Sort the k largest elements into a new iterator, in descending order.
3011+
/// Semantically equivalent to `k_smallest` with a reversed `Ord`.
3012+
/// However, this is implemented by way of a custom binary heap
3013+
/// which does not have the same performance characteristics for very large `Self::Item`.
3014+
/// ```
3015+
/// use itertools::Itertools;
3016+
///
3017+
/// // A random permutation of 0..15
3018+
/// let numbers = vec![6, 9, 1, 14, 0, 4, 8, 7, 11, 2, 10, 3, 13, 12, 5];
3019+
///
3020+
/// let five_largest = numbers
3021+
/// .into_iter()
3022+
/// .k_largest(5);
3023+
///
3024+
/// itertools::assert_equal(five_largest, vec![14,13,12,11,10]);
3025+
/// ```
3026+
#[cfg(feature = "use_alloc")]
3027+
fn k_largest(self, k: usize) -> VecIntoIter<Self::Item>
3028+
where
3029+
Self: Sized,
3030+
Self::Item: Ord,
3031+
{
3032+
self.k_largest_by(k, Self::Item::cmp)
3033+
}
3034+
3035+
/// Sort the k largest elements into a new iterator using the provided comparison.
3036+
/// Functionally equivalent to `k_smallest_by` with a reversed `Ord`.
3037+
#[cfg(feature = "use_alloc")]
3038+
fn k_largest_by<F>(self, k: usize, cmp: F) -> VecIntoIter<Self::Item>
3039+
where
3040+
Self: Sized,
3041+
F: Fn(&Self::Item, &Self::Item) -> Ordering,
3042+
{
3043+
self.k_smallest_by(k, move |a, b| cmp(b, a))
3044+
}
3045+
3046+
/// Return the elements producing the k largest outputs of the provided function.
3047+
#[cfg(feature = "use_alloc")]
3048+
fn k_largest_by_key<F, K>(self, k: usize, key: F) -> VecIntoIter<Self::Item>
3049+
where
3050+
Self: Sized,
3051+
F: Fn(&Self::Item) -> K,
3052+
K: Ord,
3053+
{
3054+
self.k_largest_by(k, k_smallest::key_to_cmp(key))
29613055
}
29623056

29633057
/// Collect all iterator elements into one of two

tests/test_std.rs

+39-12
Original file line numberDiff line numberDiff line change
@@ -492,23 +492,50 @@ fn sorted_by() {
492492
}
493493

494494
qc::quickcheck! {
495-
fn k_smallest_range(n: u64, m: u16, k: u16) -> () {
495+
fn k_smallest_range(n: i64, m: u16, k: u16) -> () {
496496
// u16 is used to constrain k and m to 0..2¹⁶,
497497
// otherwise the test could use too much memory.
498-
let (k, m) = (k as u64, m as u64);
498+
let (k, m) = (k as usize, m as u64);
499499

500+
let mut v: Vec<_> = (n..n.saturating_add(m as _)).collect();
500501
// Generate a random permutation of n..n+m
501-
let i = {
502-
let mut v: Vec<u64> = (n..n.saturating_add(m)).collect();
503-
v.shuffle(&mut thread_rng());
504-
v.into_iter()
505-
};
502+
v.shuffle(&mut thread_rng());
503+
504+
// Construct the right answers for the top and bottom elements
505+
let mut sorted = v.clone();
506+
sorted.sort();
507+
// how many elements are we checking
508+
let num_elements = min(k, m as _);
509+
510+
// Compute the top and bottom k in various combinations
511+
let smallest = v.iter().cloned().k_smallest(k);
512+
let smallest_by = v.iter().cloned().k_smallest_by(k, Ord::cmp);
513+
let smallest_by_key = v.iter().cloned().k_smallest_by_key(k, |&x| x);
514+
515+
let largest = v.iter().cloned().k_largest(k);
516+
let largest_by = v.iter().cloned().k_largest_by(k, Ord::cmp);
517+
let largest_by_key = v.iter().cloned().k_largest_by_key(k, |&x| x);
518+
519+
// Check the variations produce the same answers and that they're right
520+
for (a,b,c,d) in izip!(
521+
sorted[..num_elements].iter().cloned(),
522+
smallest,
523+
smallest_by,
524+
smallest_by_key) {
525+
assert_eq!(a,b);
526+
assert_eq!(a,c);
527+
assert_eq!(a,d);
528+
}
506529

507-
// Check that taking the k smallest elements yields n..n+min(k, m)
508-
it::assert_equal(
509-
i.k_smallest(k as usize),
510-
n..n.saturating_add(min(k, m))
511-
);
530+
for (a,b,c,d) in izip!(
531+
sorted[sorted.len()-num_elements..].iter().rev().cloned(),
532+
largest,
533+
largest_by,
534+
largest_by_key) {
535+
assert_eq!(a,b);
536+
assert_eq!(a,c);
537+
assert_eq!(a,d);
538+
}
512539
}
513540
}
514541

0 commit comments

Comments
 (0)