Skip to content

Commit

Permalink
Shifting
Browse files Browse the repository at this point in the history
matthiasgoergens committed Dec 12, 2024
1 parent b0ac76d commit 69117ca
Showing 3 changed files with 59 additions and 9 deletions.
58 changes: 57 additions & 1 deletion ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@ use std::{
fmt::Display,
iter::{Product, Sum},
mem::MaybeUninit,
ops::{Add, AddAssign, Deref, Mul, MulAssign, Neg, Shl, ShlAssign, Sub, SubAssign},
ops::{Add, AddAssign, Deref, Div, Mul, MulAssign, Neg, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign},
};

use ceno_emul::InsnKind;
@@ -361,6 +361,35 @@ impl<E: ExtensionField> ShlAssign<usize> for Expression<E> {
}
}

//

impl<E: ExtensionField> Shr<usize> for Expression<E> {
type Output = Expression<E>;
fn shr(self, rhs: usize) -> Expression<E> {
self / (1_usize << rhs)
}
}

impl<E: ExtensionField> Shr<usize> for &Expression<E> {
type Output = Expression<E>;
fn shr(self, rhs: usize) -> Expression<E> {
self.clone() >> rhs
}
}

impl<E: ExtensionField> Shr<usize> for &mut Expression<E> {
type Output = Expression<E>;
fn shr(self, rhs: usize) -> Expression<E> {
self.clone() >> rhs
}
}

impl<E: ExtensionField> ShrAssign<usize> for Expression<E> {
fn shr_assign(&mut self, rhs: usize) {
*self = self.clone() >> rhs;
}
}

impl<E: ExtensionField> Sum for Expression<E> {
fn sum<I: Iterator<Item = Expression<E>>>(iter: I) -> Expression<E> {
iter.fold(Expression::ZERO, |acc, x| acc + x)
@@ -730,6 +759,33 @@ impl<E: ExtensionField> Mul for Expression<E> {
}
}


macro_rules! div_instances {
(($($t:ty),*)) => {
$(

impl<E: ExtensionField> Div<$t> for Expression<E> {
type Output = Expression<E>;
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, rhs: $t) -> Expression<E> {
let reduced = (rhs as i128).rem_euclid(E::BaseField::MODULUS_U64 as i128) as u64;
self * E::BaseField::from(reduced).invert().unwrap().to_canonical_u64()
}
}

impl<E: ExtensionField> Div<$t> for &Expression<E> {
type Output = Expression<E>;
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, rhs: $t) -> Expression<E> {
let reduced = (rhs as i128).rem_euclid(E::BaseField::MODULUS_U64 as i128) as u64;
self * E::BaseField::from(reduced).invert().unwrap().to_canonical_u64()
}
}
)*
};
}
div_instances!((u8, u16, u32, u64, usize, i8, i16, i32, i64, isize));

#[derive(Clone, Debug, Copy)]
pub struct WitIn {
pub id: WitnessId,
6 changes: 1 addition & 5 deletions ceno_zkvm/src/instructions/riscv/insn_base.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use ceno_emul::{StepRecord, Word};
use ff::Field;
use ff_ext::ExtensionField;
use itertools::Itertools;

@@ -420,10 +419,7 @@ impl<E: ExtensionField> MemAddr<E> {
.sum();

// Range check the middle bits, that is the low limb excluding the low bits.
// TODO(Matthias): division here seems suspicious from a soundness perspective.
let shift_right =
Expression::Constant(E::BaseField::from(1 << Self::N_LOW_BITS).invert().unwrap());
let mid_u14 = (&limbs[0] - low_sum) * shift_right;
let mid_u14 = (&limbs[0] - low_sum) >> Self::N_LOW_BITS;
cb.assert_ux::<_, _, 14>(|| "mid_u14", mid_u14)?;

// Range check the high limb.
4 changes: 1 addition & 3 deletions ceno_zkvm/src/instructions/riscv/memory/gadget.rs
Original file line number Diff line number Diff line change
@@ -8,7 +8,6 @@ use crate::{
witness::LkMultiplicity,
};
use ceno_emul::StepRecord;
use ff::Field;
use ff_ext::ExtensionField;
use itertools::izip;
use std::mem::MaybeUninit;
@@ -77,10 +76,9 @@ impl<const N_ZEROS: usize> MemWordChange<N_ZEROS> {

// extract the least significant byte from u16 limb
let rs2_limb_bytes = alloc_bytes(cb, "rs2_limb[0]", 1)?;
let u8_base_inv = E::BaseField::from(1 << 8).invert().unwrap();
cb.assert_ux::<_, _, 8>(
|| "rs2_limb[0].le_bytes[1]",
Expression::Constant(u8_base_inv) * (&rs2_limbs[0] - rs2_limb_bytes[0].expr()),
(&rs2_limbs[0] - rs2_limb_bytes[0].expr()) >> 8,
)?;

// alloc a new witIn to cache degree 2 expression

0 comments on commit 69117ca

Please sign in to comment.