Skip to content

Commit 03afb61

Browse files
Optimize live point computation
This replicates the previous algorithm, but takes advantage of the bitset structures to run in tighter, better-optimized loops. It is particularly advantageous on enormous MIR blocks, which are relatively rare in practice.
1 parent ff0e148 commit 03afb61

File tree

5 files changed

+278
-29
lines changed

5 files changed

+278
-29
lines changed

compiler/rustc_borrowck/src/region_infer/values.rs

+5-23
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ impl RegionValueElements {
6060
PointIndex::new(start_index)
6161
}
6262

63+
/// Return the PointIndex for the block start of this index.
64+
crate fn to_block_start(&self, index: PointIndex) -> PointIndex {
65+
PointIndex::new(self.statements_before_block[self.basic_blocks[index]])
66+
}
67+
6368
/// Converts a `PointIndex` back to a location. O(1).
6469
crate fn to_location(&self, index: PointIndex) -> Location {
6570
assert!(index.index() < self.num_points);
@@ -76,29 +81,6 @@ impl RegionValueElements {
7681
crate fn point_in_range(&self, index: PointIndex) -> bool {
7782
// A point is in range exactly when its raw index is strictly below
// `num_points`.
index.index() < self.num_points
7883
}
79-
80-
/// Pushes all predecessors of `index` onto `stack`.
///
/// If `index` is the first statement of its basic block, one point is
/// pushed per predecessor block of `body` (that block's terminator).
/// Otherwise exactly one point is pushed: the one directly before `index`.
81-
crate fn push_predecessors(
82-
&self,
83-
body: &Body<'_>,
84-
index: PointIndex,
85-
stack: &mut Vec<PointIndex>,
86-
) {
87-
// Split the point into its block and its position within that block.
let Location { block, statement_index } = self.to_location(index);
88-
if statement_index == 0 {
89-
// If this is a basic block head, then the predecessors are
90-
// the terminators of other basic blocks.
91-
stack.extend(
92-
body.predecessors()[block]
93-
.iter()
94-
.map(|&pred_bb| body.terminator_loc(pred_bb))
95-
.map(|pred_loc| self.point_from_location(pred_loc)),
96-
);
97-
} else {
98-
// Otherwise, the predecessor is just the previous statement, whose
// point index is one less than this one.
99-
stack.push(PointIndex::new(index.index() - 1));
100-
}
101-
}
10284
}
10385

10486
rustc_index::newtype_index! {

compiler/rustc_borrowck/src/type_check/liveness/trace.rs

+35-5
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,42 @@ impl LivenessResults<'me, 'typeck, 'flow, 'tcx> {
205205

206206
self.stack.extend(self.cx.local_use_map.uses(local));
207207
while let Some(p) = self.stack.pop() {
208-
if self.defs.contains(p) {
208+
// We are live in this block from the closest to us of:
209+
//
210+
// * Inclusively, the block start
211+
// * Exclusively, the previous definition (if it's in this block)
212+
// * Exclusively, the previous live_at setting (an optimization)
213+
let block_start = self.cx.elements.to_block_start(p);
214+
let previous_defs = self.defs.last_set_in(block_start..=p);
215+
let previous_live_at = self.use_live_at.last_set_in(block_start..=p);
216+
217+
let exclusive_start = match (previous_defs, previous_live_at) {
218+
(Some(a), Some(b)) => Some(std::cmp::max(a, b)),
219+
(Some(a), None) | (None, Some(a)) => Some(a),
220+
(None, None) => None,
221+
};
222+
223+
if let Some(exclusive) = exclusive_start {
224+
self.use_live_at.insert_range(exclusive + 1..=p);
225+
226+
// If we have a bound after the start of the block, we should
227+
// not add the predecessors for this block.
209228
continue;
210-
}
211-
212-
if self.use_live_at.insert(p) {
213-
self.cx.elements.push_predecessors(self.cx.body, p, &mut self.stack)
229+
} else {
230+
// Add all the elements of this block.
231+
self.use_live_at.insert_range(block_start..=p);
232+
233+
// Then add the predecessors for this block, which are the
234+
// terminators of predecessor basic blocks. Push those onto the
235+
// stack so that the next iteration(s) will process them.
236+
237+
let block = self.cx.elements.to_location(block_start).block;
238+
self.stack.extend(
239+
self.cx.body.predecessors()[block]
240+
.iter()
241+
.map(|&pred_bb| self.cx.body.terminator_loc(pred_bb))
242+
.map(|pred_loc| self.cx.elements.point_from_location(pred_loc)),
243+
);
214244
}
215245
}
216246
}

compiler/rustc_index/src/bit_set.rs

+141-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::fmt;
44
use std::iter;
55
use std::marker::PhantomData;
66
use std::mem;
7-
use std::ops::{BitAnd, BitAndAssign, BitOrAssign, Not, Range, Shl};
7+
use std::ops::{BitAnd, BitAndAssign, BitOrAssign, Bound, Not, Range, RangeBounds, Shl};
88
use std::slice;
99

1010
use rustc_macros::{Decodable, Encodable};
@@ -22,6 +22,29 @@ pub trait BitRelations<Rhs> {
2222
fn intersect(&mut self, other: &Rhs) -> bool;
2323
}
2424

25+
#[inline]
26+
fn inclusive_start_end<T: Idx>(
27+
range: impl RangeBounds<T>,
28+
domain: usize,
29+
) -> Option<(usize, usize)> {
30+
// Both start and end are inclusive.
31+
let start = match range.start_bound().cloned() {
32+
Bound::Included(start) => start.index(),
33+
Bound::Excluded(start) => start.index() + 1,
34+
Bound::Unbounded => 0,
35+
};
36+
let end = match range.end_bound().cloned() {
37+
Bound::Included(end) => end.index(),
38+
Bound::Excluded(end) => end.index().checked_sub(1)?,
39+
Bound::Unbounded => domain - 1,
40+
};
41+
assert!(end < domain);
42+
if start > end {
43+
return None;
44+
}
45+
Some((start, end))
46+
}
47+
2548
macro_rules! bit_relations_inherent_impls {
2649
() => {
2750
/// Sets `self = self | other` and returns `true` if `self` changed
@@ -151,6 +174,33 @@ impl<T: Idx> BitSet<T> {
151174
new_word != word
152175
}
153176

177+
#[inline]
178+
pub fn insert_range(&mut self, elems: impl RangeBounds<T>) {
179+
let Some((start, end)) = inclusive_start_end(elems, self.domain_size) else {
180+
return;
181+
};
182+
183+
let (start_word_index, start_mask) = word_index_and_mask(start);
184+
let (end_word_index, end_mask) = word_index_and_mask(end);
185+
186+
// Set all words in between start and end (exclusively of both).
187+
for word_index in (start_word_index + 1)..end_word_index {
188+
self.words[word_index] = !0;
189+
}
190+
191+
if start_word_index != end_word_index {
192+
// Start and end are in different words, so we handle each in turn.
193+
//
194+
// We set all leading bits. This includes the start_mask bit.
195+
self.words[start_word_index] |= !(start_mask - 1);
196+
// And all trailing bits (i.e. from 0..=end) in the end word,
197+
// including the end.
198+
self.words[end_word_index] |= end_mask | end_mask - 1;
199+
} else {
200+
self.words[start_word_index] |= end_mask | (end_mask - start_mask);
201+
}
202+
}
203+
154204
/// Sets all bits to true.
155205
pub fn insert_all(&mut self) {
156206
for word in &mut self.words {
@@ -227,6 +277,36 @@ impl<T: Idx> BitSet<T> {
227277
not_already
228278
}
229279

280+
fn last_set_in(&self, range: impl RangeBounds<T>) -> Option<T> {
281+
let (start, end) = inclusive_start_end(range, self.domain_size)?;
282+
let (start_word_index, _) = word_index_and_mask(start);
283+
let (end_word_index, end_mask) = word_index_and_mask(end);
284+
285+
let end_word = self.words[end_word_index] & (end_mask | (end_mask - 1));
286+
if end_word != 0 {
287+
let pos = max_bit(end_word) + WORD_BITS * end_word_index;
288+
if start <= pos {
289+
return Some(T::new(pos));
290+
}
291+
}
292+
293+
// We exclude end_word_index from the range here, because we don't want
294+
// to limit ourselves to *just* the last word: the bits set it in may be
295+
// after `end`, so it may not work out.
296+
if let Some(offset) =
297+
self.words[start_word_index..end_word_index].iter().rposition(|&w| w != 0)
298+
{
299+
let word_idx = start_word_index + offset;
300+
let start_word = self.words[word_idx];
301+
let pos = max_bit(start_word) + WORD_BITS * word_idx;
302+
if start <= pos {
303+
return Some(T::new(pos));
304+
}
305+
}
306+
307+
None
308+
}
309+
230310
bit_relations_inherent_impls! {}
231311
}
232312

@@ -635,6 +715,16 @@ impl<T: Idx> SparseBitSet<T> {
635715
self.elems.iter()
636716
}
637717

718+
/// Returns the last element, in iteration order, that is contained in
/// `range`, or `None` if no element of the set falls inside it.
fn last_set_in(&self, range: impl RangeBounds<T>) -> Option<T> {
    self.iter().filter(|&elem| range.contains(elem)).last().copied()
}
727+
638728
bit_relations_inherent_impls! {}
639729
}
640730

@@ -709,6 +799,16 @@ impl<T: Idx> HybridBitSet<T> {
709799
}
710800
}
711801

802+
/// Returns the last element of the set that lies within `range`, or `None`
803+
/// if no element of the set is in `range`. The bounds of `range` are used as
804+
/// given, so an inclusive bound that is itself in the set can be returned.
805+
pub fn last_set_in(&self, range: impl RangeBounds<T>) -> Option<T> {
806+
// Delegate to whichever representation this set currently uses.
match self {
807+
HybridBitSet::Sparse(sparse) => sparse.last_set_in(range),
808+
HybridBitSet::Dense(dense) => dense.last_set_in(range),
809+
}
810+
}
811+
712812
pub fn insert(&mut self, elem: T) -> bool {
713813
// No need to check `elem` against `self.domain_size` here because all
714814
// the match cases check it, one way or another.
@@ -734,6 +834,41 @@ impl<T: Idx> HybridBitSet<T> {
734834
}
735835
}
736836

837+
pub fn insert_range(&mut self, elems: impl RangeBounds<T>) {
838+
// No need to check `elem` against `self.domain_size` here because all
839+
// the match cases check it, one way or another.
840+
let start = match elems.start_bound().cloned() {
841+
Bound::Included(start) => start.index(),
842+
Bound::Excluded(start) => start.index() + 1,
843+
Bound::Unbounded => 0,
844+
};
845+
let end = match elems.end_bound().cloned() {
846+
Bound::Included(end) => end.index() + 1,
847+
Bound::Excluded(end) => end.index(),
848+
Bound::Unbounded => self.domain_size() - 1,
849+
};
850+
let len = if let Some(l) = end.checked_sub(start) {
851+
l
852+
} else {
853+
return;
854+
};
855+
match self {
856+
HybridBitSet::Sparse(sparse) if sparse.len() + len < SPARSE_MAX => {
857+
// The set is sparse and has space for `elems`.
858+
for elem in start..end {
859+
sparse.insert(T::new(elem));
860+
}
861+
}
862+
HybridBitSet::Sparse(sparse) => {
863+
// The set is sparse and full. Convert to a dense set.
864+
let mut dense = sparse.to_dense();
865+
dense.insert_range(elems);
866+
*self = HybridBitSet::Dense(dense);
867+
}
868+
HybridBitSet::Dense(dense) => dense.insert_range(elems),
869+
}
870+
}
871+
737872
pub fn insert_all(&mut self) {
738873
let domain_size = self.domain_size();
739874
match self {
@@ -1205,6 +1340,11 @@ fn word_index_and_mask<T: Idx>(elem: T) -> (usize, Word) {
12051340
(word_index, mask)
12061341
}
12071342

1343+
#[inline]
1344+
fn max_bit(word: Word) -> usize {
1345+
WORD_BITS - 1 - word.leading_zeros() as usize
1346+
}
1347+
12081348
/// Integral type used to represent the bit set.
12091349
pub trait FiniteBitSetTy:
12101350
BitAnd<Output = Self>

compiler/rustc_index/src/bit_set/tests.rs

+95
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,101 @@ fn sparse_matrix_operations() {
370370
}
371371
}
372372

373+
// Exercises `BitSet::insert_range` with exclusive and inclusive ranges,
// including empty ranges and ranges that touch or cross a word boundary.
#[test]
374+
fn dense_insert_range() {
375+
// Inserts `range` into an empty set of size `domain`, then asserts that the
// set holds exactly the elements of `range`.
#[track_caller]
376+
fn check<R>(domain: usize, range: R)
377+
where
378+
R: RangeBounds<usize> + Clone + IntoIterator<Item = usize> + std::fmt::Debug,
379+
{
380+
let mut set = BitSet::new_empty(domain);
381+
set.insert_range(range.clone());
382+
// Nothing outside the range may have been set...
for i in set.iter() {
383+
assert!(range.contains(&i));
384+
}
385+
// ...and every element of the range must be present.
for i in range.clone() {
386+
assert!(set.contains(i), "{} in {:?}, inserted {:?}", i, set, range);
387+
}
388+
}
389+
// Hand-picked exclusive ranges.
check(300, 10..10);
390+
check(300, WORD_BITS..WORD_BITS * 2);
391+
check(300, WORD_BITS - 1..WORD_BITS * 2);
392+
check(300, WORD_BITS - 1..WORD_BITS);
393+
check(300, 10..100);
394+
check(300, 10..30);
395+
check(300, 0..5);
396+
check(300, 0..250);
397+
check(300, 200..250);
398+
399+
// The same ranges with an inclusive end.
check(300, 10..=10);
400+
check(300, WORD_BITS..=WORD_BITS * 2);
401+
check(300, WORD_BITS - 1..=WORD_BITS * 2);
402+
check(300, WORD_BITS - 1..=WORD_BITS);
403+
check(300, 10..=100);
404+
check(300, 10..=30);
405+
check(300, 0..=5);
406+
check(300, 0..=250);
407+
check(300, 200..=250);
408+
409+
// Exhaustively try every `i <= j` pair below two words, against both a
// domain of exactly two words and a domain of 300.
for i in 0..WORD_BITS * 2 {
410+
for j in i..WORD_BITS * 2 {
411+
check(WORD_BITS * 2, i..j);
412+
check(WORD_BITS * 2, i..=j);
413+
check(300, i..j);
414+
check(300, i..=j);
415+
}
416+
}
417+
}
418+
419+
// Compares `BitSet::last_set_in` against a straightforward linear scan for a
// variety of ranges and set contents.
#[test]
420+
fn dense_last_set_before() {
421+
// Reference implementation: walk the whole set and remember the last
// element that falls inside `needle`.
fn easy(set: &BitSet<usize>, needle: impl RangeBounds<usize>) -> Option<usize> {
422+
let mut last_leq = None;
423+
for e in set.iter() {
424+
if needle.contains(&e) {
425+
last_leq = Some(e);
426+
}
427+
}
428+
last_leq
429+
}
430+
431+
// Asserts that `last_set_in` and the reference scan agree for `needle`.
#[track_caller]
432+
fn cmp(set: &BitSet<usize>, needle: impl RangeBounds<usize> + Clone + std::fmt::Debug) {
433+
assert_eq!(
434+
set.last_set_in(needle.clone()),
435+
easy(set, needle.clone()),
436+
"{:?} in {:?}",
437+
needle,
438+
set
439+
);
440+
}
441+
// Hand-picked cases, starting from an empty set and adding elements around
// the first word boundary.
let mut set = BitSet::new_empty(300);
442+
cmp(&set, 50..=50);
443+
set.insert(WORD_BITS);
444+
cmp(&set, WORD_BITS..=WORD_BITS);
445+
set.insert(WORD_BITS - 1);
446+
cmp(&set, 0..=WORD_BITS - 1);
447+
cmp(&set, 0..=5);
448+
cmp(&set, 10..100);
449+
set.insert(100);
450+
cmp(&set, 100..110);
451+
cmp(&set, 99..100);
452+
cmp(&set, 99..=100);
453+
454+
// Exhaustive sweep: for every range `i..j` / `i..=j` up to two words, check
// an empty set and then a set holding the single element `k`.
for i in 0..=WORD_BITS * 2 {
455+
for j in i..=WORD_BITS * 2 {
456+
for k in 0..WORD_BITS * 2 {
457+
let mut set = BitSet::new_empty(300);
458+
cmp(&set, i..j);
459+
cmp(&set, i..=j);
460+
set.insert(k);
461+
cmp(&set, i..j);
462+
cmp(&set, i..=j);
463+
}
464+
}
465+
}
466+
}
467+
373468
/// Merge dense hybrid set into empty sparse hybrid set.
374469
#[bench]
375470
fn union_hybrid_sparse_empty_to_dense(b: &mut Bencher) {

compiler/rustc_index/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
#![feature(extend_one)]
44
#![feature(iter_zip)]
55
#![feature(min_specialization)]
6+
#![feature(step_trait)]
67
#![feature(test)]
8+
#![feature(let_else)]
79

810
pub mod bit_set;
911
pub mod vec;

0 commit comments

Comments
 (0)