feat: extern implementations for zkvm target
fix: lints

fix: nightly fmt

feat: make extern functions take *u8 instead of *u64

chore: make extern function names more unique
arayikhalatyan authored and jonathanpwang committed Dec 12, 2024
1 parent b91e6f7 commit f75557a
Showing 7 changed files with 328 additions and 4 deletions.
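The extern hooks themselves are declared in the zkvm support module (src/support/zkvm.rs), which is not expanded on this page. The block below is only a sketch inferred from the call sites in this diff: the function names and the byte-pointer calling convention come from the code shown here, while the extern "C" linkage, parameter names, and buffer-layout comments are assumptions.

extern "C" {
    // Each pointer references a 256-bit value stored as four u64 limbs (32 bytes),
    // least-significant limb first; `result` may alias one of the operands.
    pub fn zkvm_u256_wrapping_add_impl(result: *mut u8, lhs: *const u8, rhs: *const u8);
    pub fn zkvm_u256_wrapping_sub_impl(result: *mut u8, lhs: *const u8, rhs: *const u8);
    pub fn zkvm_u256_wrapping_mul_impl(result: *mut u8, lhs: *const u8, rhs: *const u8);
    pub fn zkvm_u256_bitand_impl(result: *mut u8, lhs: *const u8, rhs: *const u8);
    pub fn zkvm_u256_bitor_impl(result: *mut u8, lhs: *const u8, rhs: *const u8);
    pub fn zkvm_u256_bitxor_impl(result: *mut u8, lhs: *const u8, rhs: *const u8);
    // For the shifts, the third pointer references a single u64 shift amount.
    pub fn zkvm_u256_wrapping_shl_impl(result: *mut u8, value: *const u8, shift: *const u8);
    pub fn zkvm_u256_wrapping_shr_impl(result: *mut u8, value: *const u8, shift: *const u8);
    pub fn zkvm_u256_arithmetic_shr_impl(result: *mut u8, value: *const u8, shift: *const u8);
    // Assumes core::cmp::Ordering's #[repr(i8)] layout crosses the boundary unchanged.
    pub fn zkvm_u256_cmp_impl(lhs: *const u8, rhs: *const u8) -> core::cmp::Ordering;
}

Only the 256-bit case is routed through these externs; every other width keeps the generic pure-Rust path, as the `if BITS == 256` guards in the diffs below show.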
49 changes: 49 additions & 0 deletions src/add.rs
@@ -143,25 +143,74 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
}

/// Computes `self + rhs`, wrapping around at the boundary of the type.
#[cfg(not(target_os = "zkvm"))]
#[inline(always)]
#[must_use]
pub const fn wrapping_add(self, rhs: Self) -> Self {
self.overflowing_add(rhs).0
}

/// Computes `self + rhs`, wrapping around at the boundary of the type.
#[cfg(target_os = "zkvm")]
#[inline(always)]
#[must_use]
pub fn wrapping_add(mut self, rhs: Self) -> Self {
use crate::support::zkvm::zkvm_u256_wrapping_add_impl;
if BITS == 256 {
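// `self.limbs` is passed as both the destination and the left operand, so the
// extern computes the sum in place; this is why the method takes `mut self`.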
unsafe {
zkvm_u256_wrapping_add_impl(
self.limbs.as_mut_ptr() as *mut u8,
self.limbs.as_ptr() as *const u8,
rhs.limbs.as_ptr() as *const u8,
);
}
return self;
}
self.overflowing_add(rhs).0
}

/// Computes `-self`, wrapping around at the boundary of the type.
#[cfg(not(target_os = "zkvm"))]
#[inline(always)]
#[must_use]
pub const fn wrapping_neg(self) -> Self {
self.overflowing_neg().0
}

/// Computes `-self`, wrapping around at the boundary of the type.
#[cfg(target_os = "zkvm")]
#[inline(always)]
#[must_use]
pub fn wrapping_neg(self) -> Self {
Self::ZERO.wrapping_sub(self)
}

/// Computes `self - rhs`, wrapping around at the boundary of the type.
#[cfg(not(target_os = "zkvm"))]
#[inline(always)]
#[must_use]
pub const fn wrapping_sub(self, rhs: Self) -> Self {
self.overflowing_sub(rhs).0
}

/// Computes `self - rhs`, wrapping around at the boundary of the type.
#[cfg(target_os = "zkvm")]
#[inline(always)]
#[must_use]
pub fn wrapping_sub(mut self, rhs: Self) -> Self {
use crate::support::zkvm::zkvm_u256_wrapping_sub_impl;
if BITS == 256 {
unsafe {
zkvm_u256_wrapping_sub_impl(
self.limbs.as_mut_ptr() as *mut u8,
self.limbs.as_ptr() as *const u8,
rhs.limbs.as_ptr() as *const u8,
);
}
return self;
}
self.overflowing_sub(rhs).0
}
}

impl<const BITS: usize, const LIMBS: usize> Neg for Uint<BITS, LIMBS> {
160 changes: 156 additions & 4 deletions src/bits.rs
@@ -278,12 +278,42 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
///
/// Note: This differs from [`u64::wrapping_shl`] which first reduces `rhs`
/// by `BITS` (which is IMHO not very useful).
#[cfg(not(target_os = "zkvm"))]
#[inline(always)]
#[must_use]
pub fn wrapping_shl(self, rhs: usize) -> Self {
self.overflowing_shl(rhs).0
}

/// Left shift by `rhs` bits.
///
/// Returns $\mod{\mathtt{value} ⋅ 2^{\mathtt{rhs}}}_{2^{\mathtt{BITS}}}$.
///
/// Note: This differs from [`u64::wrapping_shl`] which first reduces `rhs`
/// by `BITS` (which is IMHO not very useful).
#[cfg(target_os = "zkvm")]
#[inline(always)]
#[must_use]
pub fn wrapping_shl(mut self, rhs: usize) -> Self {
if BITS == 256 {
if rhs >= 256 {
return Self::ZERO;
}
use crate::support::zkvm::zkvm_u256_wrapping_shl_impl;
let rhs = rhs as u64;
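// The shift amount is widened to u64 and handed to the extern through the same
// byte-pointer convention as the operands, as a pointer to a one-element array.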
unsafe {
zkvm_u256_wrapping_shl_impl(
self.limbs.as_mut_ptr() as *mut u8,
self.limbs.as_ptr() as *const u8,
[rhs].as_ptr() as *const u8,
);
}
self
} else {
self.overflowing_shl(rhs).0
}
}

/// Checked right shift by `rhs` bits.
///
/// $$
@@ -344,13 +374,47 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
///
/// Note: This differs from [`u64::wrapping_shr`] which first reduces `rhs`
/// by `BITS` (which is IMHO not very useful).
#[cfg(not(target_os = "zkvm"))]
#[inline(always)]
#[must_use]
pub fn wrapping_shr(self, rhs: usize) -> Self {
self.overflowing_shr(rhs).0
}

/// Right shift by `rhs` bits.
///
/// $$
/// \mathtt{wrapping\\_shr}(\mathtt{self}, \mathtt{rhs}) =
/// \floor{\frac{\mathtt{self}}{2^{\mathtt{rhs}}}}
/// $$
///
/// Note: This differs from [`u64::wrapping_shr`] which first reduces `rhs`
/// by `BITS` (which is IMHO not very useful).
#[cfg(target_os = "zkvm")]
#[inline(always)]
#[must_use]
pub fn wrapping_shr(mut self, rhs: usize) -> Self {
if BITS == 256 {
if rhs >= 256 {
return Self::ZERO;
}
use crate::support::zkvm::zkvm_u256_wrapping_shr_impl;
let rhs = rhs as u64;
unsafe {
zkvm_u256_wrapping_shr_impl(
self.limbs.as_mut_ptr() as *mut u8,
self.limbs.as_ptr() as *const u8,
[rhs].as_ptr() as *const u8,
);
}
self
} else {
self.overflowing_shr(rhs).0
}
}

/// Arithmetic shift right by `rhs` bits.
#[cfg(not(target_os = "zkvm"))]
#[inline]
#[must_use]
pub fn arithmetic_shr(self, rhs: usize) -> Self {
@@ -365,6 +429,36 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
r
}

/// Arithmetic shift right by `rhs` bits.
#[cfg(target_os = "zkvm")]
#[inline]
#[must_use]
pub fn arithmetic_shr(mut self, rhs: usize) -> Self {
if BITS == 256 {
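// A shift of 255 already leaves only copies of the sign bit, so clamping larger
// shift amounts to 255 produces the correct all-zeros or all-ones result.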
let rhs = if rhs >= 256 { 255 } else { rhs };
use crate::support::zkvm::zkvm_u256_arithmetic_shr_impl;
let rhs = rhs as u64;
unsafe {
zkvm_u256_arithmetic_shr_impl(
self.limbs.as_mut_ptr() as *mut u8,
self.limbs.as_ptr() as *const u8,
[rhs].as_ptr() as *const u8,
);
}
self
} else {
if BITS == 0 {
return Self::ZERO;
}
let sign = self.bit(BITS - 1);
let mut r = self >> rhs;
if sign {
r |= Self::MAX << BITS.saturating_sub(rhs);
}
r
}
}

/// Shifts the bits to the left by a specified amount, `rhs`, wrapping the
/// truncated bits to the end of the resulting integer.
#[inline]
@@ -394,6 +488,7 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
impl<const BITS: usize, const LIMBS: usize> Not for Uint<BITS, LIMBS> {
type Output = Self;

#[cfg(not(target_os = "zkvm"))]
#[inline]
fn not(mut self) -> Self::Output {
if BITS == 0 {
@@ -405,6 +500,31 @@ impl<const BITS: usize, const LIMBS: usize> Not for Uint<BITS, LIMBS> {
self.limbs[LIMBS - 1] &= Self::MASK;
self
}

#[cfg(target_os = "zkvm")]
#[inline(always)]
fn not(mut self) -> Self::Output {
use crate::support::zkvm::zkvm_u256_wrapping_sub_impl;
if BITS == 256 {
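// Bitwise NOT equals `MAX - self` for a fixed-width integer, so this reuses the
// wrapping-sub extern, writing the result into `self`'s limbs.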
unsafe {
zkvm_u256_wrapping_sub_impl(
self.limbs.as_mut_ptr() as *mut u8,
Self::MAX.limbs.as_ptr() as *const u8,
self.limbs.as_ptr() as *const u8,
);
}
self
} else {
if BITS == 0 {
return Self::ZERO;
}
for limb in &mut self.limbs {
*limb = u64::not(*limb);
}
self.limbs[LIMBS - 1] &= Self::MASK;
self
}
}
}

impl<const BITS: usize, const LIMBS: usize> Not for &Uint<BITS, LIMBS> {
@@ -417,7 +537,7 @@ impl<const BITS: usize, const LIMBS: usize> Not for &Uint<BITS, LIMBS> {
}

macro_rules! impl_bit_op {
- ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident) => {
($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident, $fn_zkvm_impl:ident) => {
impl<const BITS: usize, const LIMBS: usize> $trait_assign<Uint<BITS, LIMBS>>
for Uint<BITS, LIMBS>
{
@@ -430,12 +550,26 @@ macro_rules! impl_bit_op {
impl<const BITS: usize, const LIMBS: usize> $trait_assign<&Uint<BITS, LIMBS>>
for Uint<BITS, LIMBS>
{
#[cfg(not(target_os = "zkvm"))]
#[inline]
fn $fn_assign(&mut self, rhs: &Uint<BITS, LIMBS>) {
for i in 0..LIMBS {
u64::$fn_assign(&mut self.limbs[i], rhs.limbs[i]);
}
}

#[cfg(target_os = "zkvm")]
#[inline(always)]
fn $fn_assign(&mut self, rhs: &Uint<BITS, LIMBS>) {
use crate::support::zkvm::$fn_zkvm_impl;
unsafe {
$fn_zkvm_impl(
self.limbs.as_mut_ptr() as *mut u8,
self.limbs.as_ptr() as *const u8,
rhs.limbs.as_ptr() as *const u8,
);
}
}
}

impl<const BITS: usize, const LIMBS: usize> $trait<Uint<BITS, LIMBS>>
@@ -487,9 +621,27 @@
};
}

- impl_bit_op!(BitOr, bitor, BitOrAssign, bitor_assign);
- impl_bit_op!(BitAnd, bitand, BitAndAssign, bitand_assign);
- impl_bit_op!(BitXor, bitxor, BitXorAssign, bitxor_assign);
impl_bit_op!(
BitOr,
bitor,
BitOrAssign,
bitor_assign,
zkvm_u256_bitor_impl
);
impl_bit_op!(
BitAnd,
bitand,
BitAndAssign,
bitand_assign,
zkvm_u256_bitand_impl
);
impl_bit_op!(
BitXor,
bitxor,
BitXorAssign,
bitxor_assign,
zkvm_u256_bitxor_impl
);

impl<const BITS: usize, const LIMBS: usize> Shl<Self> for Uint<BITS, LIMBS> {
type Output = Self;
16 changes: 16 additions & 0 deletions src/cmp.rs
@@ -9,10 +9,26 @@ impl<const BITS: usize, const LIMBS: usize> PartialOrd for Uint<BITS, LIMBS> {
}

impl<const BITS: usize, const LIMBS: usize> Ord for Uint<BITS, LIMBS> {
#[cfg(not(target_os = "zkvm"))]
#[inline]
fn cmp(&self, rhs: &Self) -> Ordering {
crate::algorithms::cmp(self.as_limbs(), rhs.as_limbs())
}

#[cfg(target_os = "zkvm")]
#[inline]
fn cmp(&self, rhs: &Self) -> Ordering {
use crate::support::zkvm::zkvm_u256_cmp_impl;
if BITS == 256 {
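// The extern compares the two 32-byte limb buffers and returns the `Ordering` directly.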
return unsafe {
zkvm_u256_cmp_impl(
self.limbs.as_ptr() as *const u8,
rhs.limbs.as_ptr() as *const u8,
)
};
}
// Fall back to the generic limb comparison; calling `self.cmp(rhs)` here would recurse.
crate::algorithms::cmp(self.as_limbs(), rhs.as_limbs())
}
}

impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
9 changes: 9 additions & 0 deletions src/lib.rs
@@ -143,12 +143,21 @@ pub mod nightly {
/// requires same-sized arguments and returns a pair of lower and higher bits.
///
/// [std-overflow]: https://doc.rust-lang.org/reference/expressions/operator-expr.html#overflow
#[cfg(not(target_os = "zkvm"))]
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
#[repr(transparent)]
pub struct Uint<const BITS: usize, const LIMBS: usize> {
limbs: [u64; LIMBS],
}

/// On the zkvm target, the native implementations of `Clone` and `Eq` are used instead of the derives.
#[cfg(target_os = "zkvm")]
#[derive(Hash)]
#[repr(transparent)]
pub struct Uint<const BITS: usize, const LIMBS: usize> {
limbs: [u64; LIMBS],
}
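
The native implementations mentioned above are hand-written impls expected to live in the zkvm support module, which this page does not expand. A minimal illustrative sketch, assuming equality is derived from the accelerated `cmp` shown in src/cmp.rs (the actual impls may call dedicated externs instead):

#[cfg(target_os = "zkvm")]
impl<const BITS: usize, const LIMBS: usize> Clone for Uint<BITS, LIMBS> {
    #[inline]
    fn clone(&self) -> Self {
        // Plain limb copy; an accelerated version could delegate to a zkvm extern.
        Self { limbs: self.limbs }
    }
}

#[cfg(target_os = "zkvm")]
impl<const BITS: usize, const LIMBS: usize> PartialEq for Uint<BITS, LIMBS> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        // Reuses the accelerated comparison from src/cmp.rs.
        self.cmp(other) == core::cmp::Ordering::Equal
    }
}

#[cfg(target_os = "zkvm")]
impl<const BITS: usize, const LIMBS: usize> Eq for Uint<BITS, LIMBS> {}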

impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
/// The size of this integer type in 64-bit limbs.
pub const LIMBS: usize = {
20 changes: 20 additions & 0 deletions src/mul.rs
@@ -58,6 +58,7 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
}

/// Computes `self * rhs`, wrapping around at the boundary of the type.
#[cfg(not(target_os = "zkvm"))]
#[inline(always)]
#[must_use]
pub fn wrapping_mul(self, rhs: Self) -> Self {
@@ -69,6 +70,25 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
result
}

/// Computes `self * rhs`, wrapping around at the boundary of the type.
#[cfg(target_os = "zkvm")]
#[inline(always)]
#[must_use]
pub fn wrapping_mul(mut self, rhs: Self) -> Self {
use crate::support::zkvm::zkvm_u256_wrapping_mul_impl;
if BITS == 256 {
unsafe {
zkvm_u256_wrapping_mul_impl(
self.limbs.as_mut_ptr() as *mut u8,
self.limbs.as_ptr() as *const u8,
rhs.limbs.as_ptr() as *const u8,
);
}
return self;
}
self.overflowing_mul(rhs).0
}

/// Computes the inverse modulo $2^{\mathtt{BITS}}$ of `self`, returning
/// [`None`] if the inverse does not exist.
#[inline]