diff --git a/newsfragments/4810.added.md b/newsfragments/4810.added.md new file mode 100644 index 00000000000..e89e22e544d --- /dev/null +++ b/newsfragments/4810.added.md @@ -0,0 +1 @@ +Optimizes `nth` and `nth_back` for `BoundListIterator` \ No newline at end of file diff --git a/pyo3-benches/benches/bench_list.rs b/pyo3-benches/benches/bench_list.rs index cc790db37bf..7a19452455e 100644 --- a/pyo3-benches/benches/bench_list.rs +++ b/pyo3-benches/benches/bench_list.rs @@ -39,7 +39,33 @@ fn list_get_item(b: &mut Bencher<'_>) { }); } -#[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))] +fn list_nth(b: &mut Bencher<'_>) { + Python::with_gil(|py| { + const LEN: usize = 50; + let list = PyList::new_bound(py, 0..LEN); + let mut sum = 0; + b.iter(|| { + for i in 0..LEN { + sum += list.iter().nth(i).unwrap().extract::().unwrap(); + } + }); + }); +} + +fn list_nth_back(b: &mut Bencher<'_>) { + Python::with_gil(|py| { + const LEN: usize = 50; + let list = PyList::new_bound(py, 0..LEN); + let mut sum = 0; + b.iter(|| { + for i in 0..LEN { + sum += list.iter().nth_back(i).unwrap().extract::().unwrap(); + } + }); + }); +} + +#[cfg(not(Py_LIMITED_API))] fn list_get_item_unchecked(b: &mut Bencher<'_>) { Python::with_gil(|py| { const LEN: usize = 50_000; @@ -66,6 +92,8 @@ fn sequence_from_list(b: &mut Bencher<'_>) { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("iter_list", iter_list); c.bench_function("list_new", list_new); + c.bench_function("list_nth", list_nth); + c.bench_function("list_nth_back", list_nth_back); c.bench_function("list_get_item", list_get_item); #[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))] c.bench_function("list_get_item_unchecked", list_get_item_unchecked); diff --git a/src/types/list.rs b/src/types/list.rs index af2b557cba9..3db72d07fed 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -494,7 +494,6 @@ impl<'py> Iterator for BoundListIterator<'py> { #[inline] fn next(&mut self) -> Option { let length = self.length.min(self.list.len()); - if self.index < length { let item = unsafe { self.get_item(self.index) }; self.index += 1; @@ -509,6 +508,20 @@ impl<'py> Iterator for BoundListIterator<'py> { let len = self.len(); (len, Some(len)) } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let length = self.length.min(self.list.len()); + let target_index = self.index + n; + if self.index + n < length { + let item = unsafe { self.get_item(target_index) }; + self.index = target_index + 1; + Some(item) + } else { + self.index = self.list.len(); + None + } + } } impl DoubleEndedIterator for BoundListIterator<'_> { @@ -524,6 +537,20 @@ impl DoubleEndedIterator for BoundListIterator<'_> { None } } + + #[inline] + fn nth_back(&mut self, n: usize) -> Option { + let length = self.length.min(self.list.len()); + if self.index + n < length { + let target_index = length - n - 1; + let item = unsafe { self.get_item(target_index) }; + self.length = target_index; + Some(item) + } else { + self.length = length; + None + } + } } impl ExactSizeIterator for BoundListIterator<'_> { @@ -720,6 +747,106 @@ mod tests { }); } + #[test] + fn test_iter_nth() { + Python::with_gil(|py| { + let v = vec![6, 7, 8, 9, 10]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + iter.next(); + assert_eq!(iter.nth(1).unwrap().extract::().unwrap(), 8); + assert_eq!(iter.nth(1).unwrap().extract::().unwrap(), 10); + assert!(iter.nth(1).is_none()); + + let v: Vec = vec![]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + iter.next(); + assert!(iter.nth(1).is_none()); + + let v = vec![1, 2, 3]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert!(iter.nth(10).is_none()); + + let v = vec![6, 7, 8, 9, 10]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + let mut iter = list.iter(); + assert_eq!(iter.next().unwrap().extract::().unwrap(), 6); + assert_eq!(iter.nth(2).unwrap().extract::().unwrap(), 9); + assert_eq!(iter.next().unwrap().extract::().unwrap(), 10); + + let mut iter = list.iter(); + iter.nth_back(1); + assert_eq!(iter.nth(2).unwrap().extract::().unwrap(), 8); + assert!(iter.next().is_none()); + }); + } + + #[test] + fn test_iter_nth_back() { + Python::with_gil(|py| { + let v = vec![1, 2, 3, 4, 5]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert_eq!(iter.nth_back(0).unwrap().extract::().unwrap(), 5); + assert_eq!(iter.nth_back(1).unwrap().extract::().unwrap(), 3); + assert!(iter.nth_back(2).is_none()); + + let v: Vec = vec![]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert!(iter.nth_back(0).is_none()); + assert!(iter.nth_back(1).is_none()); + + let v = vec![1, 2, 3]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert!(iter.nth_back(5).is_none()); + + let v = vec![1, 2, 3, 4, 5]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + iter.next_back(); // Consume the last element + assert_eq!(iter.nth_back(1).unwrap().extract::().unwrap(), 3); + assert_eq!(iter.next_back().unwrap().extract::().unwrap(), 2); + assert_eq!(iter.nth_back(0).unwrap().extract::().unwrap(), 1); + + let v = vec![1, 2, 3, 4, 5]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert_eq!(iter.nth_back(1).unwrap().extract::().unwrap(), 4); + assert_eq!(iter.nth_back(2).unwrap().extract::().unwrap(), 1); + + let mut iter2 = list.iter(); + iter2.next_back(); + assert_eq!(iter2.nth_back(1).unwrap().extract::().unwrap(), 3); + assert_eq!(iter2.next_back().unwrap().extract::().unwrap(), 2); + + let mut iter3 = list.iter(); + iter3.nth(1); + assert_eq!(iter3.nth_back(2).unwrap().extract::().unwrap(), 3); + assert!(iter3.nth_back(0).is_none()); + }); + } + #[test] fn test_iter_rev() { Python::with_gil(|py| {