Skip to content

Commit 5bc328f

Browse files
committed
Allow canonicalizing the array::map loop in trusted cases
1 parent 52df055 commit 5bc328f

File tree

14 files changed

+237
-142
lines changed

14 files changed

+237
-142
lines changed

library/core/src/array/drain.rs

+29-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
1-
use crate::iter::TrustedLen;
1+
use crate::iter::{TrustedLen, UncheckedIterator};
22
use crate::mem::ManuallyDrop;
33
use crate::ptr::drop_in_place;
44
use crate::slice;
55

6-
// INVARIANT: It's ok to drop the remainder of the inner iterator.
7-
pub(crate) struct Drain<'a, T>(slice::IterMut<'a, T>);
8-
6+
/// A situationally-optimized version of `array.into_iter().for_each(func)`.
7+
///
8+
/// [`crate::array::IntoIter`]s are great when you need an owned iterator, but
9+
/// storing the entire array *inside* the iterator like that can sometimes
10+
/// pessimize code. Notable, it can be more bytes than you really want to move
11+
/// around, and because the array accesses index into it SRoA has a harder time
12+
/// optimizing away the type than it does iterators that just hold a couple pointers.
13+
///
14+
/// Thus this function exists, which gives a way to get *moved* access to the
15+
/// elements of an array using a small iterator -- no bigger than a slice iterator.
16+
///
17+
/// The function-taking-a-closure structure makes it safe, as it keeps callers
18+
/// from looking at already-dropped elements.
919
pub(crate) fn drain_array_with<T, R, const N: usize>(
1020
array: [T; N],
1121
func: impl for<'a> FnOnce(Drain<'a, T>) -> R,
@@ -16,6 +26,11 @@ pub(crate) fn drain_array_with<T, R, const N: usize>(
1626
func(drain)
1727
}
1828

29+
/// See [`drain_array_with`] -- this is `pub(crate)` only so it's allowed to be
30+
/// mentioned in the signature of that method. (Otherwise it hits `E0446`.)
31+
// INVARIANT: It's ok to drop the remainder of the inner iterator.
32+
pub(crate) struct Drain<'a, T>(slice::IterMut<'a, T>);
33+
1934
impl<T> Drop for Drain<'_, T> {
2035
fn drop(&mut self) {
2136
// SAFETY: By the type invariant, we're allowed to drop all these.
@@ -49,3 +64,13 @@ impl<T> ExactSizeIterator for Drain<'_, T> {
4964

5065
// SAFETY: This is a 1:1 wrapper for a slice iterator, which is also `TrustedLen`.
5166
unsafe impl<T> TrustedLen for Drain<'_, T> {}
67+
68+
impl<T> UncheckedIterator for Drain<'_, T> {
69+
unsafe fn next_unchecked(&mut self) -> T {
70+
// SAFETY: `Drain` is 1:1 with the inner iterator, so if the caller promised
71+
// that there's an element left, the inner iterator has one too.
72+
let p: *const T = unsafe { self.0.next_unchecked() };
73+
// SAFETY: The iterator was already advanced, so we won't drop this later.
74+
unsafe { p.read() }
75+
}
76+
}

library/core/src/array/mod.rs

+101-123
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::convert::{Infallible, TryFrom};
1010
use crate::error::Error;
1111
use crate::fmt;
1212
use crate::hash::{self, Hash};
13-
use crate::iter::TrustedLen;
13+
use crate::iter::UncheckedIterator;
1414
use crate::mem::{self, MaybeUninit};
1515
use crate::ops::{
1616
ChangeOutputType, ControlFlow, FromResidual, Index, IndexMut, NeverShortCircuit, Residual, Try,
@@ -55,16 +55,11 @@ pub use iter::IntoIter;
5555
/// ```
5656
#[inline]
5757
#[stable(feature = "array_from_fn", since = "1.63.0")]
58-
pub fn from_fn<T, const N: usize, F>(mut cb: F) -> [T; N]
58+
pub fn from_fn<T, const N: usize, F>(cb: F) -> [T; N]
5959
where
6060
F: FnMut(usize) -> T,
6161
{
62-
let mut idx = 0;
63-
[(); N].map(|_| {
64-
let res = cb(idx);
65-
idx += 1;
66-
res
67-
})
62+
try_from_fn(NeverShortCircuit::wrap_mut_1(cb)).0
6863
}
6964

7065
/// Creates an array `[T; N]` where each fallible array element `T` is returned by the `cb` call.
@@ -104,9 +99,14 @@ where
10499
R: Try,
105100
R::Residual: Residual<[R::Output; N]>,
106101
{
107-
// SAFETY: we know for certain that this iterator will yield exactly `N`
108-
// items.
109-
unsafe { try_collect_into_array_unchecked(&mut (0..N).map(cb)) }
102+
let mut array = MaybeUninit::uninit_array::<N>();
103+
match try_from_fn_erased(&mut array, cb) {
104+
ControlFlow::Break(r) => FromResidual::from_residual(r),
105+
ControlFlow::Continue(()) => {
106+
// SAFETY: All elements of the array were populated.
107+
try { unsafe { MaybeUninit::array_assume_init(array) } }
108+
}
109+
}
110110
}
111111

112112
/// Converts a reference to `T` into a reference to an array of length 1 (without copying).
@@ -430,9 +430,7 @@ trait SpecArrayClone: Clone {
430430
impl<T: Clone> SpecArrayClone for T {
431431
#[inline]
432432
default fn clone<const N: usize>(array: &[T; N]) -> [T; N] {
433-
// SAFETY: we know for certain that this iterator will yield exactly `N`
434-
// items.
435-
unsafe { collect_into_array_unchecked(&mut array.iter().cloned()) }
433+
from_trusted_iterator(array.iter().cloned())
436434
}
437435
}
438436

@@ -516,12 +514,7 @@ impl<T, const N: usize> [T; N] {
516514
where
517515
F: FnMut(T) -> U,
518516
{
519-
drain_array_with(self, |iter| {
520-
let mut iter = iter.map(f);
521-
// SAFETY: we know for certain that this iterator will yield exactly `N`
522-
// items.
523-
unsafe { collect_into_array_unchecked(&mut iter) }
524-
})
517+
self.try_map(NeverShortCircuit::wrap_mut_1(f)).0
525518
}
526519

527520
/// A fallible function `f` applied to each element on array `self` in order to
@@ -558,12 +551,7 @@ impl<T, const N: usize> [T; N] {
558551
R: Try,
559552
R::Residual: Residual<[R::Output; N]>,
560553
{
561-
drain_array_with(self, |iter| {
562-
let mut iter = iter.map(f);
563-
// SAFETY: we know for certain that this iterator will yield exactly `N`
564-
// items.
565-
unsafe { try_collect_into_array_unchecked(&mut iter) }
566-
})
554+
drain_array_with(self, |iter| try_from_trusted_iterator(iter.map(f)))
567555
}
568556

569557
/// 'Zips up' two arrays into a single array of pairs.
@@ -585,12 +573,7 @@ impl<T, const N: usize> [T; N] {
585573
#[unstable(feature = "array_zip", issue = "80094")]
586574
pub fn zip<U>(self, rhs: [U; N]) -> [(T, U); N] {
587575
drain_array_with(self, |lhs| {
588-
drain_array_with(rhs, |rhs| {
589-
let mut iter = crate::iter::zip(lhs, rhs);
590-
// SAFETY: we know for certain that this iterator will yield exactly `N`
591-
// items.
592-
unsafe { collect_into_array_unchecked(&mut iter) }
593-
})
576+
drain_array_with(rhs, |rhs| from_trusted_iterator(crate::iter::zip(lhs, rhs)))
594577
})
595578
}
596579

@@ -638,9 +621,7 @@ impl<T, const N: usize> [T; N] {
638621
/// ```
639622
#[unstable(feature = "array_methods", issue = "76118")]
640623
pub fn each_ref(&self) -> [&T; N] {
641-
// SAFETY: we know for certain that this iterator will yield exactly `N`
642-
// items.
643-
unsafe { collect_into_array_unchecked(&mut self.iter()) }
624+
from_trusted_iterator(self.iter())
644625
}
645626

646627
/// Borrows each element mutably and returns an array of mutable references
@@ -660,9 +641,7 @@ impl<T, const N: usize> [T; N] {
660641
/// ```
661642
#[unstable(feature = "array_methods", issue = "76118")]
662643
pub fn each_mut(&mut self) -> [&mut T; N] {
663-
// SAFETY: we know for certain that this iterator will yield exactly `N`
664-
// items.
665-
unsafe { collect_into_array_unchecked(&mut self.iter_mut()) }
644+
from_trusted_iterator(self.iter_mut())
666645
}
667646

668647
/// Divides one array reference into two at an index.
@@ -822,99 +801,71 @@ impl<T, const N: usize> [T; N] {
822801
}
823802
}
824803

825-
/// Pulls `N` items from `iter` and returns them as an array. If the iterator
826-
/// yields fewer than `N` items, this function exhibits undefined behavior.
804+
/// Populate an array from the first `N` elements of `iter`
827805
///
828-
/// # Safety
806+
/// # Panics
829807
///
830-
/// It is up to the caller to guarantee that `iter` yields at least `N` items.
831-
/// Violating this condition causes undefined behavior.
832-
unsafe fn try_collect_into_array_unchecked<I, T, R, const N: usize>(
833-
iter: &mut I,
834-
) -> ChangeOutputType<I::Item, [T; N]>
835-
where
836-
// Note: `TrustedLen` here is somewhat of an experiment. This is just an
837-
// internal function, so feel free to remove if this bound turns out to be a
838-
// bad idea. In that case, remember to also remove the lower bound
839-
// `debug_assert!` below!
840-
I: Iterator + TrustedLen,
841-
I::Item: Try<Output = T, Residual = R>,
842-
R: Residual<[T; N]>,
843-
{
844-
debug_assert!(N <= iter.size_hint().1.unwrap_or(usize::MAX));
845-
debug_assert!(N <= iter.size_hint().0);
846-
847-
let mut array = MaybeUninit::uninit_array::<N>();
848-
let cf = try_collect_into_array_erased(iter, &mut array);
849-
match cf {
850-
ControlFlow::Break(r) => FromResidual::from_residual(r),
851-
ControlFlow::Continue(initialized) => {
852-
debug_assert_eq!(initialized, N);
853-
// SAFETY: because of our function contract, all the elements
854-
// must have been initialized.
855-
let output = unsafe { MaybeUninit::array_assume_init(array) };
856-
Try::from_output(output)
857-
}
858-
}
808+
/// If the iterator doesn't actually have enough items.
809+
///
810+
/// By depending on `TrustedLen`, however, we can do that check up-front (where
811+
/// it easily optimizes away) so it doesn't impact the loop that fills the array.
812+
#[inline]
813+
fn from_trusted_iterator<T, const N: usize>(iter: impl UncheckedIterator<Item = T>) -> [T; N] {
814+
try_from_trusted_iterator(iter.map(NeverShortCircuit)).0
859815
}
860816

861-
/// Infallible version of [`try_collect_into_array_unchecked`].
862-
unsafe fn collect_into_array_unchecked<I, const N: usize>(iter: &mut I) -> [I::Item; N]
817+
#[inline]
818+
fn try_from_trusted_iterator<T, R, const N: usize>(
819+
iter: impl UncheckedIterator<Item = R>,
820+
) -> ChangeOutputType<R, [T; N]>
863821
where
864-
I: Iterator + TrustedLen,
822+
R: Try<Output = T>,
823+
R::Residual: Residual<[T; N]>,
865824
{
866-
let mut map = iter.map(NeverShortCircuit);
867-
868-
// SAFETY: The same safety considerations w.r.t. the iterator length
869-
// apply for `try_collect_into_array_unchecked` as for
870-
// `collect_into_array_unchecked`
871-
match unsafe { try_collect_into_array_unchecked(&mut map) } {
872-
NeverShortCircuit(array) => array,
825+
assert!(iter.size_hint().0 >= N);
826+
fn next<T>(mut iter: impl UncheckedIterator<Item = T>) -> impl FnMut(usize) -> T {
827+
move |_| {
828+
// SAFETY: We know that `from_fn` will call this at most N times,
829+
// and we checked to ensure that we have at least that many items.
830+
unsafe { iter.next_unchecked() }
831+
}
873832
}
833+
834+
try_from_fn(next(iter))
874835
}
875836

876-
/// Rather than *returning* the array, this fills in a passed-in buffer.
877-
/// If any of the iterator elements short-circuit, it drops everything in the
878-
/// buffer and return the error. Otherwise it returns the number of items
879-
/// which were initialized in the buffer.
837+
/// Version of [`try_from_fn`] using a passed-in slice in order to avoid
838+
/// needing to monomorphize for every array length.
880839
///
881-
/// (The caller is responsible for dropping those items on success, but not
882-
/// doing that is just a leak, not UB, so this function is itself safe.)
840+
/// This takes a generator rather than an iterator so that *at the type level*
841+
/// it never needs to worry about running out of items. When combined with
842+
/// an infallible `Try` type, that means the loop canonicalizes easily, allowing
843+
/// it to optimize well.
883844
///
884-
/// This means less monomorphization, but more importantly it means that the
885-
/// returned array doesn't need to be copied into the `Result`, since returning
886-
/// the result seemed (2023-01) to cause in an extra `N + 1`-length `alloca`
887-
/// even if it's always `unwrap_unchecked` later.
845+
/// It would be *possible* to unify this and [`iter_next_chunk_erased`] into one
846+
/// function that does the union of both things, but last time it was that way
847+
/// it resulted in poor codegen from the "are there enough source items?" checks
848+
/// not optimizing away. So if you give it a shot, make sure to watch what
849+
/// happens in the codegen tests.
888850
#[inline]
889-
fn try_collect_into_array_erased<I, T, R>(
890-
iter: &mut I,
851+
fn try_from_fn_erased<T, R>(
891852
buffer: &mut [MaybeUninit<T>],
892-
) -> ControlFlow<R, usize>
853+
mut generator: impl FnMut(usize) -> R,
854+
) -> ControlFlow<R::Residual>
893855
where
894-
I: Iterator,
895-
I::Item: Try<Output = T, Residual = R>,
856+
R: Try<Output = T>,
896857
{
897-
let n = buffer.len();
898858
let mut guard = Guard { array_mut: buffer, initialized: 0 };
899859

900-
for _ in 0..n {
901-
match iter.next() {
902-
Some(item_rslt) => {
903-
let item = item_rslt.branch()?;
860+
while guard.initialized < guard.array_mut.len() {
861+
let item = generator(guard.initialized).branch()?;
904862

905-
// SAFETY: `guard.initialized` starts at 0, which means push can be called
906-
// at most `n` times, which this loop does.
907-
unsafe {
908-
guard.push_unchecked(item);
909-
}
910-
}
911-
None => break,
912-
}
863+
// SAFETY: The loop condition ensures we have space to push the item
864+
unsafe { guard.push_unchecked(item) };
913865
}
914866

915-
let initialized = guard.initialized;
916867
mem::forget(guard);
917-
ControlFlow::Continue(initialized)
868+
ControlFlow::Continue(())
918869
}
919870

920871
/// Panic guard for incremental initialization of arrays.
@@ -928,7 +879,7 @@ where
928879
///
929880
/// To minimize indirection fields are still pub but callers should at least use
930881
/// `push_unchecked` to signal that something unsafe is going on.
931-
pub(crate) struct Guard<'a, T> {
882+
struct Guard<'a, T> {
932883
/// The array to be initialized.
933884
pub array_mut: &'a mut [MaybeUninit<T>],
934885
/// The number of items that have been initialized so far.
@@ -960,7 +911,7 @@ impl<T> Drop for Guard<'_, T> {
960911
// SAFETY: this slice will contain only initialized objects.
961912
unsafe {
962913
crate::ptr::drop_in_place(MaybeUninit::slice_assume_init_mut(
963-
&mut self.array_mut.get_unchecked_mut(..self.initialized),
914+
self.array_mut.get_unchecked_mut(..self.initialized),
964915
));
965916
}
966917
}
@@ -982,17 +933,44 @@ impl<T> Drop for Guard<'_, T> {
982933
pub(crate) fn iter_next_chunk<T, const N: usize>(
983934
iter: &mut impl Iterator<Item = T>,
984935
) -> Result<[T; N], IntoIter<T, N>> {
985-
let mut map = iter.map(NeverShortCircuit);
986936
let mut array = MaybeUninit::uninit_array::<N>();
987-
let ControlFlow::Continue(initialized) = try_collect_into_array_erased(&mut map, &mut array);
988-
if initialized == N {
989-
// SAFETY: All elements of the array were populated.
990-
let output = unsafe { MaybeUninit::array_assume_init(array) };
991-
Ok(output)
992-
} else {
993-
let alive = 0..initialized;
994-
// SAFETY: `array` was initialized with exactly `initialized`
995-
// number of elements.
996-
return Err(unsafe { IntoIter::new_unchecked(array, alive) });
937+
let r = iter_next_chunk_erased(&mut array, iter);
938+
match r {
939+
Ok(()) => {
940+
// SAFETY: All elements of `array` were populated.
941+
Ok(unsafe { MaybeUninit::array_assume_init(array) })
942+
}
943+
Err(initialized) => {
944+
// SAFETY: Only the first `initialized` elements were populated
945+
Err(unsafe { IntoIter::new_unchecked(array, 0..initialized) })
946+
}
947+
}
948+
}
949+
950+
/// Version of [`iter_next_chunk`] using a passed-in slice in order to avoid
951+
/// needing to monomorphize for every array length.
952+
///
953+
/// Unfortunately this loop has two exit conditions, the buffer filling up
954+
/// or the iterator running out of items, making it tend to optimize poorly.
955+
#[inline]
956+
fn iter_next_chunk_erased<T>(
957+
buffer: &mut [MaybeUninit<T>],
958+
iter: &mut impl Iterator<Item = T>,
959+
) -> Result<(), usize> {
960+
let mut guard = Guard { array_mut: buffer, initialized: 0 };
961+
while guard.initialized < guard.array_mut.len() {
962+
let Some(item) = iter.next() else {
963+
// Unlike `try_from_fn_erased`, we want to keep the partial results,
964+
// so we need to defuse the guard instead of using `?`.
965+
let initialized = guard.initialized;
966+
mem::forget(guard);
967+
return Err(initialized)
968+
};
969+
970+
// SAFETY: The loop condition ensures we have space to push the item
971+
unsafe { guard.push_unchecked(item) };
997972
}
973+
974+
mem::forget(guard);
975+
Ok(())
998976
}

0 commit comments

Comments
 (0)