Skip to content

Commit 7c0186a

Browse files
authored
Rollup merge of rust-lang#65226 - ssomers:master, r=bluss
BTreeSet symmetric_difference & union optimized No scalability changes, but: - Grew the cmp_opt function (shared by symmetric_difference & union) into a MergeIter, with less memory overhead than the pairs of Peekable iterators now, speeding up ~20% on my machine (not so clear on Travis though, I actually switched it off there because it wasn't consistent about identical code). Mainly meant to improve readability by sharing code, though it does end up using more lines of code. Extending and reusing the MergeIter in btree_map might be better, but I'm not sure that's possible or desirable. This MergeIter probably pretends to be more generic than it is, yet doesn't declare to be an iterator because there's no need to, it's only there to help construct genuine iterators SymmetricDifference & Union. - Compact the code of rust-lang#64820 by moving if/else into match guards. r? @bluss
2 parents ef8ac78 + 5697432 commit 7c0186a

File tree

2 files changed

+144
-121
lines changed

2 files changed

+144
-121
lines changed

src/liballoc/collections/btree/set.rs

+119-120
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// to TreeMap
33

44
use core::borrow::Borrow;
5-
use core::cmp::Ordering::{self, Less, Greater, Equal};
5+
use core::cmp::Ordering::{Less, Greater, Equal};
66
use core::cmp::{max, min};
77
use core::fmt::{self, Debug};
88
use core::iter::{Peekable, FromIterator, FusedIterator};
@@ -109,6 +109,77 @@ pub struct Range<'a, T: 'a> {
109109
iter: btree_map::Range<'a, T, ()>,
110110
}
111111

112+
/// Core of SymmetricDifference and Union.
113+
/// More efficient than btree.map.MergeIter,
114+
/// and crucially for SymmetricDifference, nexts() reports on both sides.
115+
#[derive(Clone)]
116+
struct MergeIterInner<I>
117+
where I: Iterator,
118+
I::Item: Copy,
119+
{
120+
a: I,
121+
b: I,
122+
peeked: Option<MergeIterPeeked<I>>,
123+
}
124+
125+
#[derive(Copy, Clone, Debug)]
126+
enum MergeIterPeeked<I: Iterator> {
127+
A(I::Item),
128+
B(I::Item),
129+
}
130+
131+
impl<I> MergeIterInner<I>
132+
where I: ExactSizeIterator + FusedIterator,
133+
I::Item: Copy + Ord,
134+
{
135+
fn new(a: I, b: I) -> Self {
136+
MergeIterInner { a, b, peeked: None }
137+
}
138+
139+
fn nexts(&mut self) -> (Option<I::Item>, Option<I::Item>) {
140+
let mut a_next = match self.peeked {
141+
Some(MergeIterPeeked::A(next)) => Some(next),
142+
_ => self.a.next(),
143+
};
144+
let mut b_next = match self.peeked {
145+
Some(MergeIterPeeked::B(next)) => Some(next),
146+
_ => self.b.next(),
147+
};
148+
let ord = match (a_next, b_next) {
149+
(None, None) => Equal,
150+
(_, None) => Less,
151+
(None, _) => Greater,
152+
(Some(a1), Some(b1)) => a1.cmp(&b1),
153+
};
154+
self.peeked = match ord {
155+
Less => b_next.take().map(MergeIterPeeked::B),
156+
Equal => None,
157+
Greater => a_next.take().map(MergeIterPeeked::A),
158+
};
159+
(a_next, b_next)
160+
}
161+
162+
fn lens(&self) -> (usize, usize) {
163+
match self.peeked {
164+
Some(MergeIterPeeked::A(_)) => (1 + self.a.len(), self.b.len()),
165+
Some(MergeIterPeeked::B(_)) => (self.a.len(), 1 + self.b.len()),
166+
_ => (self.a.len(), self.b.len()),
167+
}
168+
}
169+
}
170+
171+
impl<I> Debug for MergeIterInner<I>
172+
where I: Iterator + Debug,
173+
I::Item: Copy + Debug,
174+
{
175+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176+
f.debug_tuple("MergeIterInner")
177+
.field(&self.a)
178+
.field(&self.b)
179+
.finish()
180+
}
181+
}
182+
112183
/// A lazy iterator producing elements in the difference of `BTreeSet`s.
113184
///
114185
/// This `struct` is created by the [`difference`] method on [`BTreeSet`].
@@ -120,6 +191,7 @@ pub struct Range<'a, T: 'a> {
120191
pub struct Difference<'a, T: 'a> {
121192
inner: DifferenceInner<'a, T>,
122193
}
194+
#[derive(Debug)]
123195
enum DifferenceInner<'a, T: 'a> {
124196
Stitch {
125197
// iterate all of self and some of other, spotting matches along the way
@@ -137,21 +209,7 @@ enum DifferenceInner<'a, T: 'a> {
137209
#[stable(feature = "collection_debug", since = "1.17.0")]
138210
impl<T: fmt::Debug> fmt::Debug for Difference<'_, T> {
139211
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140-
match &self.inner {
141-
DifferenceInner::Stitch {
142-
self_iter,
143-
other_iter,
144-
} => f
145-
.debug_tuple("Difference")
146-
.field(&self_iter)
147-
.field(&other_iter)
148-
.finish(),
149-
DifferenceInner::Search {
150-
self_iter,
151-
other_set: _,
152-
} => f.debug_tuple("Difference").field(&self_iter).finish(),
153-
DifferenceInner::Iterate(iter) => f.debug_tuple("Difference").field(&iter).finish(),
154-
}
212+
f.debug_tuple("Difference").field(&self.inner).finish()
155213
}
156214
}
157215

@@ -163,18 +221,12 @@ impl<T: fmt::Debug> fmt::Debug for Difference<'_, T> {
163221
/// [`BTreeSet`]: struct.BTreeSet.html
164222
/// [`symmetric_difference`]: struct.BTreeSet.html#method.symmetric_difference
165223
#[stable(feature = "rust1", since = "1.0.0")]
166-
pub struct SymmetricDifference<'a, T: 'a> {
167-
a: Peekable<Iter<'a, T>>,
168-
b: Peekable<Iter<'a, T>>,
169-
}
224+
pub struct SymmetricDifference<'a, T: 'a>(MergeIterInner<Iter<'a, T>>);
170225

171226
#[stable(feature = "collection_debug", since = "1.17.0")]
172227
impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
173228
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174-
f.debug_tuple("SymmetricDifference")
175-
.field(&self.a)
176-
.field(&self.b)
177-
.finish()
229+
f.debug_tuple("SymmetricDifference").field(&self.0).finish()
178230
}
179231
}
180232

@@ -189,6 +241,7 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
189241
pub struct Intersection<'a, T: 'a> {
190242
inner: IntersectionInner<'a, T>,
191243
}
244+
#[derive(Debug)]
192245
enum IntersectionInner<'a, T: 'a> {
193246
Stitch {
194247
// iterate similarly sized sets jointly, spotting matches along the way
@@ -206,23 +259,7 @@ enum IntersectionInner<'a, T: 'a> {
206259
#[stable(feature = "collection_debug", since = "1.17.0")]
207260
impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
208261
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209-
match &self.inner {
210-
IntersectionInner::Stitch {
211-
a,
212-
b,
213-
} => f
214-
.debug_tuple("Intersection")
215-
.field(&a)
216-
.field(&b)
217-
.finish(),
218-
IntersectionInner::Search {
219-
small_iter,
220-
large_set: _,
221-
} => f.debug_tuple("Intersection").field(&small_iter).finish(),
222-
IntersectionInner::Answer(answer) => {
223-
f.debug_tuple("Intersection").field(&answer).finish()
224-
}
225-
}
262+
f.debug_tuple("Intersection").field(&self.inner).finish()
226263
}
227264
}
228265

@@ -234,18 +271,12 @@ impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
234271
/// [`BTreeSet`]: struct.BTreeSet.html
235272
/// [`union`]: struct.BTreeSet.html#method.union
236273
#[stable(feature = "rust1", since = "1.0.0")]
237-
pub struct Union<'a, T: 'a> {
238-
a: Peekable<Iter<'a, T>>,
239-
b: Peekable<Iter<'a, T>>,
240-
}
274+
pub struct Union<'a, T: 'a>(MergeIterInner<Iter<'a, T>>);
241275

242276
#[stable(feature = "collection_debug", since = "1.17.0")]
243277
impl<T: fmt::Debug> fmt::Debug for Union<'_, T> {
244278
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245-
f.debug_tuple("Union")
246-
.field(&self.a)
247-
.field(&self.b)
248-
.finish()
279+
f.debug_tuple("Union").field(&self.0).finish()
249280
}
250281
}
251282

@@ -355,19 +386,16 @@ impl<T: Ord> BTreeSet<T> {
355386
self_iter.next_back();
356387
DifferenceInner::Iterate(self_iter)
357388
}
358-
_ => {
359-
if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
360-
DifferenceInner::Search {
361-
self_iter: self.iter(),
362-
other_set: other,
363-
}
364-
} else {
365-
DifferenceInner::Stitch {
366-
self_iter: self.iter(),
367-
other_iter: other.iter().peekable(),
368-
}
389+
_ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
390+
DifferenceInner::Search {
391+
self_iter: self.iter(),
392+
other_set: other,
369393
}
370394
}
395+
_ => DifferenceInner::Stitch {
396+
self_iter: self.iter(),
397+
other_iter: other.iter().peekable(),
398+
},
371399
},
372400
}
373401
}
@@ -396,10 +424,7 @@ impl<T: Ord> BTreeSet<T> {
396424
pub fn symmetric_difference<'a>(&'a self,
397425
other: &'a BTreeSet<T>)
398426
-> SymmetricDifference<'a, T> {
399-
SymmetricDifference {
400-
a: self.iter().peekable(),
401-
b: other.iter().peekable(),
402-
}
427+
SymmetricDifference(MergeIterInner::new(self.iter(), other.iter()))
403428
}
404429

405430
/// Visits the values representing the intersection,
@@ -447,24 +472,22 @@ impl<T: Ord> BTreeSet<T> {
447472
(Greater, _) | (_, Less) => IntersectionInner::Answer(None),
448473
(Equal, _) => IntersectionInner::Answer(Some(self_min)),
449474
(_, Equal) => IntersectionInner::Answer(Some(self_max)),
450-
_ => {
451-
if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
452-
IntersectionInner::Search {
453-
small_iter: self.iter(),
454-
large_set: other,
455-
}
456-
} else if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
457-
IntersectionInner::Search {
458-
small_iter: other.iter(),
459-
large_set: self,
460-
}
461-
} else {
462-
IntersectionInner::Stitch {
463-
a: self.iter(),
464-
b: other.iter(),
465-
}
475+
_ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
476+
IntersectionInner::Search {
477+
small_iter: self.iter(),
478+
large_set: other,
479+
}
480+
}
481+
_ if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
482+
IntersectionInner::Search {
483+
small_iter: other.iter(),
484+
large_set: self,
466485
}
467486
}
487+
_ => IntersectionInner::Stitch {
488+
a: self.iter(),
489+
b: other.iter(),
490+
},
468491
},
469492
}
470493
}
@@ -489,10 +512,7 @@ impl<T: Ord> BTreeSet<T> {
489512
/// ```
490513
#[stable(feature = "rust1", since = "1.0.0")]
491514
pub fn union<'a>(&'a self, other: &'a BTreeSet<T>) -> Union<'a, T> {
492-
Union {
493-
a: self.iter().peekable(),
494-
b: other.iter().peekable(),
495-
}
515+
Union(MergeIterInner::new(self.iter(), other.iter()))
496516
}
497517

498518
/// Clears the set, removing all values.
@@ -1166,15 +1186,6 @@ impl<'a, T> DoubleEndedIterator for Range<'a, T> {
11661186
#[stable(feature = "fused", since = "1.26.0")]
11671187
impl<T> FusedIterator for Range<'_, T> {}
11681188

1169-
/// Compares `x` and `y`, but return `short` if x is None and `long` if y is None
1170-
fn cmp_opt<T: Ord>(x: Option<&T>, y: Option<&T>, short: Ordering, long: Ordering) -> Ordering {
1171-
match (x, y) {
1172-
(None, _) => short,
1173-
(_, None) => long,
1174-
(Some(x1), Some(y1)) => x1.cmp(y1),
1175-
}
1176-
}
1177-
11781189
#[stable(feature = "rust1", since = "1.0.0")]
11791190
impl<T> Clone for Difference<'_, T> {
11801191
fn clone(&self) -> Self {
@@ -1261,10 +1272,7 @@ impl<T: Ord> FusedIterator for Difference<'_, T> {}
12611272
#[stable(feature = "rust1", since = "1.0.0")]
12621273
impl<T> Clone for SymmetricDifference<'_, T> {
12631274
fn clone(&self) -> Self {
1264-
SymmetricDifference {
1265-
a: self.a.clone(),
1266-
b: self.b.clone(),
1267-
}
1275+
SymmetricDifference(self.0.clone())
12681276
}
12691277
}
12701278
#[stable(feature = "rust1", since = "1.0.0")]
@@ -1273,19 +1281,19 @@ impl<'a, T: Ord> Iterator for SymmetricDifference<'a, T> {
12731281

12741282
fn next(&mut self) -> Option<&'a T> {
12751283
loop {
1276-
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
1277-
Less => return self.a.next(),
1278-
Equal => {
1279-
self.a.next();
1280-
self.b.next();
1281-
}
1282-
Greater => return self.b.next(),
1284+
let (a_next, b_next) = self.0.nexts();
1285+
if a_next.and(b_next).is_none() {
1286+
return a_next.or(b_next);
12831287
}
12841288
}
12851289
}
12861290

12871291
fn size_hint(&self) -> (usize, Option<usize>) {
1288-
(0, Some(self.a.len() + self.b.len()))
1292+
let (a_len, b_len) = self.0.lens();
1293+
// No checked_add, because even if a and b refer to the same set,
1294+
// and T is an empty type, the storage overhead of sets limits
1295+
// the number of elements to less than half the range of usize.
1296+
(0, Some(a_len + b_len))
12891297
}
12901298
}
12911299

@@ -1311,7 +1319,7 @@ impl<T> Clone for Intersection<'_, T> {
13111319
small_iter: small_iter.clone(),
13121320
large_set,
13131321
},
1314-
IntersectionInner::Answer(answer) => IntersectionInner::Answer(answer.clone()),
1322+
IntersectionInner::Answer(answer) => IntersectionInner::Answer(*answer),
13151323
},
13161324
}
13171325
}
@@ -1365,30 +1373,21 @@ impl<T: Ord> FusedIterator for Intersection<'_, T> {}
13651373
#[stable(feature = "rust1", since = "1.0.0")]
13661374
impl<T> Clone for Union<'_, T> {
13671375
fn clone(&self) -> Self {
1368-
Union {
1369-
a: self.a.clone(),
1370-
b: self.b.clone(),
1371-
}
1376+
Union(self.0.clone())
13721377
}
13731378
}
13741379
#[stable(feature = "rust1", since = "1.0.0")]
13751380
impl<'a, T: Ord> Iterator for Union<'a, T> {
13761381
type Item = &'a T;
13771382

13781383
fn next(&mut self) -> Option<&'a T> {
1379-
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
1380-
Less => self.a.next(),
1381-
Equal => {
1382-
self.b.next();
1383-
self.a.next()
1384-
}
1385-
Greater => self.b.next(),
1386-
}
1384+
let (a_next, b_next) = self.0.nexts();
1385+
a_next.or(b_next)
13871386
}
13881387

13891388
fn size_hint(&self) -> (usize, Option<usize>) {
1390-
let a_len = self.a.len();
1391-
let b_len = self.b.len();
1389+
let (a_len, b_len) = self.0.lens();
1390+
// No checked_add - see SymmetricDifference::size_hint.
13921391
(max(a_len, b_len), Some(a_len + b_len))
13931392
}
13941393
}

0 commit comments

Comments
 (0)