From b25f6ba9eeea866a3c8ff937862f7b8b24308002 Mon Sep 17 00:00:00 2001 From: bhargav Date: Fri, 17 Nov 2023 02:16:11 -0800 Subject: [PATCH] feat: parallel cooley-tukey --- Cargo.lock | 1 + Cargo.toml | 1 + src/ntt.rs | 52 ++++++++++++++++++++------------------- src/numbers.rs | 66 ++++++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 93 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ba617fa..b9e92c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -210,6 +210,7 @@ dependencies = [ "itertools 0.11.0", "mod_exp", "rand", + "rayon", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 8e88ece..2fcbd60 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ hex = "0.4.3" itertools = "0.11.0" mod_exp = "1.0.1" rand = "0.8.5" +rayon = "1.8.0" [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports"] } diff --git a/src/ntt.rs b/src/ntt.rs index 0366329..6a0d9ab 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -1,5 +1,6 @@ use crate::{numbers::BigInt, prime::is_prime}; -use itertools::Itertools; +use itertools::{iterate, Itertools}; +use rayon::prelude::*; #[derive(Debug, Clone)] pub struct Constants { @@ -11,15 +12,17 @@ fn extended_gcd(a: BigInt, b: BigInt) -> BigInt { let mut a = a; let mut b = b; let n = b; - let mut q = BigInt::from(0); - let mut r = BigInt::from(1); - let mut s1 = BigInt::from(1); - let mut s2 = BigInt::from(0); - let mut s3 = BigInt::from(1); - let mut t1 = BigInt::from(0); - let mut t2 = BigInt::from(1); - let mut t3 = BigInt::from(0); let ZERO = BigInt::from(0); + let ONE = BigInt::from(1); + + let mut q = ZERO; + let mut r = ONE; + let mut s1 = ONE; + let mut s2 = ZERO; + let mut s3 = ONE; + let mut t1 = ZERO; + let mut t2 = ONE; + let mut t3 = ZERO; while r > ZERO { q = b / a; @@ -109,23 +112,22 @@ fn fft(inp: Vec, c: &Constants, w: BigInt) -> Vec { (1..N / 2).for_each(|i| pre[i] = (pre[i - 1] * w).rem(MOD)); order_reverse(&mut inp); - let mut len = 2; - - while len <= N { - let half = len / 2; - let pre_step = N / len; - (0..N).step_by(len).for_each(|i| { - let mut k = 0; - (i..i + half).for_each(|j| { - let l = j + half; - let left = inp[j]; - let right = inp[l] * pre[k]; - inp[j] = left.add_mod(right, MOD); - inp[l] = left.sub_mod(right, MOD); - k += pre_step; - }) + let mut gap = inp.len() / 2; + + while gap > 0 { + let nchunks = inp.len() / (2 * gap); + inp.par_chunks_mut(2 * gap).for_each(|cxi| { + let (lo, hi) = cxi.split_at_mut(gap); + lo.par_iter_mut() + .zip(hi) + .enumerate() + .for_each(|(idx, (lo, hi))| { + let neg = (*lo).sub_mod(*hi, MOD); + *lo = (*lo).add_mod(*hi, MOD); + *hi = neg.mul_mod(pre[nchunks * idx], MOD); + }); }); - len <<= 1; + gap /= 2; } inp } diff --git a/src/numbers.rs b/src/numbers.rs index de002fd..770cced 100644 --- a/src/numbers.rs +++ b/src/numbers.rs @@ -2,7 +2,7 @@ use std::{ fmt::Display, num::NonZeroU128, ops::{ - Add, AddAssign, BitAnd, Div, DivAssign, Mul, MulAssign, Neg, Shl, ShlAssign, Shr, + Add, AddAssign, BitAnd, BitOr, Div, DivAssign, Mul, MulAssign, Neg, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, }, }; @@ -19,7 +19,7 @@ pub enum BigIntType { #[derive(Debug, Clone, Copy)] pub struct BigInt { - v: U256, + pub v: U256, } impl BigInt { @@ -69,6 +69,10 @@ impl BigInt { (*self + rhs).rem(M) } + pub fn mul_mod(&self, rhs: BigInt, M: BigInt) -> BigInt { + (*self * rhs).rem(M) + } + pub fn sub_mod(&self, rhs: BigInt, M: BigInt) -> BigInt { if rhs > *self { M - (rhs - *self).rem(M) @@ -503,6 +507,64 @@ impl BitAnd for BigInt { } } +impl BitOr for BigInt { + type Output = BigInt; + + fn bitor(self, rhs: Self) -> Self::Output { + BigInt { v: self.v | rhs.v } + } +} + +impl BitOr for BigInt { + type Output = BigInt; + + fn bitor(self, rhs: u16) -> Self::Output { + BigInt { + v: self.v | BigInt::from(rhs).v, + } + } +} + +impl BitOr for BigInt { + type Output = BigInt; + + fn bitor(self, rhs: i32) -> Self::Output { + BigInt { + v: self.v | BigInt::from(rhs).v, + } + } +} + +impl BitOr for BigInt { + type Output = BigInt; + + fn bitor(self, rhs: u32) -> Self::Output { + BigInt { + v: self.v | BigInt::from(rhs).v, + } + } +} + +impl BitOr for BigInt { + type Output = BigInt; + + fn bitor(self, rhs: u64) -> Self::Output { + BigInt { + v: self.v | BigInt::from(rhs).v, + } + } +} + +impl BitOr for BigInt { + type Output = BigInt; + + fn bitor(self, rhs: u128) -> Self::Output { + BigInt { + v: self.v | BigInt::from(rhs).v, + } + } +} + impl Shl for BigInt { type Output = BigInt;