From ef40f30c103d90dced6f4d187ebdd81887925cf3 Mon Sep 17 00:00:00 2001 From: bhargav Date: Tue, 12 Dec 2023 15:43:23 -0800 Subject: [PATCH] feat: optimize working modulus calculation --- src/ntt.rs | 33 +++++++++++++++++++++++++-------- src/polynomial.rs | 37 ++++++++++++++++++++++++++++++++++--- 2 files changed, 59 insertions(+), 11 deletions(-) diff --git a/src/ntt.rs b/src/ntt.rs index c93e5d4..915e512 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -23,6 +23,18 @@ fn prime_factors(a: BigInt) -> Vec { ans } +#[cfg(feature = "parallel")] +fn is_primitive_root(a: BigInt, deg: BigInt, N: BigInt) -> bool { + let lhs = a.mod_exp(deg, N); + let lhs = lhs == 1; + let rhs = prime_factors(deg) + .par_iter() + .map(|&x| a.mod_exp(deg / x, N) != 1) + .all(|x| x); + lhs && rhs +} + +#[cfg(not(feature = "parallel"))] fn is_primitive_root(a: BigInt, deg: BigInt, N: BigInt) -> bool { let lhs = a.mod_exp(deg, N); let lhs = lhs == 1; @@ -34,25 +46,30 @@ fn is_primitive_root(a: BigInt, deg: BigInt, N: BigInt) -> bool { } pub fn working_modulus(n: BigInt, M: BigInt) -> Constants { - let mut N = n + 1; - let mut k = BigInt::from(1); - while !(is_prime(N) && N >= M) { - k += 1; - N = k * n + 1; + let ONE = BigInt::from(1); + let mut N = M; + if N >= ONE { + N = N * n + 1; + while !is_prime(N) { + println!("N -- {}", N); + N += n; + } } + println!("{} is prime", N); + let totient = N - ONE; assert!(N >= M); let mut gen = BigInt::from(0); - let ONE = BigInt::from(1); let mut g = BigInt::from(2); while g < N { - if is_primitive_root(g, N - 1, N) { + if is_primitive_root(g, totient, N) { gen = g; break; } g += ONE; } assert!(gen > 0); - let w = gen.mod_exp(k, N); + println!("g/gen -- {} {}", g, gen); + let w = gen.mod_exp(totient / n, N); Constants { N, w } } diff --git a/src/polynomial.rs b/src/polynomial.rs index 6871221..1119406 100644 --- a/src/polynomial.rs +++ b/src/polynomial.rs @@ -89,6 +89,9 @@ impl Polynomial { let n = (self.len() + rhs.len()).next_power_of_two(); let ZERO = BigInt::from(0); + println!("rhs -- {}", self); + println!("lhs -- {}", rhs); + let v1 = vec![ZERO; n - self.len()] .into_iter() .chain(self.coef.into_iter()) @@ -109,9 +112,11 @@ impl Polynomial { let coef = inverse(mul, &c); // n - polynomial degree - 1 let start = n - (v1_deg + v2_deg + 1) - 1; - Polynomial { + let res = Polynomial { coef: coef[start..=(start + v1_deg + v2_deg)].to_vec(), - } + }; + println!("poly -- {}", res); + res } pub fn diff(mut self) -> Self { @@ -134,6 +139,18 @@ impl Polynomial { let start = self.coef.iter().position(|&x| x != 0).unwrap(); self.len() - start - 1 } + + pub fn max(&self) -> BigInt { + let mut ans = self.coef[0]; + + self.coef[1..].iter().for_each(|&x| { + if ans < x { + ans = x; + } + }); + + ans + } } impl Add for Polynomial { @@ -195,7 +212,10 @@ mod tests { use rand::Rng; use super::Polynomial; - use crate::{ntt::working_modulus, numbers::BigInt}; + use crate::{ + ntt::{working_modulus, Constants}, + numbers::BigInt, + }; #[test] fn add() { @@ -241,4 +261,15 @@ mod tests { let da = a.diff(); println!("{}", da); } + + #[test] + fn test_comparator() { + let a = BigInt::from(550338105); + let b = BigInt::from(1); + assert!(a > b); + let p = Polynomial::new(vec![a, b]); + let hi = p.max(); + + assert_eq!(a, hi); + } }