From cdb36ae2ff99d2138580e415c9412c266cbbf15d Mon Sep 17 00:00:00 2001 From: vincentvbh Date: Mon, 26 Aug 2024 01:16:42 +0800 Subject: [PATCH] simplify rounding of fractions --- ref/poly.c | 61 ++++++++++++++++++++++++++++----------------------- ref/polyvec.c | 42 +++++++++++++++++++---------------- 2 files changed, 57 insertions(+), 46 deletions(-) diff --git a/ref/poly.c b/ref/poly.c index cbd3abf..56a5a81 100644 --- a/ref/poly.c +++ b/ref/poly.c @@ -20,22 +20,23 @@ void poly_compress(uint8_t r[KYBER_POLYCOMPRESSEDBYTES], const poly *a) { unsigned int i,j; int16_t u; - uint32_t d0; uint8_t t[8]; #if (KYBER_POLYCOMPRESSEDBYTES == 128) for(i=0;icoeffs[8*i+j]; - u += (u >> 15) & KYBER_Q; -/* t[j] = ((((uint16_t)u << 4) + KYBER_Q/2)/KYBER_Q) & 15; */ - d0 = u << 4; - d0 += 1665; - d0 *= 80635; - d0 >>= 28; - t[j] = d0 & 0xf; + + // 16-bit precision suffices for round(2^4 x / q) + // inputs are in [-q/2, ..., q/2] + // 315 = round(16 * 2^16 / q) + t[j] = (int16_t)(((int32_t)u * 315 + (1 << 15)) >> 16) & 0xf; + + // this is equivalent to first mapping to positive + // standard representatives followed by + // t[j] = ((((uint16_t)u << 4) + KYBER_Q/2)/KYBER_Q) & 0xf; + } r[0] = t[0] | (t[1] << 4); @@ -47,15 +48,17 @@ void poly_compress(uint8_t r[KYBER_POLYCOMPRESSEDBYTES], const poly *a) #elif (KYBER_POLYCOMPRESSEDBYTES == 160) for(i=0;icoeffs[8*i+j]; - u += (u >> 15) & KYBER_Q; -/* t[j] = ((((uint32_t)u << 5) + KYBER_Q/2)/KYBER_Q) & 31; */ - d0 = u << 5; - d0 += 1664; - d0 *= 40318; - d0 >>= 27; - t[j] = d0 & 0x1f; + + // 15-bit precision suffices for round(2^5 x / q) + // inputs are in [-q/2, ..., q/2] + // 315 = round(32 * 2^15 / q) + t[j] = (int16_t)(((int32_t)u * 315 + (1 << 14)) >> 15) & 0x1f; + + // this is equivalent to first mapping to positive + // standard representatives followed by + // t[j] = ((((uint32_t)u << 5) + KYBER_Q/2)/KYBER_Q) & 0x1f; + } r[0] = (t[0] >> 0) | (t[1] << 5); @@ -192,20 +195,24 @@ void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *a) { unsigned int i,j; - uint32_t t; + int16_t u; for(i=0;icoeffs[8*i+j]; - // t += ((int16_t)t >> 15) & KYBER_Q; - // t = (((t << 1) + KYBER_Q/2)/KYBER_Q) & 1; - t <<= 1; - t += 1665; - t *= 80635; - t >>= 28; - t &= 1; - msg[i] |= t << j; + u = a->coeffs[8*i+j]; + + // 19-bit precision suffices for round(2 x / q) + // inputs are in [-q/2, ..., q/2] + // 315 = round(2 * 2^19 / q) + u = (int16_t)(((int32_t)u * 315 + (1 << 18)) >> 19) & 1; + + // this is equivalent to first mapping to positive + // standard representatives followed by + // u = ((((uint16_t)u << 1) + KYBER_Q/2)/KYBER_Q) & 1; + + msg[i] |= u << j; + } } } diff --git a/ref/polyvec.c b/ref/polyvec.c index 669f6a5..fb3092d 100644 --- a/ref/polyvec.c +++ b/ref/polyvec.c @@ -15,22 +15,24 @@ void polyvec_compress(uint8_t r[KYBER_POLYVECCOMPRESSEDBYTES], const polyvec *a) { unsigned int i,j,k; - uint64_t d0; + int16_t u; #if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) uint16_t t[8]; for(i=0;ivec[i].coeffs[8*j+k]; - t[k] += ((int16_t)t[k] >> 15) & KYBER_Q; -/* t[k] = ((((uint32_t)t[k] << 11) + KYBER_Q/2)/KYBER_Q) & 0x7ff; */ - d0 = t[k]; - d0 <<= 11; - d0 += 1664; - d0 *= 645084; - d0 >>= 31; - t[k] = d0 & 0x7ff; + u = a->vec[i].coeffs[8*j+k]; + + // 21-bit suffices for round(2048 x / q) + // inputs are in [-q/2, ..., q/2] + // 1290167 = round(2048 * 2^21 / q) + t[k] = ((int16_t)(((int32_t)u * 1290167 + (1 << 20)) >> 21)) & 0x7ff; + + // this is equivalent to first mapping to positive + // standard representatives followed by + // t[k] = ((((uint32_t)u << 11) + KYBER_Q/2)/KYBER_Q) & 0x7ff; + } r[ 0] = (t[0] >> 0); @@ -52,15 +54,17 @@ void polyvec_compress(uint8_t r[KYBER_POLYVECCOMPRESSEDBYTES], const polyvec *a) for(i=0;ivec[i].coeffs[4*j+k]; - t[k] += ((int16_t)t[k] >> 15) & KYBER_Q; -/* t[k] = ((((uint32_t)t[k] << 10) + KYBER_Q/2)/ KYBER_Q) & 0x3ff; */ - d0 = t[k]; - d0 <<= 10; - d0 += 1665; - d0 *= 1290167; - d0 >>= 32; - t[k] = d0 & 0x3ff; + u = a->vec[i].coeffs[4*j+k]; + + // 22-bit suffices for round(1024 x / q) + // inputs are in [-q/2, ..., q/2] + // 1290167 = round(1024 * 2^22 / q) + t[k] = ((int16_t)(((int32_t)u * 1290167 + (1 << 21)) >> 22)) & 0x3ff; + + // this is equivalent to first mapping to positive + // standard representatives followed by + // t[k] = ((((uint32_t)u << 10) + KYBER_Q/2)/ KYBER_Q) & 0x3ff; + } r[0] = (t[0] >> 0);