Skip to content

Commit

Permalink
feat: parallel roots of unity processing
Browse files Browse the repository at this point in the history
  • Loading branch information
0xWOLAND committed Nov 19, 2023
1 parent b25f6ba commit 4d2545d
Showing 1 changed file with 42 additions and 7 deletions.
49 changes: 42 additions & 7 deletions src/ntt.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{numbers::BigInt, prime::is_prime};
use itertools::{iterate, Itertools};
use itertools::Itertools;
use rayon::prelude::*;

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -108,8 +108,17 @@ fn fft(inp: Vec<BigInt>, c: &Constants, w: BigInt) -> Vec<BigInt> {
let MOD = BigInt::from(c.N);
let ONE = BigInt::from(1);
let mut pre: Vec<BigInt> = vec![ONE; N / 2];

(1..N / 2).for_each(|i| pre[i] = (pre[i - 1] * w).rem(MOD));
let CHUNK_COUNT = 128;
let chunk_count = BigInt::from(CHUNK_COUNT);

(1..N / (2 * CHUNK_COUNT))
.for_each(|i| pre[i * CHUNK_COUNT] = w.mod_exp(BigInt::from(i) * chunk_count, MOD));
pre.par_chunks_mut(CHUNK_COUNT).for_each(|x| {
(1..x.len()).for_each(|y| {
let _x = x.to_vec();
x[y] = (w * x[y - 1]).rem(MOD);
})
});
order_reverse(&mut inp);

let mut gap = inp.len() / 2;
Expand Down Expand Up @@ -140,15 +149,15 @@ pub fn inverse(inp: Vec<BigInt>, c: &Constants) -> Vec<BigInt> {
let inv = extended_gcd(BigInt::from(inp.len()), BigInt::from(c.N));
let w = extended_gcd(c.w, c.N);

fft(inp, c, w)
.iter()
.map(|&x| (inv * x).rem(c.N))
.collect_vec()
let mut res = fft(inp, c, w);
res.par_iter_mut().for_each(|x| *x = (inv * (*x)).rem(c.N));
res
}

#[cfg(test)]
mod tests {
use rand::Rng;
use rayon::{iter::ParallelIterator, slice::ParallelSliceMut};

use crate::{
ntt::{extended_gcd, forward, inverse, working_modulus},
Expand Down Expand Up @@ -180,4 +189,30 @@ mod tests {
);
});
}

#[test]
fn test_roots_of_unity() {
let N = 10;
let ONE = BigInt::from(1);
let mut pre: Vec<BigInt> = vec![ONE; N / 2];
let mut pre2 = pre.clone();
let CHUNK_COUNT = 128;
let MOD = BigInt::from(10);
let chunk_count = BigInt::from(CHUNK_COUNT);
let w = BigInt::from(2);

(1..N / 2).for_each(|i| pre[i] = (pre[i - 1] * w).rem(MOD));

(1..N / (2 * CHUNK_COUNT))
.for_each(|i| pre2[i * CHUNK_COUNT] = w.mod_exp(BigInt::from(i) * chunk_count, MOD));
pre2.par_chunks_mut(CHUNK_COUNT).for_each(|x| {
(1..x.len()).for_each(|y| {
let _x = x.to_vec();
x[y] = (w * x[y - 1]).rem(MOD);
})
});
(0..N / 2).for_each(|i| {
assert_eq!(pre[i], pre2[i]);
})
}
}

0 comments on commit 4d2545d

Please sign in to comment.