From c1025455894d98e2322da6529b83341c76c34f4c Mon Sep 17 00:00:00 2001 From: bhargav Date: Tue, 21 Nov 2023 20:17:30 -0800 Subject: [PATCH 1/3] feat: dyn residue algebra --- BENCHMARKS.md | 58 --------------------------- Cargo.toml | 3 ++ benches/BENCHMARKS.md | 0 src/numbers.rs | 93 ++++++++++++++++++++++++++++++++++--------- 4 files changed, 78 insertions(+), 76 deletions(-) create mode 100644 benches/BENCHMARKS.md diff --git a/BENCHMARKS.md b/BENCHMARKS.md index efc8365..e69de29 100644 --- a/BENCHMARKS.md +++ b/BENCHMARKS.md @@ -1,58 +0,0 @@ -# Benchmarks - -## Table of Contents - -- [Overview](#overview) -- [Benchmark Results](#benchmark-results) - - [Number-Theoretic Transform Benchmarks](#number-theoretic-transform-benchmarks) - - [Polynomial Multiplication Benchmarks](#polynomial-multiplication-benchmarks) - -## Overview - -This benchmark comparison report shows the difference in performance between parallel, NTT-based and serial, brute-force -polynomial multiplication algorithms. Each row entry in the first table is an n-degree forward NTT and each row entry in the second table represents an n-degree polynomial multiplication. - -Computer Stats: - -``` -CPU(s): 16 -Thread(s) per core: 2 -Core(s) per socket: 8 -Socket(s): 1 -``` - -## Benchmark Results - -### Number-Theoretic Transform Benchmarks - -| | `NTT` | -|:------------|:-------------------------- | -| **`64`** | `202.26 us` (✅ **1.00x**) | -| **`128`** | `354.08 us` (✅ **1.00x**) | -| **`256`** | `665.54 us` (✅ **1.00x**) | -| **`512`** | `1.12 ms` (✅ **1.00x**) | -| **`1024`** | `2.00 ms` (✅ **1.00x**) | -| **`2048`** | `3.94 ms` (✅ **1.00x**) | -| **`4096`** | `7.69 ms` (✅ **1.00x**) | -| **`8192`** | `16.13 ms` (✅ **1.00x**) | -| **`16384`** | `34.01 ms` (✅ **1.00x**) | -| **`32768`** | `74.65 ms` (✅ **1.00x**) | - -### Polynomial Multiplication Benchmarks - -| | `NTT-Based` | `Brute-Force` | -|:------------|:--------------------------|:---------------------------------- | -| **`64`** | `1.18 ms` (✅ **1.00x**) | `48.62 us` (🚀 **24.21x faster**) | -| **`128`** | `2.30 ms` (✅ **1.00x**) | `198.30 us` (🚀 **11.59x faster**) | -| **`256`** | `3.54 ms` (✅ **1.00x**) | `766.71 us` (🚀 **4.62x faster**) | -| **`512`** | `6.50 ms` (✅ **1.00x**) | `3.11 ms` (🚀 **2.09x faster**) | -| **`1024`** | `12.43 ms` (✅ **1.00x**) | `12.34 ms` (✅ **1.01x faster**) | -| **`2048`** | `24.68 ms` (✅ **1.00x**) | `49.90 ms` (❌ *2.02x slower*) | -| **`4096`** | `51.36 ms` (✅ **1.00x**) | `200.91 ms` (❌ *3.91x slower*) | -| **`8192`** | `106.21 ms` (✅ **1.00x**) | `803.87 ms` (❌ *7.57x slower*) | -| **`16384`** | `226.19 ms` (✅ **1.00x**) | `3.24 s` (❌ *14.31x slower*) | -| **`32768`** | `467.75 ms` (✅ **1.00x**) | `12.75 s` (❌ *27.25x slower*) | - ---- -Made with [criterion-table](https://github.com/nu11ptr/criterion-table) - diff --git a/Cargo.toml b/Cargo.toml index 2fcbd60..1bf54c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,9 @@ name = "fast-ntt" version = "0.1.0" edition = "2021" +license = "MIT OR Apache-2.0" +keywords = ["ntt", "number-theoretic-transform", "fft"] +categories = ["cryptography", "data-structures"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/benches/BENCHMARKS.md b/benches/BENCHMARKS.md new file mode 100644 index 0000000..e69de29 diff --git a/src/numbers.rs b/src/numbers.rs index 859c90a..6738a38 100644 --- a/src/numbers.rs +++ b/src/numbers.rs @@ -7,8 +7,13 @@ use std::{ }, }; -use crypto_bigint::{rand_core::OsRng, Invert, NonZero, Random, RandomMod, Wrapping, U256}; +use crypto_bigint::{ + modular::runtime_mod::{DynResidue, DynResidueParams}, + rand_core::OsRng, + Invert, NonZero, Random, RandomMod, Wrapping, U256, +}; use itertools::Itertools; +use rand::{thread_rng, Rng}; pub enum BigIntType { U16(u16), @@ -20,10 +25,14 @@ pub enum BigIntType { #[derive(Debug, Clone, Copy)] pub struct BigInt { pub v: U256, + params: DynResidueParams<4>, } impl BigInt { pub fn new(_v: BigIntType) -> Self { + let params = DynResidueParams::new(&U256::from_be_hex( + "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551", + )); Self { v: match _v { BigIntType::U16(x) => U256::from(x), @@ -32,12 +41,14 @@ impl BigInt { BigIntType::U128(x) => U256::from(x), _ => panic!("received invalid `BigIntType`"), }, + params, } } pub fn rem(&self, num: BigInt) -> Self { BigInt { v: self.v.const_rem(&num.v).0, + params: self.params, } } @@ -62,6 +73,7 @@ impl BigInt { pub fn sqrt(&self) -> BigInt { BigInt { v: self.v.sqrt_vartime(), + params: self.params, } } @@ -82,14 +94,16 @@ impl BigInt { } pub fn random() -> BigInt { - BigInt { - v: U256::random(&mut OsRng), - } + let x = rand::thread_rng().gen::(); + BigInt::new(BigIntType::U128(x)) } pub fn reverse(&self) -> BigInt { let mut v = self.v; - BigInt { v } + BigInt { + v, + params: self.params, + } } } @@ -133,9 +147,11 @@ impl Add for BigInt { type Output = BigInt; fn add(self, rhs: Self) -> Self::Output { - Self { - v: (Wrapping(self.v) + Wrapping(rhs.v)).0, - } + let params = self.params; + let x = DynResidue::new(&self.v, params); + let y = DynResidue::new(&rhs.v, params); + let v = U256::from((x + y).retrieve()); + Self { v, params } } } @@ -145,6 +161,7 @@ impl Add for BigInt { fn add(self, rhs: u16) -> Self::Output { Self { v: (Wrapping(self.v) + Wrapping(BigInt::from(rhs).v)).0, + params: self.params, } } } @@ -155,6 +172,7 @@ impl Add for BigInt { fn add(self, rhs: i32) -> Self::Output { Self { v: (Wrapping(self.v) + Wrapping(BigInt::from(rhs).v)).0, + params: self.params, } } } @@ -165,6 +183,7 @@ impl Add for BigInt { fn add(self, rhs: u32) -> Self::Output { Self { v: (-Wrapping(BigInt::from(rhs).v)).0, + params: self.params, } } } @@ -175,6 +194,7 @@ impl Add for BigInt { fn add(self, rhs: u64) -> Self::Output { Self { v: (Wrapping(self.v) + Wrapping(BigInt::from(rhs).v)).0, + params: self.params, } } } @@ -185,6 +205,7 @@ impl Add for BigInt { fn add(self, rhs: u128) -> Self::Output { Self { v: (Wrapping(self.v) + Wrapping(BigInt::from(rhs).v)).0, + params: self.params, } } } @@ -229,9 +250,11 @@ impl Sub for BigInt { type Output = BigInt; fn sub(self, rhs: Self) -> Self::Output { - BigInt { - v: (Wrapping(self.v) - Wrapping(rhs.v)).0, - } + let params = self.params; + let x = DynResidue::new(&self.v, params); + let y = DynResidue::new(&rhs.v, params); + let v = U256::from((x - y).retrieve()); + Self { v, params } } } @@ -325,6 +348,7 @@ impl Neg for BigInt { fn neg(self) -> Self::Output { Self { v: (Wrapping(U256::MAX) - Wrapping(self.v)).0, + params: self.params, } } } @@ -333,9 +357,11 @@ impl Mul for BigInt { type Output = BigInt; fn mul(self, rhs: Self) -> Self::Output { - Self { - v: (Wrapping(self.v) * Wrapping(rhs.v)).0, - } + let params = self.params; + let x = DynResidue::new(&self.v, params); + let y = DynResidue::new(&rhs.v, params); + let v = U256::from((x * y).retrieve()); + Self { v, params } } } @@ -354,6 +380,7 @@ impl Div for BigInt { let half = half + (lower as u128); BigInt { v: (Wrapping(self.v) / NonZero::from(NonZeroU128::new(half).unwrap())).0, + params: self.params, } } } @@ -364,6 +391,7 @@ impl Invert for BigInt { fn invert(&self) -> Self::Output { BigInt { v: self.v.inv_mod(&U256::MAX).0, + params: self.params, } } } @@ -458,7 +486,10 @@ impl BitAnd for BigInt { type Output = BigInt; fn bitand(self, rhs: Self) -> Self::Output { - BigInt { v: self.v & rhs.v } + BigInt { + v: self.v & rhs.v, + params: self.params, + } } } @@ -468,6 +499,7 @@ impl BitAnd for BigInt { fn bitand(self, rhs: u16) -> Self::Output { BigInt { v: self.v & BigInt::from(rhs).v, + params: self.params, } } } @@ -478,6 +510,7 @@ impl BitAnd for BigInt { fn bitand(self, rhs: i32) -> Self::Output { BigInt { v: self.v & BigInt::from(rhs).v, + params: self.params, } } } @@ -488,6 +521,7 @@ impl BitAnd for BigInt { fn bitand(self, rhs: u32) -> Self::Output { BigInt { v: self.v & BigInt::from(rhs).v, + params: self.params, } } } @@ -498,6 +532,7 @@ impl BitAnd for BigInt { fn bitand(self, rhs: u64) -> Self::Output { BigInt { v: self.v & BigInt::from(rhs).v, + params: self.params, } } } @@ -508,6 +543,7 @@ impl BitAnd for BigInt { fn bitand(self, rhs: u128) -> Self::Output { BigInt { v: self.v & BigInt::from(rhs).v, + params: self.params, } } } @@ -516,7 +552,10 @@ impl BitOr for BigInt { type Output = BigInt; fn bitor(self, rhs: Self) -> Self::Output { - BigInt { v: self.v | rhs.v } + BigInt { + v: self.v | rhs.v, + params: self.params, + } } } @@ -526,6 +565,7 @@ impl BitOr for BigInt { fn bitor(self, rhs: u16) -> Self::Output { BigInt { v: self.v | BigInt::from(rhs).v, + params: self.params, } } } @@ -536,6 +576,7 @@ impl BitOr for BigInt { fn bitor(self, rhs: i32) -> Self::Output { BigInt { v: self.v | BigInt::from(rhs).v, + params: self.params, } } } @@ -546,6 +587,7 @@ impl BitOr for BigInt { fn bitor(self, rhs: u32) -> Self::Output { BigInt { v: self.v | BigInt::from(rhs).v, + params: self.params, } } } @@ -556,6 +598,7 @@ impl BitOr for BigInt { fn bitor(self, rhs: u64) -> Self::Output { BigInt { v: self.v | BigInt::from(rhs).v, + params: self.params, } } } @@ -566,6 +609,7 @@ impl BitOr for BigInt { fn bitor(self, rhs: u128) -> Self::Output { BigInt { v: self.v | BigInt::from(rhs).v, + params: self.params, } } } @@ -574,7 +618,10 @@ impl Shl for BigInt { type Output = BigInt; fn shl(self, rhs: usize) -> Self::Output { - BigInt { v: self.v << rhs } + BigInt { + v: self.v << rhs, + params: self.params, + } } } @@ -582,7 +629,10 @@ impl Shr for BigInt { type Output = BigInt; fn shr(self, rhs: usize) -> Self::Output { - BigInt { v: self.v >> rhs } + BigInt { + v: self.v >> rhs, + params: self.params, + } } } @@ -638,6 +688,13 @@ mod tests { }) } + #[test] + fn test_mul() { + let a = BigInt::from(8); + let b = BigInt::from(10); + println!("{}", a * b); + } + #[test] fn test_division() { let a = BigInt::from(8); From 37b065dccefa007537c394511d9d350a78bc054d Mon Sep 17 00:00:00 2001 From: bhargav Date: Wed, 22 Nov 2023 17:09:22 -0800 Subject: [PATCH 2/3] fix: dyn residue fft --- src/ntt.rs | 84 ++++---------- src/numbers.rs | 298 +++++++++++++++++++++++++++---------------------- src/prime.rs | 20 +++- 3 files changed, 201 insertions(+), 201 deletions(-) diff --git a/src/ntt.rs b/src/ntt.rs index 110cdd5..cc3e937 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -1,6 +1,7 @@ use std::ops::Add; use crate::{numbers::BigInt, prime::is_prime}; +use crypto_bigint::Invert; use itertools::Itertools; use rayon::prelude::*; @@ -10,44 +11,10 @@ pub struct Constants { pub w: BigInt, } -fn extended_gcd(a: BigInt, b: BigInt) -> BigInt { - let mut a = a; - let mut b = b; - let n = b; - let ZERO = BigInt::from(0); - let ONE = BigInt::from(1); - - let mut q = ZERO; - let mut r = ONE; - let mut s1 = ONE; - let mut s2 = ZERO; - let mut s3 = ONE; - let mut t1 = ZERO; - let mut t2 = ONE; - let mut t3 = ZERO; - - while r > ZERO { - q = b / a; - r = b - q * a; - s3 = s1 - q * s2; - t3 = t1 - q * t2; - - if r > ZERO { - b = a; - a = r; - s1 = s2; - s2 = s3; - t1 = t2; - t2 = t3; - } - } - (t2 + n).rem(n) -} - fn prime_factors(a: BigInt) -> Vec { let mut ans: Vec = Vec::new(); let mut x = BigInt::from(2); - while x <= a.sqrt() { + while x * x <= a { if a.rem(x) == 0 { ans.push(x); } @@ -57,11 +24,13 @@ fn prime_factors(a: BigInt) -> Vec { } fn is_primitive_root(a: BigInt, deg: BigInt, N: BigInt) -> bool { - a.mod_exp(deg, N) == 1 - && prime_factors(deg) - .iter() - .map(|&x| a.mod_exp(deg / x, N) != 1) - .all(|x| x) + let lhs = a.mod_exp(deg, N); + let lhs = lhs == 1; + let rhs = prime_factors(deg) + .iter() + .map(|&x| a.mod_exp(deg / x, N) != 1) + .all(|x| x); + lhs && rhs } pub fn working_modulus(n: BigInt, M: BigInt) -> Constants { @@ -135,7 +104,7 @@ fn fft(inp: Vec, c: &Constants, w: BigInt) -> Vec { .zip(hi) .enumerate() .for_each(|(idx, (lo, hi))| { - *hi = (*hi).mul_mod(pre[nchunks * idx], MOD); + *hi = (*hi * pre[nchunks * idx]).rem(MOD); let neg = if *lo < *hi { (MOD + *lo) - *hi } else { @@ -159,9 +128,10 @@ pub fn forward(inp: Vec, c: &Constants) -> Vec { } pub fn inverse(inp: Vec, c: &Constants) -> Vec { - let inv = extended_gcd(BigInt::from(inp.len()), BigInt::from(c.N)); - let w = extended_gcd(c.w, c.N); - + let mut inv = BigInt::from(inp.len()); + let _ = inv.set_mod(c.N); + let inv = inv.invert(); + let w = c.w.invert(); let mut res = fft(inp, c, w); res.par_iter_mut().for_each(|x| *x = (inv * (*x)).rem(c.N)); res @@ -173,34 +143,26 @@ mod tests { use rayon::{iter::ParallelIterator, slice::ParallelSliceMut}; use crate::{ - ntt::{extended_gcd, forward, inverse, working_modulus}, + ntt::{forward, inverse, working_modulus}, numbers::BigInt, }; #[test] fn test_forward() { - let n = 1 << rand::thread_rng().gen::() % 8; - let v: Vec = (0..n) - .map(|_| BigInt::from(rand::thread_rng().gen::() % (1 << 6))) - .collect(); - let M = (*v.iter().max().unwrap() << 1) * BigInt::from(n) + 1; + // let n = 1 << rand::thread_rng().gen::() % 8; + // let v: Vec = (0..n) + // .map(|_| BigInt::from(rand::thread_rng().gen::() % (1 << 6))) + // .collect(); + // let M = (*v.iter().max().unwrap() << 1) * BigInt::from(n) + 1; + let n = 8; + let v: Vec = (0..n).map(|x| BigInt::from(x)).collect(); + let M = BigInt::from(n) * BigInt::from(n) + 1; let c = working_modulus(BigInt::from(n), BigInt::from(M)); let forward = forward(v.clone(), &c); let inverse = inverse(forward, &c); v.iter().zip(inverse).for_each(|(&a, b)| assert_eq!(a, b)); } - #[test] - fn test_extended_gcd() { - (2..11).for_each(|x: u64| { - let inv = extended_gcd(BigInt::from(x), BigInt::from(11)); - assert_eq!( - (BigInt::from(x) * inv).rem(BigInt::from(11)), - BigInt::from(1) - ); - }); - } - #[test] fn test_roots_of_unity() { let N = 10; diff --git a/src/numbers.rs b/src/numbers.rs index 6738a38..750dad1 100644 --- a/src/numbers.rs +++ b/src/numbers.rs @@ -1,4 +1,5 @@ use std::{ + cmp::Ordering, fmt::Display, num::NonZeroU128, ops::{ @@ -8,12 +9,14 @@ use std::{ }; use crypto_bigint::{ - modular::runtime_mod::{DynResidue, DynResidueParams}, - rand_core::OsRng, - Invert, NonZero, Random, RandomMod, Wrapping, U256, + modular::{ + runtime_mod::{DynResidue, DynResidueParams}, + Retrieve, + }, + Invert, NonZero, Uint, U128, U256, }; use itertools::Itertools; -use rand::{thread_rng, Rng}; +use rand::{thread_rng, Error, Rng}; pub enum BigIntType { U16(u16), @@ -24,8 +27,7 @@ pub enum BigIntType { #[derive(Debug, Clone, Copy)] pub struct BigInt { - pub v: U256, - params: DynResidueParams<4>, + pub v: DynResidue<4>, } impl BigInt { @@ -35,75 +37,82 @@ impl BigInt { )); Self { v: match _v { - BigIntType::U16(x) => U256::from(x), - BigIntType::U32(x) => U256::from(x), - BigIntType::U64(x) => U256::from(x), - BigIntType::U128(x) => U256::from(x), + BigIntType::U16(x) => DynResidue::new(&U256::from(x), params), + BigIntType::U32(x) => DynResidue::new(&U256::from(x), params), + BigIntType::U64(x) => DynResidue::new(&U256::from(x), params), + BigIntType::U128(x) => DynResidue::new(&U256::from(x), params), _ => panic!("received invalid `BigIntType`"), }, - params, } } - pub fn rem(&self, num: BigInt) -> Self { - BigInt { - v: self.v.const_rem(&num.v).0, - params: self.params, + pub fn set_mod(&mut self, M: BigInt) -> Result<(), String> { + if M.is_even() { + return Err("modulus must be odd".to_string()); } + let params = DynResidueParams::new(&(U256::from(M.v.retrieve()))); + self.v = DynResidue::new(&self.v.retrieve(), params); + Ok(()) + } + + pub fn set_mod_from_residue(&mut self, params: DynResidueParams<4>) { + self.v = DynResidue::new(&self.v.retrieve(), params); + } + + pub fn rem(&self, M: BigInt) -> BigInt { + let mut res = self.clone(); + if res < M { + return res; + } + res.v = DynResidue::new( + &res.v.retrieve().rem(&NonZero::from_uint(M.v.retrieve())), + res.params(), + ); + res + } + + pub fn params(&self) -> DynResidueParams<4> { + *self.v.params() } pub fn mod_exp(&self, exp: BigInt, M: BigInt) -> BigInt { - let mut res: BigInt = if exp & 1 > 0 { + let mut res: BigInt = if !exp.is_even() { self.clone() } else { BigInt::from(1) }; let mut b = self.clone(); let mut e = exp.clone(); + res.set_mod(M); + b.set_mod(M); while e > 0 { e >>= 1; - b = (b * b).rem(M); - if e & 1 > 0 { - res = (res * b).rem(M); + b = b * b; + if M.is_even() { + b = b.rem(M); + } + if !e.is_even() && !e.is_zero() { + res = b * res; + if M.is_even() { + res = res.rem(M); + } } } res } - pub fn sqrt(&self) -> BigInt { - BigInt { - v: self.v.sqrt_vartime(), - params: self.params, - } - } - - pub fn add_mod(&self, rhs: BigInt, M: BigInt) -> BigInt { - (*self + rhs).rem(M) - } - - pub fn mul_mod(&self, rhs: BigInt, M: BigInt) -> BigInt { - (*self * rhs).rem(M) - } - - pub fn sub_mod(&self, rhs: BigInt, M: BigInt) -> BigInt { - if rhs > *self { - M - (rhs - *self).rem(M) - } else { - (*self - rhs).rem(M) - } - } - pub fn random() -> BigInt { let x = rand::thread_rng().gen::(); BigInt::new(BigIntType::U128(x)) } - pub fn reverse(&self) -> BigInt { - let mut v = self.v; - BigInt { - v, - params: self.params, - } + pub fn is_zero(&self) -> bool { + self.v.retrieve().bits() == 0 + } + + pub fn is_even(&self) -> bool { + let is_odd: bool = self.v.retrieve().bit(0).into(); + !is_odd } } @@ -147,11 +156,12 @@ impl Add for BigInt { type Output = BigInt; fn add(self, rhs: Self) -> Self::Output { - let params = self.params; - let x = DynResidue::new(&self.v, params); - let y = DynResidue::new(&rhs.v, params); - let v = U256::from((x + y).retrieve()); - Self { v, params } + if rhs.v.params() != self.v.params() { + let mut rhs = rhs.clone(); + rhs.set_mod_from_residue(self.params()); + return Self { v: self.v + rhs.v }; + } + Self { v: self.v + rhs.v } } } @@ -160,8 +170,7 @@ impl Add for BigInt { fn add(self, rhs: u16) -> Self::Output { Self { - v: (Wrapping(self.v) + Wrapping(BigInt::from(rhs).v)).0, - params: self.params, + v: self.v + BigInt::from(rhs).v, } } } @@ -171,8 +180,7 @@ impl Add for BigInt { fn add(self, rhs: i32) -> Self::Output { Self { - v: (Wrapping(self.v) + Wrapping(BigInt::from(rhs).v)).0, - params: self.params, + v: self.v + BigInt::from(rhs).v, } } } @@ -182,8 +190,7 @@ impl Add for BigInt { fn add(self, rhs: u32) -> Self::Output { Self { - v: (-Wrapping(BigInt::from(rhs).v)).0, - params: self.params, + v: self.v + BigInt::from(rhs).v, } } } @@ -193,8 +200,7 @@ impl Add for BigInt { fn add(self, rhs: u64) -> Self::Output { Self { - v: (Wrapping(self.v) + Wrapping(BigInt::from(rhs).v)).0, - params: self.params, + v: self.v + BigInt::from(rhs).v, } } } @@ -204,8 +210,7 @@ impl Add for BigInt { fn add(self, rhs: u128) -> Self::Output { Self { - v: (Wrapping(self.v) + Wrapping(BigInt::from(rhs).v)).0, - params: self.params, + v: self.v + BigInt::from(rhs).v, } } } @@ -250,11 +255,12 @@ impl Sub for BigInt { type Output = BigInt; fn sub(self, rhs: Self) -> Self::Output { - let params = self.params; - let x = DynResidue::new(&self.v, params); - let y = DynResidue::new(&rhs.v, params); - let v = U256::from((x - y).retrieve()); - Self { v, params } + if rhs.v.params() != self.v.params() { + let mut rhs = rhs.clone(); + rhs.set_mod_from_residue(self.params()); + return Self { v: self.v - rhs.v }; + } + Self { v: self.v - rhs.v } } } @@ -346,10 +352,7 @@ impl Neg for BigInt { type Output = BigInt; fn neg(self) -> Self::Output { - Self { - v: (Wrapping(U256::MAX) - Wrapping(self.v)).0, - params: self.params, - } + Self { v: self.v.neg() } } } @@ -357,11 +360,12 @@ impl Mul for BigInt { type Output = BigInt; fn mul(self, rhs: Self) -> Self::Output { - let params = self.params; - let x = DynResidue::new(&self.v, params); - let y = DynResidue::new(&rhs.v, params); - let v = U256::from((x * y).retrieve()); - Self { v, params } + if rhs.v.params() != self.v.params() { + let mut rhs = rhs.clone(); + rhs.set_mod_from_residue(self.params()); + return Self { v: self.v * rhs.v }; + } + Self { v: self.v * rhs.v } } } @@ -375,12 +379,15 @@ impl Div for BigInt { type Output = BigInt; fn div(self, rhs: Self) -> Self::Output { - let [lower, upper, _, _] = rhs.v.to_words(); - let half = (upper as u128) << 64; - let half = half + (lower as u128); BigInt { - v: (Wrapping(self.v) / NonZero::from(NonZeroU128::new(half).unwrap())).0, - params: self.params, + v: DynResidue::new( + &(self + .v + .retrieve() + .div_rem(&NonZero::from_uint(rhs.v.retrieve())) + .0), + self.params(), + ), } } } @@ -390,8 +397,7 @@ impl Invert for BigInt { fn invert(&self) -> Self::Output { BigInt { - v: self.v.inv_mod(&U256::MAX).0, - params: self.params, + v: self.v.invert().0, } } } @@ -406,79 +412,83 @@ impl Eq for BigInt {} impl Ord for BigInt { fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.v.cmp(&other.v) + let half = self + .params() + .modulus() + .div(NonZero::from(NonZeroU128::new(2).unwrap())); + (self.v - other.v).retrieve().cmp(&half) } } impl PartialEq for BigInt { fn eq(&self, other: &Self) -> bool { - self.v == other.v + self.v.retrieve() == other.v.retrieve() } } impl PartialEq for BigInt { fn eq(&self, other: &u16) -> bool { - self.v == BigInt::from(*other).v + self.v.retrieve() == BigInt::from(*other).v.retrieve() } } impl PartialEq for BigInt { fn eq(&self, other: &i32) -> bool { - self.v == BigInt::from(*other).v + self.v.retrieve() == BigInt::from(*other).v.retrieve() } } impl PartialEq for BigInt { fn eq(&self, other: &u32) -> bool { - self.v == BigInt::from(*other).v + self.v.retrieve() == BigInt::from(*other).v.retrieve() } } impl PartialEq for BigInt { fn eq(&self, other: &u64) -> bool { - self.v == BigInt::from(*other).v + self.v.retrieve() == BigInt::from(*other).v.retrieve() } } impl PartialEq for BigInt { fn eq(&self, other: &u128) -> bool { - self.v == BigInt::from(*other).v + self.v.retrieve() == BigInt::from(*other).v.retrieve() } } impl PartialOrd for BigInt { fn partial_cmp(&self, other: &Self) -> Option { - self.v.partial_cmp(&other.v) + (self.v.retrieve()).partial_cmp(&(other.v.retrieve())) } } impl PartialOrd for BigInt { fn partial_cmp(&self, other: &u16) -> Option { - self.v.partial_cmp(&BigInt::from(*other).v) + (self.v.retrieve()).partial_cmp(&BigInt::from(*other).v.retrieve()) } } impl PartialOrd for BigInt { fn partial_cmp(&self, other: &i32) -> Option { - self.v.partial_cmp(&BigInt::from(*other).v) + (self.v.retrieve()).partial_cmp(&BigInt::from(*other).v.retrieve()) } } impl PartialOrd for BigInt { fn partial_cmp(&self, other: &u32) -> Option { - self.v.partial_cmp(&BigInt::from(*other).v) + (self.v.retrieve()).partial_cmp(&BigInt::from(*other).v.retrieve()) } } impl PartialOrd for BigInt { fn partial_cmp(&self, other: &u64) -> Option { - self.v.partial_cmp(&BigInt::from(*other).v) + (self.v.retrieve()).partial_cmp(&BigInt::from(*other).v.retrieve()) } } impl PartialOrd for BigInt { fn partial_cmp(&self, other: &u128) -> Option { - self.v.partial_cmp(&BigInt::from(*other).v) + (self.v.retrieve()).partial_cmp(&BigInt::from(*other).v.retrieve()) } } @@ -487,8 +497,7 @@ impl BitAnd for BigInt { fn bitand(self, rhs: Self) -> Self::Output { BigInt { - v: self.v & rhs.v, - params: self.params, + v: DynResidue::new(&(self.v.retrieve() & rhs.v.retrieve()), self.params()), } } } @@ -498,8 +507,10 @@ impl BitAnd for BigInt { fn bitand(self, rhs: u16) -> Self::Output { BigInt { - v: self.v & BigInt::from(rhs).v, - params: self.params, + v: DynResidue::new( + &(self.v.retrieve() & BigInt::from(rhs).v.retrieve()), + self.params(), + ), } } } @@ -509,8 +520,10 @@ impl BitAnd for BigInt { fn bitand(self, rhs: i32) -> Self::Output { BigInt { - v: self.v & BigInt::from(rhs).v, - params: self.params, + v: DynResidue::new( + &(self.v.retrieve() & BigInt::from(rhs).v.retrieve()), + self.params(), + ), } } } @@ -520,8 +533,10 @@ impl BitAnd for BigInt { fn bitand(self, rhs: u32) -> Self::Output { BigInt { - v: self.v & BigInt::from(rhs).v, - params: self.params, + v: DynResidue::new( + &(self.v.retrieve() & BigInt::from(rhs).v.retrieve()), + self.params(), + ), } } } @@ -531,8 +546,10 @@ impl BitAnd for BigInt { fn bitand(self, rhs: u64) -> Self::Output { BigInt { - v: self.v & BigInt::from(rhs).v, - params: self.params, + v: DynResidue::new( + &(self.v.retrieve() & BigInt::from(rhs).v.retrieve()), + self.params(), + ), } } } @@ -542,8 +559,10 @@ impl BitAnd for BigInt { fn bitand(self, rhs: u128) -> Self::Output { BigInt { - v: self.v & BigInt::from(rhs).v, - params: self.params, + v: DynResidue::new( + &(self.v.retrieve() & BigInt::from(rhs).v.retrieve()), + self.params(), + ), } } } @@ -553,8 +572,7 @@ impl BitOr for BigInt { fn bitor(self, rhs: Self) -> Self::Output { BigInt { - v: self.v | rhs.v, - params: self.params, + v: DynResidue::new(&(self.v.retrieve() | rhs.v.retrieve()), self.params()), } } } @@ -564,8 +582,10 @@ impl BitOr for BigInt { fn bitor(self, rhs: u16) -> Self::Output { BigInt { - v: self.v | BigInt::from(rhs).v, - params: self.params, + v: DynResidue::new( + &(self.v.retrieve() | BigInt::from(rhs).v.retrieve()), + self.params(), + ), } } } @@ -575,8 +595,10 @@ impl BitOr for BigInt { fn bitor(self, rhs: i32) -> Self::Output { BigInt { - v: self.v | BigInt::from(rhs).v, - params: self.params, + v: DynResidue::new( + &(self.v.retrieve() | BigInt::from(rhs).v.retrieve()), + self.params(), + ), } } } @@ -586,8 +608,10 @@ impl BitOr for BigInt { fn bitor(self, rhs: u32) -> Self::Output { BigInt { - v: self.v | BigInt::from(rhs).v, - params: self.params, + v: DynResidue::new( + &(self.v.retrieve() | BigInt::from(rhs).v.retrieve()), + self.params(), + ), } } } @@ -597,8 +621,10 @@ impl BitOr for BigInt { fn bitor(self, rhs: u64) -> Self::Output { BigInt { - v: self.v | BigInt::from(rhs).v, - params: self.params, + v: DynResidue::new( + &(self.v.retrieve() | BigInt::from(rhs).v.retrieve()), + self.params(), + ), } } } @@ -608,8 +634,10 @@ impl BitOr for BigInt { fn bitor(self, rhs: u128) -> Self::Output { BigInt { - v: self.v | BigInt::from(rhs).v, - params: self.params, + v: DynResidue::new( + &(self.v.retrieve() | BigInt::from(rhs).v.retrieve()), + self.params(), + ), } } } @@ -619,8 +647,7 @@ impl Shl for BigInt { fn shl(self, rhs: usize) -> Self::Output { BigInt { - v: self.v << rhs, - params: self.params, + v: DynResidue::new(&self.v.retrieve().shl_vartime(rhs), self.params()), } } } @@ -630,8 +657,7 @@ impl Shr for BigInt { fn shr(self, rhs: usize) -> Self::Output { BigInt { - v: self.v >> rhs, - params: self.params, + v: DynResidue::new(&self.v.retrieve().shr_vartime(rhs), self.params()), } } } @@ -653,6 +679,7 @@ impl Display for BigInt { // concatenate bytes to string representation let str: String = self .v + .retrieve() .to_words() .iter() .rev() @@ -685,7 +712,7 @@ mod tests { BigInt::from(x).mod_exp(BigInt::from(y), BigInt::from(N)) ); }) - }) + }); } #[test] @@ -695,17 +722,23 @@ mod tests { println!("{}", a * b); } + #[test] + fn test_is_even() { + let a = BigInt::from(1 << 12); + assert!(a.is_even()); + } + #[test] fn test_division() { let a = BigInt::from(8); let b = BigInt::from(10); - println!("{}", a / b); + assert_eq!(a / b, BigInt::from(0)) } #[test] fn test_rem() { let a = BigInt::from(10); - println!("{}", a.rem(BigInt::from(4))); + assert_eq!(a.rem(BigInt::from(4)), BigInt::from(2)); } #[test] @@ -715,11 +748,8 @@ mod tests { } #[test] - fn test_sub_mod() { - let a = BigInt::from(72); - let b = BigInt::from(73); - let N = BigInt::from(1890); - - println!("{}", a.sub_mod(b, N)); + fn test_shr() { + let a = BigInt::from(1); + println!("{}", a >> 1); } } diff --git a/src/prime.rs b/src/prime.rs index f6c2d4f..cf3a2a7 100644 --- a/src/prime.rs +++ b/src/prime.rs @@ -6,19 +6,29 @@ fn miller_test(mut d: BigInt, n: BigInt, x: BigInt) -> bool { let a = BigInt::from(2) + x; let mut x = a.mod_exp(d, n); + match x.set_mod(n) { + Ok(()) => (), + Err(_) => return false, + }; + match d.set_mod(n) { + Ok(()) => (), + Err(_) => return false, + }; if x == one || x == n - one { return true; } - while d != n - one { - x = (x * x).rem(n); + // (d + 1) mod n = 0 + while !(d + one).is_zero() { + // x = x * x mod n + x = x * x; d *= two; if x == one { return false; } - if x == n - one { + if (x + one).is_zero() { return true; } } @@ -27,9 +37,7 @@ fn miller_test(mut d: BigInt, n: BigInt, x: BigInt) -> bool { } pub fn is_prime(num: BigInt) -> bool { - let zero = BigInt::from(0); let one = BigInt::from(1); - let two = BigInt::from(2); if num <= one || num == BigInt::from(4) { return false; } @@ -38,7 +46,7 @@ pub fn is_prime(num: BigInt) -> bool { } let mut d = num - one; - while d.rem(two) == zero { + while d.is_even() && !d.is_zero() { d >>= 1; } From 7bfc92f58e2528031290c96f32b88c4d5affe1d2 Mon Sep 17 00:00:00 2001 From: bhargav Date: Wed, 22 Nov 2023 20:59:31 -0800 Subject: [PATCH 3/3] bench: residue math --- BENCHMARKS.md | 58 +++++++++++++++++++++++++++++++++++++++++++ benches/BENCHMARKS.md | 0 2 files changed, 58 insertions(+) delete mode 100644 benches/BENCHMARKS.md diff --git a/BENCHMARKS.md b/BENCHMARKS.md index e69de29..cf51876 100644 --- a/BENCHMARKS.md +++ b/BENCHMARKS.md @@ -0,0 +1,58 @@ +# Benchmarks + +## Table of Contents + +- [Overview](#overview) +- [Benchmark Results](#benchmark-results) + - [Number-Theoretic Transform Benchmarks](#number-theoretic-transform-benchmarks) + - [Polynomial Multiplication Benchmarks](#polynomial-multiplication-benchmarks) + +## Overview + +This benchmark comparison report shows the difference in performance between parallel, NTT-based and serial, brute-force +polynomial multiplication algorithms. Each row entry in the first table is an n-degree forward NTT and each row entry in the second table represents an n-degree polynomial multiplication. + +Computer Stats: + +``` +CPU(s): 16 +Thread(s) per core: 2 +Core(s) per socket: 8 +Socket(s): 1 +``` + +## Benchmark Results + +### Number-Theoretic Transform Benchmarks + +| | `NTT` | +|:------------|:-------------------------- | +| **`64`** | `187.17 us` (✅ **1.00x**) | +| **`128`** | `231.50 us` (✅ **1.00x**) | +| **`256`** | `333.26 us` (✅ **1.00x**) | +| **`512`** | `623.88 us` (✅ **1.00x**) | +| **`1024`** | `951.62 us` (✅ **1.00x**) | +| **`2048`** | `1.48 ms` (✅ **1.00x**) | +| **`4096`** | `2.78 ms` (✅ **1.00x**) | +| **`8192`** | `5.48 ms` (✅ **1.00x**) | +| **`16384`** | `11.09 ms` (✅ **1.00x**) | +| **`32768`** | `23.08 ms` (✅ **1.00x**) | + +### Polynomial Multiplication Benchmarks + +| | `NTT-Based` | `Brute-Force` | +|:------------|:--------------------------|:---------------------------------- | +| **`64`** | `818.69 us` (✅ **1.00x**) | `494.52 us` (✅ **1.66x faster**) | +| **`128`** | `1.12 ms` (✅ **1.00x**) | `1.93 ms` (❌ *1.72x slower*) | +| **`256`** | `1.74 ms` (✅ **1.00x**) | `7.78 ms` (❌ *4.48x slower*) | +| **`512`** | `2.69 ms` (✅ **1.00x**) | `30.35 ms` (❌ *11.30x slower*) | +| **`1024`** | `4.33 ms` (✅ **1.00x**) | `121.49 ms` (❌ *28.05x slower*) | +| **`2048`** | `7.47 ms` (✅ **1.00x**) | `493.59 ms` (❌ *66.07x slower*) | +| **`4096`** | `14.23 ms` (✅ **1.00x**) | `1.98 s` (❌ *139.11x slower*) | +| **`8192`** | `31.60 ms` (✅ **1.00x**) | `7.88 s` (❌ *249.28x slower*) | +| **`16384`** | `65.51 ms` (✅ **1.00x**) | `31.46 s` (❌ *480.32x slower*) | +| **`32768`** | `141.24 ms` (✅ **1.00x**) | `126.02 s` (❌ *892.30x slower*) | + +--- +Made with [criterion-table](https://github.com/nu11ptr/criterion-table) + diff --git a/benches/BENCHMARKS.md b/benches/BENCHMARKS.md deleted file mode 100644 index e69de29..0000000