Skip to content

Commit

Permalink
feat: add constants as a parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
0xWOLAND committed Nov 4, 2023
1 parent 11fe6fa commit e49cf05
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 81 deletions.
45 changes: 35 additions & 10 deletions benches/benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,48 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use fast_ntt::{numbers::BigInt, polynomial::Polynomial};
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use fast_ntt::{
ntt::{forward, working_modulus, Constants},
numbers::BigInt,
polynomial::Polynomial,
};
use itertools::Itertools;

fn bench_mul(x: usize, y: usize, k: BigInt) {
let a = Polynomial::new(vec![0; x].iter().map(|_| k.clone()).collect_vec());
let b = Polynomial::new(vec![0; y].iter().map(|_| k.clone()).collect_vec());
let _ = a * b;
fn bench_mul(x: usize, y: usize, c: &Constants) {
let ONE = BigInt::from(1);
let a = Polynomial::new(vec![0; x].iter().map(|_| ONE).collect_vec());
let b = Polynomial::new(vec![0; y].iter().map(|_| ONE).collect_vec());
let _ = a.mul(b, c);
}

fn criterion_benchmark(c: &mut Criterion) {
fn bench_forward(n: usize, c: &Constants) {
let ONE = BigInt::from(1);
let a = Polynomial::new(vec![0; n].iter().map(|_| ONE).collect_vec());
let _ = forward(a.coef, c);
}

fn criterion_forward(c: &mut Criterion) {
let mut group = c.benchmark_group("bench_forward");
let deg = 16;
(1..deg).for_each(|x| {
group.bench_function(BenchmarkId::from_parameter(x), |b| {
let c = working_modulus(BigInt::from(x), BigInt::from(2 * x + 1));
b.iter(|| bench_mul(black_box(1 << x), black_box(1 << x), black_box(&c)))
});
});
}

fn criterion_mul(c: &mut Criterion) {
let mut group = c.benchmark_group("bench_mul");
let deg = 10;
let deg = 16;
(1..deg).for_each(|x| {
group.bench_function(BenchmarkId::from_parameter(x), |b| {
b.iter(|| bench_mul(x, x, BigInt::from(1)))
let N = BigInt::from((2 * x as usize).next_power_of_two());
let M = N << 1 + 1;
let c = working_modulus(N, M);
b.iter(|| bench_mul(black_box(1 << x), black_box(1 << x), black_box(&c)))
});
});
group.finish();
}

criterion_group!(benches, criterion_benchmark);
criterion_group!(benches, criterion_forward);
criterion_main!(benches);
45 changes: 26 additions & 19 deletions src/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use itertools::Itertools;

#[derive(Debug, Clone)]
pub struct Constants {
pub k: BigInt,
pub N: BigInt,
pub w: BigInt,
}
Expand All @@ -20,14 +19,15 @@ fn extended_gcd(a: BigInt, b: BigInt) -> BigInt {
let mut t1 = BigInt::from(0);
let mut t2 = BigInt::from(1);
let mut t3 = BigInt::from(0);
let ZERO = BigInt::from(0);

while r > BigInt::from(0) {
while r > ZERO {
q = b / a;
r = b - q * a;
s3 = s1 - q * s2;
t3 = t1 - q * t2;

if r > BigInt::from(0) {
if r > ZERO {
b = a;
a = r;
s1 = s2;
Expand Down Expand Up @@ -62,22 +62,24 @@ fn is_primitive_root(a: BigInt, deg: BigInt, N: BigInt) -> bool {
pub fn working_modulus(n: BigInt, M: BigInt) -> Constants {
let mut N = n + 1;
let mut k = BigInt::from(1);
while (!is_prime(N)) || N < M {
while !(is_prime(N) && N > M) {
k += 1;
N = k * n + 1;
}
assert!(N > M);
let mut gen = BigInt::from(0);
let ONE = BigInt::from(1);
let mut g = BigInt::from(2);
while g < N {
if is_primitive_root(g, N - 1, N) {
gen = g;
break;
}
g += BigInt::from(1);
g += ONE;
}
assert!(gen > 0);
let w = gen.mod_exp(k, N);
Constants { k, N, w }
Constants { N, w }
}

fn order_reverse(inp: &mut Vec<BigInt>) {
Expand All @@ -100,9 +102,11 @@ fn order_reverse(inp: &mut Vec<BigInt>) {
fn fft(inp: Vec<BigInt>, c: &Constants, w: BigInt) -> Vec<BigInt> {
let mut inp = inp.clone();
let N = inp.len();
let mut pre: Vec<BigInt> = vec![BigInt::from(1); N / 2];
let MOD = BigInt::from(c.N);
let ONE = BigInt::from(1);
let mut pre: Vec<BigInt> = vec![ONE; N / 2];

(1..N / 2).for_each(|i| pre[i] = (pre[i - 1] * w).rem(BigInt::from(c.N)));
(1..N / 2).for_each(|i| pre[i] = (pre[i - 1] * w).rem(MOD));
order_reverse(&mut inp);

let mut len = 2;
Expand All @@ -116,8 +120,8 @@ fn fft(inp: Vec<BigInt>, c: &Constants, w: BigInt) -> Vec<BigInt> {
let l = j + half;
let left = inp[j];
let right = inp[l] * pre[k];
inp[j] = left.add_mod(right, BigInt::from(c.N));
inp[l] = left.sub_mod(right, BigInt::from(c.N));
inp[j] = left.add_mod(right, MOD);
inp[l] = left.sub_mod(right, MOD);
k += pre_step;
})
});
Expand Down Expand Up @@ -151,16 +155,19 @@ mod tests {

#[test]
fn test_forward() {
let n = 1 << rand::thread_rng().gen::<u32>() % 8;
let v: Vec<BigInt> = (0..n)
.map(|_| BigInt::from(rand::thread_rng().gen::<u32>() % (1 << 6)))
.collect();
let M = (*v.iter().max().unwrap() << 1) * BigInt::from(n) + 1;
let c = working_modulus(BigInt::from(n), BigInt::from(M));
let forward = forward(v.clone(), &c);
let inverse = inverse(forward, &c);
v.iter().zip(inverse).for_each(|(&a, b)| assert_eq!(a, b));
(0..100).for_each(|_| {
let n = 1 << rand::thread_rng().gen::<u32>() % 8;
let v: Vec<BigInt> = (0..n)
.map(|_| BigInt::from(rand::thread_rng().gen::<u32>() % (1 << 6)))
.collect();
let M = (*v.iter().max().unwrap() << 1) * BigInt::from(n) + 1;
let c = working_modulus(BigInt::from(n), BigInt::from(M));
let forward = forward(v.clone(), &c);
let inverse = inverse(forward, &c);
v.iter().zip(inverse).for_each(|(&a, b)| assert_eq!(a, b));
})
}

#[test]
fn test_extended_gcd() {
(2..11).for_each(|x: u64| {
Expand Down
98 changes: 48 additions & 50 deletions src/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,51 @@ pub struct Polynomial {
impl Polynomial {
pub fn new(coef: Vec<BigInt>) -> Self {
let n = coef.len();
let ZERO = BigInt::from(0);

// if is not power of 2
if !(n & (n - 1) == 0) {
let pad = n.next_power_of_two() - n;
return Self {
coef: vec![BigInt::from(0); pad]
coef: vec![ZERO; pad]
.into_iter()
.chain(coef.into_iter())
.collect_vec(),
};
}
Self { coef }
}

pub fn mul(self, rhs: Polynomial, c: &Constants) -> Polynomial {
let v1_deg = self.degree();
let v2_deg = rhs.degree();
let n = (self.len() + rhs.len()).next_power_of_two();

let a_forward = forward(self.coef, &c);
let b_forward = forward(rhs.coef, &c);

let ZERO = BigInt::from(0);

let mut mul = vec![ZERO; n as usize];
a_forward
.iter()
.rev()
.zip_longest(b_forward.iter().rev())
.enumerate()
.for_each(|(i, p)| match p {
Both(&a, &b) => mul[i] = (a * b).rem(c.N),
Left(_) => {}
Right(_) => {}
});
mul.reverse();
let coef = inverse(mul, &c);
let start = coef.iter().position(|&x| x != 0).unwrap();

Polynomial {
coef: coef[start..=(start + v1_deg + v2_deg)].to_vec(),
}
}

pub fn diff(mut self) -> Self {
let N = self.len();
for n in (1..N).rev() {
Expand Down Expand Up @@ -89,53 +121,6 @@ impl Neg for Polynomial {
}
}

impl Mul<Polynomial> for Polynomial {
type Output = Polynomial;

fn mul(self, rhs: Polynomial) -> Self::Output {
let v1_deg = self.degree();
let v2_deg = rhs.degree();
let mut v1 = self.coef;
let mut v2 = rhs.coef;
let n = (v1.len() + v2.len()).next_power_of_two();

v1 = vec![BigInt::from(0); n - v1.len()]
.into_iter()
.chain(v1.into_iter())
.collect();
v2 = vec![BigInt::from(0); n - v2.len()]
.into_iter()
.chain(v2.into_iter())
.collect();

let N = BigInt::from(n);
let M = (*v1.iter().max().unwrap().max(v2.iter().max().unwrap()) << 1) * N + 1;
let c = working_modulus(N, M);

let a_forward = forward(v1, &c);
let b_forward = forward(v2, &c);

let mut mul = vec![BigInt::from(0); n as usize];
a_forward
.iter()
.rev()
.zip_longest(b_forward.iter().rev())
.enumerate()
.for_each(|(i, p)| match p {
Both(&a, &b) => mul[i] = (a * b).rem(c.N),
Left(_) => {}
Right(_) => {}
});
mul.reverse();
let coef = inverse(mul, &c);
let start = coef.iter().position(|&x| x != 0).unwrap();

Polynomial {
coef: coef[start..=(start + v1_deg + v2_deg)].to_vec(),
}
}
}

impl Display for Polynomial {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.coef.iter().map(|&x| write!(f, "{} ", x)).collect()
Expand All @@ -145,7 +130,7 @@ impl Display for Polynomial {
#[cfg(test)]
mod tests {
use super::Polynomial;
use crate::numbers::BigInt;
use crate::{ntt::working_modulus, numbers::BigInt};

#[test]
fn add() {
Expand All @@ -163,7 +148,20 @@ mod tests {
.map(|&x| BigInt::from(x))
.collect(),
);
println!("{}", a * b);

let N = BigInt::from((a.len() + b.len()).next_power_of_two());
let M = (*a
.coef
.iter()
.max()
.unwrap()
.max(b.coef.iter().max().unwrap())
<< 1)
* N
+ 1;
let c = working_modulus(N, M);

println!("{}", a.mul(b, &c));
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions src/prime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ pub fn is_prime(n: BigInt) -> bool {
s += 1;
}

for _ in 0..5 {
if check_composite(n, BigInt::from(2), d, s) {
for a in 2..7 {
if check_composite(n, BigInt::from(a), d, s) {
return false;
}
}
Expand Down

0 comments on commit e49cf05

Please sign in to comment.