Skip to content

Commit

Permalink
Merge pull request #2 from 0xWOLAND/dyn-residue
Browse files Browse the repository at this point in the history
Dynamic Residue
  • Loading branch information
0xWOLAND authored Nov 23, 2023
2 parents 50e4a0a + 7bfc92f commit 3189108
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 182 deletions.
40 changes: 20 additions & 20 deletions BENCHMARKS.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,31 @@ Socket(s): 1

| | `NTT` |
|:------------|:-------------------------- |
| **`64`** | `202.26 us` (✅ **1.00x**) |
| **`128`** | `354.08 us` (✅ **1.00x**) |
| **`256`** | `665.54 us` (✅ **1.00x**) |
| **`512`** | `1.12 ms` (✅ **1.00x**) |
| **`1024`** | `2.00 ms` (✅ **1.00x**) |
| **`2048`** | `3.94 ms` (✅ **1.00x**) |
| **`4096`** | `7.69 ms` (✅ **1.00x**) |
| **`8192`** | `16.13 ms` (✅ **1.00x**) |
| **`16384`** | `34.01 ms` (✅ **1.00x**) |
| **`32768`** | `74.65 ms` (✅ **1.00x**) |
| **`64`** | `187.17 us` (✅ **1.00x**) |
| **`128`** | `231.50 us` (✅ **1.00x**) |
| **`256`** | `333.26 us` (✅ **1.00x**) |
| **`512`** | `623.88 us` (✅ **1.00x**) |
| **`1024`** | `951.62 us` (✅ **1.00x**) |
| **`2048`** | `1.48 ms` (✅ **1.00x**) |
| **`4096`** | `2.78 ms` (✅ **1.00x**) |
| **`8192`** | `5.48 ms` (✅ **1.00x**) |
| **`16384`** | `11.09 ms` (✅ **1.00x**) |
| **`32768`** | `23.08 ms` (✅ **1.00x**) |

### Polynomial Multiplication Benchmarks

| | `NTT-Based` | `Brute-Force` |
|:------------|:--------------------------|:---------------------------------- |
| **`64`** | `1.18 ms` (✅ **1.00x**) | `48.62 us` (🚀 **24.21x faster**) |
| **`128`** | `2.30 ms` (✅ **1.00x**) | `198.30 us` (🚀 **11.59x faster**) |
| **`256`** | `3.54 ms` (✅ **1.00x**) | `766.71 us` (🚀 **4.62x faster**) |
| **`512`** | `6.50 ms` (✅ **1.00x**) | `3.11 ms` (🚀 **2.09x faster**) |
| **`1024`** | `12.43 ms` (✅ **1.00x**) | `12.34 ms` (**1.01x faster**) |
| **`2048`** | `24.68 ms` (✅ **1.00x**) | `49.90 ms` (❌ *2.02x slower*) |
| **`4096`** | `51.36 ms` (✅ **1.00x**) | `200.91 ms` (❌ *3.91x slower*) |
| **`8192`** | `106.21 ms` (✅ **1.00x**) | `803.87 ms` (❌ *7.57x slower*) |
| **`16384`** | `226.19 ms` (✅ **1.00x**) | `3.24 s` (❌ *14.31x slower*) |
| **`32768`** | `467.75 ms` (✅ **1.00x**) | `12.75 s` (❌ *27.25x slower*) |
| **`64`** | `818.69 us` (✅ **1.00x**) | `494.52 us` ( **1.66x faster**) |
| **`128`** | `1.12 ms` (✅ **1.00x**) | `1.93 ms` (*1.72x slower*) |
| **`256`** | `1.74 ms` (✅ **1.00x**) | `7.78 ms` (*4.48x slower*) |
| **`512`** | `2.69 ms` (✅ **1.00x**) | `30.35 ms` (*11.30x slower*) |
| **`1024`** | `4.33 ms` (✅ **1.00x**) | `121.49 ms` (*28.05x slower*) |
| **`2048`** | `7.47 ms` (✅ **1.00x**) | `493.59 ms` (❌ *66.07x slower*) |
| **`4096`** | `14.23 ms` (✅ **1.00x**) | `1.98 s` (❌ *139.11x slower*) |
| **`8192`** | `31.60 ms` (✅ **1.00x**) | `7.88 s` (❌ *249.28x slower*) |
| **`16384`** | `65.51 ms` (✅ **1.00x**) | `31.46 s` (❌ *480.32x slower*) |
| **`32768`** | `141.24 ms` (✅ **1.00x**) | `126.02 s` (❌ *892.30x slower*) |

---
Made with [criterion-table](https://github.com/nu11ptr/criterion-table)
Expand Down
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
name = "fast-ntt"
version = "0.1.0"
edition = "2021"
license = "MIT OR Apache-2.0"
keywords = ["ntt", "number-theoretic-transform", "fft"]
categories = ["cryptography", "data-structures"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand Down
84 changes: 23 additions & 61 deletions src/ntt.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::ops::Add;

use crate::{numbers::BigInt, prime::is_prime};
use crypto_bigint::Invert;
use itertools::Itertools;
use rayon::prelude::*;

Expand All @@ -10,44 +11,10 @@ pub struct Constants {
pub w: BigInt,
}

/// Runs an extended-Euclid style loop over `a` and `b` and returns
/// `(t2 + b).rem(b)`, where `b` is the value originally passed in and `t2` is
/// the coefficient left over from the last step with a non-zero remainder.
///
/// The in-file test `test_extended_gcd` relies on the result `inv` satisfying
/// `(a * inv) rem b == 1` for `b = 11`, i.e. it is used as the inverse of `a`
/// modulo `b`.
///
/// NOTE(review): nothing here checks that `a` is non-zero or that `a` and `b`
/// are coprime — presumably callers guarantee both; confirm.
/// NOTE(review): the coefficient updates subtract products and can go below
/// zero mathematically; what `BigInt` does in that case is not visible from
/// this function — confirm against `crate::numbers::BigInt`.
fn extended_gcd(a: BigInt, b: BigInt) -> BigInt {
let mut a = a;
let mut b = b;
// Keep the original second argument: `b` is overwritten inside the loop but
// the final reduction needs the original value.
let n = b;
let ZERO = BigInt::from(0);
let ONE = BigInt::from(1);

// `r` starts at ONE so that the loop body executes at least once.
let mut q = ZERO;
let mut r = ONE;
// Coefficient pairs for the two recurrences. Only `t2` feeds the return
// value; `s1`/`s2`/`s3` are updated but never read by the result.
let mut s1 = ONE;
let mut s2 = ZERO;
let mut s3 = ONE;
let mut t1 = ZERO;
let mut t2 = ONE;
let mut t3 = ZERO;

while r > ZERO {
// Quotient and remainder of the current `b` by the current `a`.
q = b / a;
r = b - q * a;
// Next coefficient in each recurrence.
s3 = s1 - q * s2;
t3 = t1 - q * t2;

// Shift the state only while the remainder is non-zero, so that on exit
// `t2` still holds the coefficient from the last non-zero-remainder step.
if r > ZERO {
b = a;
a = r;
s1 = s2;
s2 = s3;
t1 = t2;
t2 = t3;
}
}
// Add `n` before reducing — presumably to bring a negative coefficient into
// the range `[0, n)`; confirm with `BigInt`'s arithmetic.
(t2 + n).rem(n)
}

fn prime_factors(a: BigInt) -> Vec<BigInt> {
let mut ans: Vec<BigInt> = Vec::new();
let mut x = BigInt::from(2);
while x <= a.sqrt() {
while x * x <= a {
if a.rem(x) == 0 {
ans.push(x);
}
Expand All @@ -57,11 +24,13 @@ fn prime_factors(a: BigInt) -> Vec<BigInt> {
}

fn is_primitive_root(a: BigInt, deg: BigInt, N: BigInt) -> bool {
a.mod_exp(deg, N) == 1
&& prime_factors(deg)
.iter()
.map(|&x| a.mod_exp(deg / x, N) != 1)
.all(|x| x)
let lhs = a.mod_exp(deg, N);
let lhs = lhs == 1;
let rhs = prime_factors(deg)
.iter()
.map(|&x| a.mod_exp(deg / x, N) != 1)
.all(|x| x);
lhs && rhs
}

pub fn working_modulus(n: BigInt, M: BigInt) -> Constants {
Expand Down Expand Up @@ -135,7 +104,7 @@ fn fft(inp: Vec<BigInt>, c: &Constants, w: BigInt) -> Vec<BigInt> {
.zip(hi)
.enumerate()
.for_each(|(idx, (lo, hi))| {
*hi = (*hi).mul_mod(pre[nchunks * idx], MOD);
*hi = (*hi * pre[nchunks * idx]).rem(MOD);
let neg = if *lo < *hi {
(MOD + *lo) - *hi
} else {
Expand All @@ -159,9 +128,10 @@ pub fn forward(inp: Vec<BigInt>, c: &Constants) -> Vec<BigInt> {
}

pub fn inverse(inp: Vec<BigInt>, c: &Constants) -> Vec<BigInt> {
let inv = extended_gcd(BigInt::from(inp.len()), BigInt::from(c.N));
let w = extended_gcd(c.w, c.N);

let mut inv = BigInt::from(inp.len());
let _ = inv.set_mod(c.N);
let inv = inv.invert();
let w = c.w.invert();
let mut res = fft(inp, c, w);
res.par_iter_mut().for_each(|x| *x = (inv * (*x)).rem(c.N));
res
Expand All @@ -173,34 +143,26 @@ mod tests {
use rayon::{iter::ParallelIterator, slice::ParallelSliceMut};

use crate::{
ntt::{extended_gcd, forward, inverse, working_modulus},
ntt::{forward, inverse, working_modulus},
numbers::BigInt,
};

#[test]
fn test_forward() {
let n = 1 << rand::thread_rng().gen::<u32>() % 8;
let v: Vec<BigInt> = (0..n)
.map(|_| BigInt::from(rand::thread_rng().gen::<u32>() % (1 << 6)))
.collect();
let M = (*v.iter().max().unwrap() << 1) * BigInt::from(n) + 1;
// let n = 1 << rand::thread_rng().gen::<u32>() % 8;
// let v: Vec<BigInt> = (0..n)
// .map(|_| BigInt::from(rand::thread_rng().gen::<u32>() % (1 << 6)))
// .collect();
// let M = (*v.iter().max().unwrap() << 1) * BigInt::from(n) + 1;
let n = 8;
let v: Vec<BigInt> = (0..n).map(|x| BigInt::from(x)).collect();
let M = BigInt::from(n) * BigInt::from(n) + 1;
let c = working_modulus(BigInt::from(n), BigInt::from(M));
let forward = forward(v.clone(), &c);
let inverse = inverse(forward, &c);
v.iter().zip(inverse).for_each(|(&a, b)| assert_eq!(a, b));
}

/// For every `x` in `2..11`, the value returned by `extended_gcd(x, 11)`
/// multiplied by `x` must leave remainder 1 modulo 11.
#[test]
fn test_extended_gcd() {
    for x in 2u64..11 {
        let candidate = extended_gcd(BigInt::from(x), BigInt::from(11));
        let product = BigInt::from(x) * candidate;
        let residue = product.rem(BigInt::from(11));
        assert_eq!(residue, BigInt::from(1));
    }
}

#[test]
fn test_roots_of_unity() {
let N = 10;
Expand Down
Loading

0 comments on commit 3189108

Please sign in to comment.