Skip to content

Commit 295e67a

Browse files
improve signature of ffi::PyIter_Send & add PyIterator::send (#4746)
Co-authored-by: David Hewitt <[email protected]>
1 parent c068831 commit 295e67a

File tree

4 files changed

+76
-1
lines changed

4 files changed

+76
-1
lines changed

newsfragments/4746.added.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added `PyIterator::send` method to allow sending values into a python generator.

newsfragments/4746.fixed.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed the return value of pyo3-ffi's PyIter_Send() function to return PySendResult.

pyo3-ffi/src/abstract_.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,11 @@ extern "C" {
149149
pub fn PyIter_Next(arg1: *mut PyObject) -> *mut PyObject;
150150
#[cfg(all(not(PyPy), Py_3_10))]
151151
#[cfg_attr(PyPy, link_name = "PyPyIter_Send")]
152-
pub fn PyIter_Send(iter: *mut PyObject, arg: *mut PyObject, presult: *mut *mut PyObject);
152+
pub fn PyIter_Send(
153+
iter: *mut PyObject,
154+
arg: *mut PyObject,
155+
presult: *mut *mut PyObject,
156+
) -> PySendResult;
153157

154158
#[cfg_attr(PyPy, link_name = "PyPyNumber_Check")]
155159
pub fn PyNumber_Check(o: *mut PyObject) -> c_int;

src/types/iterator.rs

+69
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,35 @@ impl PyIterator {
5252
}
5353
}
5454

55+
#[derive(Debug)]
56+
#[cfg(all(not(PyPy), Py_3_10))]
57+
pub enum PySendResult<'py> {
58+
Next(Bound<'py, PyAny>),
59+
Return(Bound<'py, PyAny>),
60+
}
61+
62+
#[cfg(all(not(PyPy), Py_3_10))]
63+
impl<'py> Bound<'py, PyIterator> {
64+
/// Sends a value into a python generator. This is the equivalent of calling `generator.send(value)` in Python.
65+
/// This resumes the generator and continues its execution until the next `yield` or `return` statement.
66+
/// If the generator exits without returning a value, this function returns a `StopException`.
67+
/// The first call to `send` must be made with `None` as the argument to start the generator, failing to do so will raise a `TypeError`.
68+
#[inline]
69+
pub fn send(&self, value: &Bound<'py, PyAny>) -> PyResult<PySendResult<'py>> {
70+
let py = self.py();
71+
let mut result = std::ptr::null_mut();
72+
match unsafe { ffi::PyIter_Send(self.as_ptr(), value.as_ptr(), &mut result) } {
73+
ffi::PySendResult::PYGEN_ERROR => Err(PyErr::fetch(py)),
74+
ffi::PySendResult::PYGEN_RETURN => Ok(PySendResult::Return(unsafe {
75+
result.assume_owned_unchecked(py)
76+
})),
77+
ffi::PySendResult::PYGEN_NEXT => Ok(PySendResult::Next(unsafe {
78+
result.assume_owned_unchecked(py)
79+
})),
80+
}
81+
}
82+
}
83+
5584
impl<'py> Iterator for Bound<'py, PyIterator> {
5685
type Item = PyResult<Bound<'py, PyAny>>;
5786

@@ -106,7 +135,11 @@ impl PyTypeCheck for PyIterator {
106135
#[cfg(test)]
107136
mod tests {
108137
use super::PyIterator;
138+
#[cfg(all(not(PyPy), Py_3_10))]
139+
use super::PySendResult;
109140
use crate::exceptions::PyTypeError;
141+
#[cfg(all(not(PyPy), Py_3_10))]
142+
use crate::types::PyNone;
110143
use crate::types::{PyAnyMethods, PyDict, PyList, PyListMethods};
111144
use crate::{ffi, IntoPyObject, Python};
112145

@@ -201,6 +234,42 @@ def fibonacci(target):
201234
});
202235
}
203236

237+
#[test]
238+
#[cfg(all(not(PyPy), Py_3_10))]
239+
fn send_generator() {
240+
let generator = ffi::c_str!(
241+
r#"
242+
def gen():
243+
value = None
244+
while(True):
245+
value = yield value
246+
if value is None:
247+
return
248+
"#
249+
);
250+
251+
Python::with_gil(|py| {
252+
let context = PyDict::new(py);
253+
py.run(generator, None, Some(&context)).unwrap();
254+
255+
let generator = py.eval(ffi::c_str!("gen()"), None, Some(&context)).unwrap();
256+
257+
let one = 1i32.into_pyobject(py).unwrap();
258+
assert!(matches!(
259+
generator.try_iter().unwrap().send(&PyNone::get(py)).unwrap(),
260+
PySendResult::Next(value) if value.is_none()
261+
));
262+
assert!(matches!(
263+
generator.try_iter().unwrap().send(&one).unwrap(),
264+
PySendResult::Next(value) if value.is(&one)
265+
));
266+
assert!(matches!(
267+
generator.try_iter().unwrap().send(&PyNone::get(py)).unwrap(),
268+
PySendResult::Return(value) if value.is_none()
269+
));
270+
});
271+
}
272+
204273
#[test]
205274
fn fibonacci_generator_bound() {
206275
use crate::types::any::PyAnyMethods;

0 commit comments

Comments
 (0)