Solidify Notion of Evaluation Context #210

Merged 2 commits on Jul 5, 2024
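In short: the `(uint, uint, bool)` triple previously returned by `Context(...)` (module index, length multiplier, and a consistency flag) is replaced by a single `trace.Context` value, which is then threaded through the AIR schema builders (`AddColumn`, `AddVanishingConstraint`, `AddLookupConstraint`), the assignments, and the gadgets. The sketch below shows one plausible shape for such a type, inferred only from how it is used in this diff (`VoidContext`, `Join`, `IsVoid`, `IsConflicted`, `Module`, `LengthMultiplier`); the actual definition lives in `pkg/trace` and may well differ.

```go
package trace

import "math"

// Context identifies the evaluation context of an expression or column: the
// enclosing module plus a length multiplier.  Two sentinel states are used in
// this PR: a "void" context (e.g. for constants, which touch no column) and a
// "conflicted" context (the join of two incompatible contexts).
// NOTE: sketch inferred from usage in the diff, not the real definition.
type Context struct {
	module     uint // enclosing module index
	multiplier uint // length multiplier of the enclosing module
}

// VoidContext is the context of something that touches no column at all.
func VoidContext() Context { return Context{math.MaxUint, 0} }

// ConflictedContext marks the join of two incompatible contexts.
func ConflictedContext() Context { return Context{math.MaxUint, math.MaxUint} }

// Module returns the enclosing module index.
func (c Context) Module() uint { return c.module }

// LengthMultiplier returns the length multiplier for that module.
func (c Context) LengthMultiplier() uint { return c.multiplier }

// IsVoid reports whether this context encloses nothing.
func (c Context) IsVoid() bool { return c == VoidContext() }

// IsConflicted reports whether a join produced an inconsistency.
func (c Context) IsConflicted() bool { return c == ConflictedContext() }

// Join combines two contexts: void is the identity, and any disagreement on
// module or multiplier yields the conflicted context.
func (c Context) Join(other Context) Context {
	switch {
	case c.IsConflicted() || other.IsConflicted():
		return ConflictedContext()
	case c.IsVoid():
		return other
	case other.IsVoid():
		return c
	case c != other:
		return ConflictedContext()
	default:
		return c
	}
}
```

With this representation `Join` is associative and `VoidContext()` is its identity, which is what lets constants participate in any expression without pinning it to a module.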
20 changes: 9 additions & 11 deletions pkg/air/expr.go
@@ -1,11 +1,9 @@
package air

import (
"math"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/schema"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)

@@ -17,7 +15,7 @@ import (
// trace expansion).
type Expr interface {
util.Boundable
schema.Evaluable
sc.Evaluable

// String produces a string representing this as an S-Expression.
String() string
@@ -44,7 +42,7 @@ type Add struct{ Args []Expr }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Add) Context(schema sc.Schema) (uint, uint, bool) {
func (p *Add) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

@@ -73,7 +71,7 @@ type Sub struct{ Args []Expr }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Sub) Context(schema sc.Schema) (uint, uint, bool) {
func (p *Sub) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

@@ -102,7 +100,7 @@ type Mul struct{ Args []Expr }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Mul) Context(schema sc.Schema) (uint, uint, bool) {
func (p *Mul) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

@@ -154,8 +152,8 @@ func NewConstCopy(val *fr.Element) Expr {

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Constant) Context(schema sc.Schema) (uint, uint, bool) {
return math.MaxUint, math.MaxUint, true
func (p *Constant) Context(schema sc.Schema) trace.Context {
return trace.VoidContext()
}

// Add two expressions together, producing a third.
@@ -193,9 +191,9 @@ func NewColumnAccess(column uint, shift int) Expr {

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *ColumnAccess) Context(schema sc.Schema) (uint, uint, bool) {
func (p *ColumnAccess) Context(schema sc.Schema) trace.Context {
col := schema.Columns().Nth(p.Column)
return col.Module(), col.LengthMultiplier(), true
return col.Context()
}

// Add two expressions together, producing a third.
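The `Add`, `Sub`, and `Mul` cases above all defer to `sc.JoinContexts`, and `Constant` now returns `trace.VoidContext()` where it previously returned a `(math.MaxUint, math.MaxUint, true)` sentinel. A hedged sketch of what that fold could look like, assuming the `Context` API sketched earlier; the `Contextual` constraint and the abstract `Schema` are illustrative stand-ins, and the real generic lives in `pkg/schema`.

```go
package schema

import "github.com/consensys/go-corset/pkg/trace"

// Schema is left abstract here; the real interface is defined in pkg/schema.
type Schema interface{}

// Contextual is an illustrative stand-in for the constraint JoinContexts
// places on its element type (in the diff above, air.Expr plays this role).
type Contextual interface {
	Context(schema Schema) trace.Context
}

// JoinContexts folds the evaluation contexts of the given expressions into a
// single context.  VoidContext() is the identity of the fold, which is why a
// constant can safely return it; any disagreement between arguments (module
// or length multiplier) surfaces as a conflicted context for the caller to
// reject.
func JoinContexts[E Contextual](args []E, schema Schema) trace.Context {
	ctx := trace.VoidContext()
	for _, arg := range args {
		ctx = ctx.Join(arg.Context(schema))
	}
	return ctx
}
```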
6 changes: 3 additions & 3 deletions pkg/air/gadgets/bits.go
@@ -23,7 +23,7 @@ func ApplyBinaryGadget(col uint, schema *air.Schema) {
// Construct X * (X-1)
X_X_m1 := X.Mul(X_m1)
// Done!
schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), column.Module(), column.LengthMultiplier(), nil, X_X_m1)
schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), column.Context(), nil, X_X_m1)
}

// ApplyBitwidthGadget ensures all values in a given column fit within a given
@@ -45,7 +45,7 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
coefficient := fr.NewElement(1)
// Add decomposition assignment
index := schema.AddAssignment(
assignment.NewByteDecomposition(name, column.Module(), column.LengthMultiplier(), col, n))
assignment.NewByteDecomposition(name, column.Context(), col, n))
// Construct Columns
for i := uint(0); i < n; i++ {
// Create Column + Constraint
@@ -61,5 +61,5 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
X := air.NewColumnAccess(col, 0)
eq := X.Equate(sum)
// Construct column name
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), column.Module(), column.LengthMultiplier(), nil, eq)
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), column.Context(), nil, eq)
}
4 changes: 2 additions & 2 deletions pkg/air/gadgets/column_sort.go
@@ -38,10 +38,10 @@ func ApplyColumnSortGadget(col uint, sign bool, bitwidth uint, schema *air.Schem
}
// Add delta assignment
deltaIndex := schema.AddAssignment(
assignment.NewComputedColumn(column.Module(), deltaName, column.LengthMultiplier(), Xdiff))
assignment.NewComputedColumn(column.Context(), deltaName, Xdiff))
// Add necessary bitwidth constraints
ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
// Configure constraint: Delta[k] = X[k] - X[k-1]
Dk := air.NewColumnAccess(deltaIndex, 0)
schema.AddVanishingConstraint(deltaName, column.Module(), column.LengthMultiplier(), nil, Dk.Equate(Xdiff))
schema.AddVanishingConstraint(deltaName, column.Context(), nil, Dk.Equate(Xdiff))
}
8 changes: 4 additions & 4 deletions pkg/air/gadgets/expand.go
@@ -20,23 +20,23 @@ func Expand(e air.Expr, schema *air.Schema) uint {
return ca.Column
}
// No optimisation, therefore expand using a computedcolumn
module, multiplier := sc.DetermineEnclosingModuleOfExpression(e, schema)
ctx := e.Context(schema)
// Determine computed column name
name := e.String()
// Look up column
index, ok := sc.ColumnIndexOf(schema, module, name)
index, ok := sc.ColumnIndexOf(schema, ctx.Module(), name)
// Add new column (if it does not already exist)
if !ok {
// Add computed column
index = schema.AddAssignment(assignment.NewComputedColumn(module, name, multiplier, e))
index = schema.AddAssignment(assignment.NewComputedColumn(ctx, name, e))
}
// Construct v == [e]
v := air.NewColumnAccess(index, 0)
// Construct 1 == e/e
eq_e_v := v.Equate(e)
// Ensure (e - v) == 0, where v is value of computed column.
c_name := fmt.Sprintf("[%s]", e.String())
schema.AddVanishingConstraint(c_name, module, multiplier, nil, eq_e_v)
schema.AddVanishingConstraint(c_name, ctx, nil, eq_e_v)
//
return index
}
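The shape of `Expand` above recurs in every gadget touched by this PR (`bits.go`, `column_sort.go`, `lexicographic_sort.go`, `normalisation.go`): derive the context once, then pass the same `trace.Context` to both the assignment and the constraint, rather than carrying a loose `(module, multiplier)` pair to each call. A before/after fragment of that calling convention, paraphrased from the diff (surrounding declarations elided; the two blocks are alternatives, not sequential code):

```go
// Before: module and multiplier travel separately, and nothing ties the pair
// together at the call sites.
module, multiplier := sc.DetermineEnclosingModuleOfExpression(e, schema)
index := schema.AddAssignment(assignment.NewComputedColumn(module, name, multiplier, e))
schema.AddVanishingConstraint(cName, module, multiplier, nil, eq)

// After: a single trace.Context flows from the expression to every schema
// call that needs it, so the assignment and the constraint cannot disagree
// about their evaluation context.
ctx := e.Context(schema)
index := schema.AddAssignment(assignment.NewComputedColumn(ctx, name, e))
schema.AddVanishingConstraint(cName, ctx, nil, eq)
```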
17 changes: 9 additions & 8 deletions pkg/air/gadgets/lexicographic_sort.go
@@ -7,6 +7,7 @@ import (
"github.com/consensys/go-corset/pkg/air"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/schema/assignment"
"github.com/consensys/go-corset/pkg/trace"
)

// ApplyLexicographicSortingGadget Add sorting constraints for a sequence of one
@@ -33,19 +34,19 @@ func ApplyLexicographicSortingGadget(columns []uint, signs []bool, bitwidth uint
panic("Inconsistent number of columns and signs for lexicographic sort.")
}
// Determine enclosing module for this gadget.
module, multiplier := sc.DetermineEnclosingModuleOfColumns(columns, schema)
ctx := sc.ContextOfColumns(columns, schema)
// Construct a unique prefix for this sort.
prefix := constructLexicographicSortingPrefix(columns, signs, schema)
// Add trace computation
deltaIndex := schema.AddAssignment(
assignment.NewLexicographicSort(prefix, module, multiplier, columns, signs, bitwidth))
assignment.NewLexicographicSort(prefix, ctx, columns, signs, bitwidth))
// Construct selector bits.
addLexicographicSelectorBits(prefix, module, multiplier, deltaIndex, columns, schema)
addLexicographicSelectorBits(prefix, ctx, deltaIndex, columns, schema)
// Construct delta terms
constraint := constructLexicographicDeltaConstraint(deltaIndex, columns, signs)
// Add delta constraint
deltaName := fmt.Sprintf("%s:delta", prefix)
schema.AddVanishingConstraint(deltaName, module, multiplier, nil, constraint)
schema.AddVanishingConstraint(deltaName, ctx, nil, constraint)
// Add necessary bitwidth constraints
ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
}
@@ -77,7 +78,7 @@ func constructLexicographicSortingPrefix(columns []uint, signs []bool, schema *a
//
// NOTE: this implementation differs from the original corset which used an
// additional "Eq" bit to help ensure at most one selector bit was enabled.
func addLexicographicSelectorBits(prefix string, module uint, multiplier uint,
func addLexicographicSelectorBits(prefix string, context trace.Context,
deltaIndex uint, columns []uint, schema *air.Schema) {
ncols := uint(len(columns))
// Calculate column index of first selector bit
@@ -102,7 +103,7 @@ func addLexicographicSelectorBits(prefix string, module uint, multiplier uint,
pterms[i] = air.NewColumnAccess(bitIndex+i, 0)
pDiff := air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1))
pName := fmt.Sprintf("%s:%d:a", prefix, i)
schema.AddVanishingConstraint(pName, module, multiplier,
schema.AddVanishingConstraint(pName, context,
nil, air.NewConst64(1).Sub(&air.Add{Args: pterms}).Mul(pDiff))
// (∀j<i.Bj=0) ∧ Bi=1 ==> C[k]≠C[k-1]
qDiff := Normalise(air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1)), schema)
@@ -115,14 +116,14 @@ func addLexicographicSelectorBits(prefix string, module uint, multiplier uint,
constraint = air.NewConst64(1).Sub(&air.Add{Args: qterms}).Mul(constraint)
}

schema.AddVanishingConstraint(qName, module, multiplier, nil, constraint)
schema.AddVanishingConstraint(qName, context, nil, constraint)
}

sum := &air.Add{Args: terms}
// (sum = 0) ∨ (sum = 1)
constraint := sum.Mul(sum.Equate(air.NewConst64(1)))
name := fmt.Sprintf("%s:xor", prefix)
schema.AddVanishingConstraint(name, module, multiplier, nil, constraint)
schema.AddVanishingConstraint(name, context, nil, constraint)
}

// Construct the lexicographic delta constraint. This states that the delta
12 changes: 6 additions & 6 deletions pkg/air/gadgets/normalisation.go
@@ -29,17 +29,17 @@ func Normalise(e air.Expr, schema *air.Schema) air.Expr {
// ensure it really holds the inverted value.
func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
// Determine enclosing module.
module, multiplier := sc.DetermineEnclosingModuleOfExpression(e, schema)
ctx := e.Context(schema)
// Construct inverse computation
ie := &Inverse{Expr: e}
// Determine computed column name
name := ie.String()
// Look up column
index, ok := sc.ColumnIndexOf(schema, module, name)
index, ok := sc.ColumnIndexOf(schema, ctx.Module(), name)
// Add new column (if it does not already exist)
if !ok {
// Add computed column
index = schema.AddAssignment(assignment.NewComputedColumn(module, name, multiplier, ie))
index = schema.AddAssignment(assignment.NewComputedColumn(ctx, name, ie))
}

// Construct 1/e
@@ -54,10 +54,10 @@ func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
inv_e_implies_one_e_e := inv_e.Mul(one_e_e)
// Ensure (e != 0) ==> (1 == e/e)
l_name := fmt.Sprintf("[%s <=]", ie.String())
schema.AddVanishingConstraint(l_name, module, multiplier, nil, e_implies_one_e_e)
schema.AddVanishingConstraint(l_name, ctx, nil, e_implies_one_e_e)
// Ensure (e/e != 0) ==> (1 == e/e)
r_name := fmt.Sprintf("[%s =>]", ie.String())
schema.AddVanishingConstraint(r_name, module, multiplier, nil, inv_e_implies_one_e_e)
schema.AddVanishingConstraint(r_name, ctx, nil, inv_e_implies_one_e_e)
// Done
return air.NewColumnAccess(index, 0)
}
@@ -81,7 +81,7 @@ func (e *Inverse) Bounds() util.Bounds { return e.Expr.Bounds() }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (e *Inverse) Context(schema sc.Schema) (uint, uint, bool) {
func (e *Inverse) Context(schema sc.Schema) tr.Context {
return e.Expr.Context(schema)
}

23 changes: 12 additions & 11 deletions pkg/air/schema.go
@@ -7,6 +7,7 @@ import (
"github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/schema/assignment"
"github.com/consensys/go-corset/pkg/schema/constraint"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)

@@ -64,14 +65,14 @@ func (p *Schema) AddModule(name string) uint {

// AddColumn appends a new data column whose values must be provided by the
// user.
func (p *Schema) AddColumn(module uint, name string, datatype schema.Type) uint {
if module >= uint(len(p.modules)) {
panic(fmt.Sprintf("invalid module index (%d)", module))
func (p *Schema) AddColumn(context trace.Context, name string, datatype schema.Type) uint {
if context.Module() >= uint(len(p.modules)) {
panic(fmt.Sprintf("invalid module index (%d)", context.Module()))
}

// NOTE: the air level has no ability to enforce the type specified for a
// given column.
p.inputs = append(p.inputs, assignment.NewDataColumn(module, name, datatype))
p.inputs = append(p.inputs, assignment.NewDataColumn(context, name, datatype))
// Calculate column index
return uint(len(p.inputs) - 1)
}
@@ -87,8 +88,8 @@ func (p *Schema) AddAssignment(c schema.Assignment) uint {
}

// AddLookupConstraint appends a new lookup constraint.
func (p *Schema) AddLookupConstraint(handle string, source uint, source_multiplier uint,
target uint, target_multiplier uint, sources []uint, targets []uint) {
func (p *Schema) AddLookupConstraint(handle string, source trace.Context,
target trace.Context, sources []uint, targets []uint) {
if len(targets) != len(sources) {
panic("differing number of target / source lookup columns")
}
@@ -103,7 +104,7 @@ func (p *Schema) AddLookupConstraint(handle string, source uint, source_multipli
}
//
p.constraints = append(p.constraints,
constraint.NewLookupConstraint(handle, source, source_multiplier, target, target_multiplier, from, into))
constraint.NewLookupConstraint(handle, source, target, from, into))
}

// AddPermutationConstraint appends a new permutation constraint which
Expand All @@ -114,13 +115,13 @@ func (p *Schema) AddPermutationConstraint(targets []uint, sources []uint) {
}

// AddVanishingConstraint appends a new vanishing constraint.
func (p *Schema) AddVanishingConstraint(handle string, module uint, multiplier uint, domain *int, expr Expr) {
if module >= uint(len(p.modules)) {
panic(fmt.Sprintf("invalid module index (%d)", module))
func (p *Schema) AddVanishingConstraint(handle string, context trace.Context, domain *int, expr Expr) {
if context.Module() >= uint(len(p.modules)) {
panic(fmt.Sprintf("invalid module index (%d)", context.Module()))
}
// TODO: sanity check expression enclosed by module
p.constraints = append(p.constraints,
constraint.NewVanishingConstraint(handle, module, multiplier, domain, constraint.ZeroTest[Expr]{Expr: expr}))
constraint.NewVanishingConstraint(handle, context, domain, constraint.ZeroTest[Expr]{Expr: expr}))
}

// AddRangeConstraint appends a new range constraint.
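For schema clients the change is purely one of signatures: wherever a call previously supplied a module index together with a length multiplier, it now supplies one `trace.Context` per side. A before/after fragment based on the signatures above (handles, column lists, domains, and expressions elided):

```go
// Before: two loosely-coupled (module, multiplier) pairs for a lookup, and one
// pair for a vanishing constraint or a column declaration.
schema.AddLookupConstraint(handle, srcModule, srcMultiplier, dstModule, dstMultiplier, sources, targets)
schema.AddVanishingConstraint(handle, module, multiplier, domain, expr)
schema.AddColumn(module, name, datatype)

// After: each side of the lookup, the vanishing constraint, and the column are
// each described by a single self-contained context.
schema.AddLookupConstraint(handle, srcCtx, dstCtx, sources, targets)
schema.AddVanishingConstraint(handle, ctx, domain, expr)
schema.AddColumn(ctx, name, datatype)
```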
17 changes: 10 additions & 7 deletions pkg/binfile/computation.go
@@ -6,6 +6,7 @@ import (
"github.com/consensys/go-corset/pkg/hir"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/schema/assignment"
"github.com/consensys/go-corset/pkg/trace"
)

type jsonComputationSet struct {
@@ -27,7 +28,6 @@ type jsonSortedComputation struct {
// =============================================================================

func (e jsonComputationSet) addToSchema(schema *hir.Schema) {
var multiplier uint
//
for _, c := range e.Computations {
if c.Sorted != nil {
@@ -44,6 +44,8 @@ func (e jsonComputationSet) addToSchema(schema *hir.Schema) {
// Convert target refs into columns
targets := make([]sc.Column, len(targetRefs))
//
ctx := trace.VoidContext()
//
for i, targetRef := range targetRefs {
src_cid, src_mid := sourceRefs[i].resolve(schema)
_, dst_mid := targetRef.resolve(schema)
@@ -53,20 +55,21 @@ }
}
// Determine type of source column
ith := schema.Columns().Nth(src_cid)
ctx = ctx.Join(ith.Context())
// Sanity check we have a sensible type here.
if ith.Type().AsUint() == nil {
panic(fmt.Sprintf("source column %s has field type", sourceRefs[i]))
} else if i == 0 {
multiplier = ith.LengthMultiplier()
} else if multiplier != ith.LengthMultiplier() {
panic(fmt.Sprintf("source column %s has inconsistent length multiplier", sourceRefs[i]))
} else if ctx.IsConflicted() {
panic(fmt.Sprintf("source column %s has conflicted evaluation context", sourceRefs[i]))
} else if ctx.IsVoid() {
panic(fmt.Sprintf("source column %s has void evaluation context", sourceRefs[i]))
}

sources[i] = src_cid
targets[i] = sc.NewColumn(ith.Module(), targetRef.column, multiplier, ith.Type())
targets[i] = sc.NewColumn(ctx, targetRef.column, ith.Type())
}
// Finally, add the sorted permutation assignment
schema.AddAssignment(assignment.NewSortedPermutation(module, multiplier, targets, c.Sorted.Signs, sources))
schema.AddAssignment(assignment.NewSortedPermutation(ctx, targets, c.Sorted.Signs, sources))
}
}
}
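The loop above replaces the old "consistent length multiplier" check with a context join followed by explicit `IsConflicted`/`IsVoid` tests. A hedged sketch of that validation as a standalone helper, assuming the `Context` API sketched at the top of this page; it checks once after the fold and returns an error instead of panicking per column, so treat it as illustrative rather than a drop-in replacement.

```go
package binfile

import (
	"fmt"

	sc "github.com/consensys/go-corset/pkg/schema"
	"github.com/consensys/go-corset/pkg/trace"
)

// contextOfSources folds the evaluation contexts of the given source columns
// into one, then rejects inconsistent (conflicted) or empty (void) results
// before any assignment is built from them.
func contextOfSources(schema sc.Schema, sources []uint) (trace.Context, error) {
	ctx := trace.VoidContext()
	for _, cid := range sources {
		ctx = ctx.Join(schema.Columns().Nth(cid).Context())
	}
	if ctx.IsConflicted() {
		return ctx, fmt.Errorf("source columns have conflicting evaluation contexts")
	}
	if ctx.IsVoid() {
		return ctx, fmt.Errorf("source columns have a void evaluation context")
	}
	return ctx, nil
}
```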