Skip to content

Commit

Permalink
Fix bit operations. (#2249)
Browse files Browse the repository at this point in the history
This fixes a bug in witjitgen where the bit operations used for masking
in bit-decomposition would use field elements as types for the masks.
The problem is that we sometimes mask using masks outside the field and
the proper type for masks is FieldElement::Integer.
  • Loading branch information
chriseth authored Dec 18, 2024
1 parent bdfa4e1 commit d98b7eb
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 76 deletions.
11 changes: 5 additions & 6 deletions executor/src/witgen/jit/affine_symbolic_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ impl<T: FieldElement, V: Ord + Clone + Display> AffineSymbolicExpression<T, V> {
} else {
covered_bits |= mask;
}
let masked = -&self.offset & T::from(mask).into();
let masked = -&self.offset & mask;
effects.push(Effect::Assignment(
var.clone(),
masked.integer_div(&coeff.into()),
Expand All @@ -289,11 +289,10 @@ impl<T: FieldElement, V: Ord + Clone + Display> AffineSymbolicExpression<T, V> {

// We need to assert that the masks cover "-offset",
// otherwise the equation is not solvable.
// We assert -offset & !masks == 0 <=> -offset == -offset | masks.
// We use the latter since we cannot properly bit-negate inside the field.
// We assert -offset & !masks == 0
effects.push(Assertion::assert_eq(
-&self.offset,
-&self.offset | T::from(covered_bits).into(),
-&self.offset & !covered_bits,
T::from(0).into(),
));

ProcessResult::complete(effects)
Expand Down Expand Up @@ -566,7 +565,7 @@ mod test {
"a = ((-(10 + Z) & 65280) // 256);
b = ((-(10 + Z) & 16711680) // 65536);
c = ((-(10 + Z) & 4278190080) // 16777216);
assert -(10 + Z) == (-(10 + Z) | 4294967040);
assert (-(10 + Z) & 18446744069414584575) == 0;
"
);
}
Expand Down
27 changes: 24 additions & 3 deletions executor/src/witgen/jit/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::witgen::{

use super::{
affine_symbolic_expression::{Assertion, Effect},
symbolic_expression::{BinaryOperator, SymbolicExpression, UnaryOperator},
symbolic_expression::{BinaryOperator, BitOperator, SymbolicExpression, UnaryOperator},
variable::{Cell, Variable},
};

Expand Down Expand Up @@ -269,8 +269,6 @@ fn format_expression<T: FieldElement>(e: &SymbolicExpression<T, Variable>) -> St
BinaryOperator::Mul => format!("({left} * {right})"),
BinaryOperator::Div => format!("({left} / {right})"),
BinaryOperator::IntegerDiv => format!("integer_div({left}, {right})"),
BinaryOperator::BitAnd => format!("({left} & {right})"),
BinaryOperator::BitOr => format!("({left} | {right})"),
}
}
SymbolicExpression::UnaryOperation(op, inner, _) => {
Expand All @@ -279,6 +277,12 @@ fn format_expression<T: FieldElement>(e: &SymbolicExpression<T, Variable>) -> St
UnaryOperator::Neg => format!("-{inner}"),
}
}
SymbolicExpression::BitOperation(left, op, right, _) => {
let left = format_expression(left);
match op {
BitOperator::And => format!("({left} & {right})"),
}
}
}
}

Expand Down Expand Up @@ -817,4 +821,21 @@ extern \"C\" fn witgen(
(f.function)(params);
assert_eq!(y_val, GoldilocksField::from(7 * 2));
}

#[test]
fn bit_ops() {
let a = cell("a", 0, 0);
let x = cell("x", 1, 0);
// Test that the operators & and | work with numbers larger than the modulus.
let large_num =
<powdr_number::GoldilocksField as powdr_number::FieldElement>::Integer::from(
0xffffffffffffffff_u64,
);
assert!(large_num.to_string().parse::<u64>().unwrap() == 0xffffffffffffffff_u64);
assert!(large_num > GoldilocksField::modulus());
let effects = vec![assignment(&x, symbol(&a) & large_num)];
let known_inputs = vec![a.clone()];
let code = witgen_code(&known_inputs, &effects);
assert!(code.contains(&format!("let c_x_1_0 = (c_a_0_0 & {large_num});")));
}
}
102 changes: 41 additions & 61 deletions executor/src/witgen/jit/symbolic_expression.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::{
fmt::{self, Display, Formatter},
ops::{Add, BitAnd, BitOr, Mul, Neg},
ops::{Add, BitAnd, Mul, Neg},
rc::Rc,
};

use num_traits::Zero;
use powdr_number::FieldElement;

use crate::witgen::range_constraints::RangeConstraint;
Expand All @@ -25,6 +26,12 @@ pub enum SymbolicExpression<T: FieldElement, S> {
Option<RangeConstraint<T>>,
),
UnaryOperation(UnaryOperator, Rc<Self>, Option<RangeConstraint<T>>),
BitOperation(
Rc<Self>,
BitOperator,
T::Integer,
Option<RangeConstraint<T>>,
),
}

#[derive(Debug, Clone)]
Expand All @@ -36,8 +43,11 @@ pub enum BinaryOperator {
Div,
/// Integer division, i.e. convert field elements to unsigned integer and divide.
IntegerDiv,
BitAnd,
BitOr,
}

#[derive(Debug, Clone)]
pub enum BitOperator {
And,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -78,7 +88,8 @@ impl<T: FieldElement, S> SymbolicExpression<T, S> {
SymbolicExpression::Concrete(v) => Some(RangeConstraint::from_value(*v)),
SymbolicExpression::Symbol(.., rc)
| SymbolicExpression::BinaryOperation(.., rc)
| SymbolicExpression::UnaryOperation(.., rc) => rc.clone(),
| SymbolicExpression::UnaryOperation(.., rc)
| SymbolicExpression::BitOperation(.., rc) => rc.clone(),
}
}

Expand All @@ -87,7 +98,8 @@ impl<T: FieldElement, S> SymbolicExpression<T, S> {
SymbolicExpression::Concrete(n) => Some(*n),
SymbolicExpression::Symbol(..)
| SymbolicExpression::BinaryOperation(..)
| SymbolicExpression::UnaryOperation(..) => None,
| SymbolicExpression::UnaryOperation(..)
| SymbolicExpression::BitOperation(..) => None,
}
}
}
Expand All @@ -108,6 +120,9 @@ impl<T: FieldElement, V: Display> Display for SymbolicExpression<T, V> {
write!(f, "({lhs} {op} {rhs})")
}
SymbolicExpression::UnaryOperation(op, expr, _) => write!(f, "{op}{expr}"),
SymbolicExpression::BitOperation(expr, op, n, _) => {
write!(f, "({expr} {op} {n})")
}
}
}
}
Expand All @@ -120,8 +135,6 @@ impl Display for BinaryOperator {
BinaryOperator::Mul => write!(f, "*"),
BinaryOperator::Div => write!(f, "/"),
BinaryOperator::IntegerDiv => write!(f, "//"),
BinaryOperator::BitAnd => write!(f, "&"),
BinaryOperator::BitOr => write!(f, "|"),
}
}
}
Expand All @@ -134,6 +147,14 @@ impl Display for UnaryOperator {
}
}

impl Display for BitOperator {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
BitOperator::And => write!(f, "&"),
}
}
}

impl<T: FieldElement, V> From<T> for SymbolicExpression<T, V> {
fn from(n: T) -> Self {
SymbolicExpression::Concrete(n)
Expand Down Expand Up @@ -271,64 +292,23 @@ impl<T: FieldElement, V: Clone> SymbolicExpression<T, V> {
}
}

impl<T: FieldElement, V: Clone> BitAnd for &SymbolicExpression<T, V> {
impl<T: FieldElement, V: Clone> BitAnd<T::Integer> for SymbolicExpression<T, V> {
type Output = SymbolicExpression<T, V>;

fn bitand(self, rhs: Self) -> Self::Output {
if let (SymbolicExpression::Concrete(a), SymbolicExpression::Concrete(b)) = (self, rhs) {
SymbolicExpression::Concrete(T::from(a.to_integer() & b.to_integer()))
} else if self.is_known_zero() || rhs.is_known_zero() {
fn bitand(self, rhs: T::Integer) -> Self::Output {
if let SymbolicExpression::Concrete(a) = self {
SymbolicExpression::Concrete(T::from(a.to_integer() & rhs))
} else if self.is_known_zero() || rhs.is_zero() {
SymbolicExpression::Concrete(T::from(0))
} else {
SymbolicExpression::BinaryOperation(
Rc::new(self.clone()),
BinaryOperator::BitAnd,
Rc::new(rhs.clone()),
self.range_constraint()
.zip(rhs.range_constraint())
.map(|(a, b)| RangeConstraint::from_mask(*a.mask() & *b.mask())),
)
}
}
}

impl<T: FieldElement, V: Clone> BitAnd for SymbolicExpression<T, V> {
type Output = SymbolicExpression<T, V>;

fn bitand(self, rhs: Self) -> Self::Output {
&self & &rhs
}
}

impl<T: FieldElement, V: Clone> BitOr for &SymbolicExpression<T, V> {
type Output = SymbolicExpression<T, V>;

fn bitor(self, rhs: Self) -> Self::Output {
if let (SymbolicExpression::Concrete(a), SymbolicExpression::Concrete(b)) = (self, rhs) {
let v = a.to_integer() | b.to_integer();
assert!(v < T::modulus());
SymbolicExpression::Concrete(T::from(v))
} else if self.is_known_zero() {
rhs.clone()
} else if rhs.is_known_zero() {
self.clone()
} else {
SymbolicExpression::BinaryOperation(
Rc::new(self.clone()),
BinaryOperator::BitOr,
Rc::new(rhs.clone()),
self.range_constraint()
.zip(rhs.range_constraint())
.map(|(a, b)| RangeConstraint::from_mask(*a.mask() | *b.mask())),
)
let rc = Some(RangeConstraint::from_mask(
if let Some(rc) = self.range_constraint() {
*rc.mask() & rhs
} else {
rhs
},
));
SymbolicExpression::BitOperation(Rc::new(self), BitOperator::And, rhs, rc)
}
}
}

impl<T: FieldElement, V: Clone> BitOr for SymbolicExpression<T, V> {
type Output = SymbolicExpression<T, V>;

fn bitor(self, rhs: Self) -> Self::Output {
&self | &rhs
}
}
12 changes: 6 additions & 6 deletions executor/src/witgen/jit/witgen_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,23 +489,23 @@ namespace Xor(256 * 256);
"\
Xor::A_byte[6] = ((Xor::A[7] & 4278190080) // 16777216);
Xor::A[6] = (Xor::A[7] & 16777215);
assert Xor::A[7] == (Xor::A[7] | 4294967295);
assert (Xor::A[7] & 18446744069414584320) == 0;
Xor::C_byte[6] = ((Xor::C[7] & 4278190080) // 16777216);
Xor::C[6] = (Xor::C[7] & 16777215);
assert Xor::C[7] == (Xor::C[7] | 4294967295);
assert (Xor::C[7] & 18446744069414584320) == 0;
Xor::A_byte[5] = ((Xor::A[6] & 16711680) // 65536);
Xor::A[5] = (Xor::A[6] & 65535);
assert Xor::A[6] == (Xor::A[6] | 16777215);
assert (Xor::A[6] & 18446744073692774400) == 0;
Xor::C_byte[5] = ((Xor::C[6] & 16711680) // 65536);
Xor::C[5] = (Xor::C[6] & 65535);
assert Xor::C[6] == (Xor::C[6] | 16777215);
assert (Xor::C[6] & 18446744073692774400) == 0;
machine_call(0, [Known(Xor::A_byte[6]), Unknown(Xor::B_byte[6]), Known(Xor::C_byte[6])]);
Xor::A_byte[4] = ((Xor::A[5] & 65280) // 256);
Xor::A[4] = (Xor::A[5] & 255);
assert Xor::A[5] == (Xor::A[5] | 65535);
assert (Xor::A[5] & 18446744073709486080) == 0;
Xor::C_byte[4] = ((Xor::C[5] & 65280) // 256);
Xor::C[4] = (Xor::C[5] & 255);
assert Xor::C[5] == (Xor::C[5] | 65535);
assert (Xor::C[5] & 18446744073709486080) == 0;
machine_call(0, [Known(Xor::A_byte[5]), Unknown(Xor::B_byte[5]), Known(Xor::C_byte[5])]);
Xor::A_byte[3] = Xor::A[4];
Xor::C_byte[3] = Xor::C[4];
Expand Down

0 comments on commit d98b7eb

Please sign in to comment.