Skip to content

Commit

Permalink
fix(uints): constrain valueOf (#1139)
Browse files Browse the repository at this point in the history
* fix(uints): constrain valueOf

* chore: use existing method  instead

* chore: use the Long type restriction

* test: add test case for add

* feat: check carry correctness in Add

* chore: use unified method for Long length

* chore: use range

---------

Co-authored-by: Ivo Kubjas <[email protected]>
  • Loading branch information
bernard-wagner and ivokub authored Jun 12, 2024
1 parent ccf02f9 commit 6abed5a
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 15 deletions.
44 changes: 29 additions & 15 deletions std/math/uints/uint8.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package uints

import (
"fmt"
"math/bits"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/internal/logderivprecomp"
Expand Down Expand Up @@ -173,15 +174,18 @@ func (bf *BinaryField[T]) ValueOf(a frontend.Variable) T {
if err != nil {
panic(err)
}
// TODO: add constraint which ensures that map back to

for i := range bts {
r[i] = bf.ByteValueOf(bts[i])
}
expectedValue := bf.ToValue(r)
bf.api.AssertIsEqual(a, expectedValue)

return r
}

func (bf *BinaryField[T]) ToValue(a T) frontend.Variable {
v := make([]frontend.Variable, len(a))
v := make([]frontend.Variable, bf.lenBts())
for i := range v {
v[i] = bf.api.Mul(a[i].Val, 1<<(i*8))
}
Expand All @@ -206,17 +210,17 @@ func (bf *BinaryField[T]) PackLSB(a ...U8) T {
}

func (bf *BinaryField[T]) UnpackMSB(a T) []U8 {
ret := make([]U8, len(a))
for i := 0; i < len(a); i++ {
ret := make([]U8, bf.lenBts())
for i := 0; i < len(ret); i++ {
ret[len(a)-i-1] = a[i]
}
return ret
}

func (bf *BinaryField[T]) UnpackLSB(a T) []U8 {
// cannot deduce that a can be cast to []U8
ret := make([]U8, len(a))
for i := 0; i < len(a); i++ {
ret := make([]U8, bf.lenBts())
for i := 0; i < len(ret); i++ {
ret[i] = a[i]
}
return ret
Expand Down Expand Up @@ -255,18 +259,22 @@ func (bf *BinaryField[T]) Not(a T) T {
}

func (bf *BinaryField[T]) Add(a ...T) T {
va := make([]frontend.Variable, len(a))
tLen := bf.lenBts() * 8
inLen := len(a)
va := make([]frontend.Variable, inLen)
for i := range a {
va[i] = bf.ToValue(a[i])
}
vres := bf.api.Add(va[0], va[1], va[2:]...)
res := bf.ValueOf(vres)
// TODO: should also check the that carry we omitted is correct.
maxBitlen := bits.Len(uint(inLen)) + tLen
// bitslice.Partition below checks that the input is less than 2^maxBitlen and that we have omitted carry correctly
vreslow, _ := bitslice.Partition(bf.api, vres, uint(tLen), bitslice.WithNbDigits(maxBitlen), bitslice.WithUnconstrainedOutputs())
res := bf.ValueOf(vreslow)
return res
}

func (bf *BinaryField[T]) Lrot(a T, c int) T {
l := len(a)
l := bf.lenBts()
if c < 0 {
c = l*8 + c
}
Expand All @@ -293,23 +301,24 @@ func (bf *BinaryField[T]) Lrot(a T, c int) T {
}

func (bf *BinaryField[T]) Rshift(a T, c int) T {
lenB := bf.lenBts()
shiftBl := c / 8
shiftBt := c % 8
partitioned := make([][2]frontend.Variable, len(a)-shiftBl)
partitioned := make([][2]frontend.Variable, lenB-shiftBl)
for i := range partitioned {
lower, upper := bitslice.Partition(bf.api, a[i+shiftBl].Val, uint(shiftBt), bitslice.WithNbDigits(8))
partitioned[i] = [2]frontend.Variable{lower, upper}
}
var ret T
for i := 0; i < len(a)-shiftBl-1; i++ {
for i := 0; i < bf.lenBts()-shiftBl-1; i++ {
if shiftBt != 0 {
ret[i].Val = bf.api.Add(partitioned[i][1], bf.api.Mul(1<<(8-shiftBt), partitioned[i+1][0]))
} else {
ret[i].Val = partitioned[i][1]
}
}
ret[len(a)-shiftBl-1].Val = partitioned[len(a)-shiftBl-1][1]
for i := len(a) - shiftBl; i < len(ret); i++ {
ret[lenB-shiftBl-1].Val = partitioned[lenB-shiftBl-1][1]
for i := lenB - shiftBl; i < lenB; i++ {
ret[i] = NewU8(0)
}
return ret
Expand All @@ -320,11 +329,16 @@ func (bf *BinaryField[T]) ByteAssertEq(a, b U8) {
}

func (bf *BinaryField[T]) AssertEq(a, b T) {
for i := 0; i < len(a); i++ {
for i := 0; i < bf.lenBts(); i++ {
bf.ByteAssertEq(a[i], b[i])
}
}

func (bf *BinaryField[T]) lenBts() int {
var a T
return len(a)
}

func reslice[T U32 | U64](in []T) [][]U8 {
if len(in) == 0 {
panic("zero-length input")
Expand Down
47 changes: 47 additions & 0 deletions std/math/uints/uint8_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,50 @@ func TestRshift(t *testing.T) {
err = test.IsSolved(&rshiftCircuit{Shift: 11}, &rshiftCircuit{Shift: 11, In: NewU32(0x12345678), Expected: NewU32(0x12345678 >> 11)}, ecc.BN254.ScalarField())
assert.NoError(err)
}

type valueOfCircuit[T Long] struct {
In frontend.Variable
Expected T
}

func (c *valueOfCircuit[T]) Define(api frontend.API) error {
uapi, err := New[T](api)
if err != nil {
return err
}
res := uapi.ValueOf(c.In)
uapi.AssertEq(res, c.Expected)
return nil
}

func TestValueOf(t *testing.T) {
assert := test.NewAssert(t)
var err error
err = test.IsSolved(&valueOfCircuit[U64]{}, &valueOfCircuit[U64]{In: 0x12345678, Expected: [8]U8{NewU8(0x78), NewU8(0x56), NewU8(0x34), NewU8(0x12), NewU8(0), NewU8(0), NewU8(0), NewU8(0)}}, ecc.BN254.ScalarField())
assert.NoError(err)
err = test.IsSolved(&valueOfCircuit[U32]{}, &valueOfCircuit[U32]{In: 0x12345678, Expected: [4]U8{NewU8(0x78), NewU8(0x56), NewU8(0x34), NewU8(0x12)}}, ecc.BN254.ScalarField())
assert.NoError(err)
err = test.IsSolved(&valueOfCircuit[U32]{}, &valueOfCircuit[U32]{In: 0x1234567812345678, Expected: [4]U8{NewU8(0x78), NewU8(0x56), NewU8(0x34), NewU8(0x12)}}, ecc.BN254.ScalarField())
assert.Error(err)
}

type addCircuit struct {
In [2]U32
Expected U32
}

func (c *addCircuit) Define(api frontend.API) error {
uapi, err := New[U32](api)
if err != nil {
return err
}
res := uapi.Add(c.In[0], c.In[1])
uapi.AssertEq(res, c.Expected)
return nil
}

func TestAdd(t *testing.T) {
assert := test.NewAssert(t)
err := test.IsSolved(&addCircuit{}, &addCircuit{In: [2]U32{NewU32(^uint32(0)), NewU32(2)}, Expected: NewU32(1)}, ecc.BN254.ScalarField())
assert.NoError(err)
}

0 comments on commit 6abed5a

Please sign in to comment.