diff --git a/src/ntt.rs b/src/ntt.rs index d34651d..2ddaf60 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -1,4 +1,5 @@ use crate::prime::is_prime; +use core::panic; use mod_exp::mod_exp; use std::mem::swap; @@ -95,10 +96,21 @@ pub fn working_modulus(n: i64, M: i64) -> Constants { } pub fn forward(inp: Vec, c: &Constants) -> Vec { + let mut pre = vec![-1; inp.len().pow(2)]; + (0..inp.len()).for_each(|col| { + (0..=col).for_each(|row| { + if pre[row * inp.len() + col] == -1 { + pre[row * inp.len() + col] = mod_exp(c.w, (row * col) as i64, c.N) as i64; + } + }) + }); + (0..inp.len()) .map(|k| { inp.iter().enumerate().fold(0, |acc, (i, cur)| { - (acc + cur * mod_exp(c.w, (k * i) as i64, c.N) as i64) % c.N as i64 + let row = k.min(i); + let col = k.max(i); + (acc + cur * pre[row * inp.len() + col]) % c.N as i64 }) % c.N as i64 }) .collect() @@ -108,11 +120,22 @@ pub fn inverse(inp: Vec, c: &Constants) -> Vec { let inv = extended_gcd(inp.len() as i64, c.N); let w = extended_gcd(c.w, c.N); + let mut pre = vec![-1; inp.len().pow(2)]; + (0..inp.len()).for_each(|col| { + (0..=col).for_each(|row| { + if pre[row * inp.len() + col] == -1 { + pre[row * inp.len() + col] = mod_exp(w, (row * col) as i64, c.N) as i64; + } + }) + }); + (0..inp.len()) .map(|k| { inv as i64 * inp.iter().enumerate().fold(0, |acc, (i, cur)| { - (acc + cur * mod_exp(w, (k * i) as i64, c.N) as i64) % c.N as i64 + let row = k.min(i); + let col = k.max(i); + (acc + cur * pre[row * inp.len() + col]) % c.N as i64 }) % c.N as i64 })