Skip to content

Commit b03d939

Browse files
committed
Improve autovectorization of to_lowercase / to_uppercase functions
Refactor the code in the `convert_while_ascii` helper function to make it more suitable for auto-vectorization and also process the full ascii prefix of the string. The generic case conversion logic will only be invoked starting from the first non-ascii character. The runtime on microbenchmarks with ascii-only inputs improves between 1.5x for short and 4x for long inputs on x86_64 and aarch64. The new implementation also encapsulates all unsafe inside the `convert_while_ascii` function. Fixes #123712
1 parent eda9d7f commit b03d939

File tree

3 files changed

+87
-50
lines changed

3 files changed

+87
-50
lines changed

library/alloc/benches/str.rs

+2
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,5 @@ make_test!(rsplitn_space_char, s, s.rsplitn(10, ' ').count());
347347

348348
make_test!(split_space_str, s, s.split(" ").count());
349349
make_test!(split_ad_str, s, s.split("ad").count());
350+
351+
make_test!(to_lowercase, s, s.to_lowercase());

library/alloc/src/str.rs

+68-50
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
use core::borrow::{Borrow, BorrowMut};
1111
use core::iter::FusedIterator;
1212
use core::mem;
13+
use core::mem::MaybeUninit;
1314
use core::ptr;
1415
use core::str::pattern::{DoubleEndedSearcher, Pattern, ReverseSearcher, Searcher};
1516
use core::unicode::conversions;
@@ -366,14 +367,9 @@ impl str {
366367
without modifying the original"]
367368
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
368369
pub fn to_lowercase(&self) -> String {
369-
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_lowercase);
370+
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_lowercase);
370371

371-
// Safety: we know this is a valid char boundary since
372-
// out.len() is only progressed if ascii bytes are found
373-
let rest = unsafe { self.get_unchecked(out.len()..) };
374-
375-
// Safety: We have written only valid ASCII to our vec
376-
let mut s = unsafe { String::from_utf8_unchecked(out) };
372+
let prefix_len = s.len();
377373

378374
for (i, c) in rest.char_indices() {
379375
if c == 'Σ' {
@@ -382,8 +378,7 @@ impl str {
382378
// in `SpecialCasing.txt`,
383379
// so hard-code it rather than have a generic "condition" mechanism.
384380
// See https://github.com/rust-lang/rust/issues/26035
385-
let out_len = self.len() - rest.len();
386-
let sigma_lowercase = map_uppercase_sigma(&self, i + out_len);
381+
let sigma_lowercase = map_uppercase_sigma(self, prefix_len + i);
387382
s.push(sigma_lowercase);
388383
} else {
389384
match conversions::to_lower(c) {
@@ -459,14 +454,7 @@ impl str {
459454
without modifying the original"]
460455
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
461456
pub fn to_uppercase(&self) -> String {
462-
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_uppercase);
463-
464-
// Safety: we know this is a valid char boundary since
465-
// out.len() is only progressed if ascii bytes are found
466-
let rest = unsafe { self.get_unchecked(out.len()..) };
467-
468-
// Safety: We have written only valid ASCII to our vec
469-
let mut s = unsafe { String::from_utf8_unchecked(out) };
457+
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_uppercase);
470458

471459
for c in rest.chars() {
472460
match conversions::to_upper(c) {
@@ -615,50 +603,80 @@ pub unsafe fn from_boxed_utf8_unchecked(v: Box<[u8]>) -> Box<str> {
615603
unsafe { Box::from_raw(Box::into_raw(v) as *mut str) }
616604
}
617605

618-
/// Converts the bytes while the bytes are still ascii.
606+
/// Converts leading ascii bytes in `s` by calling the `convert` function.
607+
///
619608
/// For better average performance, this happens in chunks of `2*size_of::<usize>()`.
620-
/// Returns a vec with the converted bytes.
609+
///
610+
/// Returns a tuple of the converted prefix and the remainder starting from
611+
/// the first non-ascii character.
621612
#[inline]
622613
#[cfg(not(test))]
623614
#[cfg(not(no_global_oom_handling))]
624-
fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8) -> Vec<u8> {
625-
let mut out = Vec::with_capacity(b.len());
626-
615+
fn convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> (String, &str) {
616+
// process the input in chunks to enable auto-vectorization
627617
const USIZE_SIZE: usize = mem::size_of::<usize>();
628618
const MAGIC_UNROLL: usize = 2;
629619
const N: usize = USIZE_SIZE * MAGIC_UNROLL;
630-
const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; USIZE_SIZE]);
631620

632-
let mut i = 0;
633-
unsafe {
634-
while i + N <= b.len() {
635-
// Safety: we have checks the sizes `b` and `out` to know that our
636-
let in_chunk = b.get_unchecked(i..i + N);
637-
let out_chunk = out.spare_capacity_mut().get_unchecked_mut(i..i + N);
638-
639-
let mut bits = 0;
640-
for j in 0..MAGIC_UNROLL {
641-
// read the bytes 1 usize at a time (unaligned since we haven't checked the alignment)
642-
// safety: in_chunk is valid bytes in the range
643-
bits |= in_chunk.as_ptr().cast::<usize>().add(j).read_unaligned();
644-
}
645-
// if our chunks aren't ascii, then return only the prior bytes as init
646-
if bits & NONASCII_MASK != 0 {
647-
break;
648-
}
621+
let mut slice = s.as_bytes();
622+
let mut out = Vec::with_capacity(slice.len());
623+
let mut out_slice = &mut out.spare_capacity_mut()[..slice.len()];
649624

650-
// perform the case conversions on N bytes (gets heavily autovec'd)
651-
for j in 0..N {
652-
// safety: in_chunk and out_chunk is valid bytes in the range
653-
let out = out_chunk.get_unchecked_mut(j);
654-
out.write(convert(in_chunk.get_unchecked(j)));
655-
}
625+
let mut ascii_prefix_len = 0_usize;
626+
let mut is_ascii = [false; N];
627+
628+
while slice.len() >= N {
629+
// Safety: checked in loop condition
630+
let chunk = unsafe { slice.get_unchecked(..N) };
631+
// Safety: out_slice has same length as input slice and gets sliced with the same offsets
632+
let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) };
633+
634+
for j in 0..N {
635+
is_ascii[j] = chunk[j] <= 127;
636+
}
656637

657-
// mark these bytes as initialised
658-
i += N;
638+
// auto-vectorization for this check is a bit fragile,
639+
// sum and comparing against the chunk size gives the best result,
640+
// specifically a pmovmsk instruction on x86.
641+
if is_ascii.iter().map(|x| *x as u8).sum::<u8>() as usize != N {
642+
break;
659643
}
660-
out.set_len(i);
644+
645+
for j in 0..N {
646+
out_chunk[j] = MaybeUninit::new(convert(&chunk[j]));
647+
}
648+
649+
ascii_prefix_len += N;
650+
slice = unsafe { slice.get_unchecked(N..) };
651+
out_slice = unsafe { out_slice.get_unchecked_mut(N..) };
652+
}
653+
654+
// handle the remainder as individual bytes
655+
while slice.len() > 0 {
656+
let byte = slice[0];
657+
if byte > 127 {
658+
break;
659+
}
660+
// Safety: out_slice has same length as input slice and gets sliced with the same offsets
661+
unsafe {
662+
*out_slice.get_unchecked_mut(0) = MaybeUninit::new(convert(&byte));
663+
}
664+
ascii_prefix_len += 1;
665+
slice = unsafe { slice.get_unchecked(1..) };
666+
out_slice = unsafe { out_slice.get_unchecked_mut(1..) };
661667
}
662668

663-
out
669+
unsafe {
670+
// SAFETY: ascii_prefix_len bytes have been initialized above
671+
out.set_len(ascii_prefix_len);
672+
673+
// SAFETY: We have written only valid ascii to the output vec
674+
let ascii_string = String::from_utf8_unchecked(out);
675+
676+
// SAFETY: we know this is a valid char boundary
677+
// since we only skipped over leading ascii bytes
678+
let rest = core::str::from_utf8_unchecked(slice);
679+
680+
(ascii_string, rest)
681+
}
664682
}

library/alloc/tests/str.rs

+17
Original file line numberDiff line numberDiff line change
@@ -1826,6 +1826,19 @@ fn to_lowercase() {
18261826
assert_eq!("Α'Σ".to_lowercase(), "α'ς");
18271827
assert_eq!("Α''Σ".to_lowercase(), "α''ς");
18281828

1829+
assert_eq!("aΣ".to_lowercase(), "aς");
1830+
assert_eq!("a'Σ".to_lowercase(), "a'ς");
1831+
assert_eq!("a''Σ".to_lowercase(), "a''ς");
1832+
1833+
assert_eq!("ÄΣ".to_lowercase(), "äς");
1834+
assert_eq!("ä'Σ".to_lowercase(), "ä'ς");
1835+
assert_eq!("ä''Σ".to_lowercase(), "ä''ς");
1836+
1837+
// input lengths around the boundary of the chunk size used by the ascii prefix optimization
1838+
assert_eq!("abcdefghijklmnoΣ".to_lowercase(), "abcdefghijklmnoς");
1839+
assert_eq!("abcdefghijklmnopΣ".to_lowercase(), "abcdefghijklmnopς");
1840+
assert_eq!("abcdefghijklmnopqΣ".to_lowercase(), "abcdefghijklmnopqς");
1841+
18291842
assert_eq!("ΑΣ Α".to_lowercase(), "ας α");
18301843
assert_eq!("Α'Σ Α".to_lowercase(), "α'ς α");
18311844
assert_eq!("Α''Σ Α".to_lowercase(), "α''ς α");
@@ -1840,6 +1853,10 @@ fn to_lowercase() {
18401853
assert_eq!("Α 'Σ".to_lowercase(), "α 'σ");
18411854
assert_eq!("Α ''Σ".to_lowercase(), "α ''σ");
18421855

1856+
assert_eq!("Ä Σ".to_lowercase(), "ä σ");
1857+
assert_eq!("Ä 'Σ".to_lowercase(), "ä 'σ");
1858+
assert_eq!("Ä ''Σ".to_lowercase(), "ä ''σ");
1859+
18431860
assert_eq!("Σ".to_lowercase(), "σ");
18441861
assert_eq!("'Σ".to_lowercase(), "'σ");
18451862
assert_eq!("''Σ".to_lowercase(), "''σ");

0 commit comments

Comments
 (0)