
Commit 7fb55b4

Rollup merge of rust-lang#94212 - scottmcm:swapper, r=dtolnay
Stop manually SIMDing in `swap_nonoverlapping`

Like I previously did for `reverse` (rust-lang#90821), this leaves it to LLVM to pick how to vectorize it, since it can know better the chunk size to use, compared to the "32 bytes always" approach we currently have. A variety of codegen tests are included to confirm that the various cases are still being vectorized.

It does still need logic to type-erase in some cases, though, as while LLVM is now smart enough to vectorize over slices of things like `[u8; 4]`, it fails to do so over slices of `[u8; 3]`.

As a bonus, this change also means one no longer gets the spurious `memcpy`(s?) at the end when swapping a slice of `__m256`s: <https://rust.godbolt.org/z/joofr4v8Y>

<details>
<summary>ASM for this example</summary>

## Before (from godbolt)

Note the `push`/`pop`s and `memcpy`:

```x86
swap_m256_slice:
        push r15
        push r14
        push r13
        push r12
        push rbx
        sub rsp, 32
        cmp rsi, rcx
        jne .LBB0_6
        mov r14, rsi
        shl r14, 5
        je .LBB0_6
        mov r15, rdx
        mov rbx, rdi
        xor eax, eax
.LBB0_3:
        mov rcx, rax
        vmovaps ymm0, ymmword ptr [rbx + rax]
        vmovaps ymm1, ymmword ptr [r15 + rax]
        vmovaps ymmword ptr [rbx + rax], ymm1
        vmovaps ymmword ptr [r15 + rax], ymm0
        add rax, 32
        add rcx, 64
        cmp rcx, r14
        jbe .LBB0_3
        sub r14, rax
        jbe .LBB0_6
        add rbx, rax
        add r15, rax
        mov r12, rsp
        mov r13, qword ptr [rip + memcpy@GOTPCREL]
        mov rdi, r12
        mov rsi, rbx
        mov rdx, r14
        vzeroupper
        call r13
        mov rdi, rbx
        mov rsi, r15
        mov rdx, r14
        call r13
        mov rdi, r15
        mov rsi, r12
        mov rdx, r14
        call r13
.LBB0_6:
        add rsp, 32
        pop rbx
        pop r12
        pop r13
        pop r14
        pop r15
        vzeroupper
        ret
```

## After (from my machine)

Note no `rsp` manipulation; sorry for the different ASM syntax.

```x86
swap_m256_slice:
        cmpq %r9, %rdx
        jne .LBB1_6
        testq %rdx, %rdx
        je .LBB1_6
        cmpq $1, %rdx
        jne .LBB1_7
        xorl %r10d, %r10d
        jmp .LBB1_4
.LBB1_7:
        movq %rdx, %r9
        andq $-2, %r9
        movl $32, %eax
        xorl %r10d, %r10d
        .p2align 4, 0x90
.LBB1_8:
        vmovaps -32(%rcx,%rax), %ymm0
        vmovaps -32(%r8,%rax), %ymm1
        vmovaps %ymm1, -32(%rcx,%rax)
        vmovaps %ymm0, -32(%r8,%rax)
        vmovaps (%rcx,%rax), %ymm0
        vmovaps (%r8,%rax), %ymm1
        vmovaps %ymm1, (%rcx,%rax)
        vmovaps %ymm0, (%r8,%rax)
        addq $2, %r10
        addq $64, %rax
        cmpq %r10, %r9
        jne .LBB1_8
.LBB1_4:
        testb $1, %dl
        je .LBB1_6
        shlq $5, %r10
        vmovaps (%rcx,%r10), %ymm0
        vmovaps (%r8,%r10), %ymm1
        vmovaps %ymm1, (%rcx,%r10)
        vmovaps %ymm0, (%r8,%r10)
.LBB1_6:
        vzeroupper
        retq
```

</details>

This does all its copying operations as either the original type or as `MaybeUninit`s, so as far as I know there should be no potential abstract machine issues with reading padding bytes as integers.

<details>
<summary>Perf is essentially unchanged</summary>

Though perhaps with more target features this would help more, if it could pick bigger chunks.

## Before

```
running 10 tests
test slice::swap_with_slice_4x_usize_30   ... bench:         894 ns/iter (+/- 11)
test slice::swap_with_slice_4x_usize_3000 ... bench:      99,476 ns/iter (+/- 2,784)
test slice::swap_with_slice_5x_usize_30   ... bench:       1,257 ns/iter (+/- 7)
test slice::swap_with_slice_5x_usize_3000 ... bench:     139,922 ns/iter (+/- 959)
test slice::swap_with_slice_rgb_30        ... bench:         328 ns/iter (+/- 27)
test slice::swap_with_slice_rgb_3000      ... bench:      16,215 ns/iter (+/- 176)
test slice::swap_with_slice_u8_30         ... bench:         312 ns/iter (+/- 9)
test slice::swap_with_slice_u8_3000       ... bench:       5,401 ns/iter (+/- 123)
test slice::swap_with_slice_usize_30      ... bench:         368 ns/iter (+/- 3)
test slice::swap_with_slice_usize_3000    ... bench:      28,472 ns/iter (+/- 3,913)
```

## After

```
running 10 tests
test slice::swap_with_slice_4x_usize_30   ... bench:         868 ns/iter (+/- 36)
test slice::swap_with_slice_4x_usize_3000 ... bench:      99,642 ns/iter (+/- 1,507)
test slice::swap_with_slice_5x_usize_30   ... bench:       1,194 ns/iter (+/- 11)
test slice::swap_with_slice_5x_usize_3000 ... bench:     139,761 ns/iter (+/- 5,018)
test slice::swap_with_slice_rgb_30        ... bench:         324 ns/iter (+/- 6)
test slice::swap_with_slice_rgb_3000      ... bench:      15,962 ns/iter (+/- 287)
test slice::swap_with_slice_u8_30         ... bench:         281 ns/iter (+/- 5)
test slice::swap_with_slice_u8_3000       ... bench:       5,324 ns/iter (+/- 40)
test slice::swap_with_slice_usize_30      ... bench:         275 ns/iter (+/- 5)
test slice::swap_with_slice_usize_3000    ... bench:      28,277 ns/iter (+/- 277)
```

</details>
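For reference, the function behind the godbolt link is essentially the `swap_m256_slice` codegen test added later in this commit; a minimal sketch (assuming an x86_64 target compiled with `-O -C target-feature=+avx`) that can be pasted into godbolt to reproduce the before/after assembly above:

```rust
// Minimal reproduction sketch for the linked godbolt comparison; it mirrors the
// `swap_m256_slice` codegen test added by this commit.
use std::arch::x86_64::__m256;

pub fn swap_m256_slice(x: &mut [__m256], y: &mut [__m256]) {
    // `swap_with_slice` bottoms out in `ptr::swap_nonoverlapping`, which is the
    // function this commit rewrites.
    if x.len() == y.len() {
        x.swap_with_slice(y);
    }
}
```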
2 parents 000e38d + 8ca47d7 commit 7fb55b4

File tree: 6 files changed, +263 -97 lines

library/core/benches/slice.rs (+39 -4)

```diff
@@ -89,6 +89,15 @@ fn binary_search_l3_worst_case(b: &mut Bencher) {
     binary_search_worst_case(b, Cache::L3);
 }
 
+#[derive(Clone)]
+struct Rgb(u8, u8, u8);
+
+impl Rgb {
+    fn gen(i: usize) -> Self {
+        Rgb(i as u8, (i as u8).wrapping_add(7), (i as u8).wrapping_add(42))
+    }
+}
+
 macro_rules! rotate {
     ($fn:ident, $n:expr, $mapper:expr) => {
         #[bench]
@@ -104,17 +113,43 @@ macro_rules! rotate {
     };
 }
 
-#[derive(Clone)]
-struct Rgb(u8, u8, u8);
-
 rotate!(rotate_u8, 32, |i| i as u8);
-rotate!(rotate_rgb, 32, |i| Rgb(i as u8, (i as u8).wrapping_add(7), (i as u8).wrapping_add(42)));
+rotate!(rotate_rgb, 32, Rgb::gen);
 rotate!(rotate_usize, 32, |i| i);
 rotate!(rotate_16_usize_4, 16, |i| [i; 4]);
 rotate!(rotate_16_usize_5, 16, |i| [i; 5]);
 rotate!(rotate_64_usize_4, 64, |i| [i; 4]);
 rotate!(rotate_64_usize_5, 64, |i| [i; 5]);
 
+macro_rules! swap_with_slice {
+    ($fn:ident, $n:expr, $mapper:expr) => {
+        #[bench]
+        fn $fn(b: &mut Bencher) {
+            let mut x = (0usize..$n).map(&$mapper).collect::<Vec<_>>();
+            let mut y = ($n..($n * 2)).map(&$mapper).collect::<Vec<_>>();
+            let mut skip = 0;
+            b.iter(|| {
+                for _ in 0..32 {
+                    x[skip..].swap_with_slice(&mut y[..($n - skip)]);
+                    skip = black_box(skip + 1) % 8;
+                }
+                black_box((x[$n / 3].clone(), y[$n * 2 / 3].clone()))
+            })
+        }
+    };
+}
+
+swap_with_slice!(swap_with_slice_u8_30, 30, |i| i as u8);
+swap_with_slice!(swap_with_slice_u8_3000, 3000, |i| i as u8);
+swap_with_slice!(swap_with_slice_rgb_30, 30, Rgb::gen);
+swap_with_slice!(swap_with_slice_rgb_3000, 3000, Rgb::gen);
+swap_with_slice!(swap_with_slice_usize_30, 30, |i| i);
+swap_with_slice!(swap_with_slice_usize_3000, 3000, |i| i);
+swap_with_slice!(swap_with_slice_4x_usize_30, 30, |i| [i; 4]);
+swap_with_slice!(swap_with_slice_4x_usize_3000, 3000, |i| [i; 4]);
+swap_with_slice!(swap_with_slice_5x_usize_30, 30, |i| [i; 5]);
+swap_with_slice!(swap_with_slice_5x_usize_3000, 3000, |i| [i; 5]);
+
 #[bench]
 fn fill_byte_sized(b: &mut Bencher) {
     #[derive(Copy, Clone)]
```

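To make the shape of the new benchmarks concrete, here is roughly what one instance, `swap_with_slice!(swap_with_slice_u8_30, 30, |i| i as u8)`, expands to (a hand-expanded sketch, not code from the patch; `Bencher` and `black_box` come from the `test` crate already imported by this bench file):

```rust
// Hand-expanded sketch of `swap_with_slice!(swap_with_slice_u8_30, 30, |i| i as u8)`.
// The cycling `skip` offset varies the start and length of the swapped ranges across
// iterations, and `black_box` keeps the work from being optimized away.
#[bench]
fn swap_with_slice_u8_30(b: &mut Bencher) {
    let mapper = |i: usize| i as u8;
    let mut x = (0usize..30).map(&mapper).collect::<Vec<_>>();
    let mut y = (30..(30 * 2)).map(&mapper).collect::<Vec<_>>();
    let mut skip = 0;
    b.iter(|| {
        for _ in 0..32 {
            x[skip..].swap_with_slice(&mut y[..(30 - skip)]);
            skip = black_box(skip + 1) % 8;
        }
        black_box((x[30 / 3].clone(), y[30 * 2 / 3].clone()))
    })
}
```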
library/core/src/mem/mod.rs (+42 -3)

```diff
@@ -700,10 +700,49 @@ pub unsafe fn uninitialized<T>() -> T {
 #[stable(feature = "rust1", since = "1.0.0")]
 #[rustc_const_unstable(feature = "const_swap", issue = "83163")]
 pub const fn swap<T>(x: &mut T, y: &mut T) {
-    // SAFETY: the raw pointers have been created from safe mutable references satisfying all the
-    // constraints on `ptr::swap_nonoverlapping_one`
+    // NOTE(eddyb) SPIR-V's Logical addressing model doesn't allow for arbitrary
+    // reinterpretation of values as (chunkable) byte arrays, and the loop in the
+    // block optimization in `swap_slice` is hard to rewrite back
+    // into the (unoptimized) direct swapping implementation, so we disable it.
+    // FIXME(eddyb) the block optimization also prevents MIR optimizations from
+    // understanding `mem::replace`, `Option::take`, etc. - a better overall
+    // solution might be to make `ptr::swap_nonoverlapping` into an intrinsic, which
+    // a backend can choose to implement using the block optimization, or not.
+    #[cfg(not(target_arch = "spirv"))]
+    {
+        // For types that are larger multiples of their alignment, the simple way
+        // tends to copy the whole thing to stack rather than doing it one part
+        // at a time, so instead treat them as one-element slices and piggy-back
+        // the slice optimizations that will split up the swaps.
+        if size_of::<T>() / align_of::<T>() > 4 {
+            // SAFETY: exclusive references always point to one non-overlapping
+            // element and are non-null and properly aligned.
+            return unsafe { ptr::swap_nonoverlapping(x, y, 1) };
+        }
+    }
+
+    // If a scalar consists of just a small number of alignment units, let
+    // the codegen just swap those pieces directly, as it's likely just a
+    // few instructions and anything else is probably overcomplicated.
+    //
+    // Most importantly, this covers primitives and simd types that tend to
+    // have size=align where doing anything else can be a pessimization.
+    // (This will also be used for ZSTs, though any solution works for them.)
+    swap_simple(x, y);
+}
+
+/// Same as [`swap`] semantically, but always uses the simple implementation.
+///
+/// Used elsewhere in `mem` and `ptr` at the bottom layer of calls.
+#[rustc_const_unstable(feature = "const_swap", issue = "83163")]
+#[inline]
+pub(crate) const fn swap_simple<T>(x: &mut T, y: &mut T) {
+    // SAFETY: exclusive references are always valid to read/write,
+    // are non-overlapping, and nothing here panics so it's drop-safe.
     unsafe {
-        ptr::swap_nonoverlapping_one(x, y);
+        let z = ptr::read(x);
+        ptr::copy_nonoverlapping(y, x, 1);
+        ptr::write(y, z);
     }
 }
 
```

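To make the `size_of / align_of > 4` threshold in the new `mem::swap` concrete, here is a small hypothetical illustration (the `takes_slice_path` helper is not part of the patch; it just restates the check):

```rust
// Hypothetical restatement of the heuristic in the new `mem::swap`: types spanning
// more than 4 alignment units are routed through `ptr::swap_nonoverlapping` as a
// one-element slice; everything else goes through `swap_simple`.
use std::mem::{align_of, size_of};

fn takes_slice_path<T>() -> bool {
    size_of::<T>() / align_of::<T>() > 4
}

fn main() {
    type KeccakBuffer = [[u64; 5]; 5];
    assert!(takes_slice_path::<KeccakBuffer>()); // 200 bytes / 8-byte align = 25 units
    assert!(!takes_slice_path::<u64>()); // 1 unit: swapped directly
    #[cfg(target_arch = "x86_64")]
    assert!(!takes_slice_path::<std::arch::x86_64::__m256>()); // 32 / 32 = 1 unit
}
```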
library/core/src/ptr/mod.rs (+42 -90)

```diff
@@ -419,106 +419,58 @@ pub const unsafe fn swap<T>(x: *mut T, y: *mut T) {
 #[stable(feature = "swap_nonoverlapping", since = "1.27.0")]
 #[rustc_const_unstable(feature = "const_swap", issue = "83163")]
 pub const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
-    let x = x as *mut u8;
-    let y = y as *mut u8;
-    let len = mem::size_of::<T>() * count;
-    // SAFETY: the caller must guarantee that `x` and `y` are
-    // valid for writes and properly aligned.
-    unsafe { swap_nonoverlapping_bytes(x, y, len) }
-}
+    macro_rules! attempt_swap_as_chunks {
+        ($ChunkTy:ty) => {
+            if mem::align_of::<T>() >= mem::align_of::<$ChunkTy>()
+                && mem::size_of::<T>() % mem::size_of::<$ChunkTy>() == 0
+            {
+                let x: *mut MaybeUninit<$ChunkTy> = x.cast();
+                let y: *mut MaybeUninit<$ChunkTy> = y.cast();
+                let count = count * (mem::size_of::<T>() / mem::size_of::<$ChunkTy>());
+                // SAFETY: these are the same bytes that the caller promised were
+                // ok, just typed as `MaybeUninit<ChunkTy>`s instead of as `T`s.
+                // The `if` condition above ensures that we're not violating
+                // alignment requirements, and that the division is exact so
+                // that we don't lose any bytes off the end.
+                return unsafe { swap_nonoverlapping_simple(x, y, count) };
+            }
+        };
+    }
 
-#[inline]
-#[rustc_const_unstable(feature = "const_swap", issue = "83163")]
-pub(crate) const unsafe fn swap_nonoverlapping_one<T>(x: *mut T, y: *mut T) {
-    // NOTE(eddyb) SPIR-V's Logical addressing model doesn't allow for arbitrary
-    // reinterpretation of values as (chunkable) byte arrays, and the loop in the
-    // block optimization in `swap_nonoverlapping_bytes` is hard to rewrite back
-    // into the (unoptimized) direct swapping implementation, so we disable it.
-    // FIXME(eddyb) the block optimization also prevents MIR optimizations from
-    // understanding `mem::replace`, `Option::take`, etc. - a better overall
-    // solution might be to make `swap_nonoverlapping` into an intrinsic, which
-    // a backend can choose to implement using the block optimization, or not.
-    #[cfg(not(target_arch = "spirv"))]
+    // Split up the slice into small power-of-two-sized chunks that LLVM is able
+    // to vectorize (unless it's a special type with more-than-pointer alignment,
+    // because we don't want to pessimize things like slices of SIMD vectors.)
+    if mem::align_of::<T>() <= mem::size_of::<usize>()
+        && (!mem::size_of::<T>().is_power_of_two()
+            || mem::size_of::<T>() > mem::size_of::<usize>() * 2)
     {
-        // Only apply the block optimization in `swap_nonoverlapping_bytes` for types
-        // at least as large as the block size, to avoid pessimizing codegen.
-        if mem::size_of::<T>() >= 32 {
-            // SAFETY: the caller must uphold the safety contract for `swap_nonoverlapping`.
-            unsafe { swap_nonoverlapping(x, y, 1) };
-            return;
-        }
+        attempt_swap_as_chunks!(usize);
+        attempt_swap_as_chunks!(u8);
     }
 
-    // Direct swapping, for the cases not going through the block optimization.
-    // SAFETY: the caller must guarantee that `x` and `y` are valid
-    // for writes, properly aligned, and non-overlapping.
-    unsafe {
-        let z = read(x);
-        copy_nonoverlapping(y, x, 1);
-        write(y, z);
-    }
+    // SAFETY: Same preconditions as this function
+    unsafe { swap_nonoverlapping_simple(x, y, count) }
 }
 
+/// Same behaviour and safety conditions as [`swap_nonoverlapping`]
+///
+/// LLVM can vectorize this (at least it can for the power-of-two-sized types
+/// `swap_nonoverlapping` tries to use) so no need to manually SIMD it.
 #[inline]
 #[rustc_const_unstable(feature = "const_swap", issue = "83163")]
-const unsafe fn swap_nonoverlapping_bytes(x: *mut u8, y: *mut u8, len: usize) {
-    // The approach here is to utilize simd to swap x & y efficiently. Testing reveals
-    // that swapping either 32 bytes or 64 bytes at a time is most efficient for Intel
-    // Haswell E processors. LLVM is more able to optimize if we give a struct a
-    // #[repr(simd)], even if we don't actually use this struct directly.
-    //
-    // FIXME repr(simd) broken on emscripten and redox
-    #[cfg_attr(not(any(target_os = "emscripten", target_os = "redox")), repr(simd))]
-    struct Block(u64, u64, u64, u64);
-    struct UnalignedBlock(u64, u64, u64, u64);
-
-    let block_size = mem::size_of::<Block>();
-
-    // Loop through x & y, copying them `Block` at a time
-    // The optimizer should unroll the loop fully for most types
-    // N.B. We can't use a for loop as the `range` impl calls `mem::swap` recursively
+const unsafe fn swap_nonoverlapping_simple<T>(x: *mut T, y: *mut T, count: usize) {
     let mut i = 0;
-    while i + block_size <= len {
-        // Create some uninitialized memory as scratch space
-        // Declaring `t` here avoids aligning the stack when this loop is unused
-        let mut t = mem::MaybeUninit::<Block>::uninit();
-        let t = t.as_mut_ptr() as *mut u8;
-
-        // SAFETY: As `i < len`, and as the caller must guarantee that `x` and `y` are valid
-        // for `len` bytes, `x + i` and `y + i` must be valid addresses, which fulfills the
-        // safety contract for `add`.
-        //
-        // Also, the caller must guarantee that `x` and `y` are valid for writes, properly aligned,
-        // and non-overlapping, which fulfills the safety contract for `copy_nonoverlapping`.
-        unsafe {
-            let x = x.add(i);
-            let y = y.add(i);
+    while i < count {
+        let x: &mut T =
+            // SAFETY: By precondition, `i` is in-bounds because it's below `n`
+            unsafe { &mut *x.add(i) };
+        let y: &mut T =
+            // SAFETY: By precondition, `i` is in-bounds because it's below `n`
+            // and it's distinct from `x` since the ranges are non-overlapping
+            unsafe { &mut *y.add(i) };
+        mem::swap_simple(x, y);
 
-            // Swap a block of bytes of x & y, using t as a temporary buffer
-            // This should be optimized into efficient SIMD operations where available
-            copy_nonoverlapping(x, t, block_size);
-            copy_nonoverlapping(y, x, block_size);
-            copy_nonoverlapping(t, y, block_size);
-        }
-        i += block_size;
-    }
-
-    if i < len {
-        // Swap any remaining bytes
-        let mut t = mem::MaybeUninit::<UnalignedBlock>::uninit();
-        let rem = len - i;
-
-        let t = t.as_mut_ptr() as *mut u8;
-
-        // SAFETY: see previous safety comment.
-        unsafe {
-            let x = x.add(i);
-            let y = y.add(i);
-
-            copy_nonoverlapping(x, t, rem);
-            copy_nonoverlapping(y, x, rem);
-            copy_nonoverlapping(t, y, rem);
-        }
+        i += 1;
     }
 }
 
```

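As a rough guide to which types the new `swap_nonoverlapping` type-erases, and to what, here is a hypothetical standalone restatement of its selection logic (the `chunk_choice` helper is illustrative only, not library code):

```rust
// Hypothetical restatement of the chunk-selection logic above: a type is only
// type-erased if its alignment is no larger than `size_of::<usize>()` and it is
// either not power-of-two-sized or larger than two `usize`s; `usize` chunks are
// preferred over `u8` chunks when size and alignment allow it.
use std::mem::{align_of, size_of};

fn chunk_choice<T>() -> &'static str {
    let (size, align) = (size_of::<T>(), align_of::<T>());
    let eligible = align <= size_of::<usize>()
        && (!size.is_power_of_two() || size > size_of::<usize>() * 2);
    if eligible && align >= align_of::<usize>() && size % size_of::<usize>() == 0 {
        "MaybeUninit<usize> chunks"
    } else if eligible && align >= align_of::<u8>() && size % size_of::<u8>() == 0 {
        "MaybeUninit<u8> chunks"
    } else {
        "whole `T` elements"
    }
}

fn main() {
    println!("[u8; 3]       -> {}", chunk_choice::<[u8; 3]>()); // u8 chunks
    println!("[u8; 4]       -> {}", chunk_choice::<[u8; 4]>()); // whole T (LLVM already vectorizes it)
    println!("[[u64; 5]; 5] -> {}", chunk_choice::<[[u64; 5]; 5]>()); // usize chunks
    #[cfg(target_arch = "x86_64")]
    println!("__m256        -> {}", chunk_choice::<std::arch::x86_64::__m256>()); // whole T (over-aligned)
}
```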
src/test/codegen/swap-large-types.rs (new file, +64)

```rust
// compile-flags: -O
// only-x86_64
// ignore-debug: the debug assertions get in the way

#![crate_type = "lib"]

use std::mem::swap;
use std::ptr::{read, copy_nonoverlapping, write};

type KeccakBuffer = [[u64; 5]; 5];

// A basic read+copy+write swap implementation ends up copying one of the values
// to stack for large types, which is completely unnecessary as the lack of
// overlap means we can just do whatever fits in registers at a time.

// CHECK-LABEL: @swap_basic
#[no_mangle]
pub fn swap_basic(x: &mut KeccakBuffer, y: &mut KeccakBuffer) {
    // CHECK: alloca [5 x [5 x i64]]

    // SAFETY: exclusive references are always valid to read/write,
    // are non-overlapping, and nothing here panics so it's drop-safe.
    unsafe {
        let z = read(x);
        copy_nonoverlapping(y, x, 1);
        write(y, z);
    }
}

// This test verifies that the library does something smarter, and thus
// doesn't need any scratch space on the stack.

// CHECK-LABEL: @swap_std
#[no_mangle]
pub fn swap_std(x: &mut KeccakBuffer, y: &mut KeccakBuffer) {
    // CHECK-NOT: alloca
    // CHECK: load <{{[0-9]+}} x i64>
    // CHECK: store <{{[0-9]+}} x i64>
    swap(x, y)
}

// CHECK-LABEL: @swap_slice
#[no_mangle]
pub fn swap_slice(x: &mut [KeccakBuffer], y: &mut [KeccakBuffer]) {
    // CHECK-NOT: alloca
    // CHECK: load <{{[0-9]+}} x i64>
    // CHECK: store <{{[0-9]+}} x i64>
    if x.len() == y.len() {
        x.swap_with_slice(y);
    }
}

type OneKilobyteBuffer = [u8; 1024];

// CHECK-LABEL: @swap_1kb_slices
#[no_mangle]
pub fn swap_1kb_slices(x: &mut [OneKilobyteBuffer], y: &mut [OneKilobyteBuffer]) {
    // CHECK-NOT: alloca
    // CHECK: load <{{[0-9]+}} x i8>
    // CHECK: store <{{[0-9]+}} x i8>
    if x.len() == y.len() {
        x.swap_with_slice(y);
    }
}
```

src/test/codegen/swap-simd-types.rs (new file, +32)

```rust
// compile-flags: -O -C target-feature=+avx
// only-x86_64
// ignore-debug: the debug assertions get in the way

#![crate_type = "lib"]

use std::mem::swap;

// SIMD types are highly-aligned already, so make sure the swap code leaves their
// types alone and doesn't pessimize them (such as by swapping them as `usize`s).
extern crate core;
use core::arch::x86_64::__m256;

// CHECK-LABEL: @swap_single_m256
#[no_mangle]
pub fn swap_single_m256(x: &mut __m256, y: &mut __m256) {
    // CHECK-NOT: alloca
    // CHECK: load <8 x float>{{.+}}align 32
    // CHECK: store <8 x float>{{.+}}align 32
    swap(x, y)
}

// CHECK-LABEL: @swap_m256_slice
#[no_mangle]
pub fn swap_m256_slice(x: &mut [__m256], y: &mut [__m256]) {
    // CHECK-NOT: alloca
    // CHECK: load <8 x float>{{.+}}align 32
    // CHECK: store <8 x float>{{.+}}align 32
    if x.len() == y.len() {
        x.swap_with_slice(y);
    }
}
```
