Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: zkvm target for the ruint library #7

Merged
merged 4 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,25 +143,74 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
}

/// Computes `self + rhs`, wrapping around at the boundary of the type.
#[cfg(not(target_os = "zkvm"))]
#[inline(always)]
#[must_use]
pub const fn wrapping_add(self, rhs: Self) -> Self {
    // Discard the carry flag; only the wrapped sum is wanted here.
    let (sum, _overflowed) = self.overflowing_add(rhs);
    sum
}

/// Computes `self + rhs`, wrapping around at the boundary of the type.
///
/// On zkvm, 256-bit additions are delegated to the accelerated
/// `wrapping_add_impl` intrinsic; every other width falls back to the
/// generic limb-based implementation.
#[cfg(target_os = "zkvm")]
#[inline(always)]
#[must_use]
pub fn wrapping_add(mut self, rhs: Self) -> Self {
    use crate::support::zkvm::wrapping_add_impl;
    if BITS == 256 {
        // SAFETY: all three pointers refer to live `[u64; LIMBS]` buffers.
        // The result pointer aliases the first operand (in-place update);
        // this same src==dst pattern is used by every intrinsic call in
        // this file — TODO confirm the zkvm intrinsic contract permits it.
        unsafe {
            wrapping_add_impl(
                self.limbs.as_ptr(),
                rhs.limbs.as_ptr(),
                self.limbs.as_mut_ptr(),
            );
        }
        return self;
    }
    self.overflowing_add(rhs).0
}

/// Computes `-self`, wrapping around at the boundary of the type.
#[cfg(not(target_os = "zkvm"))]
#[inline(always)]
#[must_use]
pub const fn wrapping_neg(self) -> Self {
    // The overflow indicator is irrelevant for the wrapping variant.
    let (negated, _overflowed) = self.overflowing_neg();
    negated
}

/// Computes `-self`, wrapping around at the boundary of the type.
///
/// Expressed as `0 - self` so that the zkvm-accelerated `wrapping_sub`
/// path (for 256-bit values) is reused instead of a dedicated negation
/// intrinsic.
#[cfg(target_os = "zkvm")]
#[inline(always)]
#[must_use]
pub fn wrapping_neg(self) -> Self {
    Self::ZERO.wrapping_sub(self)
}

/// Computes `self - rhs`, wrapping around at the boundary of the type.
#[cfg(not(target_os = "zkvm"))]
#[inline(always)]
#[must_use]
pub const fn wrapping_sub(self, rhs: Self) -> Self {
    // Keep only the wrapped difference; the borrow flag is dropped.
    let (difference, _borrowed) = self.overflowing_sub(rhs);
    difference
}

/// Computes `self - rhs`, wrapping around at the boundary of the type.
///
/// On zkvm, 256-bit subtractions are delegated to the accelerated
/// `wrapping_sub_impl` intrinsic; other widths use the generic fallback.
#[cfg(target_os = "zkvm")]
#[inline(always)]
#[must_use]
pub fn wrapping_sub(mut self, rhs: Self) -> Self {
    use crate::support::zkvm::wrapping_sub_impl;
    if BITS == 256 {
        // SAFETY: pointers refer to live `[u64; LIMBS]` buffers; the output
        // aliases the first operand for an in-place update — TODO confirm
        // the intrinsic supports src==dst aliasing (pattern used throughout
        // this file).
        unsafe {
            wrapping_sub_impl(
                self.limbs.as_ptr(),
                rhs.limbs.as_ptr(),
                self.limbs.as_mut_ptr(),
            );
        }
        return self;
    }
    self.overflowing_sub(rhs).0
}
}

impl<const BITS: usize, const LIMBS: usize> Neg for Uint<BITS, LIMBS> {
Expand Down
130 changes: 126 additions & 4 deletions src/bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,38 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
///
/// Note: This differs from [`u64::wrapping_shl`] which first reduces `rhs`
/// by `BITS` (which is IMHO not very useful).
#[cfg(not(target_os = "zkvm"))]
#[inline(always)]
#[must_use]
pub fn wrapping_shl(self, rhs: usize) -> Self {
    // Only the shifted value is needed; the overflow flag is ignored.
    let (shifted, _overflowed) = self.overflowing_shl(rhs);
    shifted
}

/// Left shift by `rhs` bits.
///
/// Returns $\mod{\mathtt{value} ⋅ 2^{\mathtt{rhs}}}_{2^{\mathtt{BITS}}}$.
///
/// Note: This differs from [`u64::wrapping_shl`] which first reduces `rhs`
/// by `BITS` (which is IMHO not very useful).
#[cfg(target_os = "zkvm")]
#[inline(always)]
#[must_use]
pub fn wrapping_shl(mut self, rhs: usize) -> Self {
    if BITS == 256 {
        // A shift by the full width (or more) clears every bit; the
        // intrinsic is only given in-range shift amounts.
        if rhs >= 256 {
            return Self::ZERO;
        }
        use crate::support::zkvm::wrapping_shl_impl;
        let rhs = rhs as u64;
        // SAFETY: the shift amount is passed as a pointer to a one-element
        // stack array; the output pointer aliases the input for an in-place
        // update — TODO confirm the intrinsic permits src==dst aliasing.
        unsafe {
            wrapping_shl_impl(self.limbs.as_ptr(), [rhs].as_ptr(), self.limbs.as_mut_ptr());
        }
        self
    } else {
        self.overflowing_shl(rhs).0
    }
}

/// Checked right shift by `rhs` bits.
///
/// $$
Expand Down Expand Up @@ -344,13 +370,43 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
///
/// Note: This differs from [`u64::wrapping_shr`] which first reduces `rhs`
/// by `BITS` (which is IMHO not very useful).
#[cfg(not(target_os = "zkvm"))]
#[inline(always)]
#[must_use]
pub fn wrapping_shr(self, rhs: usize) -> Self {
    // Discard the overflow indicator; callers only want the shifted value.
    let (shifted, _overflowed) = self.overflowing_shr(rhs);
    shifted
}

/// Right shift by `rhs` bits.
///
/// $$
/// \mathtt{wrapping\\_shr}(\mathtt{self}, \mathtt{rhs}) =
/// \floor{\frac{\mathtt{self}}{2^{\mathtt{rhs}}}}
/// $$
///
/// Note: This differs from [`u64::wrapping_shr`] which first reduces `rhs`
/// by `BITS` (which is IMHO not very useful).
#[cfg(target_os = "zkvm")]
#[inline(always)]
#[must_use]
pub fn wrapping_shr(mut self, rhs: usize) -> Self {
    if BITS == 256 {
        // Logical right shift by the full width (or more) yields zero; keep
        // the intrinsic's input in range.
        if rhs >= 256 {
            return Self::ZERO;
        }
        use crate::support::zkvm::wrapping_shr_impl;
        let rhs = rhs as u64;
        // SAFETY: shift amount passed via a one-element stack array; output
        // pointer aliases the input for an in-place update — TODO confirm
        // the intrinsic permits src==dst aliasing.
        unsafe {
            wrapping_shr_impl(self.limbs.as_ptr(), [rhs].as_ptr(), self.limbs.as_mut_ptr());
        }
        self
    } else {
        self.overflowing_shr(rhs).0
    }
}

/// Arithmetic shift right by `rhs` bits.
#[cfg(not(target_os = "zkvm"))]
#[inline]
#[must_use]
pub fn arithmetic_shr(self, rhs: usize) -> Self {
Expand All @@ -365,6 +421,32 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
r
}

/// Arithmetic shift right by `rhs` bits.
///
/// Shifts in copies of the sign bit (bit `BITS - 1`) from the left. On
/// zkvm, 256-bit values use the accelerated `arithmetic_shr_impl`
/// intrinsic; other widths use the generic fallback.
///
/// Fix: the original diff had GitHub review-comment text interleaved with
/// the code (rendering the block non-compiling); this is the clean body.
#[cfg(target_os = "zkvm")]
#[inline]
#[must_use]
pub fn arithmetic_shr(mut self, rhs: usize) -> Self {
    if BITS == 256 {
        // An arithmetic shift by 255 or more leaves only copies of the sign
        // bit, so out-of-range shift amounts can be clamped to 255 without
        // changing the result.
        let rhs = if rhs >= 256 { 255 } else { rhs };
        use crate::support::zkvm::arithmetic_shr_impl;
        let rhs = rhs as u64;
        // SAFETY: shift amount passed via a one-element stack array; output
        // pointer aliases the input for an in-place update — TODO confirm
        // the intrinsic permits src==dst aliasing.
        unsafe {
            arithmetic_shr_impl(self.limbs.as_ptr(), [rhs].as_ptr(), self.limbs.as_mut_ptr());
        }
        self
    } else {
        if BITS == 0 {
            return Self::ZERO;
        }
        // Perform a logical shift, then OR in the sign-extension bits when
        // the value was "negative" (top bit set).
        let sign = self.bit(BITS - 1);
        let mut r = self >> rhs;
        if sign {
            r |= Self::MAX << BITS.saturating_sub(rhs);
        }
        r
    }
}

/// Shifts the bits to the left by a specified amount, `rhs`, wrapping the
/// truncated bits to the end of the resulting integer.
#[inline]
Expand Down Expand Up @@ -394,6 +476,7 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
impl<const BITS: usize, const LIMBS: usize> Not for Uint<BITS, LIMBS> {
type Output = Self;

#[cfg(not(target_os = "zkvm"))]
#[inline]
fn not(mut self) -> Self::Output {
if BITS == 0 {
Expand All @@ -405,6 +488,31 @@ impl<const BITS: usize, const LIMBS: usize> Not for Uint<BITS, LIMBS> {
self.limbs[LIMBS - 1] &= Self::MASK;
self
}

// Bitwise NOT for zkvm targets. For 256-bit values the identity
// `!x == MAX - x` holds (no borrows occur because `x <= MAX`), so the
// accelerated subtraction intrinsic can compute the complement.
#[cfg(target_os = "zkvm")]
#[inline(always)]
fn not(mut self) -> Self::Output {
    use crate::support::zkvm::wrapping_sub_impl;
    if BITS == 256 {
        // SAFETY: all pointers refer to live `[u64; LIMBS]` buffers; the
        // output aliases the second operand (in-place) — TODO confirm the
        // intrinsic permits this aliasing (pattern used throughout the file).
        unsafe {
            wrapping_sub_impl(
                Self::MAX.limbs.as_ptr(),
                self.limbs.as_ptr(),
                self.limbs.as_mut_ptr(),
            );
        }
        self
    } else {
        if BITS == 0 {
            return Self::ZERO;
        }
        // Complement each limb, then clear the unused high bits of the top
        // limb so the value stays canonical.
        for limb in &mut self.limbs {
            *limb = u64::not(*limb);
        }
        self.limbs[LIMBS - 1] &= Self::MASK;
        self
    }
}
}

impl<const BITS: usize, const LIMBS: usize> Not for &Uint<BITS, LIMBS> {
Expand All @@ -417,7 +525,7 @@ impl<const BITS: usize, const LIMBS: usize> Not for &Uint<BITS, LIMBS> {
}

macro_rules! impl_bit_op {
($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident) => {
($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident, $fn_zkvm_impl:ident) => {
impl<const BITS: usize, const LIMBS: usize> $trait_assign<Uint<BITS, LIMBS>>
for Uint<BITS, LIMBS>
{
Expand All @@ -430,12 +538,26 @@ macro_rules! impl_bit_op {
impl<const BITS: usize, const LIMBS: usize> $trait_assign<&Uint<BITS, LIMBS>>
for Uint<BITS, LIMBS>
{
#[cfg(not(target_os = "zkvm"))]
#[inline]
fn $fn_assign(&mut self, rhs: &Uint<BITS, LIMBS>) {
for i in 0..LIMBS {
u64::$fn_assign(&mut self.limbs[i], rhs.limbs[i]);
}
}

#[cfg(target_os = "zkvm")]
#[inline(always)]
fn $fn_assign(&mut self, rhs: &Uint<BITS, LIMBS>) {
use crate::support::zkvm::$fn_zkvm_impl;
unsafe {
$fn_zkvm_impl(
self.limbs.as_ptr(),
rhs.limbs.as_ptr(),
self.limbs.as_mut_ptr(),
);
}
}
}

impl<const BITS: usize, const LIMBS: usize> $trait<Uint<BITS, LIMBS>>
Expand Down Expand Up @@ -487,9 +609,9 @@ macro_rules! impl_bit_op {
};
}

impl_bit_op!(BitOr, bitor, BitOrAssign, bitor_assign);
impl_bit_op!(BitAnd, bitand, BitAndAssign, bitand_assign);
impl_bit_op!(BitXor, bitxor, BitXorAssign, bitxor_assign);
impl_bit_op!(BitOr, bitor, BitOrAssign, bitor_assign, bitor_impl);
impl_bit_op!(BitAnd, bitand, BitAndAssign, bitand_assign, bitand_impl);
impl_bit_op!(BitXor, bitxor, BitXorAssign, bitxor_assign, bitxor_impl);

impl<const BITS: usize, const LIMBS: usize> Shl<Self> for Uint<BITS, LIMBS> {
type Output = Self;
Expand Down
11 changes: 11 additions & 0 deletions src/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,21 @@ impl<const BITS: usize, const LIMBS: usize> PartialOrd for Uint<BITS, LIMBS> {
}

impl<const BITS: usize, const LIMBS: usize> Ord for Uint<BITS, LIMBS> {
    /// Unsigned comparison via the shared limb-comparison algorithm.
    #[cfg(not(target_os = "zkvm"))]
    #[inline]
    fn cmp(&self, rhs: &Self) -> Ordering {
        crate::algorithms::cmp(self.as_limbs(), rhs.as_limbs())
    }

    /// Unsigned comparison; 256-bit values use the zkvm `cmp_impl`
    /// intrinsic, everything else the shared limb-comparison algorithm.
    #[cfg(target_os = "zkvm")]
    #[inline]
    fn cmp(&self, rhs: &Self) -> Ordering {
        use crate::support::zkvm::cmp_impl;
        if BITS == 256 {
            return unsafe { cmp_impl(self.limbs.as_ptr(), rhs.limbs.as_ptr()) };
        }
        // BUG FIX: the previous fallback was `self.cmp(rhs)`, which calls
        // this very method again and recurses infinitely (stack overflow)
        // for any BITS != 256 on zkvm. Delegate to the limb comparison used
        // by the non-zkvm implementation instead.
        crate::algorithms::cmp(self.as_limbs(), rhs.as_limbs())
    }
}

impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
Expand Down
9 changes: 9 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,21 @@ pub mod nightly {
/// requires same-sized arguments and returns a pair of lower and higher bits.
///
/// [std-overflow]: https://doc.rust-lang.org/reference/expressions/operator-expr.html#overflow
#[cfg(not(target_os = "zkvm"))]
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
#[repr(transparent)]
pub struct Uint<const BITS: usize, const LIMBS: usize> {
    // Raw 64-bit limbs backing the integer. `repr(transparent)` guarantees
    // the struct has exactly the layout of `[u64; LIMBS]`.
    limbs: [u64; LIMBS],
}

/// In case of zkvm, we use the native implementations of `Clone` and `Eq`
#[cfg(target_os = "zkvm")]
#[derive(Hash)]
#[repr(transparent)]
pub struct Uint<const BITS: usize, const LIMBS: usize> {
    // Same field as the non-zkvm definition; `repr(transparent)` keeps the
    // layout identical to `[u64; LIMBS]` so the limbs can be handed to zkvm
    // intrinsics by raw pointer.
    // NOTE(review): the non-zkvm type also derives `Copy`, `PartialEq` and
    // `Eq`. Presumably the zkvm support module provides those manually (the
    // doc comment above mentions `Clone` and `Eq`) — confirm `Copy` parity,
    // otherwise by-value code compiles only off-zkvm.
    limbs: [u64; LIMBS],
}

impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
/// The size of this integer type in 64-bit limbs.
pub const LIMBS: usize = {
Expand Down
20 changes: 20 additions & 0 deletions src/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
}

/// Computes `self * rhs`, wrapping around at the boundary of the type.
#[cfg(not(target_os = "zkvm"))]
#[inline(always)]
#[must_use]
pub fn wrapping_mul(self, rhs: Self) -> Self {
Expand All @@ -69,6 +70,25 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
result
}

/// Computes `self * rhs`, wrapping around at the boundary of the type.
///
/// On zkvm, 256-bit multiplications are delegated to the accelerated
/// `wrapping_mul_impl` intrinsic; other widths use the generic fallback.
#[cfg(target_os = "zkvm")]
#[inline(always)]
#[must_use]
pub fn wrapping_mul(mut self, rhs: Self) -> Self {
    use crate::support::zkvm::wrapping_mul_impl;
    if BITS == 256 {
        // SAFETY: pointers refer to live `[u64; LIMBS]` buffers; the output
        // aliases the first operand for an in-place update — TODO confirm
        // the intrinsic supports src==dst aliasing (pattern used throughout
        // this file).
        unsafe {
            wrapping_mul_impl(
                self.limbs.as_ptr(),
                rhs.limbs.as_ptr(),
                self.limbs.as_mut_ptr(),
            );
        }
        return self;
    }
    self.overflowing_mul(rhs).0
}

/// Computes the inverse modulo $2^{\mathtt{BITS}}$ of `self`, returning
/// [`None`] if the inverse does not exist.
#[inline]
Expand Down
2 changes: 2 additions & 0 deletions src/support/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ pub mod ssz;
mod subtle;
mod valuable;
mod zeroize;
#[cfg(target_os = "zkvm")]
pub mod zkvm;

// FEATURE: Support for many more traits and crates.
// * https://crates.io/crates/der
Expand Down
Loading