Skip to main content

pyo3/types/
iterator.rs

1use crate::ffi_ptr_ext::FfiPtrExt;
2use crate::py_result_ext::PyResultExt;
3use crate::sync::PyOnceLock;
4#[cfg(Py_LIMITED_API)]
5use crate::types::PyAnyMethods;
6use crate::types::{PyType, PyTypeMethods};
7use crate::{ffi, Bound, Py, PyAny, PyErr, PyResult};
8
/// A Python iterator object.
///
/// Values of this type are accessed via PyO3's smart pointers, e.g. as
/// [`Py<PyIterator>`][crate::Py] or [`Bound<'py, PyIterator>`][Bound].
///
/// [`Bound<'py, PyIterator>`][Bound] implements Rust's [`Iterator`] trait with
/// `PyResult<Bound<'py, PyAny>>` items, so the standard iterator adaptors can be
/// used on it directly.
///
/// # Examples
///
/// ```rust
/// use pyo3::prelude::*;
///
/// # fn main() -> PyResult<()> {
/// Python::attach(|py| -> PyResult<()> {
///     let iterable = py.eval(c"iter([1, 2, 3, 4])", None, None)?;
///     let numbers: PyResult<Vec<usize>> = iterable
///         .try_iter()?
///         .map(|item| item.and_then(|item| item.extract::<usize>()))
///         .collect();
///     let sum: usize = numbers?.iter().sum();
///     assert_eq!(sum, 10);
///     Ok(())
/// })
/// # }
/// ```
#[repr(transparent)]
pub struct PyIterator(PyAny);
35
// Registers `PyIterator` as a native Python type.
//
// The type object is `collections.abc.Iterator`, imported on first use and cached in a
// `PyOnceLock` static; the `unwrap()` means a failed import panics. `#checkfunction`
// supplies `ffi::PyIter_Check` as the type-check function for this type.
//
// NOTE(review): the two positional string arguments presumably give the type's module and
// name as consumed by the macro — confirm against the macro definition.
pyobject_native_type_core!(
    PyIterator,
    |py| {
        static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
        TYPE.import(py, "collections.abc", "Iterator")
            .unwrap()
            .as_type_ptr()
    },
    "collections.abc",
    "Iterator",
    #module=Some("collections.abc"),
    #checkfunction=ffi::PyIter_Check
);
49
50impl PyIterator {
51    /// Builds an iterator for an iterable Python object; the equivalent of calling `iter(obj)` in Python.
52    ///
53    /// Usually it is more convenient to write [`obj.try_iter()`][crate::types::any::PyAnyMethods::try_iter],
54    /// which is a more concise way of calling this function.
55    pub fn from_object<'py>(obj: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyIterator>> {
56        unsafe {
57            ffi::PyObject_GetIter(obj.as_ptr())
58                .assume_owned_or_err(obj.py())
59                .cast_into_unchecked()
60        }
61    }
62}
63
/// Outcome of sending a value into a Python generator.
///
/// This type is only compiled when the `PyPy` cfg is not set and the `Py_3_10` cfg is set.
#[derive(Debug)]
#[cfg(all(not(PyPy), Py_3_10))]
pub enum PySendResult<'py> {
    /// The generator yielded a new value; it can be resumed again.
    Next(Bound<'py, PyAny>),
    /// The generator completed, returning a (possibly `None`) final value.
    Return(Bound<'py, PyAny>),
}
73
#[cfg(all(not(PyPy), Py_3_10))]
impl<'py> Bound<'py, PyIterator> {
    /// Sends `value` into a Python generator; the equivalent of `generator.send(value)` in
    /// Python.
    ///
    /// The generator is resumed and runs until its next `yield` or `return` statement:
    ///
    /// - a yielded value is reported as [`PySendResult::Next`];
    /// - when the generator completes, its (optional) return value is reported as
    ///   [`PySendResult::Return`]. All subsequent calls will return
    ///   `PySendResult::Return(None)`.
    ///
    /// The first call to `send` must be made with `None` as the argument to start the
    /// generator; failing to do so will raise a `TypeError`.
    ///
    /// # Errors
    ///
    /// Returns the pending Python exception when `PyIter_Send` reports an error.
    #[inline]
    pub fn send(&self, value: &Bound<'py, PyAny>) -> PyResult<PySendResult<'py>> {
        let py = self.py();
        let mut output = std::ptr::null_mut();

        // SAFETY: `self` and `value` are live Python objects and `output` is a valid
        // location for `PyIter_Send` to store its result pointer.
        let status = unsafe { ffi::PyIter_Send(self.as_ptr(), value.as_ptr(), &mut output) };

        match status {
            ffi::PySendResult::PYGEN_NEXT => {
                // SAFETY: per the CPython documentation, `PyIter_Send` stores a new
                // (non-null) reference in `output` when it reports a yielded value.
                Ok(PySendResult::Next(unsafe {
                    output.assume_owned_unchecked(py)
                }))
            }
            ffi::PySendResult::PYGEN_RETURN => {
                // SAFETY: as above, `output` holds a new reference to the return value.
                Ok(PySendResult::Return(unsafe {
                    output.assume_owned_unchecked(py)
                }))
            }
            ffi::PySendResult::PYGEN_ERROR => Err(PyErr::fetch(py)),
        }
    }
}
97
98impl<'py> Iterator for Bound<'py, PyIterator> {
99    type Item = PyResult<Bound<'py, PyAny>>;
100
101    /// Retrieves the next item from an iterator.
102    ///
103    /// Returns `None` when the iterator is exhausted.
104    /// If an exception occurs, returns `Some(Err(..))`.
105    /// Further `next()` calls after an exception occurs are likely
106    /// to repeatedly result in the same exception.
107    fn next(&mut self) -> Option<Self::Item> {
108        let py = self.py();
109        let mut item = std::ptr::null_mut();
110
111        // SAFETY: `self` is a valid iterator object, `item` is a valid pointer to receive the next item
112        match unsafe { ffi::compat::PyIter_NextItem(self.as_ptr(), &mut item) } {
113            std::ffi::c_int::MIN..=-1 => Some(Err(PyErr::fetch(py))),
114            0 => None,
115            // SAFETY: `item` is guaranteed to be a non-null strong reference
116            1..=std::ffi::c_int::MAX => Some(Ok(unsafe { item.assume_owned_unchecked(py) })),
117        }
118    }
119
120    fn size_hint(&self) -> (usize, Option<usize>) {
121        match length_hint(self) {
122            Ok(hint) => (hint, None),
123            Err(e) => {
124                e.write_unraisable(self.py(), Some(self));
125                (0, None)
126            }
127        }
128    }
129}
130
131#[cfg(not(Py_LIMITED_API))]
132fn length_hint(iter: &Bound<'_, PyIterator>) -> PyResult<usize> {
133    // SAFETY: `iter` is a valid iterator object
134    let hint = unsafe { ffi::PyObject_LengthHint(iter.as_ptr(), 0) };
135    if hint < 0 {
136        Err(PyErr::fetch(iter.py()))
137    } else {
138        Ok(hint as usize)
139    }
140}
141
/// Returns the length hint of `iter` by calling Python's `operator.length_hint(iter, 0)`.
///
/// On the limited API, we cannot use `PyObject_LengthHint`, so we fall back to
/// `operator.length_hint()`, which is documented as equivalent to calling
/// `PyObject_LengthHint`. The imported function is cached in a `PyOnceLock` static.
#[cfg(Py_LIMITED_API)]
fn length_hint(iter: &Bound<'_, PyIterator>) -> PyResult<usize> {
    static LENGTH_HINT: PyOnceLock<Py<PyAny>> = PyOnceLock::new();

    let py = iter.py();
    let hint_fn = LENGTH_HINT.import(py, "operator", "length_hint")?;
    let hint = hint_fn.call1((iter, 0))?;
    hint.extract()
}
150
151impl<'py> IntoIterator for &Bound<'py, PyIterator> {
152    type Item = PyResult<Bound<'py, PyAny>>;
153    type IntoIter = Bound<'py, PyIterator>;
154
155    fn into_iter(self) -> Self::IntoIter {
156        self.clone()
157    }
158}
159
160#[cfg(test)]
161mod tests {
162    use super::PyIterator;
163    #[cfg(all(not(PyPy), Py_3_10))]
164    use super::PySendResult;
165    use crate::exceptions::PyTypeError;
166    #[cfg(all(not(PyPy), Py_3_10))]
167    use crate::types::PyNone;
168    use crate::types::{PyAnyMethods, PyDict, PyList, PyListMethods};
169    #[cfg(feature = "macros")]
170    use crate::PyErr;
171    use crate::{IntoPyObject, PyTypeInfo, Python};
172
173    #[test]
174    fn vec_iter() {
175        Python::attach(|py| {
176            let inst = vec![10, 20].into_pyobject(py).unwrap();
177            let mut it = inst.try_iter().unwrap();
178            assert_eq!(
179                10_i32,
180                it.next().unwrap().unwrap().extract::<'_, i32>().unwrap()
181            );
182            assert_eq!(
183                20_i32,
184                it.next().unwrap().unwrap().extract::<'_, i32>().unwrap()
185            );
186            assert!(it.next().is_none());
187        });
188    }
189
190    #[test]
191    fn iter_refcnt() {
192        let (obj, count) = Python::attach(|py| {
193            let obj = vec![10, 20].into_pyobject(py).unwrap();
194            let count = obj._get_refcnt();
195            (obj.unbind(), count)
196        });
197
198        Python::attach(|py| {
199            let inst = obj.bind(py);
200            let mut it = inst.try_iter().unwrap();
201
202            assert_eq!(
203                10_i32,
204                it.next().unwrap().unwrap().extract::<'_, i32>().unwrap()
205            );
206        });
207
208        Python::attach(move |py| {
209            assert_eq!(count, obj._get_refcnt(py));
210        });
211    }
212
213    #[test]
214    fn iter_item_refcnt() {
215        Python::attach(|py| {
216            let count;
217            let obj = py.eval(c"object()", None, None).unwrap();
218            let list = {
219                let list = PyList::empty(py);
220                list.append(10).unwrap();
221                list.append(&obj).unwrap();
222                count = obj._get_refcnt();
223                list
224            };
225
226            {
227                let mut it = list.iter();
228
229                assert_eq!(10_i32, it.next().unwrap().extract::<'_, i32>().unwrap());
230                assert!(it.next().unwrap().is(&obj));
231                assert!(it.next().is_none());
232            }
233            assert_eq!(count, obj._get_refcnt());
234        });
235    }
236
237    #[test]
238    fn fibonacci_generator() {
239        let fibonacci_generator = cr#"
240def fibonacci(target):
241    a = 1
242    b = 1
243    for _ in range(target):
244        yield a
245        a, b = b, a + b
246"#;
247
248        Python::attach(|py| {
249            let context = PyDict::new(py);
250            py.run(fibonacci_generator, None, Some(&context)).unwrap();
251
252            let generator = py.eval(c"fibonacci(5)", None, Some(&context)).unwrap();
253            for (actual, expected) in generator.try_iter().unwrap().zip(&[1, 1, 2, 3, 5]) {
254                let actual = actual.unwrap().extract::<usize>().unwrap();
255                assert_eq!(actual, *expected)
256            }
257        });
258    }
259
260    #[test]
261    #[cfg(all(not(PyPy), Py_3_10))]
262    fn send_generator() {
263        let generator = cr#"
264def gen():
265    value = None
266    while(True):
267        value = yield value
268        if value is None:
269            return
270"#;
271
272        Python::attach(|py| {
273            let context = PyDict::new(py);
274            py.run(generator, None, Some(&context)).unwrap();
275
276            let generator = py.eval(c"gen()", None, Some(&context)).unwrap();
277
278            let one = 1i32.into_pyobject(py).unwrap();
279            assert!(matches!(
280                generator.try_iter().unwrap().send(&PyNone::get(py)).unwrap(),
281                PySendResult::Next(value) if value.is_none()
282            ));
283            assert!(matches!(
284                generator.try_iter().unwrap().send(&one).unwrap(),
285                PySendResult::Next(value) if value.is(&one)
286            ));
287            assert!(matches!(
288                generator.try_iter().unwrap().send(&PyNone::get(py)).unwrap(),
289                PySendResult::Return(value) if value.is_none()
290            ));
291        });
292    }
293
294    #[test]
295    fn fibonacci_generator_bound() {
296        use crate::types::any::PyAnyMethods;
297        use crate::Bound;
298
299        let fibonacci_generator = cr#"
300def fibonacci(target):
301    a = 1
302    b = 1
303    for _ in range(target):
304        yield a
305        a, b = b, a + b
306"#;
307
308        Python::attach(|py| {
309            let context = PyDict::new(py);
310            py.run(fibonacci_generator, None, Some(&context)).unwrap();
311
312            let generator: Bound<'_, PyIterator> = py
313                .eval(c"fibonacci(5)", None, Some(&context))
314                .unwrap()
315                .cast_into()
316                .unwrap();
317            let mut items = vec![];
318            for actual in &generator {
319                let actual = actual.unwrap().extract::<usize>().unwrap();
320                items.push(actual);
321            }
322            assert_eq!(items, [1, 1, 2, 3, 5]);
323        });
324    }
325
326    #[test]
327    fn int_not_iterable() {
328        Python::attach(|py| {
329            let x = 5i32.into_pyobject(py).unwrap();
330            let err = PyIterator::from_object(&x).unwrap_err();
331
332            assert!(err.is_instance_of::<PyTypeError>(py));
333        });
334    }
335
    // Casting an object that is not an iterator to `PyIterator` must fail with a
    // descriptive `TypeError`. The Python class used here subclasses
    // `collections.abc.Sequence` and defines only `__getitem__` and `__len__`.
    #[test]
    #[cfg(feature = "macros")]
    fn python_class_not_iterator() {
        use crate::PyErr;

        // Helper class that stores the error produced by the failed cast so the Rust side
        // can inspect it after the Python snippet has run.
        #[crate::pyclass(crate = "crate")]
        struct Downcaster {
            failed: Option<PyErr>,
        }

        #[crate::pymethods(crate = "crate")]
        impl Downcaster {
            // Attempts the cast, expects it to fail, and records the resulting error.
            fn downcast_iterator(&mut self, obj: &crate::Bound<'_, crate::PyAny>) {
                self.failed = Some(obj.cast::<PyIterator>().unwrap_err().into());
            }
        }

        // Regression test for 2913
        Python::attach(|py| {
            let downcaster = crate::Py::new(py, Downcaster { failed: None }).unwrap();
            crate::py_run!(
                py,
                downcaster,
                r#"
                    from collections.abc import Sequence

                    class MySequence(Sequence):
                        def __init__(self):
                            self._data = [1, 2, 3]

                        def __getitem__(self, index):
                            return self._data[index]

                        def __len__(self):
                            return len(self._data)

                    downcaster.downcast_iterator(MySequence())
                "#
            );

            // The recorded error must name both the offending type and the target type.
            assert_eq!(
                downcaster.borrow_mut(py).failed.take().unwrap().to_string(),
                "TypeError: 'MySequence' object is not an instance of 'Iterator'"
            );
        });
    }
382
    // A user-defined Python class that defines only `__next__` must be accepted by a cast
    // to `PyIterator`.
    #[test]
    #[cfg(feature = "macros")]
    fn python_class_iterator() {
        // Called from the Python snippet below; panics if the cast is rejected.
        #[crate::pyfunction(crate = "crate")]
        fn assert_iterator(obj: &crate::Bound<'_, crate::PyAny>) {
            assert!(obj.cast::<PyIterator>().is_ok())
        }

        // Regression test for 2913
        Python::attach(|py| {
            let assert_iterator = crate::wrap_pyfunction!(assert_iterator, py).unwrap();
            crate::py_run!(
                py,
                assert_iterator,
                r#"
                    class MyIter:
                        def __next__(self):
                            raise StopIteration

                    assert_iterator(MyIter())
                "#
            );
        });
    }
407
408    #[test]
409    fn length_hint_becomes_size_hint_lower_bound() {
410        Python::attach(|py| {
411            let list = py.eval(c"[1, 2, 3]", None, None).unwrap();
412            let iter = list.try_iter().unwrap();
413            let hint = iter.size_hint();
414            assert_eq!(hint, (3, None));
415        });
416    }
417
    // Checks how `size_hint` behaves when `__length_hint__` gives no usable answer:
    // - `NoHintIter` returns `NotImplemented`: `size_hint` is `(0, None)` and nothing is
    //   reported as unraisable;
    // - `ErrorHintIter` raises `ValueError`: `size_hint` is still `(0, None)` and the error
    //   is reported as unraisable.
    // In both cases no Python exception may be left pending afterwards.
    #[test]
    #[cfg(feature = "macros")]
    fn length_hint_error() {
        // Called from the Python snippet below; `should_error` states whether an
        // unraisable error is expected to be captured while computing the size hint.
        #[crate::pyfunction(crate = "crate")]
        fn test_size_hint(obj: &crate::Bound<'_, crate::PyAny>, should_error: bool) {
            let iter = obj.cast::<PyIterator>().unwrap();
            crate::test_utils::UnraisableCapture::enter(obj.py(), |capture| {
                assert_eq!((0, None), iter.size_hint());
                assert_eq!(should_error, capture.take_capture().is_some());
            });
            // `size_hint` must not leave an exception set on the interpreter.
            assert!(PyErr::take(obj.py()).is_none());
        }

        Python::attach(|py| {
            let test_size_hint = crate::wrap_pyfunction!(test_size_hint, py).unwrap();
            crate::py_run!(
                py,
                test_size_hint,
                r#"
                    class NoHintIter:
                        def __next__(self):
                            raise StopIteration

                        def __length_hint__(self):
                            return NotImplemented

                    class ErrorHintIter:
                        def __next__(self):
                            raise StopIteration

                        def __length_hint__(self):
                            raise ValueError("bad hint impl")

                    test_size_hint(NoHintIter(), False)
                    test_size_hint(ErrorHintIter(), True)
                "#
            );
        });
    }
457
458    #[test]
459    fn test_type_object() {
460        Python::attach(|py| {
461            let abc = PyIterator::type_object(py);
462            let iter = py.eval(c"iter(())", None, None).unwrap();
463            assert!(iter.is_instance(&abc).unwrap());
464        })
465    }
466}