Skip to content

Commit 9ee55a3

Browse files
ngoldbaumepilys
authored andcommittedJan 11, 2025
Implement locked iteration for PyList (#4789)
* implement locked iteration for PyList * fix limited API and PyPy support * fix formatting of safety docstrings * only define fold and rfold on not(feature = "nightly") * add missing try_fold implementation on nightly * Use split borrows for locked iteration for PyList Inline ListIterImpl implementations by using split borrows and destructuring let Self { .. } = self destructuring inside BoundListIterator impls. Signed-off-by: Manos Pitsidianakis <[email protected]> * use a function to do the split borrow * add changelog entries * fix clippy on limited API and PyPy * use a macro for the split borrow * add a test that mutates the list during a fold * enable next_unchecked on PyPy * fix incorrect docstring for locked_for_each * simplify borrows by adding BoundListIterator::with_critical_section * fix build on GIL-enabled and limited API builds * fix docs build on MSRV --------- Signed-off-by: Manos Pitsidianakis <[email protected]> Co-authored-by: Manos Pitsidianakis <[email protected]>
1 parent 4b04bb3 commit 9ee55a3

File tree

3 files changed

+454
-33
lines changed

3 files changed

+454
-33
lines changed
 

Diff for: ‎newsfragments/4789.added.md

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
* Added `PyList::locked_for_each`, which is equivalent to `PyList::for_each` on
2+
the GIL-enabled build and uses a critical section to lock the list on the
3+
free-threaded build, similar to `PyDict::locked_for_each`.

Diff for: ‎newsfragments/4789.changed.md

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
* Operations that process a PyList via an iterator now use a critical section
2+
on the free-threaded build to amortize synchronization cost and prevent race conditions.

Diff for: ‎src/types/list.rs

+449-33
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,9 @@ pub trait PyListMethods<'py>: crate::sealed::Sealed {
179179
/// # Safety
180180
///
181181
/// Caller must verify that the index is within the bounds of the list.
182-
#[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))]
182+
/// On the free-threaded build, caller must verify they have exclusive access to the list
183+
/// via a lock or by holding the innermost critical section on the list.
184+
#[cfg(not(Py_LIMITED_API))]
183185
unsafe fn get_item_unchecked(&self, index: usize) -> Bound<'py, PyAny>;
184186

185187
/// Takes the slice `self[low:high]` and returns it as a new list.
@@ -239,6 +241,17 @@ pub trait PyListMethods<'py>: crate::sealed::Sealed {
239241
/// Returns an iterator over this list's items.
240242
fn iter(&self) -> BoundListIterator<'py>;
241243

244+
/// Iterates over the contents of this list while holding a critical section on the list.
245+
/// This is useful when the GIL is disabled and the list is shared between threads.
246+
/// It is not guaranteed that the list will not be modified during iteration when the
247+
/// closure calls arbitrary Python code that releases the critical section held by the
248+
/// iterator. Otherwise, the list will not be modified during iteration.
249+
///
250+
/// This is equivalent to for_each if the GIL is enabled.
251+
fn locked_for_each<F>(&self, closure: F) -> PyResult<()>
252+
where
253+
F: Fn(Bound<'py, PyAny>) -> PyResult<()>;
254+
242255
/// Sorts the list in-place. Equivalent to the Python expression `l.sort()`.
243256
fn sort(&self) -> PyResult<()>;
244257

@@ -302,7 +315,7 @@ impl<'py> PyListMethods<'py> for Bound<'py, PyList> {
302315
/// # Safety
303316
///
304317
/// Caller must verify that the index is within the bounds of the list.
305-
#[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))]
318+
#[cfg(not(Py_LIMITED_API))]
306319
unsafe fn get_item_unchecked(&self, index: usize) -> Bound<'py, PyAny> {
307320
// PyList_GET_ITEM return borrowed ptr; must make owned for safety (see #890).
308321
ffi::PyList_GET_ITEM(self.as_ptr(), index as Py_ssize_t)
@@ -440,6 +453,14 @@ impl<'py> PyListMethods<'py> for Bound<'py, PyList> {
440453
BoundListIterator::new(self.clone())
441454
}
442455

456+
/// Iterates over a list while holding a critical section, calling a closure on each item
457+
fn locked_for_each<F>(&self, closure: F) -> PyResult<()>
458+
where
459+
F: Fn(Bound<'py, PyAny>) -> PyResult<()>,
460+
{
461+
crate::sync::with_critical_section(self, || self.iter().try_for_each(closure))
462+
}
463+
443464
/// Sorts the list in-place. Equivalent to the Python expression `l.sort()`.
444465
fn sort(&self) -> PyResult<()> {
445466
err::error_on_minusone(self.py(), unsafe { ffi::PyList_Sort(self.as_ptr()) })
@@ -462,73 +483,332 @@ impl<'py> PyListMethods<'py> for Bound<'py, PyList> {
462483
}
463484
}
464485

486+
// New types for type checking when using BoundListIterator associated methods, like
487+
// BoundListIterator::next_unchecked.
488+
struct Index(usize);
489+
struct Length(usize);
490+
465491
/// Used by `PyList::iter()`.
466492
pub struct BoundListIterator<'py> {
467493
list: Bound<'py, PyList>,
468-
index: usize,
469-
length: usize,
494+
index: Index,
495+
length: Length,
470496
}
471497

472498
impl<'py> BoundListIterator<'py> {
473499
fn new(list: Bound<'py, PyList>) -> Self {
474-
let length: usize = list.len();
475-
BoundListIterator {
500+
Self {
501+
index: Index(0),
502+
length: Length(list.len()),
476503
list,
477-
index: 0,
478-
length,
479504
}
480505
}
481506

482-
unsafe fn get_item(&self, index: usize) -> Bound<'py, PyAny> {
483-
#[cfg(any(Py_LIMITED_API, PyPy, Py_GIL_DISABLED))]
484-
let item = self.list.get_item(index).expect("list.get failed");
485-
#[cfg(not(any(Py_LIMITED_API, PyPy, Py_GIL_DISABLED)))]
486-
let item = self.list.get_item_unchecked(index);
487-
item
507+
/// # Safety
508+
///
509+
/// On the free-threaded build, caller must verify they have exclusive
510+
/// access to the list by holding a lock or by holding the innermost
511+
/// critical section on the list.
512+
#[inline]
513+
#[cfg(not(Py_LIMITED_API))]
514+
#[deny(unsafe_op_in_unsafe_fn)]
515+
unsafe fn next_unchecked(
516+
index: &mut Index,
517+
length: &mut Length,
518+
list: &Bound<'py, PyList>,
519+
) -> Option<Bound<'py, PyAny>> {
520+
let length = length.0.min(list.len());
521+
let my_index = index.0;
522+
523+
if index.0 < length {
524+
let item = unsafe { list.get_item_unchecked(my_index) };
525+
index.0 += 1;
526+
Some(item)
527+
} else {
528+
None
529+
}
488530
}
489-
}
490531

491-
impl<'py> Iterator for BoundListIterator<'py> {
492-
type Item = Bound<'py, PyAny>;
532+
#[cfg(Py_LIMITED_API)]
533+
fn next(
534+
index: &mut Index,
535+
length: &mut Length,
536+
list: &Bound<'py, PyList>,
537+
) -> Option<Bound<'py, PyAny>> {
538+
let length = length.0.min(list.len());
539+
let my_index = index.0;
493540

541+
if index.0 < length {
542+
let item = list.get_item(my_index).expect("get-item failed");
543+
index.0 += 1;
544+
Some(item)
545+
} else {
546+
None
547+
}
548+
}
549+
550+
/// # Safety
551+
///
552+
/// On the free-threaded build, caller must verify they have exclusive
553+
/// access to the list by holding a lock or by holding the innermost
554+
/// critical section on the list.
494555
#[inline]
495-
fn next(&mut self) -> Option<Self::Item> {
496-
let length = self.length.min(self.list.len());
556+
#[cfg(not(Py_LIMITED_API))]
557+
#[deny(unsafe_op_in_unsafe_fn)]
558+
unsafe fn next_back_unchecked(
559+
index: &mut Index,
560+
length: &mut Length,
561+
list: &Bound<'py, PyList>,
562+
) -> Option<Bound<'py, PyAny>> {
563+
let current_length = length.0.min(list.len());
564+
565+
if index.0 < current_length {
566+
let item = unsafe { list.get_item_unchecked(current_length - 1) };
567+
length.0 = current_length - 1;
568+
Some(item)
569+
} else {
570+
None
571+
}
572+
}
497573

498-
if self.index < length {
499-
let item = unsafe { self.get_item(self.index) };
500-
self.index += 1;
574+
#[inline]
575+
#[cfg(Py_LIMITED_API)]
576+
fn next_back(
577+
index: &mut Index,
578+
length: &mut Length,
579+
list: &Bound<'py, PyList>,
580+
) -> Option<Bound<'py, PyAny>> {
581+
let current_length = (length.0).min(list.len());
582+
583+
if index.0 < current_length {
584+
let item = list.get_item(current_length - 1).expect("get-item failed");
585+
length.0 = current_length - 1;
501586
Some(item)
502587
} else {
503588
None
504589
}
505590
}
506591

592+
#[cfg(not(Py_LIMITED_API))]
593+
fn with_critical_section<R>(
594+
&mut self,
595+
f: impl FnOnce(&mut Index, &mut Length, &Bound<'py, PyList>) -> R,
596+
) -> R {
597+
let Self {
598+
index,
599+
length,
600+
list,
601+
} = self;
602+
crate::sync::with_critical_section(list, || f(index, length, list))
603+
}
604+
}
605+
606+
impl<'py> Iterator for BoundListIterator<'py> {
607+
type Item = Bound<'py, PyAny>;
608+
609+
#[inline]
610+
fn next(&mut self) -> Option<Self::Item> {
611+
#[cfg(not(Py_LIMITED_API))]
612+
{
613+
self.with_critical_section(|index, length, list| unsafe {
614+
Self::next_unchecked(index, length, list)
615+
})
616+
}
617+
#[cfg(Py_LIMITED_API)]
618+
{
619+
let Self {
620+
index,
621+
length,
622+
list,
623+
} = self;
624+
Self::next(index, length, list)
625+
}
626+
}
627+
507628
#[inline]
508629
fn size_hint(&self) -> (usize, Option<usize>) {
509630
let len = self.len();
510631
(len, Some(len))
511632
}
633+
634+
#[inline]
635+
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
636+
fn fold<B, F>(mut self, init: B, mut f: F) -> B
637+
where
638+
Self: Sized,
639+
F: FnMut(B, Self::Item) -> B,
640+
{
641+
self.with_critical_section(|index, length, list| {
642+
let mut accum = init;
643+
while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } {
644+
accum = f(accum, x);
645+
}
646+
accum
647+
})
648+
}
649+
650+
#[inline]
651+
#[cfg(all(Py_GIL_DISABLED, feature = "nightly"))]
652+
fn try_fold<B, F, R>(&mut self, init: B, mut f: F) -> R
653+
where
654+
Self: Sized,
655+
F: FnMut(B, Self::Item) -> R,
656+
R: std::ops::Try<Output = B>,
657+
{
658+
self.with_critical_section(|index, length, list| {
659+
let mut accum = init;
660+
while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } {
661+
accum = f(accum, x)?
662+
}
663+
R::from_output(accum)
664+
})
665+
}
666+
667+
#[inline]
668+
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
669+
fn all<F>(&mut self, mut f: F) -> bool
670+
where
671+
Self: Sized,
672+
F: FnMut(Self::Item) -> bool,
673+
{
674+
self.with_critical_section(|index, length, list| {
675+
while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } {
676+
if !f(x) {
677+
return false;
678+
}
679+
}
680+
true
681+
})
682+
}
683+
684+
#[inline]
685+
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
686+
fn any<F>(&mut self, mut f: F) -> bool
687+
where
688+
Self: Sized,
689+
F: FnMut(Self::Item) -> bool,
690+
{
691+
self.with_critical_section(|index, length, list| {
692+
while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } {
693+
if f(x) {
694+
return true;
695+
}
696+
}
697+
false
698+
})
699+
}
700+
701+
#[inline]
702+
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
703+
fn find<P>(&mut self, mut predicate: P) -> Option<Self::Item>
704+
where
705+
Self: Sized,
706+
P: FnMut(&Self::Item) -> bool,
707+
{
708+
self.with_critical_section(|index, length, list| {
709+
while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } {
710+
if predicate(&x) {
711+
return Some(x);
712+
}
713+
}
714+
None
715+
})
716+
}
717+
718+
#[inline]
719+
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
720+
fn find_map<B, F>(&mut self, mut f: F) -> Option<B>
721+
where
722+
Self: Sized,
723+
F: FnMut(Self::Item) -> Option<B>,
724+
{
725+
self.with_critical_section(|index, length, list| {
726+
while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } {
727+
if let found @ Some(_) = f(x) {
728+
return found;
729+
}
730+
}
731+
None
732+
})
733+
}
734+
735+
#[inline]
736+
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
737+
fn position<P>(&mut self, mut predicate: P) -> Option<usize>
738+
where
739+
Self: Sized,
740+
P: FnMut(Self::Item) -> bool,
741+
{
742+
self.with_critical_section(|index, length, list| {
743+
let mut acc = 0;
744+
while let Some(x) = unsafe { Self::next_unchecked(index, length, list) } {
745+
if predicate(x) {
746+
return Some(acc);
747+
}
748+
acc += 1;
749+
}
750+
None
751+
})
752+
}
512753
}
513754

514755
impl DoubleEndedIterator for BoundListIterator<'_> {
515756
#[inline]
516757
fn next_back(&mut self) -> Option<Self::Item> {
517-
let length = self.length.min(self.list.len());
518-
519-
if self.index < length {
520-
let item = unsafe { self.get_item(length - 1) };
521-
self.length = length - 1;
522-
Some(item)
523-
} else {
524-
None
758+
#[cfg(not(Py_LIMITED_API))]
759+
{
760+
self.with_critical_section(|index, length, list| unsafe {
761+
Self::next_back_unchecked(index, length, list)
762+
})
763+
}
764+
#[cfg(Py_LIMITED_API)]
765+
{
766+
let Self {
767+
index,
768+
length,
769+
list,
770+
} = self;
771+
Self::next_back(index, length, list)
525772
}
526773
}
774+
775+
#[inline]
776+
#[cfg(all(Py_GIL_DISABLED, not(feature = "nightly")))]
777+
fn rfold<B, F>(mut self, init: B, mut f: F) -> B
778+
where
779+
Self: Sized,
780+
F: FnMut(B, Self::Item) -> B,
781+
{
782+
self.with_critical_section(|index, length, list| {
783+
let mut accum = init;
784+
while let Some(x) = unsafe { Self::next_back_unchecked(index, length, list) } {
785+
accum = f(accum, x);
786+
}
787+
accum
788+
})
789+
}
790+
791+
#[inline]
792+
#[cfg(all(Py_GIL_DISABLED, feature = "nightly"))]
793+
fn try_rfold<B, F, R>(&mut self, init: B, mut f: F) -> R
794+
where
795+
Self: Sized,
796+
F: FnMut(B, Self::Item) -> R,
797+
R: std::ops::Try<Output = B>,
798+
{
799+
self.with_critical_section(|index, length, list| {
800+
let mut accum = init;
801+
while let Some(x) = unsafe { Self::next_back_unchecked(index, length, list) } {
802+
accum = f(accum, x)?
803+
}
804+
R::from_output(accum)
805+
})
806+
}
527807
}
528808

529809
impl ExactSizeIterator for BoundListIterator<'_> {
530810
fn len(&self) -> usize {
531-
self.length.saturating_sub(self.index)
811+
self.length.0.saturating_sub(self.index.0)
532812
}
533813
}
534814

@@ -558,7 +838,7 @@ mod tests {
558838
use crate::types::list::PyListMethods;
559839
use crate::types::sequence::PySequenceMethods;
560840
use crate::types::{PyList, PyTuple};
561-
use crate::{ffi, IntoPyObject, Python};
841+
use crate::{ffi, IntoPyObject, PyResult, Python};
562842

563843
#[test]
564844
fn test_new() {
@@ -748,6 +1028,142 @@ mod tests {
7481028
});
7491029
}
7501030

1031+
#[test]
1032+
fn test_iter_all() {
1033+
Python::with_gil(|py| {
1034+
let list = PyList::new(py, [true, true, true]).unwrap();
1035+
assert!(list.iter().all(|x| x.extract::<bool>().unwrap()));
1036+
1037+
let list = PyList::new(py, [true, false, true]).unwrap();
1038+
assert!(!list.iter().all(|x| x.extract::<bool>().unwrap()));
1039+
});
1040+
}
1041+
1042+
#[test]
1043+
fn test_iter_any() {
1044+
Python::with_gil(|py| {
1045+
let list = PyList::new(py, [true, true, true]).unwrap();
1046+
assert!(list.iter().any(|x| x.extract::<bool>().unwrap()));
1047+
1048+
let list = PyList::new(py, [true, false, true]).unwrap();
1049+
assert!(list.iter().any(|x| x.extract::<bool>().unwrap()));
1050+
1051+
let list = PyList::new(py, [false, false, false]).unwrap();
1052+
assert!(!list.iter().any(|x| x.extract::<bool>().unwrap()));
1053+
});
1054+
}
1055+
1056+
#[test]
1057+
fn test_iter_find() {
1058+
Python::with_gil(|py: Python<'_>| {
1059+
let list = PyList::new(py, ["hello", "world"]).unwrap();
1060+
assert_eq!(
1061+
Some("world".to_string()),
1062+
list.iter()
1063+
.find(|v| v.extract::<String>().unwrap() == "world")
1064+
.map(|v| v.extract::<String>().unwrap())
1065+
);
1066+
assert_eq!(
1067+
None,
1068+
list.iter()
1069+
.find(|v| v.extract::<String>().unwrap() == "foobar")
1070+
.map(|v| v.extract::<String>().unwrap())
1071+
);
1072+
});
1073+
}
1074+
1075+
#[test]
1076+
fn test_iter_position() {
1077+
Python::with_gil(|py: Python<'_>| {
1078+
let list = PyList::new(py, ["hello", "world"]).unwrap();
1079+
assert_eq!(
1080+
Some(1),
1081+
list.iter()
1082+
.position(|v| v.extract::<String>().unwrap() == "world")
1083+
);
1084+
assert_eq!(
1085+
None,
1086+
list.iter()
1087+
.position(|v| v.extract::<String>().unwrap() == "foobar")
1088+
);
1089+
});
1090+
}
1091+
1092+
#[test]
1093+
fn test_iter_fold() {
1094+
Python::with_gil(|py: Python<'_>| {
1095+
let list = PyList::new(py, [1, 2, 3]).unwrap();
1096+
let sum = list
1097+
.iter()
1098+
.fold(0, |acc, v| acc + v.extract::<usize>().unwrap());
1099+
assert_eq!(sum, 6);
1100+
});
1101+
}
1102+
1103+
#[test]
1104+
fn test_iter_fold_out_of_bounds() {
1105+
Python::with_gil(|py: Python<'_>| {
1106+
let list = PyList::new(py, [1, 2, 3]).unwrap();
1107+
let sum = list.iter().fold(0, |_, _| {
1108+
// clear the list to create a pathological fold operation
1109+
// that mutates the list as it processes it
1110+
for _ in 0..3 {
1111+
list.del_item(0).unwrap();
1112+
}
1113+
-5
1114+
});
1115+
assert_eq!(sum, -5);
1116+
assert!(list.len() == 0);
1117+
});
1118+
}
1119+
1120+
#[test]
1121+
fn test_iter_rfold() {
1122+
Python::with_gil(|py: Python<'_>| {
1123+
let list = PyList::new(py, [1, 2, 3]).unwrap();
1124+
let sum = list
1125+
.iter()
1126+
.rfold(0, |acc, v| acc + v.extract::<usize>().unwrap());
1127+
assert_eq!(sum, 6);
1128+
});
1129+
}
1130+
1131+
#[test]
1132+
fn test_iter_try_fold() {
1133+
Python::with_gil(|py: Python<'_>| {
1134+
let list = PyList::new(py, [1, 2, 3]).unwrap();
1135+
let sum = list
1136+
.iter()
1137+
.try_fold(0, |acc, v| PyResult::Ok(acc + v.extract::<usize>()?))
1138+
.unwrap();
1139+
assert_eq!(sum, 6);
1140+
1141+
let list = PyList::new(py, ["foo", "bar"]).unwrap();
1142+
assert!(list
1143+
.iter()
1144+
.try_fold(0, |acc, v| PyResult::Ok(acc + v.extract::<usize>()?))
1145+
.is_err());
1146+
});
1147+
}
1148+
1149+
#[test]
1150+
fn test_iter_try_rfold() {
1151+
Python::with_gil(|py: Python<'_>| {
1152+
let list = PyList::new(py, [1, 2, 3]).unwrap();
1153+
let sum = list
1154+
.iter()
1155+
.try_rfold(0, |acc, v| PyResult::Ok(acc + v.extract::<usize>()?))
1156+
.unwrap();
1157+
assert_eq!(sum, 6);
1158+
1159+
let list = PyList::new(py, ["foo", "bar"]).unwrap();
1160+
assert!(list
1161+
.iter()
1162+
.try_rfold(0, |acc, v| PyResult::Ok(acc + v.extract::<usize>()?))
1163+
.is_err());
1164+
});
1165+
}
1166+
7511167
#[test]
7521168
fn test_into_iter() {
7531169
Python::with_gil(|py| {
@@ -877,7 +1293,7 @@ mod tests {
8771293
});
8781294
}
8791295

880-
#[cfg(not(any(Py_LIMITED_API, PyPy, Py_GIL_DISABLED)))]
1296+
#[cfg(not(Py_LIMITED_API))]
8811297
#[test]
8821298
fn test_list_get_item_unchecked_sanity() {
8831299
Python::with_gil(|py| {

0 commit comments

Comments
 (0)
Please sign in to comment.