diff --git a/src/fields/u64/fp/fiat.rs b/src/fields/u64/fp/fiat.rs index ce12d08..862111f 100644 --- a/src/fields/u64/fp/fiat.rs +++ b/src/fields/u64/fp/fiat.rs @@ -3356,17 +3356,17 @@ pub fn fp_divstep_precomp(out1: &mut [u64; 6]) { /// out4: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]] /// out5: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]] pub fn fp_divstep( - out1: &mut u64, - out2: &mut [u64; 7], - out3: &mut [u64; 7], - out4: &mut [u64; 6], - out5: &mut [u64; 6], arg1: u64, arg2: &[u64; 7], arg3: &[u64; 7], arg4: &[u64; 6], arg5: &[u64; 6], -) { +) -> (u64, [u64; 7], [u64; 7], [u64; 6], [u64; 6]) { + let mut out1: u64 = 0; + let mut out2: [u64; 7] = [0; 7]; + let mut out3: [u64; 7] = [0; 7]; + let mut out4: [u64; 6] = [0; 6]; + let mut out5: [u64; 6] = [0; 6]; let mut x1: u64 = 0; let mut x2: FpU1 = 0; fp_addcarryx_u64(&mut x1, &mut x2, 0x0, (!arg1), (0x1 as u64)); @@ -3653,7 +3653,7 @@ pub fn fp_divstep( fp_cmovznz_u64(&mut x177, x157, x152, x140); let mut x178: u64 = 0; fp_cmovznz_u64(&mut x178, x157, x154, x142); - *out1 = x158; + out1 = x158; out2[0] = x7; out2[1] = x8; out2[2] = x9; @@ -3680,4 +3680,5 @@ pub fn fp_divstep( out5[3] = x176; out5[4] = x177; out5[5] = x178; -} + (out1, out2, out3, out4, out5) +} \ No newline at end of file diff --git a/src/fields/u64/fp/wrapper.rs b/src/fields/u64/fp/wrapper.rs index 41b39c5..15ea5c0 100644 --- a/src/fields/u64/fp/wrapper.rs +++ b/src/fields/u64/fp/wrapper.rs @@ -1,5 +1,5 @@ -use core::ops::{Add, Mul, Neg, Sub}; +use core::ops::{Add, Mul, Neg, Sub}; use super::fiat; #[derive(Copy, Clone)] @@ -58,6 +58,59 @@ impl Fp { fiat::fp_square(&mut result, &self.0); Self(result) } + + fn div(&self) -> Fp { + const LEN_PRIME: usize = 377; + const ITERATIONS: usize = (49 * LEN_PRIME + 57) / 17; + + let mut a = fiat::FpNonMontgomeryDomainFieldElement([0; 6]); + fiat::fp_from_montgomery(&mut a, &self.0); + let mut d = 1; + let mut f: [u64; 7] = [0u64; 7]; + fiat::fp_msat(&mut f); + let mut g = [0u64; 7]; + let mut v = [0u64; 6]; + let mut r: [u64; 6] = Fp::one().0.0; + let mut i = 0; + let mut j = 0; + + while j < 6 { + g[j] = a[j]; + j += 1; + } + + while i < ITERATIONS - ITERATIONS % 2 { + let (out1, out2, out3, out4, out5) = fiat::fp_divstep(d, &f, &g, &v, &r); + let (out1, out2, out3, out4, out5) = fiat::fp_divstep(out1, &out2, &out3, &out4, &out5); + d = out1; + f = out2; + g = out3; + v = out4; + r = out5; + i += 2; + } + + if ITERATIONS % 2 != 0 { + let (_out1, out2, _out3, out4, _out5) = fiat::fp_divstep(d, &f, &g, &v, &r); + v = out4; + f = out2; + } + + let s = ((f[f.len() - 1] >> 64 - 1) & 1) as u8; + let mut neg = fiat::FpMontgomeryDomainFieldElement([0; 6]); + fiat::fp_opp(&mut neg, &fiat::FpMontgomeryDomainFieldElement(v)); + + let mut v_prime: [u64; 6] = [0u64; 6]; + fiat::fp_selectznz(&mut v_prime, s, &v, &neg.0); + + let mut pre_comp: [u64; 6] = [0u64; 6]; + fiat::fp_divstep_precomp(&mut pre_comp); + + let mut result = fiat::FpMontgomeryDomainFieldElement([0; 6]); + fiat::fp_mul(&mut result, &fiat::FpMontgomeryDomainFieldElement(v_prime), &fiat::FpMontgomeryDomainFieldElement(pre_comp)); + + Fp(result) + } } impl Add for Fp { @@ -99,3 +152,20 @@ impl Neg for Fp { Fp(result) } } + +#[cfg(test)] +mod tests { + use super::*; + use ark_std::println; + + #[test] + fn inversion_test() { + let one = Fp(fiat::FpMontgomeryDomainFieldElement([202099033278250856, 5854854902718660529, 11492539364873682930, 8885205928937022213, 5545221690922665192, 39800542322357402])); + let one_invert = one.div(); + assert_eq!(one_invert, one); + + let three = Fp::one().add(Fp::one().add(Fp::one())); + let three_invert = three.div(); + assert_eq!(three.mul(three_invert), Fp::one()); + } +} \ No newline at end of file