From cf5ca25826807cc1b6e440a0ff70be5c88d70b1b Mon Sep 17 00:00:00 2001 From: John-John Tedro Date: Fri, 25 Aug 2023 04:05:55 +0200 Subject: [PATCH] Implement HashSet using hashbrown::raw --- crates/rune-macros/src/any.rs | 14 +- crates/rune/src/hashbrown.rs | 3 + crates/rune/src/hashbrown/table.rs | 271 ++++++++++++++ crates/rune/src/lib.rs | 3 + .../rune/src/modules/collections/hash_map.rs | 260 +++---------- .../rune/src/modules/collections/hash_set.rs | 354 +++++++++++------- crates/rune/src/modules/iter.rs | 9 +- crates/rune/src/runtime/hasher.rs | 3 +- crates/rune/src/runtime/iterator.rs | 10 - crates/rune/src/runtime/static_type.rs | 5 - 10 files changed, 551 insertions(+), 381 deletions(-) create mode 100644 crates/rune/src/hashbrown.rs create mode 100644 crates/rune/src/hashbrown/table.rs diff --git a/crates/rune-macros/src/any.rs b/crates/rune-macros/src/any.rs index 1479262fb..2ee7aa13a 100644 --- a/crates/rune-macros/src/any.rs +++ b/crates/rune-macros/src/any.rs @@ -89,7 +89,7 @@ impl Derive { let mut installers = Vec::new(); - expand_install_with(&cx, &self.input, &tokens, &attr, &mut installers)?; + expand_install_with(cx, &self.input, &tokens, &attr, &mut installers)?; let name = match &attr.name { Some(name) => name, @@ -717,7 +717,7 @@ where impl #from_value for #ty { fn from_value(value: Value) -> #vm_result { let value = #vm_try!(#path(value)); - let value = #vm_try!(value.take()); + let value = #vm_try!(#shared::take(value)); #vm_result::Ok(value) } } @@ -725,11 +725,9 @@ where impl #unsafe_to_ref for #ty { type Guard = #raw_ref; - unsafe fn unsafe_to_ref<'a>( - value: #value, - ) -> #vm_result<(&'a Self, Self::Guard)> { + unsafe fn unsafe_to_ref<'a>(value: #value) -> #vm_result<(&'a Self, Self::Guard)> { let value = #vm_try!(#path(value)); - let value = #vm_try!(value.into_ref()); + let value = #vm_try!(#shared::into_ref(value)); let (value, guard) = #ref_::into_raw(value); #vm_result::Ok((value.as_ref(), guard)) } @@ -738,9 +736,7 @@ where impl #unsafe_to_mut for #ty { type Guard = #raw_mut; - unsafe fn unsafe_to_mut<'a>( - value: #value, - ) -> #vm_result<(&'a mut Self, Self::Guard)> { + unsafe fn unsafe_to_mut<'a>(value: #value) -> #vm_result<(&'a mut Self, Self::Guard)> { let value = #vm_try!(#path(value)); let value = #vm_try!(#shared::into_mut(value)); let (mut value, guard) = #mut_::into_raw(value); diff --git a/crates/rune/src/hashbrown.rs b/crates/rune/src/hashbrown.rs new file mode 100644 index 000000000..64c374fa2 --- /dev/null +++ b/crates/rune/src/hashbrown.rs @@ -0,0 +1,3 @@ +mod table; +pub(crate) use self::table::{IterRef, Table}; +pub(crate) use ::hashbrown::raw::RawIter; diff --git a/crates/rune/src/hashbrown/table.rs b/crates/rune/src/hashbrown/table.rs new file mode 100644 index 000000000..41352a598 --- /dev/null +++ b/crates/rune/src/hashbrown/table.rs @@ -0,0 +1,271 @@ +use core::hash::BuildHasher; +use core::iter; +use core::marker::PhantomData; +use core::mem; +use core::ptr; + +use hashbrown::raw::{RawIter, RawTable}; +use std::collections::hash_map::{DefaultHasher, RandomState}; + +use crate::runtime::{Hasher, ProtocolCaller, RawRef, Ref, Value, VmError, VmResult}; + +#[derive(Clone)] +pub(crate) struct Table { + table: RawTable<(Value, V)>, + state: RandomState, +} + +impl Table { + #[inline(always)] + pub(crate) fn new() -> Self { + Self { + table: RawTable::new(), + state: RandomState::new(), + } + } + + #[inline(always)] + pub(crate) fn with_capacity(capacity: usize) -> Self { + Self { + table: RawTable::with_capacity(capacity), + state: RandomState::new(), + } + } + + #[inline(always)] + pub(crate) fn len(&self) -> usize { + self.table.len() + } + + #[inline(always)] + pub(crate) fn capacity(&self) -> usize { + self.table.capacity() + } + + #[inline(always)] + pub(crate) fn is_empty(&self) -> bool { + self.table.is_empty() + } + + #[inline(always)] + pub(crate) fn insert_with

( + &mut self, + key: Value, + value: V, + caller: &mut P, + ) -> VmResult> + where + P: ?Sized + ProtocolCaller, + { + let hash = vm_try!(hash(&self.state, &key, caller)); + + let result = + match self + .table + .find_or_find_insert_slot2(caller, hash, eq(&key), hasher(&self.state)) + { + Ok(result) => result, + Err(error) => return VmResult::Err(error), + }; + + let existing = match result { + Ok(bucket) => Some(mem::replace(unsafe { &mut bucket.as_mut().1 }, value)), + Err(slot) => { + unsafe { + self.table.insert_in_slot(hash, slot, (key, value)); + } + None + } + }; + + VmResult::Ok(existing) + } + + pub(crate) fn get

(&self, key: &Value, caller: &mut P) -> VmResult> + where + P: ?Sized + ProtocolCaller, + { + if self.table.is_empty() { + return VmResult::Ok(None); + } + + let hash = vm_try!(hash(&self.state, key, caller)); + VmResult::Ok(vm_try!(self.table.get2(caller, hash, eq(key)))) + } + + #[inline(always)] + pub(crate) fn remove_with

(&mut self, key: &Value, caller: &mut P) -> VmResult> + where + P: ?Sized + ProtocolCaller, + { + let hash = vm_try!(hash(&self.state, key, caller)); + + match self.table.remove_entry2(caller, hash, eq(key)) { + Ok(value) => VmResult::Ok(value.map(|(_, value)| value)), + Err(error) => VmResult::Err(error), + } + } + + #[inline(always)] + pub(crate) fn clear(&mut self) { + self.table.clear() + } + + pub(crate) fn iter(&self) -> Iter<'_, V> { + // SAFETY: lifetime is held by returned iterator. + let iter = unsafe { self.table.iter() }; + + Iter { + iter, + _marker: PhantomData, + } + } + + #[inline(always)] + pub(crate) fn iter_ref(this: Ref) -> IterRef { + let (this, _guard) = Ref::into_raw(this); + // SAFETY: Table will be alive and a reference to it held for as long as + // `RawRef` is alive. + let iter = unsafe { this.as_ref().table.iter() }; + IterRef { iter, _guard } + } + + #[inline(always)] + pub(crate) unsafe fn iter_ref_raw(this: ptr::NonNull>) -> RawIter<(Value, V)> { + this.as_ref().table.iter() + } + + #[inline(always)] + pub(crate) fn keys_ref(this: Ref) -> KeysRef { + let (this, _guard) = Ref::into_raw(this); + // SAFETY: Table will be alive and a reference to it held for as long as + // `RawRef` is alive. + let iter = unsafe { this.as_ref().table.iter() }; + KeysRef { iter, _guard } + } + + #[inline(always)] + pub(crate) fn values_ref(this: Ref) -> ValuesRef { + let (this, _guard) = Ref::into_raw(this); + // SAFETY: Table will be alive and a reference to it held for as long as + // `RawRef` is alive. + let iter = unsafe { this.as_ref().table.iter() }; + ValuesRef { iter, _guard } + } +} + +pub(crate) struct Iter<'a, V> { + iter: RawIter<(Value, V)>, + _marker: PhantomData<&'a V>, +} + +impl<'a, V> iter::Iterator for Iter<'a, V> { + type Item = &'a (Value, V); + + #[inline] + fn next(&mut self) -> Option { + // SAFETY: we're still holding onto the `RawRef` guard. + unsafe { Some(self.iter.next()?.as_ref().clone()) } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +pub(crate) struct IterRef { + iter: RawIter<(Value, V)>, + _guard: RawRef, +} + +impl iter::Iterator for IterRef +where + V: Clone, +{ + type Item = (Value, V); + + #[inline] + fn next(&mut self) -> Option { + // SAFETY: we're still holding onto the `RawRef` guard. + unsafe { Some(self.iter.next()?.as_ref().clone()) } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +pub(crate) struct KeysRef { + iter: RawIter<(Value, V)>, + _guard: RawRef, +} + +impl iter::Iterator for KeysRef { + type Item = Value; + + #[inline] + fn next(&mut self) -> Option { + // SAFETY: we're still holding onto the `RawRef` guard. + unsafe { Some(self.iter.next()?.as_ref().0.clone()) } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +pub(crate) struct ValuesRef { + iter: RawIter<(Value, V)>, + _guard: RawRef, +} + +impl iter::Iterator for ValuesRef +where + V: Clone, +{ + type Item = V; + + #[inline] + fn next(&mut self) -> Option { + // SAFETY: we're still holding onto the `RawRef` guard. + unsafe { Some(self.iter.next()?.as_ref().1.clone()) } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +/// Convenience function to hash a value. +fn hash(state: &S, value: &Value, caller: &mut impl ProtocolCaller) -> VmResult +where + S: BuildHasher, +{ + let mut hasher = Hasher::new_with(state); + vm_try!(value.hash_with(&mut hasher, caller)); + VmResult::Ok(hasher.finish()) +} + +/// Construct a hasher for a value in the table. +fn hasher(state: &S) -> impl Fn(&mut P, &(Value, V)) -> Result + '_ +where + P: ?Sized + ProtocolCaller, + S: BuildHasher, +{ + move |caller, (key, _): &(Value, V)| hash(state, key, caller).into_result() +} + +/// Construct an equality function for a value in the table that will compare an +/// entry with the current key. +fn eq(key: &Value) -> impl Fn(&mut P, &(Value, V)) -> Result + '_ +where + P: ?Sized + ProtocolCaller, +{ + move |caller: &mut P, (other, _): &(Value, V)| -> Result { + key.eq_with(other, caller).into_result() + } +} diff --git a/crates/rune/src/lib.rs b/crates/rune/src/lib.rs index 7a89db126..dea901da5 100644 --- a/crates/rune/src/lib.rs +++ b/crates/rune/src/lib.rs @@ -269,6 +269,9 @@ cfg_workspace! { pub mod workspace; } +#[cfg(feature = "std")] +mod hashbrown; + // Macros used internally and re-exported. pub(crate) use rune_macros::__internal_impl_any; diff --git a/crates/rune/src/modules/collections/hash_map.rs b/crates/rune/src/modules/collections/hash_map.rs index 959e62d24..2f7c84b87 100644 --- a/crates/rune/src/modules/collections/hash_map.rs +++ b/crates/rune/src/modules/collections/hash_map.rs @@ -1,15 +1,10 @@ use core::fmt::{self, Write}; -use core::hash::BuildHasher; -use core::iter; -use core::mem; - -use hashbrown::raw::{RawIter, RawTable}; -use std::collections::hash_map::{DefaultHasher, RandomState}; use crate as rune; +use crate::hashbrown::Table; use crate::runtime::{ - EnvProtocolCaller, Formatter, FromValue, Hasher, Iterator, ProtocolCaller, RawRef, Ref, Value, - VmError, VmErrorKind, VmResult, + EnvProtocolCaller, Formatter, FromValue, Iterator, ProtocolCaller, Ref, Value, VmErrorKind, + VmResult, }; use crate::{Any, ContextError, Module}; @@ -18,6 +13,7 @@ pub(super) fn setup(module: &mut Module) -> Result<(), ContextError> { module.function_meta(HashMap::new__meta)?; module.function_meta(HashMap::with_capacity__meta)?; module.function_meta(HashMap::len__meta)?; + module.function_meta(HashMap::capacity__meta)?; module.function_meta(HashMap::insert__meta)?; module.function_meta(HashMap::get__meta)?; module.function_meta(HashMap::contains_key__meta)?; @@ -39,41 +35,10 @@ pub(super) fn setup(module: &mut Module) -> Result<(), ContextError> { Ok(()) } -/// Convenience function to hash a value. -fn hash(state: &S, value: &Value, caller: &mut impl ProtocolCaller) -> VmResult -where - S: BuildHasher, -{ - let mut hasher = Hasher::new_with(state); - vm_try!(value.hash_with(&mut hasher, caller)); - VmResult::Ok(hasher.finish()) -} - -/// Construct a hasher for a value in the table. -fn hasher(state: &S) -> impl Fn(&mut P, &(Value, Value)) -> Result + '_ -where - S: BuildHasher, - P: ProtocolCaller, -{ - move |caller, (key, _): &(Value, Value)| hash(state, key, caller).into_result() -} - -/// Construct an equality function for a value in the table that will compare an -/// entry with the current key. -fn eq(key: &Value) -> impl Fn(&mut P, &(Value, Value)) -> Result + '_ -where - P: ProtocolCaller, -{ - move |caller: &mut P, (other, _): &(Value, Value)| -> Result { - key.eq_with(other, caller).into_result() - } -} - #[derive(Any, Clone)] #[rune(item = ::std::collections)] pub(crate) struct HashMap { - table: RawTable<(Value, Value)>, - state: RandomState, + table: Table, } impl HashMap { @@ -91,8 +56,7 @@ impl HashMap { #[rune::function(keep, path = Self::new)] fn new() -> Self { Self { - table: RawTable::new(), - state: RandomState::new(), + table: Table::new(), } } @@ -111,8 +75,7 @@ impl HashMap { #[rune::function(keep, path = Self::with_capacity)] fn with_capacity(capacity: usize) -> Self { Self { - table: RawTable::with_capacity(capacity), - state: RandomState::new(), + table: Table::with_capacity(capacity), } } @@ -194,30 +157,7 @@ impl HashMap { #[rune::function(keep)] fn insert(&mut self, key: Value, value: Value) -> VmResult> { let mut caller = EnvProtocolCaller; - - let hash = vm_try!(hash(&self.state, &key, &mut caller)); - - let result = match self.table.find_or_find_insert_slot2( - &mut caller, - hash, - eq(&key), - hasher(&self.state), - ) { - Ok(result) => result, - Err(error) => return VmResult::Err(error), - }; - - let existing = match result { - Ok(bucket) => Some(mem::replace(unsafe { &mut bucket.as_mut().1 }, value)), - Err(slot) => { - unsafe { - self.table.insert_in_slot(hash, slot, (key, value)); - } - None - } - }; - - VmResult::Ok(existing) + self.table.insert_with(key, value, &mut caller) } /// Returns the value corresponding to the [`Key`]. @@ -234,17 +174,8 @@ impl HashMap { /// ``` #[rune::function(keep)] fn get(&self, key: Value) -> VmResult> { - VmResult::Ok(vm_try!(self.get_inner(&key)).map(|(_, value)| value.clone())) - } - - fn get_inner(&self, key: &Value) -> VmResult> { - if self.table.is_empty() { - return VmResult::Ok(None); - } - let mut caller = EnvProtocolCaller; - let hash = vm_try!(hash(&self.state, key, &mut caller)); - VmResult::Ok(vm_try!(self.table.get2(&mut caller, hash, eq(key)))) + VmResult::Ok(vm_try!(self.table.get(&key, &mut caller)).map(|(_, v)| v.clone())) } /// Returns `true` if the map contains a value for the specified [`Key`]. @@ -261,13 +192,8 @@ impl HashMap { /// ``` #[rune::function(keep)] fn contains_key(&self, key: Value) -> VmResult { - if self.table.is_empty() { - return VmResult::Ok(false); - } - let mut caller = EnvProtocolCaller; - let hash = vm_try!(hash(&self.state, &key, &mut caller)); - VmResult::Ok(vm_try!(self.table.get2(&mut caller, hash, eq(&key))).is_some()) + VmResult::Ok(vm_try!(self.table.get(&key, &mut caller)).is_some()) } /// Removes a key from the map, returning the value at the [`Key`] if the @@ -286,12 +212,7 @@ impl HashMap { #[rune::function(keep)] fn remove(&mut self, key: Value) -> VmResult> { let mut caller = EnvProtocolCaller; - let hash = vm_try!(hash(&self.state, &key, &mut caller)); - - match self.table.remove_entry2(&mut caller, hash, eq(&key)) { - Ok(value) => VmResult::Ok(value.map(|(_, value)| value)), - Err(error) => VmResult::Err(error), - } + self.table.remove_with(&key, &mut caller) } /// Clears the map, removing all key-value pairs. Keeps the allocated memory @@ -336,32 +257,8 @@ impl HashMap { /// instead of O(len) because it internally visits empty buckets too. #[rune::function(keep, instance, path = Self::iter)] fn iter(this: Ref) -> Iterator { - struct Iter { - iter: RawIter<(Value, Value)>, - _guard: RawRef, - } - - impl iter::Iterator for Iter { - type Item = (Value, Value); - - #[inline] - fn next(&mut self) -> Option { - // SAFETY: we're still holding onto the `RawRef` guard. - unsafe { Some(self.iter.next()?.as_ref().clone()) } - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } - } - - let (this, _guard) = Ref::into_raw(this); - - // SAFETY: Table will be alive and a reference to it held for as long as - // `RawRef` is alive. - let iter = unsafe { this.as_ref().table.iter() }; - Iterator::from("std::collections::hash_map::Iter", Iter { iter, _guard }) + let iter = Table::iter_ref(Ref::map(this, |this| &this.table)); + Iterator::from("std::collections::hash_map::Iter", iter) } /// An iterator visiting all keys in arbitrary order. @@ -388,33 +285,8 @@ impl HashMap { /// time instead of O(len) because it internally visits empty buckets too. #[rune::function(keep, instance, path = Self::keys)] fn keys(this: Ref) -> Iterator { - struct Keys { - iter: RawIter<(Value, Value)>, - _guard: RawRef, - } - - impl iter::Iterator for Keys { - type Item = Value; - - #[inline] - fn next(&mut self) -> Option { - // SAFETY: we're still holding onto the `RawRef` guard. - unsafe { Some(self.iter.next()?.as_ref().0.clone()) } - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } - } - - let (this, _guard) = Ref::into_raw(this); - - // SAFETY: Table will be alive and a reference to it held for as long as - // `RawRef` is alive. - let iter = unsafe { this.as_ref().table.iter() }; - - Iterator::from("std::collections::hash_map::Keys", Keys { iter, _guard }) + let iter = Table::keys_ref(Ref::map(this, |this| &this.table)); + Iterator::from("std::collections::hash_map::Keys", iter) } /// An iterator visiting all values in arbitrary order. @@ -441,36 +313,9 @@ impl HashMap { /// time instead of O(len) because it internally visits empty buckets too. #[rune::function(keep, instance, path = Self::values)] fn values(this: Ref) -> Iterator { - struct Values { - iter: RawIter<(Value, Value)>, - _guard: RawRef, - } - - impl iter::Iterator for Values { - type Item = Value; + let iter = Table::values_ref(Ref::map(this, |this| &this.table)); - #[inline] - fn next(&mut self) -> Option { - // SAFETY: we're still holding onto the `RawRef` guard. - unsafe { Some(self.iter.next()?.as_ref().1.clone()) } - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } - } - - let (this, _guard) = Ref::into_raw(this); - - // SAFETY: Table will be alive and a reference to it held for as long as - // `RawRef` is alive. - let iter = unsafe { this.as_ref().table.iter() }; - - Iterator::from( - "std::collections::hash_map::Values", - Values { iter, _guard }, - ) + Iterator::from("std::collections::hash_map::Values", iter) } /// Extend this map from an iterator. @@ -506,7 +351,8 @@ impl HashMap { /// protocol, and each item produces should be a tuple pair. #[rune::function(keep, path = Self::from)] fn from(value: Value) -> VmResult { - HashMap::from_iter(vm_try!(value.into_iter())) + let mut caller = EnvProtocolCaller; + HashMap::from_iter(vm_try!(value.into_iter()), &mut caller) } /// Clone the map. @@ -529,12 +375,15 @@ impl HashMap { Clone::clone(this) } - pub(crate) fn from_iter(mut it: Iterator) -> VmResult { + pub(crate) fn from_iter

(mut it: Iterator, caller: &mut P) -> VmResult + where + P: ?Sized + ProtocolCaller, + { let mut map = Self::new(); while let Some(value) = vm_try!(it.next()) { let (key, value) = vm_try!(<(Value, Value)>::from_value(value)); - vm_try!(map.insert(key, value)); + vm_try!(map.table.insert_with(key, value, caller)); } VmResult::Ok(map) @@ -590,7 +439,9 @@ impl HashMap { fn index_get(&self, key: Value) -> VmResult { use crate::runtime::TypeOf; - let Some((_, value)) = vm_try!(self.get_inner(&key)) else { + let mut caller = EnvProtocolCaller; + + let Some((_, value)) = vm_try!(self.table.get(&key, &mut caller)) else { return VmResult::err(VmErrorKind::MissingIndexKey { target: Self::type_info(), }); @@ -623,22 +474,21 @@ impl HashMap { ) -> VmResult { vm_write!(f, "{{"); - // SAFETY: we're holding onto the map for the duration of the iteration. - unsafe { - let mut it = self.table.iter().peekable(); + let mut it = self.table.iter().peekable(); - while let Some(bucket) = it.next() { - let (key, value) = bucket.as_ref(); + while let Some((key, value)) = it.next() { + if let Err(fmt::Error) = vm_try!(key.string_debug_with(f, caller)) { + return VmResult::Ok(Err(fmt::Error)); + } - vm_write!(f, "{:?}: ", key); + vm_write!(f, ": "); - if let Err(fmt::Error) = vm_try!(value.string_debug_with(f, caller)) { - return VmResult::Ok(Err(fmt::Error)); - } + if let Err(fmt::Error) = vm_try!(value.string_debug_with(f, caller)) { + return VmResult::Ok(Err(fmt::Error)); + } - if it.peek().is_some() { - vm_write!(f, ", "); - } + if it.peek().is_some() { + vm_write!(f, ", "); } } @@ -682,18 +532,13 @@ impl HashMap { return VmResult::Ok(false); } - // SAFETY: we're holding onto the map for the duration of the iteration. - unsafe { - for bucket in self.table.iter() { - let (k, v1) = bucket.as_ref(); - - let Some((_, v2)) = vm_try!(other.get_inner(k)) else { - return VmResult::Ok(false); - }; + for (k, v1) in self.table.iter() { + let Some((_, v2)) = vm_try!(other.table.get(k, caller)) else { + return VmResult::Ok(false); + }; - if !vm_try!(Value::partial_eq_with(v1, v2, caller)) { - return VmResult::Ok(false); - } + if !vm_try!(Value::partial_eq_with(v1, v2, caller)) { + return VmResult::Ok(false); } } @@ -732,18 +577,13 @@ impl HashMap { return VmResult::Ok(false); } - // SAFETY: we're holding onto the map for the duration of the iteration. - unsafe { - for bucket in self.table.iter() { - let (k, v1) = bucket.as_ref(); - - let Some((_, v2)) = vm_try!(other.get_inner(k)) else { - return VmResult::Ok(false); - }; + for (k, v1) in self.table.iter() { + let Some((_, v2)) = vm_try!(other.table.get(k, caller)) else { + return VmResult::Ok(false); + }; - if !vm_try!(Value::eq_with(v1, v2, caller)) { - return VmResult::Ok(false); - } + if !vm_try!(Value::eq_with(v1, v2, caller)) { + return VmResult::Ok(false); } } diff --git a/crates/rune/src/modules/collections/hash_set.rs b/crates/rune/src/modules/collections/hash_set.rs index 9c1d567b9..6d6e5c932 100644 --- a/crates/rune/src/modules/collections/hash_set.rs +++ b/crates/rune/src/modules/collections/hash_set.rs @@ -1,44 +1,44 @@ use core::fmt::{self, Write}; use core::iter; +use core::ptr; use crate as rune; -use crate::no_std::collections; +use crate::hashbrown::{IterRef, RawIter, Table}; use crate::runtime::{ - EnvProtocolCaller, Formatter, Iterator, IteratorTrait, Key, ProtocolCaller, Ref, Value, - VmResult, + EnvProtocolCaller, Formatter, Iterator, ProtocolCaller, RawRef, Ref, Value, VmResult, }; use crate::{Any, ContextError, Module}; pub(super) fn setup(module: &mut Module) -> Result<(), ContextError> { module.ty::()?; - module.function_meta(HashSet::new)?; - module.function_meta(HashSet::with_capacity)?; - module.function_meta(HashSet::len)?; - module.function_meta(HashSet::is_empty)?; - module.function_meta(HashSet::capacity)?; - module.function_meta(HashSet::insert)?; - module.function_meta(HashSet::remove)?; - module.function_meta(HashSet::contains)?; - module.function_meta(HashSet::clear)?; - module.function_meta(HashSet::difference)?; - module.function_meta(HashSet::extend)?; - module.function_meta(HashSet::intersection)?; - module.function_meta(HashSet::union)?; - module.function_meta(HashSet::iter)?; - module.function_meta(clone)?; - module.function_meta(from)?; - module.function_meta(HashSet::into_iter)?; - module.function_meta(HashSet::string_debug)?; - module.function_meta(HashSet::partial_eq)?; - module.function_meta(HashSet::eq)?; + module.function_meta(HashSet::new__meta)?; + module.function_meta(HashSet::with_capacity__meta)?; + module.function_meta(HashSet::len__meta)?; + module.function_meta(HashSet::is_empty__meta)?; + module.function_meta(HashSet::capacity__meta)?; + module.function_meta(HashSet::insert__meta)?; + module.function_meta(HashSet::remove__meta)?; + module.function_meta(HashSet::contains__meta)?; + module.function_meta(HashSet::clear__meta)?; + module.function_meta(HashSet::difference__meta)?; + module.function_meta(HashSet::extend__meta)?; + module.function_meta(HashSet::intersection__meta)?; + module.function_meta(HashSet::union__meta)?; + module.function_meta(HashSet::iter__meta)?; + module.function_meta(HashSet::into_iter__meta)?; + module.function_meta(HashSet::string_debug__meta)?; + module.function_meta(HashSet::partial_eq__meta)?; + module.function_meta(HashSet::eq__meta)?; + module.function_meta(HashSet::clone__meta)?; + module.function_meta(HashSet::from__meta)?; Ok(()) } #[derive(Any, Clone)] #[rune(module = crate, item = ::std::collections)] pub(crate) struct HashSet { - set: collections::HashSet, + table: Table<()>, } impl HashSet { @@ -54,10 +54,10 @@ impl HashSet { /// /// let set = HashSet::new(); /// ``` - #[rune::function(path = Self::new)] + #[rune::function(keep, path = Self::new)] fn new() -> Self { Self { - set: collections::HashSet::new(), + table: Table::new(), } } @@ -75,10 +75,10 @@ impl HashSet { /// let set = HashSet::with_capacity(10); /// assert!(set.capacity() >= 10); /// ``` - #[rune::function(path = Self::with_capacity)] + #[rune::function(keep, path = Self::with_capacity)] fn with_capacity(capacity: usize) -> Self { Self { - set: collections::HashSet::with_capacity(capacity), + table: Table::with_capacity(capacity), } } @@ -94,9 +94,9 @@ impl HashSet { /// v.insert(1); /// assert_eq!(v.len(), 1); /// ``` - #[rune::function] + #[rune::function(keep)] fn len(&self) -> usize { - self.set.len() + self.table.len() } /// Returns `true` if the set contains no elements. @@ -111,9 +111,9 @@ impl HashSet { /// v.insert(1); /// assert!(!v.is_empty()); /// ``` - #[rune::function] + #[rune::function(keep)] fn is_empty(&self) -> bool { - self.set.is_empty() + self.table.is_empty() } /// Returns the number of elements the set can hold without reallocating. @@ -126,9 +126,9 @@ impl HashSet { /// let set = HashSet::with_capacity(100); /// assert!(set.capacity() >= 100); /// ``` - #[rune::function] + #[rune::function(keep)] fn capacity(&self) -> usize { - self.set.capacity() + self.table.capacity() } /// Adds a value to the set. @@ -149,9 +149,10 @@ impl HashSet { /// assert_eq!(set.insert(2), false); /// assert_eq!(set.len(), 1); /// ``` - #[rune::function] - fn insert(&mut self, key: Key) -> bool { - self.set.insert(key) + #[rune::function(keep)] + fn insert(&mut self, key: Value) -> VmResult { + let mut caller = EnvProtocolCaller; + VmResult::Ok(vm_try!(self.table.insert_with(key, (), &mut caller)).is_none()) } /// Removes a value from the set. Returns whether the value was present in @@ -168,9 +169,10 @@ impl HashSet { /// assert_eq!(set.remove(2), true); /// assert_eq!(set.remove(2), false); /// ``` - #[rune::function] - fn remove(&mut self, key: Key) -> bool { - self.set.remove(&key) + #[rune::function(keep)] + fn remove(&mut self, key: Value) -> VmResult { + let mut caller = EnvProtocolCaller; + VmResult::Ok(vm_try!(self.table.remove_with(&key, &mut caller)).is_some()) } /// Returns `true` if the set contains a value. @@ -184,9 +186,10 @@ impl HashSet { /// assert_eq!(set.contains(1), true); /// assert_eq!(set.contains(4), false); /// ``` - #[rune::function] - fn contains(&self, key: Key) -> bool { - self.set.contains(&key) + #[rune::function(keep)] + fn contains(&self, key: Value) -> VmResult { + let mut caller = EnvProtocolCaller; + VmResult::Ok(vm_try!(self.table.get(&key, &mut caller)).is_some()) } /// Clears the set, removing all values. @@ -201,9 +204,9 @@ impl HashSet { /// v.clear(); /// assert!(v.is_empty()); /// ``` - #[rune::function] + #[rune::function(keep)] fn clear(&mut self) { - self.set.clear() + self.table.clear() } /// Visits the values representing the difference, i.e., the values that are @@ -225,15 +228,17 @@ impl HashSet { /// let diff = b.difference(a).collect::(); /// assert_eq!(diff, [4].iter().collect::()); /// ``` - #[rune::function(instance, path = Self::difference)] + #[rune::function(keep, instance, path = Self::difference)] fn difference(this: Ref, other: Ref) -> Iterator { - Iterator::from( - "std::collections::set::Difference", - Difference { - this: this.set.clone().into_iter(), - other: Some(other), - }, - ) + let iter = Self::difference_inner(this, other); + Iterator::from("std::collections::hash_set::Difference", iter) + } + + fn difference_inner(this: Ref, other: Ref) -> Difference { + Difference { + this: Table::iter_ref(Ref::map(this, |this| &this.table)), + other: Some(other), + } } /// Visits the values representing the intersection, i.e., the values that @@ -253,22 +258,22 @@ impl HashSet { /// let values = a.intersection(b).collect::(); /// assert_eq!(values, [2, 3].iter().collect::()); /// ``` - #[rune::function(instance, path = Self::intersection)] + #[rune::function(keep, instance, path = Self::intersection)] fn intersection(this: Ref, other: Ref) -> Iterator { // use shortest iterator as driver for intersections - let intersection = if this.set.len() <= other.set.len() { + let iter = if this.table.len() <= other.table.len() { Intersection { - this: this.set.clone().into_iter(), + this: Table::iter_ref(Ref::map(this, |this| &this.table)), other: Some(other), } } else { Intersection { - this: other.set.clone().into_iter(), + this: Table::iter_ref(Ref::map(other, |this| &this.table)), other: Some(this), } }; - Iterator::from("std::collections::set::Intersection", intersection) + Iterator::from("std::collections::hash_set::Intersection", iter) } /// Visits the values representing the union, i.e., all the values in `self` @@ -284,21 +289,41 @@ impl HashSet { /// /// let union = a.union(b).collect::(); /// assert_eq!(union, HashSet::from([1, 2, 3, 4])); + /// + /// let union = b.union(a).collect::(); + /// assert_eq!(union, HashSet::from([1, 2, 3, 4])); /// ``` - #[rune::function(instance, path = Self::union)] + #[rune::function(keep, instance, path = Self::union)] fn union(this: Ref, other: Ref) -> VmResult { - // use longest as lead and then append any missing that are in second - let iter = Union { - iter: if this.set.len() >= other.set.len() { - vm_try!(HashSet::__rune_fn__iter(&this) - .chain_raw(HashSet::__rune_fn__difference(other, this))) + unsafe { + let (this, this_guard) = Ref::into_raw(Ref::map(this, |this| &this.table)); + let (other, other_guard) = Ref::into_raw(Ref::map(other, |this| &this.table)); + + // use longest as lead and then append any missing that are in second + let iter = if this.as_ref().len() >= other.as_ref().len() { + let this_iter = Table::iter_ref_raw(this); + let other_iter = Table::iter_ref_raw(other); + + Union { + this, + this_iter, + other_iter, + _guards: (this_guard, other_guard), + } } else { - vm_try!(HashSet::__rune_fn__iter(&other) - .chain_raw(HashSet::__rune_fn__difference(this, other))) - }, - }; - - VmResult::Ok(Iterator::from("std::collections::set::Union", iter)) + let this_iter = Table::iter_ref_raw(other); + let other_iter = Table::iter_ref_raw(this); + + Union { + this: other, + this_iter, + other_iter, + _guards: (other_guard, this_guard), + } + }; + + VmResult::Ok(Iterator::from("std::collections::hash_set::Union", iter)) + } } /// Iterate over the hash set. @@ -313,20 +338,24 @@ impl HashSet { /// vec.sort(); /// assert_eq!(vec, [1, 2, 3]); /// ``` - #[rune::function] - fn iter(&self) -> Iterator { - let iter = self.set.clone().into_iter(); - Iterator::from("std::collections::set::Iter", iter) + #[rune::function(keep, instance, path = Self::iter)] + fn iter(this: Ref) -> VmResult { + let iter = Self::iter_inner(this); + VmResult::Ok(Iterator::from("std::collections::hash_set::Iter", iter)) + } + + fn iter_inner(this: Ref) -> impl iter::Iterator { + Table::iter_ref(Ref::map(this, |this| &this.table)).map(|(key, ())| key) } /// Extend this set from an iterator. - #[rune::function] + #[rune::function(keep)] fn extend(&mut self, value: Value) -> VmResult<()> { + let mut caller = EnvProtocolCaller; let mut it = vm_try!(value.into_iter()); - while let Some(value) = vm_try!(it.next()) { - let key = vm_try!(Key::from_value(&value)); - self.set.insert(key); + while let Some(key) = vm_try!(it.next()) { + vm_try!(self.table.insert_with(key, (), &mut caller)); } VmResult::Ok(()) @@ -349,9 +378,9 @@ impl HashSet { /// vec.sort(); /// assert_eq!(vec, [1, 2, 3]); /// ``` - #[rune::function(protocol = INTO_ITER)] - fn into_iter(&self) -> Iterator { - self.__rune_fn__iter() + #[rune::function(keep, instance, protocol = INTO_ITER, path = Self)] + fn into_iter(this: Ref) -> VmResult { + Self::iter(this) } /// Write a debug representation to a string. @@ -367,7 +396,7 @@ impl HashSet { /// let set = HashSet::from([1, 2, 3]); /// println!("{:?}", set); /// ``` - #[rune::function(protocol = STRING_DEBUG)] + #[rune::function(keep, protocol = STRING_DEBUG)] fn string_debug(&self, f: &mut Formatter) -> VmResult { self.string_debug_with(f, &mut EnvProtocolCaller) } @@ -379,7 +408,7 @@ impl HashSet { ) -> VmResult { vm_write!(f, "{{"); - let mut it = self.set.iter().peekable(); + let mut it = self.table.iter().peekable(); while let Some(value) = it.next() { vm_write!(f, "{:?}", value); @@ -393,14 +422,17 @@ impl HashSet { VmResult::Ok(Ok(())) } - pub(crate) fn from_iter(mut it: Iterator) -> VmResult { - let mut set = collections::HashSet::with_capacity(it.size_hint().0); + pub(crate) fn from_iter

(mut it: Iterator, caller: &mut P) -> VmResult + where + P: ?Sized + ProtocolCaller, + { + let mut set = Table::with_capacity(it.size_hint().0); - while let Some(value) = vm_try!(it.next()) { - set.insert(vm_try!(Key::from_value(&value))); + while let Some(key) = vm_try!(it.next()) { + vm_try!(set.insert_with(key, (), caller)); } - VmResult::Ok(HashSet { set }) + VmResult::Ok(HashSet { table: set }) } /// Perform a partial equality test between two sets. @@ -416,9 +448,9 @@ impl HashSet { /// assert_eq!(set, HashSet::from([1, 2, 3])); /// assert_ne!(set, HashSet::from([2, 3, 4])); /// ``` - #[rune::function(protocol = PARTIAL_EQ)] - fn partial_eq(&self, other: &Self) -> bool { - self.set == other.set + #[rune::function(keep, protocol = PARTIAL_EQ)] + fn partial_eq(&self, other: &Self) -> VmResult { + self.eq_with(other, &mut EnvProtocolCaller) } /// Perform a total equality test between two sets. @@ -433,36 +465,62 @@ impl HashSet { /// assert!(eq(set, HashSet::from([1, 2, 3]))); /// assert!(!eq(set, HashSet::from([2, 3, 4]))); /// ``` - #[rune::function(protocol = EQ)] - fn eq(&self, other: &Self) -> bool { - self.set == other.set + #[rune::function(keep, protocol = EQ)] + fn eq(&self, other: &Self) -> VmResult { + self.eq_with(other, &mut EnvProtocolCaller) + } + + fn eq_with(&self, other: &Self, caller: &mut EnvProtocolCaller) -> VmResult { + if self.table.len() != other.table.len() { + return VmResult::Ok(false); + } + + for (key, ()) in self.table.iter() { + if vm_try!(other.table.get(key, caller)).is_none() { + return VmResult::Ok(false); + } + } + + VmResult::Ok(true) + } + + #[rune::function(keep, path = Self::from)] + fn from(value: Value) -> VmResult { + let mut caller = EnvProtocolCaller; + HashSet::from_iter(vm_try!(value.into_iter()), &mut caller) + } + + #[rune::function(keep, instance, path = Self::clone)] + fn clone(this: &HashSet) -> HashSet { + this.clone() } } -struct Intersection -where - I: iter::Iterator, -{ - this: I, +struct Intersection { + this: IterRef<()>, other: Option>, } -impl iter::Iterator for Intersection -where - I: iter::Iterator, -{ - type Item = Key; +impl iter::Iterator for Intersection { + type Item = VmResult; + fn next(&mut self) -> Option { - let other = self.other.take()?; + let mut caller = EnvProtocolCaller; + let other = self.other.as_ref()?; - loop { - let item = self.this.next()?; + for (key, ()) in self.this.by_ref() { + let c = match other.table.get(&key, &mut caller) { + VmResult::Ok(c) => c.is_some(), + VmResult::Err(e) => return Some(VmResult::Err(e)), + }; - if other.set.contains(&item) { - self.other = Some(other); - return Some(item); + if c { + return Some(VmResult::Ok(key)); } } + + self.other = None; + None } #[inline] @@ -472,31 +530,31 @@ where } } -struct Difference -where - I: iter::Iterator, -{ - this: I, +struct Difference { + this: IterRef<()>, other: Option>, } -impl iter::Iterator for Difference -where - I: iter::Iterator, -{ - type Item = Key; +impl iter::Iterator for Difference { + type Item = VmResult; fn next(&mut self) -> Option { - let other = self.other.take()?; + let mut caller = EnvProtocolCaller; + let other = self.other.as_ref()?; - loop { - let item = self.this.next()?; + for (key, ()) in self.this.by_ref() { + let c = match other.table.get(&key, &mut caller) { + VmResult::Ok(c) => c.is_some(), + VmResult::Err(e) => return Some(VmResult::Err(e)), + }; - if !other.set.contains(&item) { - self.other = Some(other); - return Some(item); + if !c { + return Some(VmResult::Ok(key)); } } + + self.other = None; + None } #[inline] @@ -507,25 +565,37 @@ where } struct Union { - iter: Iterator, + this: ptr::NonNull>, + this_iter: RawIter<(Value, ())>, + other_iter: RawIter<(Value, ())>, + _guards: (RawRef, RawRef), } -impl IteratorTrait for Union { - fn next(&mut self) -> VmResult> { - self.iter.next() - } +impl iter::Iterator for Union { + type Item = VmResult; - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } -} + fn next(&mut self) -> Option { + // SAFETY: we're holding onto the ref guards for both collections during + // iteration, so this is valid for the lifetime of the iterator. + unsafe { + if let Some(bucket) = self.this_iter.next() { + let (value, ()) = bucket.as_ref(); + return Some(VmResult::Ok(value.clone())); + } -#[rune::function(free, path = HashSet::from)] -fn from(value: Value) -> VmResult { - HashSet::from_iter(vm_try!(value.into_iter())) -} + let mut caller = EnvProtocolCaller; + + for bucket in self.other_iter.by_ref() { + let (key, ()) = bucket.as_ref(); -#[rune::function(instance)] -fn clone(this: &HashSet) -> HashSet { - this.clone() + match self.this.as_ref().get(key, &mut caller) { + VmResult::Ok(None) => return Some(VmResult::Ok(key.clone())), + VmResult::Ok(..) => {} + VmResult::Err(e) => return Some(VmResult::Err(e)), + } + } + + None + } + } } diff --git a/crates/rune/src/modules/iter.rs b/crates/rune/src/modules/iter.rs index 8ed18dc4c..6704fa651 100644 --- a/crates/rune/src/modules/iter.rs +++ b/crates/rune/src/modules/iter.rs @@ -5,7 +5,8 @@ use crate::no_std::prelude::*; use crate as rune; use crate::modules::collections::{HashMap, HashSet, VecDeque}; use crate::runtime::{ - FromValue, Function, Iterator, Object, OwnedTuple, Protocol, Value, Vec, VmResult, + EnvProtocolCaller, FromValue, Function, Iterator, Object, OwnedTuple, Protocol, Value, Vec, + VmResult, }; use crate::{ContextError, Module}; @@ -1070,7 +1071,8 @@ fn collect_vec_deque(it: Iterator) -> VmResult { /// ``` #[rune::function(instance, path = collect::)] fn collect_hash_set(it: Iterator) -> VmResult { - HashSet::from_iter(it) + let mut caller = EnvProtocolCaller; + HashSet::from_iter(it, &mut caller) } /// Collect the iterator as a [`HashMap`]. @@ -1086,7 +1088,8 @@ fn collect_hash_set(it: Iterator) -> VmResult { /// ``` #[rune::function(instance, path = collect::)] fn collect_hash_map(it: Iterator) -> VmResult { - HashMap::from_iter(it) + let mut caller = EnvProtocolCaller; + HashMap::from_iter(it, &mut caller) } /// Collect the iterator as a [`Tuple`]. diff --git a/crates/rune/src/runtime/hasher.rs b/crates/rune/src/runtime/hasher.rs index 5184f1973..f90c4478f 100644 --- a/crates/rune/src/runtime/hasher.rs +++ b/crates/rune/src/runtime/hasher.rs @@ -1,14 +1,13 @@ use core::hash::{BuildHasher, Hasher as _}; use crate::no_std::collections::hash_map::DefaultHasher; -use crate::runtime::Value; use crate as rune; use crate::Any; /// The default hasher used in Rune. #[derive(Any)] -#[rune(builtin, from_value = Value::into_hasher, static_type = HASHER_TYPE)] +#[rune(item = ::std::hash)] pub struct Hasher { hasher: DefaultHasher, } diff --git a/crates/rune/src/runtime/iterator.rs b/crates/rune/src/runtime/iterator.rs index c0603d736..ca56f486c 100644 --- a/crates/rune/src/runtime/iterator.rs +++ b/crates/rune/src/runtime/iterator.rs @@ -224,16 +224,6 @@ impl Iterator { }) } - #[inline] - pub(crate) fn chain_raw(self, other: Self) -> VmResult { - VmResult::Ok(Self { - iter: IterRepr::Chain(Box::new(Chain { - a: Some(self.iter), - b: Some(other.iter), - })), - }) - } - #[inline] pub(crate) fn rev(self) -> VmResult { if !self.iter.is_double_ended() { diff --git a/crates/rune/src/runtime/static_type.rs b/crates/rune/src/runtime/static_type.rs index 30b90a259..36222fa07 100644 --- a/crates/rune/src/runtime/static_type.rs +++ b/crates/rune/src/runtime/static_type.rs @@ -233,11 +233,6 @@ pub(crate) static ORDERING_TYPE: &StaticType = &StaticType { impl_static_type!(Ordering => ORDERING_TYPE); -pub(crate) static HASHER_TYPE: &StaticType = &StaticType { - name: RawStr::from_str("Hasher"), - hash: ::rune_macros::hash!(::std::hash::Hasher), -}; - pub(crate) static TYPE: &StaticType = &StaticType { name: RawStr::from_str("Type"), hash: ::rune_macros::hash!(::std::any::Type),