From cc9cabd30fe76d4340941fa3bbc23ec39eda70d4 Mon Sep 17 00:00:00 2001 From: Owen Leung Date: Sat, 21 Dec 2024 14:43:58 +0800 Subject: [PATCH 1/3] Optimize nth and nth_back for BoundListIterator. Add unit test and benchmarks --- pyo3-benches/benches/bench_list.rs | 30 ++++++- src/types/list.rs | 137 ++++++++++++++++++++++++++++- 2 files changed, 165 insertions(+), 2 deletions(-) diff --git a/pyo3-benches/benches/bench_list.rs b/pyo3-benches/benches/bench_list.rs index cc790db37bf..7a19452455e 100644 --- a/pyo3-benches/benches/bench_list.rs +++ b/pyo3-benches/benches/bench_list.rs @@ -39,7 +39,33 @@ fn list_get_item(b: &mut Bencher<'_>) { }); } -#[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))] +fn list_nth(b: &mut Bencher<'_>) { + Python::with_gil(|py| { + const LEN: usize = 50; + let list = PyList::new_bound(py, 0..LEN); + let mut sum = 0; + b.iter(|| { + for i in 0..LEN { + sum += list.iter().nth(i).unwrap().extract::().unwrap(); + } + }); + }); +} + +fn list_nth_back(b: &mut Bencher<'_>) { + Python::with_gil(|py| { + const LEN: usize = 50; + let list = PyList::new_bound(py, 0..LEN); + let mut sum = 0; + b.iter(|| { + for i in 0..LEN { + sum += list.iter().nth_back(i).unwrap().extract::().unwrap(); + } + }); + }); +} + +#[cfg(not(Py_LIMITED_API))] fn list_get_item_unchecked(b: &mut Bencher<'_>) { Python::with_gil(|py| { const LEN: usize = 50_000; @@ -66,6 +92,8 @@ fn sequence_from_list(b: &mut Bencher<'_>) { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("iter_list", iter_list); c.bench_function("list_new", list_new); + c.bench_function("list_nth", list_nth); + c.bench_function("list_nth_back", list_nth_back); c.bench_function("list_get_item", list_get_item); #[cfg(not(any(Py_LIMITED_API, Py_GIL_DISABLED)))] c.bench_function("list_get_item_unchecked", list_get_item_unchecked); diff --git a/src/types/list.rs b/src/types/list.rs index af2b557cba9..07089872bae 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -494,7 +494,6 @@ impl<'py> Iterator for BoundListIterator<'py> { #[inline] fn next(&mut self) -> Option { let length = self.length.min(self.list.len()); - if self.index < length { let item = unsafe { self.get_item(self.index) }; self.index += 1; @@ -509,6 +508,20 @@ impl<'py> Iterator for BoundListIterator<'py> { let len = self.len(); (len, Some(len)) } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let length = self.length.min(self.list.len()); + let target_index = self.index + n; + if self.index + n < length { + let item = unsafe { self.get_item(target_index) }; + self.index = target_index + 1; + Some(item) + } else { + self.index = self.list.len(); + None + } + } } impl DoubleEndedIterator for BoundListIterator<'_> { @@ -524,6 +537,20 @@ impl DoubleEndedIterator for BoundListIterator<'_> { None } } + + #[inline] + fn nth_back(&mut self, n: usize) -> Option { + let length = self.length.min(self.list.len()); + if self.index + n < length { + let target_index = length - n - 1; + let item = unsafe { self.get_item(target_index) }; + self.length = target_index; + Some(item) + } else { + self.length = length; + None + } + } } impl ExactSizeIterator for BoundListIterator<'_> { @@ -720,6 +747,114 @@ mod tests { }); } + #[test] + fn test_iter_nth() { + Python::with_gil(|py| { + let v = vec![6, 7, 8, 9, 10]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert_eq!(iter.nth(0).unwrap().extract::().unwrap(), 6); + assert_eq!(iter.nth(1).unwrap().extract::().unwrap(), 8); + assert_eq!(iter.nth(1).unwrap().extract::().unwrap(), 10); + assert!(iter.nth(1).is_none()); + + let v: Vec = vec![]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert!(iter.nth(0).is_none()); + assert!(iter.nth(1).is_none()); + + let v = vec![1, 2, 3]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert!(iter.nth(10).is_none()); + + let v = vec![10]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert_eq!(iter.nth(0).unwrap().extract::().unwrap(), 10); + assert!(iter.nth(0).is_none()); + + let v = vec![6, 7, 8, 9, 10]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + let mut iter = list.iter(); + assert_eq!(iter.next().unwrap().extract::().unwrap(), 6); + assert_eq!(iter.nth(2).unwrap().extract::().unwrap(), 9); + assert_eq!(iter.next().unwrap().extract::().unwrap(), 10); + + let mut iter = list.iter(); + iter.nth_back(1); + assert_eq!(iter.nth(2).unwrap().extract::().unwrap(), 8); + assert!(iter.nth(0).is_none()); + }); + } + + #[test] + fn test_iter_nth_back() { + Python::with_gil(|py| { + let v = vec![1, 2, 3, 4, 5]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert_eq!(iter.nth_back(0).unwrap().extract::().unwrap(), 5); + assert_eq!(iter.nth_back(1).unwrap().extract::().unwrap(), 3); + assert!(iter.nth_back(2).is_none()); + + let v: Vec = vec![]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert!(iter.nth_back(0).is_none()); + assert!(iter.nth_back(1).is_none()); + + let v = vec![1, 2, 3]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert!(iter.nth_back(5).is_none()); + + let v = vec![1, 2, 3, 4, 5]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + iter.next_back(); // Consume the last element + assert_eq!(iter.nth_back(1).unwrap().extract::().unwrap(), 3); + assert_eq!(iter.next_back().unwrap().extract::().unwrap(), 2); + assert_eq!(iter.nth_back(0).unwrap().extract::().unwrap(), 1); + + let v = vec![1,2,3,4,5]; + let ob = (&v).into_pyobject(py).unwrap(); + let list = ob.downcast::().unwrap(); + + let mut iter = list.iter(); + assert_eq!(iter.nth_back(1).unwrap().extract::().unwrap(), 4); + assert_eq!(iter.nth_back(2).unwrap().extract::().unwrap(), 1); + + let mut iter2 = list.iter(); + iter2.next_back(); + assert_eq!(iter2.nth_back(1).unwrap().extract::().unwrap(), 3); + assert_eq!(iter2.next_back().unwrap().extract::().unwrap(), 2); + + let mut iter3 = list.iter(); + iter3.nth(1); + assert_eq!(iter3.nth_back(2).unwrap().extract::().unwrap(), 3); + assert!(iter3.nth_back(0).is_none()); + }); + } + #[test] fn test_iter_rev() { Python::with_gil(|py| { From 3a0c19665c368e190ada74376d89ad86ba087dc7 Mon Sep 17 00:00:00 2001 From: Owen Leung Date: Sat, 21 Dec 2024 16:09:23 +0800 Subject: [PATCH 2/3] Fix fmt and newsfragment CI --- newsfragments/4787.added.md | 1 + src/types/list.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 newsfragments/4787.added.md diff --git a/newsfragments/4787.added.md b/newsfragments/4787.added.md new file mode 100644 index 00000000000..e89e22e544d --- /dev/null +++ b/newsfragments/4787.added.md @@ -0,0 +1 @@ +Optimizes `nth` and `nth_back` for `BoundListIterator` \ No newline at end of file diff --git a/src/types/list.rs b/src/types/list.rs index 07089872bae..9e74d7ad1d6 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -835,7 +835,7 @@ mod tests { assert_eq!(iter.next_back().unwrap().extract::().unwrap(), 2); assert_eq!(iter.nth_back(0).unwrap().extract::().unwrap(), 1); - let v = vec![1,2,3,4,5]; + let v = vec![1, 2, 3, 4, 5]; let ob = (&v).into_pyobject(py).unwrap(); let list = ob.downcast::().unwrap(); From 40d38f3c66fa0488f8d814c22a303354eb24a2c9 Mon Sep 17 00:00:00 2001 From: Owen Leung Date: Sat, 21 Dec 2024 16:55:29 +0800 Subject: [PATCH 3/3] Fix clippy and changelog CI --- newsfragments/{4787.added.md => 4810.added.md} | 0 src/types/list.rs | 14 +++----------- 2 files changed, 3 insertions(+), 11 deletions(-) rename newsfragments/{4787.added.md => 4810.added.md} (100%) diff --git a/newsfragments/4787.added.md b/newsfragments/4810.added.md similarity index 100% rename from newsfragments/4787.added.md rename to newsfragments/4810.added.md diff --git a/src/types/list.rs b/src/types/list.rs index 9e74d7ad1d6..3db72d07fed 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -755,7 +755,7 @@ mod tests { let list = ob.downcast::().unwrap(); let mut iter = list.iter(); - assert_eq!(iter.nth(0).unwrap().extract::().unwrap(), 6); + iter.next(); assert_eq!(iter.nth(1).unwrap().extract::().unwrap(), 8); assert_eq!(iter.nth(1).unwrap().extract::().unwrap(), 10); assert!(iter.nth(1).is_none()); @@ -765,7 +765,7 @@ mod tests { let list = ob.downcast::().unwrap(); let mut iter = list.iter(); - assert!(iter.nth(0).is_none()); + iter.next(); assert!(iter.nth(1).is_none()); let v = vec![1, 2, 3]; @@ -775,14 +775,6 @@ mod tests { let mut iter = list.iter(); assert!(iter.nth(10).is_none()); - let v = vec![10]; - let ob = (&v).into_pyobject(py).unwrap(); - let list = ob.downcast::().unwrap(); - - let mut iter = list.iter(); - assert_eq!(iter.nth(0).unwrap().extract::().unwrap(), 10); - assert!(iter.nth(0).is_none()); - let v = vec![6, 7, 8, 9, 10]; let ob = (&v).into_pyobject(py).unwrap(); let list = ob.downcast::().unwrap(); @@ -794,7 +786,7 @@ mod tests { let mut iter = list.iter(); iter.nth_back(1); assert_eq!(iter.nth(2).unwrap().extract::().unwrap(), 8); - assert!(iter.nth(0).is_none()); + assert!(iter.next().is_none()); }); }