From c33ce2db71de4e65082cde68f540a08463ec3961 Mon Sep 17 00:00:00 2001 From: Amanieu d'Antras Date: Fri, 22 Sep 2023 12:28:10 +0800 Subject: [PATCH] Add the ability to recover the original HashTable from an entry --- src/external_trait_impls/rayon/table.rs | 8 +- src/table.rs | 102 ++++++++++++++---------- 2 files changed, 63 insertions(+), 47 deletions(-) diff --git a/src/external_trait_impls/rayon/table.rs b/src/external_trait_impls/rayon/table.rs index 9ece26d41..5ed5849d1 100644 --- a/src/external_trait_impls/rayon/table.rs +++ b/src/external_trait_impls/rayon/table.rs @@ -165,7 +165,7 @@ impl HashTable { #[cfg_attr(feature = "inline-more", inline)] pub fn par_drain(&mut self) -> ParDrain<'_, T, A> { ParDrain { - inner: self.table.par_drain(), + inner: self.raw.par_drain(), } } } @@ -177,7 +177,7 @@ impl IntoParallelIterator for HashTable { #[cfg_attr(feature = "inline-more", inline)] fn into_par_iter(self) -> Self::Iter { IntoParIter { - inner: self.table.into_par_iter(), + inner: self.raw.into_par_iter(), } } } @@ -189,7 +189,7 @@ impl<'a, T: Sync, A: Allocator> IntoParallelIterator for &'a HashTable { #[cfg_attr(feature = "inline-more", inline)] fn into_par_iter(self) -> Self::Iter { ParIter { - inner: unsafe { self.table.par_iter() }, + inner: unsafe { self.raw.par_iter() }, marker: PhantomData, } } @@ -202,7 +202,7 @@ impl<'a, T: Send, A: Allocator> IntoParallelIterator for &'a mut HashTable #[cfg_attr(feature = "inline-more", inline)] fn into_par_iter(self) -> Self::Iter { ParIterMut { - inner: unsafe { self.table.par_iter() }, + inner: unsafe { self.raw.par_iter() }, marker: PhantomData, } } diff --git a/src/table.rs b/src/table.rs index 37ec55693..24d5bf1af 100644 --- a/src/table.rs +++ b/src/table.rs @@ -46,7 +46,7 @@ pub struct HashTable where A: Allocator, { - pub(crate) table: RawTable, + pub(crate) raw: RawTable, } impl HashTable { @@ -65,7 +65,7 @@ impl HashTable { /// ``` pub const fn new() -> Self { Self { - table: RawTable::new(), + raw: RawTable::new(), } } @@ -84,7 +84,7 @@ impl HashTable { /// ``` pub fn with_capacity(capacity: usize) -> Self { Self { - table: RawTable::with_capacity(capacity), + raw: RawTable::with_capacity(capacity), } } } @@ -133,7 +133,7 @@ where /// ``` pub const fn new_in(alloc: A) -> Self { Self { - table: RawTable::new_in(alloc), + raw: RawTable::new_in(alloc), } } @@ -182,13 +182,13 @@ where /// ``` pub fn with_capacity_in(capacity: usize, alloc: A) -> Self { Self { - table: RawTable::with_capacity_in(capacity, alloc), + raw: RawTable::with_capacity_in(capacity, alloc), } } /// Returns a reference to the underlying allocator. pub fn allocator(&self) -> &A { - self.table.allocator() + self.raw.allocator() } /// Returns a reference to an entry in the table with the given hash and @@ -222,7 +222,7 @@ where /// # } /// ``` pub fn find(&self, hash: u64, eq: impl FnMut(&T) -> bool) -> Option<&T> { - self.table + self.raw .find(hash, eq) .map(|bucket| unsafe { bucket.as_ref() }) } @@ -263,7 +263,7 @@ where /// # } /// ``` pub fn find_mut(&mut self, hash: u64, eq: impl FnMut(&T) -> bool) -> Option<&mut T> { - self.table + self.raw .find(hash, eq) .map(|bucket| unsafe { bucket.as_mut() }) } @@ -292,7 +292,7 @@ where /// let hasher = BuildHasherDefault::::default(); /// let hasher = |val: &_| hasher.hash_one(val); /// table.insert_unchecked(hasher(&1), (1, "a"), |val| hasher(&val.0)); - /// if let Some(entry) = table.find_entry(hasher(&1), |val| val.0 == 1) { + /// if let Ok(entry) = table.find_entry(hasher(&1), |val| val.0 == 1) { /// entry.remove(); /// } /// assert_eq!(table.find(hasher(&1), |val| val.0 == 1), None); @@ -306,12 +306,15 @@ where &mut self, hash: u64, eq: impl FnMut(&T) -> bool, - ) -> Option> { - self.table.find(hash, eq).map(|bucket| OccupiedEntry { - hash, - bucket, - table: &mut self.table, - }) + ) -> Result, &mut Self> { + match self.raw.find(hash, eq) { + Some(bucket) => Ok(OccupiedEntry { + hash, + bucket, + table: self, + }), + None => Err(self), + } } /// Returns an `Entry` for an entry in the table with the given hash @@ -365,16 +368,16 @@ where eq: impl FnMut(&T) -> bool, hasher: impl Fn(&T) -> u64, ) -> Entry<'_, T, A> { - match self.table.find_or_find_insert_slot(hash, eq, hasher) { + match self.raw.find_or_find_insert_slot(hash, eq, hasher) { Ok(bucket) => Entry::Occupied(OccupiedEntry { hash, bucket, - table: &mut self.table, + table: self, }), Err(insert_slot) => Entry::Vacant(VacantEntry { hash, insert_slot, - table: &mut self.table, + table: self, }), } } @@ -393,11 +396,11 @@ where value: T, hasher: impl Fn(&T) -> u64, ) -> OccupiedEntry<'_, T, A> { - let bucket = self.table.insert(hash, value, hasher); + let bucket = self.raw.insert(hash, value, hasher); OccupiedEntry { hash, bucket, - table: &mut self.table, + table: self, } } @@ -425,7 +428,7 @@ where /// # } /// ``` pub fn clear(&mut self) { - self.table.clear(); + self.raw.clear(); } /// Shrinks the capacity of the table as much as possible. It will drop @@ -459,7 +462,7 @@ where /// # } /// ``` pub fn shrink_to_fit(&mut self, hasher: impl Fn(&T) -> u64) { - self.table.shrink_to(self.len(), hasher) + self.raw.shrink_to(self.len(), hasher) } /// Shrinks the capacity of the table with a lower limit. It will drop @@ -498,7 +501,7 @@ where /// # } /// ``` pub fn shrink_to(&mut self, min_capacity: usize, hasher: impl Fn(&T) -> u64) { - self.table.shrink_to(min_capacity, hasher); + self.raw.shrink_to(min_capacity, hasher); } /// Reserves capacity for at least `additional` more elements to be inserted @@ -538,7 +541,7 @@ where /// # } /// ``` pub fn reserve(&mut self, additional: usize, hasher: impl Fn(&T) -> u64) { - self.table.reserve(additional, hasher) + self.raw.reserve(additional, hasher) } /// Tries to reserve capacity for at least `additional` more elements to be inserted @@ -579,7 +582,7 @@ where additional: usize, hasher: impl Fn(&T) -> u64, ) -> Result<(), TryReserveError> { - self.table.try_reserve(additional, hasher) + self.raw.try_reserve(additional, hasher) } /// Returns the number of elements the table can hold without reallocating. @@ -592,7 +595,7 @@ where /// assert!(table.capacity() >= 100); /// ``` pub fn capacity(&self) -> usize { - self.table.capacity() + self.raw.capacity() } /// Returns the number of elements in the table. @@ -619,7 +622,7 @@ where /// # } /// ``` pub fn len(&self) -> usize { - self.table.len() + self.raw.len() } /// Returns `true` if the set contains no elements. @@ -646,7 +649,7 @@ where /// # } /// ``` pub fn is_empty(&self) -> bool { - self.table.is_empty() + self.raw.is_empty() } /// An iterator visiting all elements in arbitrary order. @@ -679,7 +682,7 @@ where /// ``` pub fn iter(&self) -> Iter<'_, T> { Iter { - inner: unsafe { self.table.iter() }, + inner: unsafe { self.raw.iter() }, marker: PhantomData, } } @@ -731,7 +734,7 @@ where /// ``` pub fn iter_mut(&mut self) -> IterMut<'_, T> { IterMut { - inner: unsafe { self.table.iter() }, + inner: unsafe { self.raw.iter() }, marker: PhantomData, } } @@ -766,9 +769,9 @@ where pub fn retain(&mut self, mut f: impl FnMut(&mut T) -> bool) { // Here we only use `iter` as a temporary, preventing use-after-free unsafe { - for item in self.table.iter() { + for item in self.raw.iter() { if !f(item.as_mut()) { - self.table.erase(item); + self.raw.erase(item); } } } @@ -807,7 +810,7 @@ where /// ``` pub fn drain(&mut self) -> Drain<'_, T, A> { Drain { - inner: self.table.drain(), + inner: self.raw.drain(), } } @@ -858,8 +861,8 @@ where ExtractIf { f, inner: RawExtractIf { - iter: unsafe { self.table.iter() }, - table: &mut self.table, + iter: unsafe { self.raw.iter() }, + table: &mut self.raw, }, } } @@ -922,7 +925,7 @@ where hashes: [u64; N], eq: impl FnMut(usize, &T) -> bool, ) -> Option<[&'_ mut T; N]> { - self.table.get_many_mut(hashes, eq) + self.raw.get_many_mut(hashes, eq) } /// Attempts to get mutable references to `N` values in the map at once, without validating that @@ -992,7 +995,7 @@ where hashes: [u64; N], eq: impl FnMut(usize, &T) -> bool, ) -> Option<[&'_ mut T; N]> { - self.table.get_many_unchecked_mut(hashes, eq) + self.raw.get_many_unchecked_mut(hashes, eq) } } @@ -1005,7 +1008,7 @@ where fn into_iter(self) -> IntoIter { IntoIter { - inner: self.table.into_iter(), + inner: self.raw.into_iter(), } } } @@ -1040,7 +1043,7 @@ where { fn default() -> Self { Self { - table: Default::default(), + raw: Default::default(), } } } @@ -1052,7 +1055,7 @@ where { fn clone(&self) -> Self { Self { - table: self.table.clone(), + raw: self.raw.clone(), } } } @@ -1429,7 +1432,7 @@ where { hash: u64, bucket: Bucket, - table: &'a mut RawTable, + table: &'a mut HashTable, } unsafe impl Send for OccupiedEntry<'_, T, A> @@ -1496,7 +1499,7 @@ where /// # } /// ``` pub fn remove(self) -> (T, VacantEntry<'a, T, A>) { - let (val, slot) = unsafe { self.table.remove(self.bucket) }; + let (val, slot) = unsafe { self.table.raw.remove(self.bucket) }; ( val, VacantEntry { @@ -1642,6 +1645,12 @@ where pub fn into_mut(self) -> &'a mut T { unsafe { self.bucket.as_mut() } } + + /// Converts the OccupiedEntry into a mutable reference to the underlying + /// table. + pub fn into_table(self) -> &'a mut HashTable { + self.table + } } /// A view into a vacant entry in a `HashTable`. @@ -1689,7 +1698,7 @@ where { hash: u64, insert_slot: InsertSlot, - table: &'a mut RawTable, + table: &'a mut HashTable, } impl fmt::Debug for VacantEntry<'_, T, A> { @@ -1737,6 +1746,7 @@ where pub fn insert(self, value: T) -> OccupiedEntry<'a, T, A> { let bucket = unsafe { self.table + .raw .insert_in_slot(self.hash, self.insert_slot, value) }; OccupiedEntry { @@ -1745,6 +1755,12 @@ where table: self.table, } } + + /// Converts the OccupiedEntry into a mutable reference to the underlying + /// table. + pub fn into_table(self) -> &'a mut HashTable { + self.table + } } /// An iterator over the entries of a `HashTable` in arbitrary order.