Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PeekMut::refresh #138161

Merged
merged 1 commit into from
Mar 12, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 78 additions & 4 deletions library/alloc/src/collections/binary_heap/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,74 @@ impl<T: Ord, A: Allocator> DerefMut for PeekMut<'_, T, A> {
}

impl<'a, T: Ord, A: Allocator> PeekMut<'a, T, A> {
/// Sifts the current element to its new position.
///
/// Afterwards refers to the new element. Returns if the element changed.
///
/// ## Examples
///
/// The condition can be used to upper bound all elements in the heap. When only few elements
/// are affected, the heap's sort ensures this is faster than a reconstruction from the raw
/// element list and requires no additional allocation.
///
/// ```
/// #![feature(binary_heap_peek_mut_refresh)]
/// use std::collections::BinaryHeap;
///
/// let mut heap: BinaryHeap<u32> = (0..128).collect();
/// let mut peek = heap.peek_mut().unwrap();
///
/// loop {
/// *peek = 99;
///
/// if !peek.refresh() {
/// break;
/// }
/// }
///
/// // Post condition, this is now an upper bound.
/// assert!(*peek < 100);
/// ```
///
/// When the element remains the maximum after modification, the peek remains unchanged:
///
/// ```
/// #![feature(binary_heap_peek_mut_refresh)]
/// use std::collections::BinaryHeap;
///
/// let mut heap: BinaryHeap<u32> = [1, 2, 3].into();
/// let mut peek = heap.peek_mut().unwrap();
///
/// assert_eq!(*peek, 3);
/// *peek = 42;
///
/// // When we refresh, the peek is updated to the new maximum.
/// assert!(!peek.refresh(), "42 is even larger than 3");
/// assert_eq!(*peek, 42);
/// ```
#[unstable(feature = "binary_heap_peek_mut_refresh", issue = "138355")]
#[must_use = "is equivalent to dropping and getting a new PeekMut except for return information"]
pub fn refresh(&mut self) -> bool {
// The length of the underlying heap is unchanged by sifting down. The value stored for leak
// amplification thus remains accurate. We erase the leak amplification firstly because the
// operation is then equivalent to constructing a new PeekMut and secondly this avoids any
// future complication where original_len being non-empty would be interpreted as the heap
// having been leak amplified instead of checking the heap itself.
if let Some(original_len) = self.original_len.take() {
// SAFETY: This is how many elements were in the Vec at the time of
// the BinaryHeap::peek_mut call.
unsafe { self.heap.data.set_len(original_len.get()) };

// The length of the heap did not change by sifting, upholding our own invariants.

// SAFETY: PeekMut is only instantiated for non-empty heaps.
(unsafe { self.heap.sift_down(0) }) != 0
} else {
// The element was not modified.
false
}
}

/// Removes the peeked value from the heap and returns it.
#[stable(feature = "binary_heap_peek_mut_pop", since = "1.18.0")]
pub fn pop(mut this: PeekMut<'a, T, A>) -> T {
Expand Down Expand Up @@ -670,6 +738,8 @@ impl<T: Ord, A: Allocator> BinaryHeap<T, A> {
/// # Safety
///
/// The caller must guarantee that `pos < self.len()`.
///
/// Returns the new position of the element.
unsafe fn sift_up(&mut self, start: usize, pos: usize) -> usize {
// Take out the value at `pos` and create a hole.
// SAFETY: The caller guarantees that pos < self.len()
Expand All @@ -696,10 +766,12 @@ impl<T: Ord, A: Allocator> BinaryHeap<T, A> {
/// Take an element at `pos` and move it down the heap,
/// while its children are larger.
///
/// Returns the new position of the element.
///
/// # Safety
///
/// The caller must guarantee that `pos < end <= self.len()`.
unsafe fn sift_down_range(&mut self, pos: usize, end: usize) {
unsafe fn sift_down_range(&mut self, pos: usize, end: usize) -> usize {
// SAFETY: The caller guarantees that pos < end <= self.len().
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
let mut child = 2 * hole.pos() + 1;
Expand All @@ -719,7 +791,7 @@ impl<T: Ord, A: Allocator> BinaryHeap<T, A> {
// SAFETY: child is now either the old child or the old child+1
// We already proven that both are < self.len() and != hole.pos()
if hole.element() >= unsafe { hole.get(child) } {
return;
return hole.pos();
}

// SAFETY: same as above.
Expand All @@ -734,16 +806,18 @@ impl<T: Ord, A: Allocator> BinaryHeap<T, A> {
// child == 2 * hole.pos() + 1 != hole.pos().
unsafe { hole.move_to(child) };
}

hole.pos()
}

/// # Safety
///
/// The caller must guarantee that `pos < self.len()`.
unsafe fn sift_down(&mut self, pos: usize) {
unsafe fn sift_down(&mut self, pos: usize) -> usize {
let len = self.len();
// SAFETY: pos < len is guaranteed by the caller and
// obviously len = self.len() <= self.len().
unsafe { self.sift_down_range(pos, len) };
unsafe { self.sift_down_range(pos, len) }
}

/// Take an element at `pos` and move it all the way down the heap,
Expand Down
Loading