|
| 1 | +use crate::cell::UnsafeCell; |
1 | 2 | use crate::sync::atomic::{
|
2 |
| - AtomicI32, |
| 3 | + AtomicI32, AtomicUsize, |
3 | 4 | Ordering::{Acquire, Relaxed, Release},
|
4 | 5 | };
|
5 | 6 | use crate::sys::futex::{futex_wait, futex_wake, futex_wake_all};
|
@@ -162,3 +163,98 @@ impl Condvar {
|
162 | 163 | r
|
163 | 164 | }
|
164 | 165 | }
|
| 166 | + |
| 167 | +/// A reentrant mutex. Used by stdout().lock() and friends. |
| 168 | +/// |
| 169 | +/// The 'owner' field tracks which thread has locked the mutex. |
| 170 | +/// |
| 171 | +/// We use current_thread_unique_ptr() as the thread identifier, |
| 172 | +/// which is just the address of a thread local variable. |
| 173 | +/// |
| 174 | +/// If `owner` is set to the identifier of the current thread, |
| 175 | +/// we assume the mutex is already locked and instead of locking it again, |
| 176 | +/// we increment `lock_count`. |
| 177 | +/// |
| 178 | +/// When unlocking, we decrement `lock_count`, and only unlock the mutex when |
| 179 | +/// it reaches zero. |
| 180 | +/// |
| 181 | +/// `lock_count` is protected by the mutex and only accessed by the thread that has |
| 182 | +/// locked the mutex, so needs no synchronization. |
| 183 | +/// |
| 184 | +/// `owner` can be checked by other threads that want to see if they already |
| 185 | +/// hold the lock, so needs to be atomic. If it compares equal, we're on the |
| 186 | +/// same thread that holds the mutex and memory access can use relaxed ordering |
| 187 | +/// since we're not dealing with multiple threads. If it compares unequal, |
| 188 | +/// synchronization is left to the mutex, making relaxed memory ordering for |
| 189 | +/// the `owner` field fine in all cases. |
| 190 | +pub struct ReentrantMutex { |
| 191 | + mutex: Mutex, |
| 192 | + owner: AtomicUsize, |
| 193 | + lock_count: UnsafeCell<u32>, |
| 194 | +} |
| 195 | + |
| 196 | +unsafe impl Send for ReentrantMutex {} |
| 197 | +unsafe impl Sync for ReentrantMutex {} |
| 198 | + |
| 199 | +impl ReentrantMutex { |
| 200 | + #[inline] |
| 201 | + pub const unsafe fn uninitialized() -> Self { |
| 202 | + Self { mutex: Mutex::new(), owner: AtomicUsize::new(0), lock_count: UnsafeCell::new(0) } |
| 203 | + } |
| 204 | + |
| 205 | + #[inline] |
| 206 | + pub unsafe fn init(&self) {} |
| 207 | + |
| 208 | + #[inline] |
| 209 | + pub unsafe fn destroy(&self) {} |
| 210 | + |
| 211 | + pub unsafe fn try_lock(&self) -> bool { |
| 212 | + let this_thread = current_thread_unique_ptr(); |
| 213 | + if self.owner.load(Relaxed) == this_thread { |
| 214 | + self.increment_lock_count(); |
| 215 | + true |
| 216 | + } else if self.mutex.try_lock() { |
| 217 | + self.owner.store(this_thread, Relaxed); |
| 218 | + debug_assert_eq!(*self.lock_count.get(), 0); |
| 219 | + *self.lock_count.get() = 1; |
| 220 | + true |
| 221 | + } else { |
| 222 | + false |
| 223 | + } |
| 224 | + } |
| 225 | + |
| 226 | + pub unsafe fn lock(&self) { |
| 227 | + let this_thread = current_thread_unique_ptr(); |
| 228 | + if self.owner.load(Relaxed) == this_thread { |
| 229 | + self.increment_lock_count(); |
| 230 | + } else { |
| 231 | + self.mutex.lock(); |
| 232 | + self.owner.store(this_thread, Relaxed); |
| 233 | + debug_assert_eq!(*self.lock_count.get(), 0); |
| 234 | + *self.lock_count.get() = 1; |
| 235 | + } |
| 236 | + } |
| 237 | + |
| 238 | + unsafe fn increment_lock_count(&self) { |
| 239 | + *self.lock_count.get() = (*self.lock_count.get()) |
| 240 | + .checked_add(1) |
| 241 | + .expect("lock count overflow in reentrant mutex"); |
| 242 | + } |
| 243 | + |
| 244 | + pub unsafe fn unlock(&self) { |
| 245 | + *self.lock_count.get() -= 1; |
| 246 | + if *self.lock_count.get() == 0 { |
| 247 | + self.owner.store(0, Relaxed); |
| 248 | + self.mutex.unlock(); |
| 249 | + } |
| 250 | + } |
| 251 | +} |
| 252 | + |
| 253 | +/// Get an address that is unique per running thread. |
| 254 | +/// |
| 255 | +/// This can be used as a non-null usize-sized ID. |
| 256 | +pub fn current_thread_unique_ptr() -> usize { |
| 257 | + // Use a non-drop type to make sure it's still available during thread destruction. |
| 258 | + thread_local! { static X: u8 = const { 0 } } |
| 259 | + X.with(|x| <*const _>::addr(x)) |
| 260 | +} |
0 commit comments