From 4d2545d876a2b53106f5183c090151976b582a09 Mon Sep 17 00:00:00 2001
From: bhargav
Date: Sat, 18 Nov 2023 18:43:18 -0800
Subject: [PATCH] feat: parallel roots of unity processing

---
 src/ntt.rs | 55 ++++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 48 insertions(+), 7 deletions(-)

diff --git a/src/ntt.rs b/src/ntt.rs
index 6a0d9ab..86ae043 100644
--- a/src/ntt.rs
+++ b/src/ntt.rs
@@ -1,5 +1,5 @@
 use crate::{numbers::BigInt, prime::is_prime};
-use itertools::{iterate, Itertools};
+use itertools::Itertools;
 use rayon::prelude::*;
 
 #[derive(Debug, Clone)]
@@ -108,8 +108,20 @@ fn fft(inp: Vec<BigInt>, c: &Constants, w: BigInt) -> Vec<BigInt> {
     let MOD = BigInt::from(c.N);
     let ONE = BigInt::from(1);
     let mut pre: Vec<BigInt> = vec![ONE; N / 2];
-
-    (1..N / 2).for_each(|i| pre[i] = (pre[i - 1] * w).rem(MOD));
+    // Length of each chunk handed to `par_chunks_mut` (a size, despite the name).
+    let CHUNK_COUNT = 128;
+    let chunk_count = BigInt::from(CHUNK_COUNT);
+
+    // Seed the first slot of every chunk after the first via `mod_exp` (exponent i * CHUNK_COUNT).
+    // The bound is a ceiling division so a trailing partial chunk is seeded too.
+    (1..(N / 2 + CHUNK_COUNT - 1) / CHUNK_COUNT)
+        .for_each(|i| pre[i * CHUNK_COUNT] = w.mod_exp(BigInt::from(i) * chunk_count, MOD));
+    // Each chunk then extends its own seed independently: x[y] = (w * x[y - 1]) rem MOD.
+    pre.par_chunks_mut(CHUNK_COUNT).for_each(|x| {
+        (1..x.len()).for_each(|y| {
+            x[y] = (w * x[y - 1]).rem(MOD);
+        })
+    });
     order_reverse(&mut inp);
 
     let mut gap = inp.len() / 2;
@@ -140,15 +152,15 @@ pub fn inverse(inp: Vec<BigInt>, c: &Constants) -> Vec<BigInt> {
     let inv = extended_gcd(BigInt::from(inp.len()), BigInt::from(c.N));
     let w = extended_gcd(c.w, c.N);
 
-    fft(inp, c, w)
-        .iter()
-        .map(|&x| (inv * x).rem(c.N))
-        .collect_vec()
+    let mut res = fft(inp, c, w);
+    res.par_iter_mut().for_each(|x| *x = (inv * (*x)).rem(c.N));
+    res
 }
 
 #[cfg(test)]
 mod tests {
     use rand::Rng;
+    use rayon::{iter::ParallelIterator, slice::ParallelSliceMut};
 
     use crate::{
         ntt::{extended_gcd, forward, inverse, working_modulus},
@@ -180,4 +192,33 @@ mod tests {
             );
         });
     }
+
+    #[test]
+    fn test_roots_of_unity() {
+        // 500 table entries span three full 128-entry chunks plus a partial one,
+        // so the chunk-seeding path is actually exercised.
+        let N = 1000;
+        let ONE = BigInt::from(1);
+        let mut pre: Vec<BigInt> = vec![ONE; N / 2];
+        let mut pre2 = pre.clone();
+        let CHUNK_COUNT = 128;
+        let MOD = BigInt::from(257);
+        let chunk_count = BigInt::from(CHUNK_COUNT);
+        let w = BigInt::from(3);
+
+        // Reference: sequential running product.
+        (1..N / 2).for_each(|i| pre[i] = (pre[i - 1] * w).rem(MOD));
+
+        // Mirror of the chunked computation in `fft`.
+        (1..(N / 2 + CHUNK_COUNT - 1) / CHUNK_COUNT)
+            .for_each(|i| pre2[i * CHUNK_COUNT] = w.mod_exp(BigInt::from(i) * chunk_count, MOD));
+        pre2.par_chunks_mut(CHUNK_COUNT).for_each(|x| {
+            (1..x.len()).for_each(|y| {
+                x[y] = (w * x[y - 1]).rem(MOD);
+            })
+        });
+        (0..N / 2).for_each(|i| {
+            assert_eq!(pre[i], pre2[i]);
+        })
+    }
 }