From d98b7ebb58f0f2b88f3f9e68d562b63d57cda185 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 18 Dec 2024 14:56:29 +0100 Subject: [PATCH] Fix bit operations. (#2249) This fixes a bug in witjitgen where the bit operations used for masking in bit-decomposition would use field elements as types for the masks. The problem is that we sometimes mask using masks outside the field and the proper type for masks is FieldElement::Integer. --- .../witgen/jit/affine_symbolic_expression.rs | 11 +- executor/src/witgen/jit/compiler.rs | 27 ++++- .../src/witgen/jit/symbolic_expression.rs | 102 +++++++----------- executor/src/witgen/jit/witgen_inference.rs | 12 +-- 4 files changed, 76 insertions(+), 76 deletions(-) diff --git a/executor/src/witgen/jit/affine_symbolic_expression.rs b/executor/src/witgen/jit/affine_symbolic_expression.rs index d4b3c9357..d2e5398b8 100644 --- a/executor/src/witgen/jit/affine_symbolic_expression.rs +++ b/executor/src/witgen/jit/affine_symbolic_expression.rs @@ -276,7 +276,7 @@ impl AffineSymbolicExpression { } else { covered_bits |= mask; } - let masked = -&self.offset & T::from(mask).into(); + let masked = -&self.offset & mask; effects.push(Effect::Assignment( var.clone(), masked.integer_div(&coeff.into()), @@ -289,11 +289,10 @@ impl AffineSymbolicExpression { // We need to assert that the masks cover "-offset", // otherwise the equation is not solvable. - // We assert -offset & !masks == 0 <=> -offset == -offset | masks. - // We use the latter since we cannot properly bit-negate inside the field. + // We assert -offset & !masks == 0 effects.push(Assertion::assert_eq( - -&self.offset, - -&self.offset | T::from(covered_bits).into(), + -&self.offset & !covered_bits, + T::from(0).into(), )); ProcessResult::complete(effects) @@ -566,7 +565,7 @@ mod test { "a = ((-(10 + Z) & 65280) // 256); b = ((-(10 + Z) & 16711680) // 65536); c = ((-(10 + Z) & 4278190080) // 16777216); -assert -(10 + Z) == (-(10 + Z) | 4294967040); +assert (-(10 + Z) & 18446744069414584575) == 0; " ); } diff --git a/executor/src/witgen/jit/compiler.rs b/executor/src/witgen/jit/compiler.rs index 6e7b2e6db..e92fbbb3f 100644 --- a/executor/src/witgen/jit/compiler.rs +++ b/executor/src/witgen/jit/compiler.rs @@ -18,7 +18,7 @@ use crate::witgen::{ use super::{ affine_symbolic_expression::{Assertion, Effect}, - symbolic_expression::{BinaryOperator, SymbolicExpression, UnaryOperator}, + symbolic_expression::{BinaryOperator, BitOperator, SymbolicExpression, UnaryOperator}, variable::{Cell, Variable}, }; @@ -269,8 +269,6 @@ fn format_expression(e: &SymbolicExpression) -> St BinaryOperator::Mul => format!("({left} * {right})"), BinaryOperator::Div => format!("({left} / {right})"), BinaryOperator::IntegerDiv => format!("integer_div({left}, {right})"), - BinaryOperator::BitAnd => format!("({left} & {right})"), - BinaryOperator::BitOr => format!("({left} | {right})"), } } SymbolicExpression::UnaryOperation(op, inner, _) => { @@ -279,6 +277,12 @@ fn format_expression(e: &SymbolicExpression) -> St UnaryOperator::Neg => format!("-{inner}"), } } + SymbolicExpression::BitOperation(left, op, right, _) => { + let left = format_expression(left); + match op { + BitOperator::And => format!("({left} & {right})"), + } + } } } @@ -817,4 +821,21 @@ extern \"C\" fn witgen( (f.function)(params); assert_eq!(y_val, GoldilocksField::from(7 * 2)); } + + #[test] + fn bit_ops() { + let a = cell("a", 0, 0); + let x = cell("x", 1, 0); + // Test that the operators & and | work with numbers larger than the modulus. + let large_num = + ::Integer::from( + 0xffffffffffffffff_u64, + ); + assert!(large_num.to_string().parse::().unwrap() == 0xffffffffffffffff_u64); + assert!(large_num > GoldilocksField::modulus()); + let effects = vec![assignment(&x, symbol(&a) & large_num)]; + let known_inputs = vec![a.clone()]; + let code = witgen_code(&known_inputs, &effects); + assert!(code.contains(&format!("let c_x_1_0 = (c_a_0_0 & {large_num});"))); + } } diff --git a/executor/src/witgen/jit/symbolic_expression.rs b/executor/src/witgen/jit/symbolic_expression.rs index 900135c7b..aa27e50c2 100644 --- a/executor/src/witgen/jit/symbolic_expression.rs +++ b/executor/src/witgen/jit/symbolic_expression.rs @@ -1,9 +1,10 @@ use std::{ fmt::{self, Display, Formatter}, - ops::{Add, BitAnd, BitOr, Mul, Neg}, + ops::{Add, BitAnd, Mul, Neg}, rc::Rc, }; +use num_traits::Zero; use powdr_number::FieldElement; use crate::witgen::range_constraints::RangeConstraint; @@ -25,6 +26,12 @@ pub enum SymbolicExpression { Option>, ), UnaryOperation(UnaryOperator, Rc, Option>), + BitOperation( + Rc, + BitOperator, + T::Integer, + Option>, + ), } #[derive(Debug, Clone)] @@ -36,8 +43,11 @@ pub enum BinaryOperator { Div, /// Integer division, i.e. convert field elements to unsigned integer and divide. IntegerDiv, - BitAnd, - BitOr, +} + +#[derive(Debug, Clone)] +pub enum BitOperator { + And, } #[derive(Debug, Clone)] @@ -78,7 +88,8 @@ impl SymbolicExpression { SymbolicExpression::Concrete(v) => Some(RangeConstraint::from_value(*v)), SymbolicExpression::Symbol(.., rc) | SymbolicExpression::BinaryOperation(.., rc) - | SymbolicExpression::UnaryOperation(.., rc) => rc.clone(), + | SymbolicExpression::UnaryOperation(.., rc) + | SymbolicExpression::BitOperation(.., rc) => rc.clone(), } } @@ -87,7 +98,8 @@ impl SymbolicExpression { SymbolicExpression::Concrete(n) => Some(*n), SymbolicExpression::Symbol(..) | SymbolicExpression::BinaryOperation(..) - | SymbolicExpression::UnaryOperation(..) => None, + | SymbolicExpression::UnaryOperation(..) + | SymbolicExpression::BitOperation(..) => None, } } } @@ -108,6 +120,9 @@ impl Display for SymbolicExpression { write!(f, "({lhs} {op} {rhs})") } SymbolicExpression::UnaryOperation(op, expr, _) => write!(f, "{op}{expr}"), + SymbolicExpression::BitOperation(expr, op, n, _) => { + write!(f, "({expr} {op} {n})") + } } } } @@ -120,8 +135,6 @@ impl Display for BinaryOperator { BinaryOperator::Mul => write!(f, "*"), BinaryOperator::Div => write!(f, "/"), BinaryOperator::IntegerDiv => write!(f, "//"), - BinaryOperator::BitAnd => write!(f, "&"), - BinaryOperator::BitOr => write!(f, "|"), } } } @@ -134,6 +147,14 @@ impl Display for UnaryOperator { } } +impl Display for BitOperator { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + BitOperator::And => write!(f, "&"), + } + } +} + impl From for SymbolicExpression { fn from(n: T) -> Self { SymbolicExpression::Concrete(n) @@ -271,64 +292,23 @@ impl SymbolicExpression { } } -impl BitAnd for &SymbolicExpression { +impl BitAnd for SymbolicExpression { type Output = SymbolicExpression; - fn bitand(self, rhs: Self) -> Self::Output { - if let (SymbolicExpression::Concrete(a), SymbolicExpression::Concrete(b)) = (self, rhs) { - SymbolicExpression::Concrete(T::from(a.to_integer() & b.to_integer())) - } else if self.is_known_zero() || rhs.is_known_zero() { + fn bitand(self, rhs: T::Integer) -> Self::Output { + if let SymbolicExpression::Concrete(a) = self { + SymbolicExpression::Concrete(T::from(a.to_integer() & rhs)) + } else if self.is_known_zero() || rhs.is_zero() { SymbolicExpression::Concrete(T::from(0)) } else { - SymbolicExpression::BinaryOperation( - Rc::new(self.clone()), - BinaryOperator::BitAnd, - Rc::new(rhs.clone()), - self.range_constraint() - .zip(rhs.range_constraint()) - .map(|(a, b)| RangeConstraint::from_mask(*a.mask() & *b.mask())), - ) - } - } -} - -impl BitAnd for SymbolicExpression { - type Output = SymbolicExpression; - - fn bitand(self, rhs: Self) -> Self::Output { - &self & &rhs - } -} - -impl BitOr for &SymbolicExpression { - type Output = SymbolicExpression; - - fn bitor(self, rhs: Self) -> Self::Output { - if let (SymbolicExpression::Concrete(a), SymbolicExpression::Concrete(b)) = (self, rhs) { - let v = a.to_integer() | b.to_integer(); - assert!(v < T::modulus()); - SymbolicExpression::Concrete(T::from(v)) - } else if self.is_known_zero() { - rhs.clone() - } else if rhs.is_known_zero() { - self.clone() - } else { - SymbolicExpression::BinaryOperation( - Rc::new(self.clone()), - BinaryOperator::BitOr, - Rc::new(rhs.clone()), - self.range_constraint() - .zip(rhs.range_constraint()) - .map(|(a, b)| RangeConstraint::from_mask(*a.mask() | *b.mask())), - ) + let rc = Some(RangeConstraint::from_mask( + if let Some(rc) = self.range_constraint() { + *rc.mask() & rhs + } else { + rhs + }, + )); + SymbolicExpression::BitOperation(Rc::new(self), BitOperator::And, rhs, rc) } } } - -impl BitOr for SymbolicExpression { - type Output = SymbolicExpression; - - fn bitor(self, rhs: Self) -> Self::Output { - &self | &rhs - } -} diff --git a/executor/src/witgen/jit/witgen_inference.rs b/executor/src/witgen/jit/witgen_inference.rs index 81223339d..f126a1f6a 100644 --- a/executor/src/witgen/jit/witgen_inference.rs +++ b/executor/src/witgen/jit/witgen_inference.rs @@ -489,23 +489,23 @@ namespace Xor(256 * 256); "\ Xor::A_byte[6] = ((Xor::A[7] & 4278190080) // 16777216); Xor::A[6] = (Xor::A[7] & 16777215); -assert Xor::A[7] == (Xor::A[7] | 4294967295); +assert (Xor::A[7] & 18446744069414584320) == 0; Xor::C_byte[6] = ((Xor::C[7] & 4278190080) // 16777216); Xor::C[6] = (Xor::C[7] & 16777215); -assert Xor::C[7] == (Xor::C[7] | 4294967295); +assert (Xor::C[7] & 18446744069414584320) == 0; Xor::A_byte[5] = ((Xor::A[6] & 16711680) // 65536); Xor::A[5] = (Xor::A[6] & 65535); -assert Xor::A[6] == (Xor::A[6] | 16777215); +assert (Xor::A[6] & 18446744073692774400) == 0; Xor::C_byte[5] = ((Xor::C[6] & 16711680) // 65536); Xor::C[5] = (Xor::C[6] & 65535); -assert Xor::C[6] == (Xor::C[6] | 16777215); +assert (Xor::C[6] & 18446744073692774400) == 0; machine_call(0, [Known(Xor::A_byte[6]), Unknown(Xor::B_byte[6]), Known(Xor::C_byte[6])]); Xor::A_byte[4] = ((Xor::A[5] & 65280) // 256); Xor::A[4] = (Xor::A[5] & 255); -assert Xor::A[5] == (Xor::A[5] | 65535); +assert (Xor::A[5] & 18446744073709486080) == 0; Xor::C_byte[4] = ((Xor::C[5] & 65280) // 256); Xor::C[4] = (Xor::C[5] & 255); -assert Xor::C[5] == (Xor::C[5] | 65535); +assert (Xor::C[5] & 18446744073709486080) == 0; machine_call(0, [Known(Xor::A_byte[5]), Unknown(Xor::B_byte[5]), Known(Xor::C_byte[5])]); Xor::A_byte[3] = Xor::A[4]; Xor::C_byte[3] = Xor::C[4];