Skip to content

Commit

Permalink
feat: parallel cooley-tukey
Browse files Browse the repository at this point in the history
  • Loading branch information
0xWOLAND committed Nov 17, 2023
1 parent b2bf517 commit b25f6ba
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 27 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ hex = "0.4.3"
itertools = "0.11.0"
mod_exp = "1.0.1"
rand = "0.8.5"
rayon = "1.8.0"

[dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
Expand Down
52 changes: 27 additions & 25 deletions src/ntt.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{numbers::BigInt, prime::is_prime};
use itertools::Itertools;
use itertools::{iterate, Itertools};
use rayon::prelude::*;

#[derive(Debug, Clone)]
pub struct Constants {
Expand All @@ -11,15 +12,17 @@ fn extended_gcd(a: BigInt, b: BigInt) -> BigInt {
let mut a = a;
let mut b = b;
let n = b;
let mut q = BigInt::from(0);
let mut r = BigInt::from(1);
let mut s1 = BigInt::from(1);
let mut s2 = BigInt::from(0);
let mut s3 = BigInt::from(1);
let mut t1 = BigInt::from(0);
let mut t2 = BigInt::from(1);
let mut t3 = BigInt::from(0);
let ZERO = BigInt::from(0);
let ONE = BigInt::from(1);

let mut q = ZERO;
let mut r = ONE;
let mut s1 = ONE;
let mut s2 = ZERO;
let mut s3 = ONE;
let mut t1 = ZERO;
let mut t2 = ONE;
let mut t3 = ZERO;

while r > ZERO {
q = b / a;
Expand Down Expand Up @@ -109,23 +112,22 @@ fn fft(inp: Vec<BigInt>, c: &Constants, w: BigInt) -> Vec<BigInt> {
(1..N / 2).for_each(|i| pre[i] = (pre[i - 1] * w).rem(MOD));
order_reverse(&mut inp);

let mut len = 2;

while len <= N {
let half = len / 2;
let pre_step = N / len;
(0..N).step_by(len).for_each(|i| {
let mut k = 0;
(i..i + half).for_each(|j| {
let l = j + half;
let left = inp[j];
let right = inp[l] * pre[k];
inp[j] = left.add_mod(right, MOD);
inp[l] = left.sub_mod(right, MOD);
k += pre_step;
})
let mut gap = inp.len() / 2;

while gap > 0 {
let nchunks = inp.len() / (2 * gap);
inp.par_chunks_mut(2 * gap).for_each(|cxi| {
let (lo, hi) = cxi.split_at_mut(gap);
lo.par_iter_mut()
.zip(hi)
.enumerate()
.for_each(|(idx, (lo, hi))| {
let neg = (*lo).sub_mod(*hi, MOD);
*lo = (*lo).add_mod(*hi, MOD);
*hi = neg.mul_mod(pre[nchunks * idx], MOD);
});
});
len <<= 1;
gap /= 2;
}
inp
}
Expand Down
66 changes: 64 additions & 2 deletions src/numbers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{
fmt::Display,
num::NonZeroU128,
ops::{
Add, AddAssign, BitAnd, Div, DivAssign, Mul, MulAssign, Neg, Shl, ShlAssign, Shr,
Add, AddAssign, BitAnd, BitOr, Div, DivAssign, Mul, MulAssign, Neg, Shl, ShlAssign, Shr,
ShrAssign, Sub, SubAssign,
},
};
Expand All @@ -19,7 +19,7 @@ pub enum BigIntType {

#[derive(Debug, Clone, Copy)]
pub struct BigInt {
v: U256,
pub v: U256,
}

impl BigInt {
Expand Down Expand Up @@ -69,6 +69,10 @@ impl BigInt {
(*self + rhs).rem(M)
}

pub fn mul_mod(&self, rhs: BigInt, M: BigInt) -> BigInt {
(*self * rhs).rem(M)
}

pub fn sub_mod(&self, rhs: BigInt, M: BigInt) -> BigInt {
if rhs > *self {
M - (rhs - *self).rem(M)
Expand Down Expand Up @@ -503,6 +507,64 @@ impl BitAnd<u128> for BigInt {
}
}

impl BitOr for BigInt {
type Output = BigInt;

fn bitor(self, rhs: Self) -> Self::Output {
BigInt { v: self.v | rhs.v }
}
}

impl BitOr<u16> for BigInt {
type Output = BigInt;

fn bitor(self, rhs: u16) -> Self::Output {
BigInt {
v: self.v | BigInt::from(rhs).v,
}
}
}

impl BitOr<i32> for BigInt {
type Output = BigInt;

fn bitor(self, rhs: i32) -> Self::Output {
BigInt {
v: self.v | BigInt::from(rhs).v,
}
}
}

impl BitOr<u32> for BigInt {
type Output = BigInt;

fn bitor(self, rhs: u32) -> Self::Output {
BigInt {
v: self.v | BigInt::from(rhs).v,
}
}
}

impl BitOr<u64> for BigInt {
type Output = BigInt;

fn bitor(self, rhs: u64) -> Self::Output {
BigInt {
v: self.v | BigInt::from(rhs).v,
}
}
}

impl BitOr<u128> for BigInt {
type Output = BigInt;

fn bitor(self, rhs: u128) -> Self::Output {
BigInt {
v: self.v | BigInt::from(rhs).v,
}
}
}

impl Shl<usize> for BigInt {
type Output = BigInt;

Expand Down

0 comments on commit b25f6ba

Please sign in to comment.