From f75557a0df45f65b076f32c783dca9e1d0f46734 Mon Sep 17 00:00:00 2001 From: Arayi Date: Tue, 10 Dec 2024 13:04:08 -0500 Subject: [PATCH] feat: extern implementations for zkvm target fix: lints fix: nightly fmt feat: make extern functions take *u8 instead of *u64 chore: make extern function names more unique --- src/add.rs | 49 ++++++++++++++ src/bits.rs | 160 ++++++++++++++++++++++++++++++++++++++++++-- src/cmp.rs | 16 +++++ src/lib.rs | 9 +++ src/mul.rs | 20 ++++++ src/support/mod.rs | 2 + src/support/zkvm.rs | 76 +++++++++++++++++++++ 7 files changed, 328 insertions(+), 4 deletions(-) create mode 100644 src/support/zkvm.rs diff --git a/src/add.rs b/src/add.rs index 77269b5..d4f352b 100644 --- a/src/add.rs +++ b/src/add.rs @@ -143,25 +143,74 @@ impl Uint { } /// Computes `self + rhs`, wrapping around at the boundary of the type. + #[cfg(not(target_os = "zkvm"))] #[inline(always)] #[must_use] pub const fn wrapping_add(self, rhs: Self) -> Self { self.overflowing_add(rhs).0 } + /// Computes `self + rhs`, wrapping around at the boundary of the type. + #[cfg(target_os = "zkvm")] + #[inline(always)] + #[must_use] + pub fn wrapping_add(mut self, rhs: Self) -> Self { + use crate::support::zkvm::zkvm_u256_wrapping_add_impl; + if BITS == 256 { + unsafe { + zkvm_u256_wrapping_add_impl( + self.limbs.as_mut_ptr() as *mut u8, + self.limbs.as_ptr() as *const u8, + rhs.limbs.as_ptr() as *const u8, + ); + } + return self; + } + self.overflowing_add(rhs).0 + } + /// Computes `-self`, wrapping around at the boundary of the type. + #[cfg(not(target_os = "zkvm"))] #[inline(always)] #[must_use] pub const fn wrapping_neg(self) -> Self { self.overflowing_neg().0 } + /// Computes `-self`, wrapping around at the boundary of the type. + #[cfg(target_os = "zkvm")] + #[inline(always)] + #[must_use] + pub fn wrapping_neg(self) -> Self { + Self::ZERO.wrapping_sub(self) + } + /// Computes `self - rhs`, wrapping around at the boundary of the type. 
+ #[cfg(not(target_os = "zkvm"))] #[inline(always)] #[must_use] pub const fn wrapping_sub(self, rhs: Self) -> Self { self.overflowing_sub(rhs).0 } + + /// Computes `self - rhs`, wrapping around at the boundary of the type. + #[cfg(target_os = "zkvm")] + #[inline(always)] + #[must_use] + pub fn wrapping_sub(mut self, rhs: Self) -> Self { + use crate::support::zkvm::zkvm_u256_wrapping_sub_impl; + if BITS == 256 { + unsafe { + zkvm_u256_wrapping_sub_impl( + self.limbs.as_mut_ptr() as *mut u8, + self.limbs.as_ptr() as *const u8, + rhs.limbs.as_ptr() as *const u8, + ); + } + return self; + } + self.overflowing_sub(rhs).0 + } } impl Neg for Uint { diff --git a/src/bits.rs b/src/bits.rs index 94d10c7..001eece 100644 --- a/src/bits.rs +++ b/src/bits.rs @@ -278,12 +278,42 @@ impl Uint { /// /// Note: This differs from [`u64::wrapping_shl`] which first reduces `rhs` /// by `BITS` (which is IMHO not very useful). + #[cfg(not(target_os = "zkvm"))] #[inline(always)] #[must_use] pub fn wrapping_shl(self, rhs: usize) -> Self { self.overflowing_shl(rhs).0 } + /// Left shift by `rhs` bits. + /// + /// Returns $\mod{\mathtt{value} ⋅ 2^{\mathtt{rhs}}}_{2^{\mathtt{BITS}}}$. + /// + /// Note: This differs from [`u64::wrapping_shl`] which first reduces `rhs` + /// by `BITS` (which is IMHO not very useful). + #[cfg(target_os = "zkvm")] + #[inline(always)] + #[must_use] + pub fn wrapping_shl(mut self, rhs: usize) -> Self { + if BITS == 256 { + if rhs >= 256 { + return Self::ZERO; + } + use crate::support::zkvm::zkvm_u256_wrapping_shl_impl; + let rhs = rhs as u64; + unsafe { + zkvm_u256_wrapping_shl_impl( + self.limbs.as_mut_ptr() as *mut u8, + self.limbs.as_ptr() as *const u8, + [rhs].as_ptr() as *const u8, + ); + } + self + } else { + self.overflowing_shl(rhs).0 + } + } + /// Checked right shift by `rhs` bits. /// /// $$ @@ -344,13 +374,47 @@ impl Uint { /// /// Note: This differs from [`u64::wrapping_shr`] which first reduces `rhs` /// by `BITS` (which is IMHO not very useful). 
+ #[cfg(not(target_os = "zkvm"))] #[inline(always)] #[must_use] pub fn wrapping_shr(self, rhs: usize) -> Self { self.overflowing_shr(rhs).0 } + /// Right shift by `rhs` bits. + /// + /// $$ + /// \mathtt{wrapping\\_shr}(\mathtt{self}, \mathtt{rhs}) = + /// \floor{\frac{\mathtt{self}}{2^{\mathtt{rhs}}}} + /// $$ + /// + /// Note: This differs from [`u64::wrapping_shr`] which first reduces `rhs` + /// by `BITS` (which is IMHO not very useful). + #[cfg(target_os = "zkvm")] + #[inline(always)] + #[must_use] + pub fn wrapping_shr(mut self, rhs: usize) -> Self { + if BITS == 256 { + if rhs >= 256 { + return Self::ZERO; + } + use crate::support::zkvm::zkvm_u256_wrapping_shr_impl; + let rhs = rhs as u64; + unsafe { + zkvm_u256_wrapping_shr_impl( + self.limbs.as_mut_ptr() as *mut u8, + self.limbs.as_ptr() as *const u8, + [rhs].as_ptr() as *const u8, + ); + } + self + } else { + self.overflowing_shr(rhs).0 + } + } + /// Arithmetic shift right by `rhs` bits. + #[cfg(not(target_os = "zkvm"))] #[inline] #[must_use] pub fn arithmetic_shr(self, rhs: usize) -> Self { @@ -365,6 +429,36 @@ impl Uint { r } + /// Arithmetic shift right by `rhs` bits. + #[cfg(target_os = "zkvm")] + #[inline] + #[must_use] + pub fn arithmetic_shr(mut self, rhs: usize) -> Self { + if BITS == 256 { + let rhs = if rhs >= 256 { 255 } else { rhs }; + use crate::support::zkvm::zkvm_u256_arithmetic_shr_impl; + let rhs = rhs as u64; + unsafe { + zkvm_u256_arithmetic_shr_impl( + self.limbs.as_mut_ptr() as *mut u8, + self.limbs.as_ptr() as *const u8, + [rhs].as_ptr() as *const u8, + ); + } + self + } else { + if BITS == 0 { + return Self::ZERO; + } + let sign = self.bit(BITS - 1); + let mut r = self >> rhs; + if sign { + r |= Self::MAX << BITS.saturating_sub(rhs); + } + r + } + } + /// Shifts the bits to the left by a specified amount, `rhs`, wrapping the /// truncated bits to the end of the resulting integer. 
#[inline] @@ -394,6 +488,7 @@ impl Uint { impl Not for Uint { type Output = Self; + #[cfg(not(target_os = "zkvm"))] #[inline] fn not(mut self) -> Self::Output { if BITS == 0 { @@ -405,6 +500,31 @@ impl Not for Uint { self.limbs[LIMBS - 1] &= Self::MASK; self } + + #[cfg(target_os = "zkvm")] + #[inline(always)] + fn not(mut self) -> Self::Output { + use crate::support::zkvm::zkvm_u256_wrapping_sub_impl; + if BITS == 256 { + unsafe { + zkvm_u256_wrapping_sub_impl( + self.limbs.as_mut_ptr() as *mut u8, + Self::MAX.limbs.as_ptr() as *const u8, + self.limbs.as_ptr() as *const u8, + ); + } + self + } else { + if BITS == 0 { + return Self::ZERO; + } + for limb in &mut self.limbs { + *limb = u64::not(*limb); + } + self.limbs[LIMBS - 1] &= Self::MASK; + self + } + } } impl Not for &Uint { @@ -417,7 +537,7 @@ impl Not for &Uint { } macro_rules! impl_bit_op { - ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident) => { + ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident, $fn_zkvm_impl:ident) => { impl $trait_assign> for Uint { @@ -430,12 +550,26 @@ macro_rules! impl_bit_op { impl $trait_assign<&Uint> for Uint { + #[cfg(not(target_os = "zkvm"))] #[inline] fn $fn_assign(&mut self, rhs: &Uint) { for i in 0..LIMBS { u64::$fn_assign(&mut self.limbs[i], rhs.limbs[i]); } } + + #[cfg(target_os = "zkvm")] + #[inline(always)] + fn $fn_assign(&mut self, rhs: &Uint) { + use crate::support::zkvm::$fn_zkvm_impl; + unsafe { + $fn_zkvm_impl( + self.limbs.as_mut_ptr() as *mut u8, + self.limbs.as_ptr() as *const u8, + rhs.limbs.as_ptr() as *const u8, + ); + } + } } impl $trait> @@ -487,9 +621,27 @@ macro_rules! 
impl_bit_op { }; } -impl_bit_op!(BitOr, bitor, BitOrAssign, bitor_assign); -impl_bit_op!(BitAnd, bitand, BitAndAssign, bitand_assign); -impl_bit_op!(BitXor, bitxor, BitXorAssign, bitxor_assign); +impl_bit_op!( + BitOr, + bitor, + BitOrAssign, + bitor_assign, + zkvm_u256_bitor_impl +); +impl_bit_op!( + BitAnd, + bitand, + BitAndAssign, + bitand_assign, + zkvm_u256_bitand_impl +); +impl_bit_op!( + BitXor, + bitxor, + BitXorAssign, + bitxor_assign, + zkvm_u256_bitxor_impl +); impl Shl for Uint { type Output = Self; diff --git a/src/cmp.rs b/src/cmp.rs index 130ab43..c1b6940 100644 --- a/src/cmp.rs +++ b/src/cmp.rs @@ -9,10 +9,26 @@ impl PartialOrd for Uint { } impl Ord for Uint { + #[cfg(not(target_os = "zkvm"))] #[inline] fn cmp(&self, rhs: &Self) -> Ordering { crate::algorithms::cmp(self.as_limbs(), rhs.as_limbs()) } + + #[cfg(target_os = "zkvm")] + #[inline] + fn cmp(&self, rhs: &Self) -> Ordering { + use crate::support::zkvm::zkvm_u256_cmp_impl; + if BITS == 256 { + return unsafe { + zkvm_u256_cmp_impl( + self.limbs.as_ptr() as *const u8, + rhs.limbs.as_ptr() as *const u8, + ) + }; + } + self.cmp(rhs) + } } impl Uint { diff --git a/src/lib.rs b/src/lib.rs index f40111a..9c6ae08 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -143,12 +143,21 @@ pub mod nightly { /// requires same-sized arguments and returns a pair of lower and higher bits. /// /// [std-overflow]: https://doc.rust-lang.org/reference/expressions/operator-expr.html#overflow +#[cfg(not(target_os = "zkvm"))] #[derive(Clone, Copy, Eq, PartialEq, Hash)] #[repr(transparent)] pub struct Uint { limbs: [u64; LIMBS], } +/// In case of zkvm, we use the native implementations of `Clone` and `Eq` +#[cfg(target_os = "zkvm")] +#[derive(Hash)] +#[repr(transparent)] +pub struct Uint { + limbs: [u64; LIMBS], +} + impl Uint { /// The size of this integer type in 64-bit limbs. 
pub const LIMBS: usize = { diff --git a/src/mul.rs b/src/mul.rs index 5105603..de0f4a2 100644 --- a/src/mul.rs +++ b/src/mul.rs @@ -58,6 +58,7 @@ impl Uint { } /// Computes `self * rhs`, wrapping around at the boundary of the type. + #[cfg(not(target_os = "zkvm"))] #[inline(always)] #[must_use] pub fn wrapping_mul(self, rhs: Self) -> Self { @@ -69,6 +70,25 @@ impl Uint { result } + /// Computes `self * rhs`, wrapping around at the boundary of the type. + #[cfg(target_os = "zkvm")] + #[inline(always)] + #[must_use] + pub fn wrapping_mul(mut self, rhs: Self) -> Self { + use crate::support::zkvm::zkvm_u256_wrapping_mul_impl; + if BITS == 256 { + unsafe { + zkvm_u256_wrapping_mul_impl( + self.limbs.as_mut_ptr() as *mut u8, + self.limbs.as_ptr() as *const u8, + rhs.limbs.as_ptr() as *const u8, + ); + } + return self; + } + self.overflowing_mul(rhs).0 + } + /// Computes the inverse modulo $2^{\mathtt{BITS}}$ of `self`, returning /// [`None`] if the inverse does not exist. #[inline] diff --git a/src/support/mod.rs b/src/support/mod.rs index 2894c35..2c09297 100644 --- a/src/support/mod.rs +++ b/src/support/mod.rs @@ -29,6 +29,8 @@ pub mod ssz; mod subtle; mod valuable; mod zeroize; +#[cfg(target_os = "zkvm")] +pub mod zkvm; // FEATURE: Support for many more traits and crates. // * https://crates.io/crates/der diff --git a/src/support/zkvm.rs b/src/support/zkvm.rs new file mode 100644 index 0000000..27833f7 --- /dev/null +++ b/src/support/zkvm.rs @@ -0,0 +1,76 @@ +/// This file allows users to define more efficient native implementations for +/// the zkvm target which can be used to speed up the operations on [Uint]'s. +/// +/// The functions defined here are not meant to be used by the user, but rather +/// to be used by the library to define more efficient native implementations +/// for the zkvm target. +/// +/// Currently these functions are specified to support only 256 bit [Uint]'s and +/// take pointers to their limbs as arguments. 
Providing other sizes +/// will result in undefined behavior. +use core::{cmp::Ordering, mem::MaybeUninit}; + +use crate::Uint; + +extern "C" { + /// Add two 256-bit numbers and store in `result`. + pub fn zkvm_u256_wrapping_add_impl(result: *mut u8, a: *const u8, b: *const u8); + /// Subtract two 256-bit numbers and store in `result`. + pub fn zkvm_u256_wrapping_sub_impl(result: *mut u8, a: *const u8, b: *const u8); + /// Multiply two 256-bit numbers and store in `result`. + pub fn zkvm_u256_wrapping_mul_impl(result: *mut u8, a: *const u8, b: *const u8); + /// Bitwise XOR two 256-bit numbers and store in `result`. + pub fn zkvm_u256_bitxor_impl(result: *mut u8, a: *const u8, b: *const u8); + /// Bitwise AND two 256-bit numbers and store in `result`. + pub fn zkvm_u256_bitand_impl(result: *mut u8, a: *const u8, b: *const u8); + /// Bitwise OR two 256-bit numbers and store in `result`. + pub fn zkvm_u256_bitor_impl(result: *mut u8, a: *const u8, b: *const u8); + /// Shift left two 256-bit numbers and store in `result`. + pub fn zkvm_u256_wrapping_shl_impl(result: *mut u8, a: *const u8, b: *const u8); + /// Shift right two 256-bit numbers and store in `result`. + pub fn zkvm_u256_wrapping_shr_impl(result: *mut u8, a: *const u8, b: *const u8); + /// Arithmetic shift right two 256-bit numbers and store in `result`. + pub fn zkvm_u256_arithmetic_shr_impl(result: *mut u8, a: *const u8, b: *const u8); + /// Check if two 256-bit numbers are equal. + pub fn zkvm_u256_eq_impl(a: *const u8, b: *const u8) -> bool; + /// Compare two 256-bit numbers. + pub fn zkvm_u256_cmp_impl(a: *const u8, b: *const u8) -> Ordering; + /// Clone a 256-bit number into `result`. 
`zero` has to point to a 256-bit zero value (the visible caller passes `Self::ZERO`'s limbs). + pub fn zkvm_u256_clone_impl(result: *mut u8, a: *const u8, zero: *const u8); +} + +impl Copy for Uint {} + +impl Clone for Uint { + fn clone(&self) -> Self { + if BITS == 256 { + let mut uninit: MaybeUninit = MaybeUninit::uninit(); + unsafe { + zkvm_u256_clone_impl( + (*uninit.as_mut_ptr()).limbs.as_mut_ptr() as *mut u8, + self.limbs.as_ptr() as *const u8, + Self::ZERO.limbs.as_ptr() as *const u8, + ); + } + return unsafe { uninit.assume_init() }; + } + Self { limbs: self.limbs } + } +} + +impl PartialEq for Uint { + fn eq(&self, other: &Self) -> bool { + if BITS == 256 { + unsafe { + zkvm_u256_eq_impl( + self.limbs.as_ptr() as *const u8, + other.limbs.as_ptr() as *const u8, + ) + } + } else { + self.limbs == other.limbs + } + } +} + +impl Eq for Uint {}