Skip to content

Commit b32267f

Browse files
committed
Auto merge of #45595 - scottmcm:iter-try-fold, r=dtolnay
Short-circuiting internal iteration with Iterator::try_fold & try_rfold These are the core methods in terms of which the other methods (`fold`, `all`, `any`, `find`, `position`, `nth`, ...) can be implemented, allowing Iterator implementors to get the full goodness of internal iteration by only overriding one method (per direction). Based off the `Try` trait, so works with both `Result` and `Option` (:tada: #42526). The `try_fold` rustdoc examples use `Option` and the `try_rfold` ones use `Result`. AKA continuing in the vein of PRs #44682 & #44856 for more of `Iterator`. New bench following the pattern from the latter of those: ``` test iter::bench_take_while_chain_ref_sum ... bench: 1,130,843 ns/iter (+/- 25,110) test iter::bench_take_while_chain_sum ... bench: 362,530 ns/iter (+/- 391) ``` I also ran the benches without the `fold` & `rfold` overrides to test their new default impls, with basically no change. I left them there, though, to take advantage of existing overrides and because `AlwaysOk` has some sub-optimality due to #43278 (which 45225 should fix). If you're wondering why there are three type parameters, see issue #45462 Thanks for @bluss for the [original IRLO thread](https://internals.rust-lang.org/t/pre-rfc-fold-ok-is-composable-internal-iteration/4434) and the rfold PR and to @cuviper for adding so many folds, [encouraging me](#45379 (comment)) to make this PR, and finding a catastrophic bug in a pre-review.
2 parents 3bcb00d + b5dba91 commit b32267f

File tree

8 files changed

+929
-182
lines changed

8 files changed

+929
-182
lines changed

src/libcore/benches/iter.rs

+6
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,9 @@ bench_sums! {
275275
bench_skip_while_chain_ref_sum,
276276
(0i64..1000000).chain(0..1000000).skip_while(|&x| x < 1000)
277277
}
278+
279+
bench_sums! {
280+
bench_take_while_chain_sum,
281+
bench_take_while_chain_ref_sum,
282+
(0i64..1000000).chain(1000000..).take_while(|&x| x < 1111111)
283+
}

src/libcore/iter/iterator.rs

+138-56
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
// except according to those terms.
1010

1111
use cmp::Ordering;
12+
use ops::Try;
1213

14+
use super::{AlwaysOk, LoopState};
1315
use super::{Chain, Cycle, Cloned, Enumerate, Filter, FilterMap, FlatMap, Fuse};
1416
use super::{Inspect, Map, Peekable, Scan, Skip, SkipWhile, StepBy, Take, TakeWhile, Rev};
1517
use super::{Zip, Sum, Product};
@@ -251,12 +253,8 @@ pub trait Iterator {
251253
/// ```
252254
#[inline]
253255
#[stable(feature = "rust1", since = "1.0.0")]
254-
fn nth(&mut self, mut n: usize) -> Option<Self::Item> {
255-
for x in self {
256-
if n == 0 { return Some(x) }
257-
n -= 1;
258-
}
259-
None
256+
fn nth(&mut self, n: usize) -> Option<Self::Item> {
257+
self.spec_nth(n)
260258
}
261259

262260
/// Creates an iterator starting at the same point, but stepping by
@@ -1337,6 +1335,78 @@ pub trait Iterator {
13371335
(left, right)
13381336
}
13391337

1338+
/// An iterator method that applies a function as long as it returns
1339+
/// successfully, producing a single, final value.
1340+
///
1341+
/// `try_fold()` takes two arguments: an initial value, and a closure with
1342+
/// two arguments: an 'accumulator', and an element. The closure either
1343+
/// returns successfully, with the value that the accumulator should have
1344+
/// for the next iteration, or it returns failure, with an error value that
1345+
/// is propagated back to the caller immediately (short-circuiting).
1346+
///
1347+
/// The initial value is the value the accumulator will have on the first
1348+
/// call. If applying the closure succeeded against every element of the
1349+
/// iterator, `try_fold()` returns the final accumulator as success.
1350+
///
1351+
/// Folding is useful whenever you have a collection of something, and want
1352+
/// to produce a single value from it.
1353+
///
1354+
/// # Note to Implementors
1355+
///
1356+
/// Most of the other (forward) methods have default implementations in
1357+
/// terms of this one, so try to implement this explicitly if it can
1358+
/// do something better than the default `for` loop implementation.
1359+
///
1360+
/// In particular, try to have this call `try_fold()` on the internal parts
1361+
/// from which this iterator is composed. If multiple calls are needed,
1362+
/// the `?` operator be convenient for chaining the accumulator value along,
1363+
/// but beware any invariants that need to be upheld before those early
1364+
/// returns. This is a `&mut self` method, so iteration needs to be
1365+
/// resumable after hitting an error here.
1366+
///
1367+
/// # Examples
1368+
///
1369+
/// Basic usage:
1370+
///
1371+
/// ```
1372+
/// #![feature(iterator_try_fold)]
1373+
/// let a = [1, 2, 3];
1374+
///
1375+
/// // the checked sum of all of the elements of the array
1376+
/// let sum = a.iter()
1377+
/// .try_fold(0i8, |acc, &x| acc.checked_add(x));
1378+
///
1379+
/// assert_eq!(sum, Some(6));
1380+
/// ```
1381+
///
1382+
/// Short-circuiting:
1383+
///
1384+
/// ```
1385+
/// #![feature(iterator_try_fold)]
1386+
/// let a = [10, 20, 30, 100, 40, 50];
1387+
/// let mut it = a.iter();
1388+
///
1389+
/// // This sum overflows when adding the 100 element
1390+
/// let sum = it.try_fold(0i8, |acc, &x| acc.checked_add(x));
1391+
/// assert_eq!(sum, None);
1392+
///
1393+
/// // Because it short-circuited, the remaining elements are still
1394+
/// // available through the iterator.
1395+
/// assert_eq!(it.len(), 2);
1396+
/// assert_eq!(it.next(), Some(&40));
1397+
/// ```
1398+
#[inline]
1399+
#[unstable(feature = "iterator_try_fold", issue = "45594")]
1400+
fn try_fold<B, F, R>(&mut self, init: B, mut f: F) -> R where
1401+
Self: Sized, F: FnMut(B, Self::Item) -> R, R: Try<Ok=B>
1402+
{
1403+
let mut accum = init;
1404+
while let Some(x) = self.next() {
1405+
accum = f(accum, x)?;
1406+
}
1407+
Try::from_ok(accum)
1408+
}
1409+
13401410
/// An iterator method that applies a function, producing a single, final value.
13411411
///
13421412
/// `fold()` takes two arguments: an initial value, and a closure with two
@@ -1361,7 +1431,7 @@ pub trait Iterator {
13611431
/// ```
13621432
/// let a = [1, 2, 3];
13631433
///
1364-
/// // the sum of all of the elements of a
1434+
/// // the sum of all of the elements of the array
13651435
/// let sum = a.iter()
13661436
/// .fold(0, |acc, &x| acc + x);
13671437
///
@@ -1403,14 +1473,10 @@ pub trait Iterator {
14031473
/// ```
14041474
#[inline]
14051475
#[stable(feature = "rust1", since = "1.0.0")]
1406-
fn fold<B, F>(self, init: B, mut f: F) -> B where
1476+
fn fold<B, F>(mut self, init: B, mut f: F) -> B where
14071477
Self: Sized, F: FnMut(B, Self::Item) -> B,
14081478
{
1409-
let mut accum = init;
1410-
for x in self {
1411-
accum = f(accum, x);
1412-
}
1413-
accum
1479+
self.try_fold(init, move |acc, x| AlwaysOk(f(acc, x))).0
14141480
}
14151481

14161482
/// Tests if every element of the iterator matches a predicate.
@@ -1455,12 +1521,10 @@ pub trait Iterator {
14551521
fn all<F>(&mut self, mut f: F) -> bool where
14561522
Self: Sized, F: FnMut(Self::Item) -> bool
14571523
{
1458-
for x in self {
1459-
if !f(x) {
1460-
return false;
1461-
}
1462-
}
1463-
true
1524+
self.try_fold((), move |(), x| {
1525+
if f(x) { LoopState::Continue(()) }
1526+
else { LoopState::Break(()) }
1527+
}) == LoopState::Continue(())
14641528
}
14651529

14661530
/// Tests if any element of the iterator matches a predicate.
@@ -1506,12 +1570,10 @@ pub trait Iterator {
15061570
Self: Sized,
15071571
F: FnMut(Self::Item) -> bool
15081572
{
1509-
for x in self {
1510-
if f(x) {
1511-
return true;
1512-
}
1513-
}
1514-
false
1573+
self.try_fold((), move |(), x| {
1574+
if f(x) { LoopState::Break(()) }
1575+
else { LoopState::Continue(()) }
1576+
}) == LoopState::Break(())
15151577
}
15161578

15171579
/// Searches for an element of an iterator that satisfies a predicate.
@@ -1562,10 +1624,10 @@ pub trait Iterator {
15621624
Self: Sized,
15631625
P: FnMut(&Self::Item) -> bool,
15641626
{
1565-
for x in self {
1566-
if predicate(&x) { return Some(x) }
1567-
}
1568-
None
1627+
self.try_fold((), move |(), x| {
1628+
if predicate(&x) { LoopState::Break(x) }
1629+
else { LoopState::Continue(()) }
1630+
}).break_value()
15691631
}
15701632

15711633
/// Searches for an element in an iterator, returning its index.
@@ -1623,18 +1685,17 @@ pub trait Iterator {
16231685
///
16241686
/// ```
16251687
#[inline]
1688+
#[rustc_inherit_overflow_checks]
16261689
#[stable(feature = "rust1", since = "1.0.0")]
16271690
fn position<P>(&mut self, mut predicate: P) -> Option<usize> where
16281691
Self: Sized,
16291692
P: FnMut(Self::Item) -> bool,
16301693
{
1631-
// `enumerate` might overflow.
1632-
for (i, x) in self.enumerate() {
1633-
if predicate(x) {
1634-
return Some(i);
1635-
}
1636-
}
1637-
None
1694+
// The addition might panic on overflow
1695+
self.try_fold(0, move |i, x| {
1696+
if predicate(x) { LoopState::Break(i) }
1697+
else { LoopState::Continue(i + 1) }
1698+
}).break_value()
16381699
}
16391700

16401701
/// Searches for an element in an iterator from the right, returning its
@@ -1681,17 +1742,14 @@ pub trait Iterator {
16811742
P: FnMut(Self::Item) -> bool,
16821743
Self: Sized + ExactSizeIterator + DoubleEndedIterator
16831744
{
1684-
let mut i = self.len();
1685-
1686-
while let Some(v) = self.next_back() {
1687-
// No need for an overflow check here, because `ExactSizeIterator`
1688-
// implies that the number of elements fits into a `usize`.
1689-
i -= 1;
1690-
if predicate(v) {
1691-
return Some(i);
1692-
}
1693-
}
1694-
None
1745+
// No need for an overflow check here, because `ExactSizeIterator`
1746+
// implies that the number of elements fits into a `usize`.
1747+
let n = self.len();
1748+
self.try_rfold(n, move |i, x| {
1749+
let i = i - 1;
1750+
if predicate(x) { LoopState::Break(i) }
1751+
else { LoopState::Continue(i) }
1752+
}).break_value()
16951753
}
16961754

16971755
/// Returns the maximum element of an iterator.
@@ -1922,10 +1980,10 @@ pub trait Iterator {
19221980
let mut ts: FromA = Default::default();
19231981
let mut us: FromB = Default::default();
19241982

1925-
for (t, u) in self {
1983+
self.for_each(|(t, u)| {
19261984
ts.extend(Some(t));
19271985
us.extend(Some(u));
1928-
}
1986+
});
19291987

19301988
(ts, us)
19311989
}
@@ -2300,17 +2358,17 @@ fn select_fold1<I, B, FProj, FCmp>(mut it: I,
23002358
// start with the first element as our selection. This avoids
23012359
// having to use `Option`s inside the loop, translating to a
23022360
// sizeable performance gain (6x in one case).
2303-
it.next().map(|mut sel| {
2304-
let mut sel_p = f_proj(&sel);
2361+
it.next().map(|first| {
2362+
let first_p = f_proj(&first);
23052363

2306-
for x in it {
2364+
it.fold((first_p, first), |(sel_p, sel), x| {
23072365
let x_p = f_proj(&x);
23082366
if f_cmp(&sel_p, &sel, &x_p, &x) {
2309-
sel = x;
2310-
sel_p = x_p;
2367+
(x_p, x)
2368+
} else {
2369+
(sel_p, sel)
23112370
}
2312-
}
2313-
(sel_p, sel)
2371+
})
23142372
})
23152373
}
23162374

@@ -2323,3 +2381,27 @@ impl<'a, I: Iterator + ?Sized> Iterator for &'a mut I {
23232381
(**self).nth(n)
23242382
}
23252383
}
2384+
2385+
2386+
trait SpecIterator : Iterator {
2387+
fn spec_nth(&mut self, n: usize) -> Option<Self::Item>;
2388+
}
2389+
2390+
impl<I: Iterator + ?Sized> SpecIterator for I {
2391+
default fn spec_nth(&mut self, mut n: usize) -> Option<Self::Item> {
2392+
for x in self {
2393+
if n == 0 { return Some(x) }
2394+
n -= 1;
2395+
}
2396+
None
2397+
}
2398+
}
2399+
2400+
impl<I: Iterator + Sized> SpecIterator for I {
2401+
fn spec_nth(&mut self, n: usize) -> Option<Self::Item> {
2402+
self.try_fold(n, move |i, x| {
2403+
if i == 0 { LoopState::Break(x) }
2404+
else { LoopState::Continue(i - 1) }
2405+
}).break_value()
2406+
}
2407+
}

0 commit comments

Comments
 (0)