diff --git a/pkg/air/expr.go b/pkg/air/expr.go index ae45acf..8eab2b5 100644 --- a/pkg/air/expr.go +++ b/pkg/air/expr.go @@ -1,11 +1,9 @@ package air import ( - "math" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/go-corset/pkg/schema" sc "github.com/consensys/go-corset/pkg/schema" + "github.com/consensys/go-corset/pkg/trace" "github.com/consensys/go-corset/pkg/util" ) @@ -17,7 +15,7 @@ import ( // trace expansion). type Expr interface { util.Boundable - schema.Evaluable + sc.Evaluable // String produces a string representing this as an S-Expression. String() string @@ -44,7 +42,7 @@ type Add struct{ Args []Expr } // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *Add) Context(schema sc.Schema) (uint, uint, bool) { +func (p *Add) Context(schema sc.Schema) trace.Context { return sc.JoinContexts[Expr](p.Args, schema) } @@ -73,7 +71,7 @@ type Sub struct{ Args []Expr } // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *Sub) Context(schema sc.Schema) (uint, uint, bool) { +func (p *Sub) Context(schema sc.Schema) trace.Context { return sc.JoinContexts[Expr](p.Args, schema) } @@ -102,7 +100,7 @@ type Mul struct{ Args []Expr } // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *Mul) Context(schema sc.Schema) (uint, uint, bool) { +func (p *Mul) Context(schema sc.Schema) trace.Context { return sc.JoinContexts[Expr](p.Args, schema) } @@ -154,8 +152,8 @@ func NewConstCopy(val *fr.Element) Expr { // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *Constant) Context(schema sc.Schema) (uint, uint, bool) { - return math.MaxUint, math.MaxUint, true +func (p *Constant) Context(schema sc.Schema) trace.Context { + return trace.VoidContext() } // Add two expressions together, producing a third. @@ -193,9 +191,9 @@ func NewColumnAccess(column uint, shift int) Expr { // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *ColumnAccess) Context(schema sc.Schema) (uint, uint, bool) { +func (p *ColumnAccess) Context(schema sc.Schema) trace.Context { col := schema.Columns().Nth(p.Column) - return col.Module(), col.LengthMultiplier(), true + return col.Context() } // Add two expressions together, producing a third. diff --git a/pkg/air/gadgets/bits.go b/pkg/air/gadgets/bits.go index 8d8845d..2c279b7 100644 --- a/pkg/air/gadgets/bits.go +++ b/pkg/air/gadgets/bits.go @@ -23,7 +23,7 @@ func ApplyBinaryGadget(col uint, schema *air.Schema) { // Construct X * (X-1) X_X_m1 := X.Mul(X_m1) // Done! - schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), column.Module(), column.LengthMultiplier(), nil, X_X_m1) + schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), column.Context(), nil, X_X_m1) } // ApplyBitwidthGadget ensures all values in a given column fit within a given @@ -45,7 +45,7 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) { coefficient := fr.NewElement(1) // Add decomposition assignment index := schema.AddAssignment( - assignment.NewByteDecomposition(name, column.Module(), column.LengthMultiplier(), col, n)) + assignment.NewByteDecomposition(name, column.Context(), col, n)) // Construct Columns for i := uint(0); i < n; i++ { // Create Column + Constraint @@ -61,5 +61,5 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) { X := air.NewColumnAccess(col, 0) eq := X.Equate(sum) // Construct column name - schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), column.Module(), column.LengthMultiplier(), nil, eq) + schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), column.Context(), nil, eq) } diff --git a/pkg/air/gadgets/column_sort.go b/pkg/air/gadgets/column_sort.go index de08f8b..6877b2f 100644 --- a/pkg/air/gadgets/column_sort.go +++ b/pkg/air/gadgets/column_sort.go @@ -38,10 +38,10 @@ func ApplyColumnSortGadget(col uint, sign bool, bitwidth uint, schema *air.Schem } // Add delta assignment deltaIndex := schema.AddAssignment( - assignment.NewComputedColumn(column.Module(), deltaName, column.LengthMultiplier(), Xdiff)) + assignment.NewComputedColumn(column.Context(), deltaName, Xdiff)) // Add necessary bitwidth constraints ApplyBitwidthGadget(deltaIndex, bitwidth, schema) // Configure constraint: Delta[k] = X[k] - X[k-1] Dk := air.NewColumnAccess(deltaIndex, 0) - schema.AddVanishingConstraint(deltaName, column.Module(), column.LengthMultiplier(), nil, Dk.Equate(Xdiff)) + schema.AddVanishingConstraint(deltaName, column.Context(), nil, Dk.Equate(Xdiff)) } diff --git a/pkg/air/gadgets/expand.go b/pkg/air/gadgets/expand.go index be975b5..dbca243 100644 --- a/pkg/air/gadgets/expand.go +++ b/pkg/air/gadgets/expand.go @@ -20,15 +20,15 @@ func Expand(e air.Expr, schema *air.Schema) uint { return ca.Column } // No optimisation, therefore expand using a computedcolumn - module, multiplier := sc.DetermineEnclosingModuleOfExpression(e, schema) + ctx := e.Context(schema) // Determine computed column name name := e.String() // Look up column - index, ok := sc.ColumnIndexOf(schema, module, name) + index, ok := sc.ColumnIndexOf(schema, ctx.Module(), name) // Add new column (if it does not already exist) if !ok { // Add computed column - index = schema.AddAssignment(assignment.NewComputedColumn(module, name, multiplier, e)) + index = schema.AddAssignment(assignment.NewComputedColumn(ctx, name, e)) } // Construct v == [e] v := air.NewColumnAccess(index, 0) @@ -36,7 +36,7 @@ func Expand(e air.Expr, schema *air.Schema) uint { eq_e_v := v.Equate(e) // Ensure (e - v) == 0, where v is value of computed column. c_name := fmt.Sprintf("[%s]", e.String()) - schema.AddVanishingConstraint(c_name, module, multiplier, nil, eq_e_v) + schema.AddVanishingConstraint(c_name, ctx, nil, eq_e_v) // return index } diff --git a/pkg/air/gadgets/lexicographic_sort.go b/pkg/air/gadgets/lexicographic_sort.go index f12610b..151c2d9 100644 --- a/pkg/air/gadgets/lexicographic_sort.go +++ b/pkg/air/gadgets/lexicographic_sort.go @@ -7,6 +7,7 @@ import ( "github.com/consensys/go-corset/pkg/air" sc "github.com/consensys/go-corset/pkg/schema" "github.com/consensys/go-corset/pkg/schema/assignment" + "github.com/consensys/go-corset/pkg/trace" ) // ApplyLexicographicSortingGadget Add sorting constraints for a sequence of one @@ -33,19 +34,19 @@ func ApplyLexicographicSortingGadget(columns []uint, signs []bool, bitwidth uint panic("Inconsistent number of columns and signs for lexicographic sort.") } // Determine enclosing module for this gadget. - module, multiplier := sc.DetermineEnclosingModuleOfColumns(columns, schema) + ctx := sc.ContextOfColumns(columns, schema) // Construct a unique prefix for this sort. prefix := constructLexicographicSortingPrefix(columns, signs, schema) // Add trace computation deltaIndex := schema.AddAssignment( - assignment.NewLexicographicSort(prefix, module, multiplier, columns, signs, bitwidth)) + assignment.NewLexicographicSort(prefix, ctx, columns, signs, bitwidth)) // Construct selecto bits. - addLexicographicSelectorBits(prefix, module, multiplier, deltaIndex, columns, schema) + addLexicographicSelectorBits(prefix, ctx, deltaIndex, columns, schema) // Construct delta terms constraint := constructLexicographicDeltaConstraint(deltaIndex, columns, signs) // Add delta constraint deltaName := fmt.Sprintf("%s:delta", prefix) - schema.AddVanishingConstraint(deltaName, module, multiplier, nil, constraint) + schema.AddVanishingConstraint(deltaName, ctx, nil, constraint) // Add necessary bitwidth constraints ApplyBitwidthGadget(deltaIndex, bitwidth, schema) } @@ -77,7 +78,7 @@ func constructLexicographicSortingPrefix(columns []uint, signs []bool, schema *a // // NOTE: this implementation differs from the original corset which used an // additional "Eq" bit to help ensure at most one selector bit was enabled. -func addLexicographicSelectorBits(prefix string, module uint, multiplier uint, +func addLexicographicSelectorBits(prefix string, context trace.Context, deltaIndex uint, columns []uint, schema *air.Schema) { ncols := uint(len(columns)) // Calculate column index of first selector bit @@ -102,7 +103,7 @@ func addLexicographicSelectorBits(prefix string, module uint, multiplier uint, pterms[i] = air.NewColumnAccess(bitIndex+i, 0) pDiff := air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1)) pName := fmt.Sprintf("%s:%d:a", prefix, i) - schema.AddVanishingConstraint(pName, module, multiplier, + schema.AddVanishingConstraint(pName, context, nil, air.NewConst64(1).Sub(&air.Add{Args: pterms}).Mul(pDiff)) // (∀j C[k]≠C[k-1] qDiff := Normalise(air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1)), schema) @@ -115,14 +116,14 @@ func addLexicographicSelectorBits(prefix string, module uint, multiplier uint, constraint = air.NewConst64(1).Sub(&air.Add{Args: qterms}).Mul(constraint) } - schema.AddVanishingConstraint(qName, module, multiplier, nil, constraint) + schema.AddVanishingConstraint(qName, context, nil, constraint) } sum := &air.Add{Args: terms} // (sum = 0) ∨ (sum = 1) constraint := sum.Mul(sum.Equate(air.NewConst64(1))) name := fmt.Sprintf("%s:xor", prefix) - schema.AddVanishingConstraint(name, module, multiplier, nil, constraint) + schema.AddVanishingConstraint(name, context, nil, constraint) } // Construct the lexicographic delta constraint. This states that the delta diff --git a/pkg/air/gadgets/normalisation.go b/pkg/air/gadgets/normalisation.go index 6f6d489..c00a3fb 100644 --- a/pkg/air/gadgets/normalisation.go +++ b/pkg/air/gadgets/normalisation.go @@ -29,17 +29,17 @@ func Normalise(e air.Expr, schema *air.Schema) air.Expr { // ensure it really holds the inverted value. func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr { // Determine enclosing module. - module, multiplier := sc.DetermineEnclosingModuleOfExpression(e, schema) + ctx := e.Context(schema) // Construct inverse computation ie := &Inverse{Expr: e} // Determine computed column name name := ie.String() // Look up column - index, ok := sc.ColumnIndexOf(schema, module, name) + index, ok := sc.ColumnIndexOf(schema, ctx.Module(), name) // Add new column (if it does not already exist) if !ok { // Add computed column - index = schema.AddAssignment(assignment.NewComputedColumn(module, name, multiplier, ie)) + index = schema.AddAssignment(assignment.NewComputedColumn(ctx, name, ie)) } // Construct 1/e @@ -54,10 +54,10 @@ func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr { inv_e_implies_one_e_e := inv_e.Mul(one_e_e) // Ensure (e != 0) ==> (1 == e/e) l_name := fmt.Sprintf("[%s <=]", ie.String()) - schema.AddVanishingConstraint(l_name, module, multiplier, nil, e_implies_one_e_e) + schema.AddVanishingConstraint(l_name, ctx, nil, e_implies_one_e_e) // Ensure (e/e != 0) ==> (1 == e/e) r_name := fmt.Sprintf("[%s =>]", ie.String()) - schema.AddVanishingConstraint(r_name, module, multiplier, nil, inv_e_implies_one_e_e) + schema.AddVanishingConstraint(r_name, ctx, nil, inv_e_implies_one_e_e) // Done return air.NewColumnAccess(index, 0) } @@ -81,7 +81,7 @@ func (e *Inverse) Bounds() util.Bounds { return e.Expr.Bounds() } // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (e *Inverse) Context(schema sc.Schema) (uint, uint, bool) { +func (e *Inverse) Context(schema sc.Schema) tr.Context { return e.Expr.Context(schema) } diff --git a/pkg/air/schema.go b/pkg/air/schema.go index c14c958..bcfdefa 100644 --- a/pkg/air/schema.go +++ b/pkg/air/schema.go @@ -7,6 +7,7 @@ import ( "github.com/consensys/go-corset/pkg/schema" "github.com/consensys/go-corset/pkg/schema/assignment" "github.com/consensys/go-corset/pkg/schema/constraint" + "github.com/consensys/go-corset/pkg/trace" "github.com/consensys/go-corset/pkg/util" ) @@ -64,14 +65,14 @@ func (p *Schema) AddModule(name string) uint { // AddColumn appends a new data column whose values must be provided by the // user. -func (p *Schema) AddColumn(module uint, name string, datatype schema.Type) uint { - if module >= uint(len(p.modules)) { - panic(fmt.Sprintf("invalid module index (%d)", module)) +func (p *Schema) AddColumn(context trace.Context, name string, datatype schema.Type) uint { + if context.Module() >= uint(len(p.modules)) { + panic(fmt.Sprintf("invalid module index (%d)", context.Module())) } // NOTE: the air level has no ability to enforce the type specified for a // given column. - p.inputs = append(p.inputs, assignment.NewDataColumn(module, name, datatype)) + p.inputs = append(p.inputs, assignment.NewDataColumn(context, name, datatype)) // Calculate column index return uint(len(p.inputs) - 1) } @@ -87,8 +88,8 @@ func (p *Schema) AddAssignment(c schema.Assignment) uint { } // AddLookupConstraint appends a new lookup constraint. -func (p *Schema) AddLookupConstraint(handle string, source uint, source_multiplier uint, - target uint, target_multiplier uint, sources []uint, targets []uint) { +func (p *Schema) AddLookupConstraint(handle string, source trace.Context, + target trace.Context, sources []uint, targets []uint) { if len(targets) != len(sources) { panic("differeng number of target / source lookup columns") } @@ -103,7 +104,7 @@ func (p *Schema) AddLookupConstraint(handle string, source uint, source_multipli } // p.constraints = append(p.constraints, - constraint.NewLookupConstraint(handle, source, source_multiplier, target, target_multiplier, from, into)) + constraint.NewLookupConstraint(handle, source, target, from, into)) } // AddPermutationConstraint appends a new permutation constraint which @@ -114,13 +115,13 @@ func (p *Schema) AddPermutationConstraint(targets []uint, sources []uint) { } // AddVanishingConstraint appends a new vanishing constraint. -func (p *Schema) AddVanishingConstraint(handle string, module uint, multiplier uint, domain *int, expr Expr) { - if module >= uint(len(p.modules)) { - panic(fmt.Sprintf("invalid module index (%d)", module)) +func (p *Schema) AddVanishingConstraint(handle string, context trace.Context, domain *int, expr Expr) { + if context.Module() >= uint(len(p.modules)) { + panic(fmt.Sprintf("invalid module index (%d)", context.Module())) } // TODO: sanity check expression enclosed by module p.constraints = append(p.constraints, - constraint.NewVanishingConstraint(handle, module, multiplier, domain, constraint.ZeroTest[Expr]{Expr: expr})) + constraint.NewVanishingConstraint(handle, context, domain, constraint.ZeroTest[Expr]{Expr: expr})) } // AddRangeConstraint appends a new range constraint. diff --git a/pkg/binfile/computation.go b/pkg/binfile/computation.go index 2a95b73..30fa768 100644 --- a/pkg/binfile/computation.go +++ b/pkg/binfile/computation.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/go-corset/pkg/hir" sc "github.com/consensys/go-corset/pkg/schema" "github.com/consensys/go-corset/pkg/schema/assignment" + "github.com/consensys/go-corset/pkg/trace" ) type jsonComputationSet struct { @@ -27,7 +28,6 @@ type jsonSortedComputation struct { // ============================================================================= func (e jsonComputationSet) addToSchema(schema *hir.Schema) { - var multiplier uint // for _, c := range e.Computations { if c.Sorted != nil { @@ -44,6 +44,8 @@ func (e jsonComputationSet) addToSchema(schema *hir.Schema) { // Convert target refs into columns targets := make([]sc.Column, len(targetRefs)) // + ctx := trace.VoidContext() + // for i, targetRef := range targetRefs { src_cid, src_mid := sourceRefs[i].resolve(schema) _, dst_mid := targetRef.resolve(schema) @@ -53,20 +55,21 @@ func (e jsonComputationSet) addToSchema(schema *hir.Schema) { } // Determine type of source column ith := schema.Columns().Nth(src_cid) + ctx = ctx.Join(ith.Context()) // Sanity check we have a sensible type here. if ith.Type().AsUint() == nil { panic(fmt.Sprintf("source column %s has field type", sourceRefs[i])) - } else if i == 0 { - multiplier = ith.LengthMultiplier() - } else if multiplier != ith.LengthMultiplier() { - panic(fmt.Sprintf("source column %s has inconsistent length multiplier", sourceRefs[i])) + } else if ctx.IsConflicted() { + panic(fmt.Sprintf("source column %s has conflicted evaluation context", sourceRefs[i])) + } else if ctx.IsVoid() { + panic(fmt.Sprintf("source column %s has void evaluation context", sourceRefs[i])) } sources[i] = src_cid - targets[i] = sc.NewColumn(ith.Module(), targetRef.column, multiplier, ith.Type()) + targets[i] = sc.NewColumn(ctx, targetRef.column, ith.Type()) } // Finally, add the sorted permutation assignment - schema.AddAssignment(assignment.NewSortedPermutation(module, multiplier, targets, c.Sorted.Signs, sources)) + schema.AddAssignment(assignment.NewSortedPermutation(ctx, targets, c.Sorted.Signs, sources)) } } } diff --git a/pkg/binfile/constraint.go b/pkg/binfile/constraint.go index 7a8dcc9..1b71d28 100644 --- a/pkg/binfile/constraint.go +++ b/pkg/binfile/constraint.go @@ -1,6 +1,8 @@ package binfile import ( + "fmt" + "github.com/consensys/go-corset/pkg/hir" sc "github.com/consensys/go-corset/pkg/schema" ) @@ -52,22 +54,22 @@ func (e jsonConstraint) addToSchema(schema *hir.Schema) { // Translate Domain domain := e.Vanishes.Domain.toHir() // Determine enclosing module - module, multiplier := sc.DetermineEnclosingModuleOfExpression(expr, schema) + ctx := expr.Context(schema) // Construct the vanishing constraint - schema.AddVanishingConstraint(e.Vanishes.Handle, module, multiplier, domain, expr) + schema.AddVanishingConstraint(e.Vanishes.Handle, ctx, domain, expr) } else if e.Lookup != nil { sources := jsonExprsToHirUnit(e.Lookup.From, schema) targets := jsonExprsToHirUnit(e.Lookup.To, schema) - source, source_multiplier, err1 := sc.DetermineEnclosingModuleOfExpressions(sources, schema) - target, target_multiplier, err2 := sc.DetermineEnclosingModuleOfExpressions(targets, schema) + sourceCtx := sc.JoinContexts(sources, schema) + targetCtx := sc.JoinContexts(targets, schema) // Error check - if err1 != nil { - panic(err1.Error) - } else if err2 != nil { - panic(err2.Error) + if sourceCtx.IsConflicted() || sourceCtx.IsVoid() { + panic(fmt.Sprintf("lookup %s has conflicting source evaluation context", e.Lookup.Handle)) + } else if targetCtx.IsConflicted() || targetCtx.IsVoid() { + panic(fmt.Sprintf("lookup %s has conflicting target evaluation context", e.Lookup.Handle)) } // Add constraint - schema.AddLookupConstraint(e.Lookup.Handle, source, source_multiplier, target, target_multiplier, sources, targets) + schema.AddLookupConstraint(e.Lookup.Handle, sourceCtx, targetCtx, sources, targets) } else if e.Permutation == nil { // Catch all panic("Unknown JSON constraint encountered") diff --git a/pkg/binfile/constraint_set.go b/pkg/binfile/constraint_set.go index 2c9eaac..a9220b2 100644 --- a/pkg/binfile/constraint_set.go +++ b/pkg/binfile/constraint_set.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/go-corset/pkg/hir" sc "github.com/consensys/go-corset/pkg/schema" + "github.com/consensys/go-corset/pkg/trace" ) // This is very much a Work-In-Progress :) @@ -109,9 +110,11 @@ func HirSchemaFromJson(bytes []byte) (schema *hir.Schema, err error) { } else { cref := asColumnRef(c.Handle) mid := registerModule(schema, cref.module) + // NOTE: assumption here that length multiplier is always one. + ctx := trace.NewContext(mid, 1) col_type := c.Type.toHir() // Add column for this - schema.AddDataColumn(mid, cref.column, col_type) + schema.AddDataColumn(ctx, cref.column, col_type) // Check whether a type constraint required or not. if c.MustProve { cid, ok := sc.ColumnIndexOf(schema, mid, cref.column) diff --git a/pkg/cmd/trace.go b/pkg/cmd/trace.go index 32bdd6e..5117db7 100644 --- a/pkg/cmd/trace.go +++ b/pkg/cmd/trace.go @@ -63,7 +63,7 @@ func filterColumns(tr trace.Trace, prefix string) trace.Trace { // across traces. for i := uint(0); i < n; i++ { ith := tr.Columns().Get(i) - name := tr.Modules().Get(ith.Module()).Name() + name := tr.Modules().Get(ith.Context().Module()).Name() if !builder.HasModule(name) { if _, err := builder.Register(name, ith.Height()); err != nil { diff --git a/pkg/cmd/util.go b/pkg/cmd/util.go index 74ee4a4..c7a9774 100644 --- a/pkg/cmd/util.go +++ b/pkg/cmd/util.go @@ -187,7 +187,7 @@ func printSyntaxError(filename string, err *sexp.SyntaxError, text string) { // index. func QualifiedColumnName(cid uint, tr trace.Trace) string { col := tr.Columns().Get(cid) - mod := tr.Modules().Get(col.Module()) + mod := tr.Modules().Get(col.Context().Module()) // Check whether qualification required if mod.Name() != "" { return fmt.Sprintf("%s.%s", mod.Name(), col.Name()) diff --git a/pkg/hir/environment.go b/pkg/hir/environment.go index c3644a4..bdd0968 100644 --- a/pkg/hir/environment.go +++ b/pkg/hir/environment.go @@ -5,6 +5,7 @@ import ( "github.com/consensys/go-corset/pkg/schema" sc "github.com/consensys/go-corset/pkg/schema" + "github.com/consensys/go-corset/pkg/trace" ) // =================================================================== @@ -41,7 +42,7 @@ func EmptyEnvironment() *Environment { // RegisterModule registers a new module within this environment. Observe that // this will panic if the module already exists. -func (p *Environment) RegisterModule(module string) uint { +func (p *Environment) RegisterModule(module string) trace.Context { if p.HasModule(module) { panic(fmt.Sprintf("module %s already exists", module)) } @@ -50,20 +51,20 @@ func (p *Environment) RegisterModule(module string) uint { // Update cache p.modules[module] = mid // Done - return mid + return trace.NewContext(mid, 1) } // AddDataColumn registers a new column within a given module. Observe that // this will panic if the column already exists. -func (p *Environment) AddDataColumn(module uint, column string, datatype sc.Type) uint { - if p.HasColumn(module, column) { - panic(fmt.Sprintf("column %d:%s already exists", module, column)) +func (p *Environment) AddDataColumn(context trace.Context, column string, datatype sc.Type) uint { + if p.HasColumn(context, column) { + panic(fmt.Sprintf("column %d:%s already exists", context.Module(), column)) } // Update schema - p.schema.AddDataColumn(module, column, datatype) + p.schema.AddDataColumn(context, column, datatype) // Update cache cid := uint(len(p.columns)) - cref := columnRef{module, column} + cref := columnRef{context.Module(), column} p.columns[cref] = cid // Done return cid @@ -78,7 +79,7 @@ func (p *Environment) AddAssignment(decl schema.Assignment) { // Update cache for i := decl.Columns(); i.HasNext(); { ith := i.Next() - cref := columnRef{ith.Module(), ith.Name()} + cref := columnRef{ith.Context().Module(), ith.Name()} p.columns[cref] = index index++ } @@ -86,15 +87,15 @@ func (p *Environment) AddAssignment(decl schema.Assignment) { // LookupModule determines the module index for a given named module, or return // false if no such module exists. -func (p *Environment) LookupModule(module string) (uint, bool) { +func (p *Environment) LookupModule(module string) (trace.Context, bool) { mid, ok := p.modules[module] - return mid, ok + return trace.NewContext(mid, 1), ok } // LookupColumn determines the column index for a given named column in a given // module, or return false if no such column exists. -func (p *Environment) LookupColumn(module uint, column string) (uint, bool) { - cref := columnRef{module, column} +func (p *Environment) LookupColumn(context trace.Context, column string) (uint, bool) { + cref := columnRef{context.Module(), column} cid, ok := p.columns[cref] return cid, ok @@ -108,8 +109,8 @@ func (p *Environment) HasModule(module string) bool { } // HasColumn checks whether a given module has a given column, or not. -func (p *Environment) HasColumn(module uint, column string) bool { - _, ok := p.LookupColumn(module, column) +func (p *Environment) HasColumn(context trace.Context, column string) bool { + _, ok := p.LookupColumn(context, column) // Discard column index return ok } diff --git a/pkg/hir/expr.go b/pkg/hir/expr.go index 531aef4..d65cc9c 100644 --- a/pkg/hir/expr.go +++ b/pkg/hir/expr.go @@ -1,8 +1,6 @@ package hir import ( - "math" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/mir" sc "github.com/consensys/go-corset/pkg/schema" @@ -50,7 +48,7 @@ func (p *Add) Bounds() util.Bounds { return util.BoundsForArray(p.Args) } // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *Add) Context(schema sc.Schema) (uint, uint, bool) { +func (p *Add) Context(schema sc.Schema) trace.Context { return sc.JoinContexts[Expr](p.Args, schema) } @@ -67,7 +65,7 @@ func (p *Sub) Bounds() util.Bounds { return util.BoundsForArray(p.Args) } // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *Sub) Context(schema sc.Schema) (uint, uint, bool) { +func (p *Sub) Context(schema sc.Schema) trace.Context { return sc.JoinContexts[Expr](p.Args, schema) } @@ -84,7 +82,7 @@ func (p *Mul) Bounds() util.Bounds { return util.BoundsForArray(p.Args) } // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *Mul) Context(schema sc.Schema) (uint, uint, bool) { +func (p *Mul) Context(schema sc.Schema) trace.Context { return sc.JoinContexts[Expr](p.Args, schema) } @@ -101,7 +99,7 @@ func (p *List) Bounds() util.Bounds { return util.BoundsForArray(p.Args) } // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *List) Context(schema sc.Schema) (uint, uint, bool) { +func (p *List) Context(schema sc.Schema) trace.Context { return sc.JoinContexts[Expr](p.Args, schema) } @@ -118,8 +116,8 @@ func (p *Constant) Bounds() util.Bounds { return util.EMPTY_BOUND } // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *Constant) Context(schema sc.Schema) (uint, uint, bool) { - return math.MaxUint, math.MaxUint, true +func (p *Constant) Context(schema sc.Schema) trace.Context { + return trace.VoidContext() } // ============================================================================ @@ -157,7 +155,7 @@ func (p *IfZero) Bounds() util.Bounds { // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *IfZero) Context(schema sc.Schema) (uint, uint, bool) { +func (p *IfZero) Context(schema sc.Schema) trace.Context { if p.TrueBranch != nil && p.FalseBranch != nil { args := []Expr{p.Condition, p.TrueBranch, p.FalseBranch} return sc.JoinContexts[Expr](args, schema) @@ -188,7 +186,7 @@ func (p *Normalise) Bounds() util.Bounds { // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *Normalise) Context(schema sc.Schema) (uint, uint, bool) { +func (p *Normalise) Context(schema sc.Schema) trace.Context { return p.Arg.Context(schema) } @@ -220,7 +218,7 @@ func (p *ColumnAccess) Bounds() util.Bounds { // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *ColumnAccess) Context(schema sc.Schema) (uint, uint, bool) { +func (p *ColumnAccess) Context(schema sc.Schema) trace.Context { col := schema.Columns().Nth(p.Column) - return col.Module(), col.LengthMultiplier(), true + return col.Context() } diff --git a/pkg/hir/lower.go b/pkg/hir/lower.go index 09b3091..93f35de 100644 --- a/pkg/hir/lower.go +++ b/pkg/hir/lower.go @@ -21,7 +21,7 @@ func (p *Schema) LowerToMir() *mir.Schema { // Lower columns for _, input := range p.inputs { col := input.(DataColumn) - mirSchema.AddDataColumn(col.Module(), col.Name(), col.Type()) + mirSchema.AddDataColumn(col.Context(), col.Name(), col.Type()) } // Lower assignments (nothing to do here) for _, asn := range p.assignments { @@ -51,7 +51,7 @@ func lowerConstraintToMir(c sc.Constraint, schema *mir.Schema) { mir_exprs := v.Constraint().Expr.LowerTo(schema) // Add individual constraints arising for _, mir_expr := range mir_exprs { - schema.AddVanishingConstraint(v.Handle(), v.Module(), v.LengthMultiplier(), v.Domain(), mir_expr) + schema.AddVanishingConstraint(v.Handle(), v.Context(), v.Domain(), mir_expr) } } else if v, ok := c.(*constraint.TypeConstraint); ok { schema.AddTypeConstraint(v.Target(), v.Type()) @@ -73,9 +73,7 @@ func lowerLookupConstraint(c LookupConstraint, schema *mir.Schema) { into[i] = lowerUnitTo(targets[i], schema) } // - src_mod, src_mul := c.SourceContext() - dst_mod, dst_mul := c.TargetContext() - schema.AddLookupConstraint(c.Handle(), src_mod, src_mul, dst_mod, dst_mul, from, into) + schema.AddLookupConstraint(c.Handle(), c.SourceContext(), c.TargetContext(), from, into) } // Lower an expression which is expected to lower into a single expression. diff --git a/pkg/hir/parser.go b/pkg/hir/parser.go index 9ed0204..ceedd7f 100644 --- a/pkg/hir/parser.go +++ b/pkg/hir/parser.go @@ -8,10 +8,10 @@ import ( "unicode" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/go-corset/pkg/schema" sc "github.com/consensys/go-corset/pkg/schema" "github.com/consensys/go-corset/pkg/schema/assignment" "github.com/consensys/go-corset/pkg/sexp" + "github.com/consensys/go-corset/pkg/trace" ) // =================================================================== @@ -51,7 +51,7 @@ type hirParser struct { // Translator used for recursive expressions. translator *sexp.Translator[Expr] // Current module being parsed. - module uint + module trace.Context // Environment used during parsing to resolve column names into column // indices. env *Environment @@ -164,19 +164,20 @@ func (p *hirParser) parseColumnDeclaration(l *sexp.List) error { // Parse a sorted permutation declaration func (p *hirParser) parseSortedPermutationDeclaration(l *sexp.List) error { - var multiplier uint // Target columns are (sorted) permutations of source columns. sexpTargets := l.Elements[1].AsList() // Source columns. sexpSources := l.Elements[2].AsList() // Convert into appropriate form. - targets := make([]schema.Column, sexpTargets.Len()) + targets := make([]sc.Column, sexpTargets.Len()) sources := make([]uint, sexpSources.Len()) signs := make([]bool, sexpSources.Len()) // if sexpTargets.Len() != sexpSources.Len() { return p.translator.SyntaxError(l, "sorted permutation requires matching number of source and target columns") } + // initialise context + ctx := trace.VoidContext() // for i := 0; i < sexpSources.Len(); i++ { source := sexpSources.Get(i).AsSymbol() @@ -214,23 +215,22 @@ func (p *hirParser) parseSortedPermutationDeclaration(l *sexp.List) error { // No, it doesn't. return p.translator.SyntaxError(sexpTargets.Get(i), fmt.Sprintf("duplicate column %s", targetName)) } - // Check multiplier calculation + // Check source context sourceCol := p.env.schema.Columns().Nth(sourceIndex) - if i == 0 { - // First time around, multiplier is determine by the first source column. - multiplier = sourceCol.LengthMultiplier() - } else if sourceCol.LengthMultiplier() != multiplier { - // In all other cases, multiplier must match that of first source column. - return p.translator.SyntaxError(sexpSources.Get(i), "inconsistent length multiplier") + ctx = ctx.Join(sourceCol.Context()) + // Sanity check we have a sensible type here. + if ctx.IsConflicted() { + panic(fmt.Sprintf("source column %s has conflicted evaluation context", sexpSources.Get(i))) + } else if ctx.IsVoid() { + panic(fmt.Sprintf("source column %s has void evaluation context", sexpSources.Get(i))) } // Copy over column name sources[i] = sourceIndex // FIXME: determine source column type - targets[i] = schema.NewColumn(p.module, targetName, multiplier, &schema.FieldType{}) + targets[i] = sc.NewColumn(ctx, targetName, &sc.FieldType{}) } // - //p.env.AddPermutationColumns(p.module, targets, signs, sources) - p.env.AddAssignment(assignment.NewSortedPermutation(p.module, multiplier, targets, signs, sources)) + p.env.AddAssignment(assignment.NewSortedPermutation(ctx, targets, signs, sources)) // return nil } @@ -279,23 +279,26 @@ func (p *hirParser) parseLookupDeclaration(l *sexp.List) error { sources[i] = UnitExpr{source} } // Sanity check enclosing source and target modules - source, src_multiplier, err1 := schema.DetermineEnclosingModuleOfExpressions(sources, p.env.schema) - target, target_multiplier, err2 := schema.DetermineEnclosingModuleOfExpressions(targets, p.env.schema) + sourceCtx := sc.JoinContexts(sources, p.env.schema) + targetCtx := sc.JoinContexts(targets, p.env.schema) // Propagate errors - if err1 != nil { - return p.translator.SyntaxError(sexpSources.Get(int(source)), err1.Error()) - } else if err2 != nil { - return p.translator.SyntaxError(sexpTargets.Get(int(target)), err2.Error()) + if sourceCtx.IsConflicted() { + return p.translator.SyntaxError(sexpSources, "conflicting evaluation context") + } else if targetCtx.IsConflicted() { + return p.translator.SyntaxError(sexpTargets, "conflicting evaluation context") + } else if sourceCtx.IsVoid() { + return p.translator.SyntaxError(sexpSources, "empty evaluation context") + } else if targetCtx.IsVoid() { + return p.translator.SyntaxError(sexpTargets, "empty evaluation context") } // Finally add constraint - p.env.schema.AddLookupConstraint(handle, source, src_multiplier, target, target_multiplier, sources, targets) + p.env.schema.AddLookupConstraint(handle, sourceCtx, targetCtx, sources, targets) // Done return nil } // Parse am interleaving declaration func (p *hirParser) parseInterleavingDeclaration(l *sexp.List) error { - var multiplier uint // Target columns are (sorted) permutations of source columns. sexpTarget := l.Elements[1].AsSymbol() // Source columns. @@ -308,6 +311,7 @@ func (p *hirParser) parseInterleavingDeclaration(l *sexp.List) error { } // Construct and check source columns sources := make([]uint, sexpSources.Len()) + ctx := trace.VoidContext() for i := 0; i < sexpSources.Len(); i++ { ith := sexpSources.Get(i) @@ -324,18 +328,18 @@ func (p *hirParser) parseInterleavingDeclaration(l *sexp.List) error { } // Check multiplier calculation sourceCol := p.env.schema.Columns().Nth(cid) - if i == 0 { - // First time around, multiplier is determine by the first source column. - multiplier = sourceCol.LengthMultiplier() - } else if sourceCol.LengthMultiplier() != multiplier { - // In all other cases, multiplier must match that of first source column. - return p.translator.SyntaxError(sexpSources.Get(i), "inconsistent length multiplier") + ctx = ctx.Join(sourceCol.Context()) + // Sanity check we have a sensible context here. + if ctx.IsConflicted() { + panic(fmt.Sprintf("source column %s has conflicted evaluation context", sexpSources.Get(i))) + } else if ctx.IsVoid() { + panic(fmt.Sprintf("source column %s has void evaluation context", sexpSources.Get(i))) } // Assign sources[i] = cid } // Add assignment - p.env.AddAssignment(assignment.NewInterleaving(p.module, sexpTarget.Value, multiplier, sources)) + p.env.AddAssignment(assignment.NewInterleaving(ctx, sexpTarget.Value, sources)) // Done return nil } @@ -352,7 +356,7 @@ func (p *hirParser) parseAssertionDeclaration(elements []sexp.SExp) error { return err } // Add assertion. - p.env.schema.AddPropertyAssertion(p.module, handle, expr) + p.env.schema.AddPropertyAssertion(p.module.Module(), handle, expr) return nil } @@ -368,12 +372,16 @@ func (p *hirParser) parseVanishingDeclaration(elements []sexp.SExp, domain *int) if err != nil { return err } - // TODO: improve error reporting here, since the following will just panic if - // the evaluation context is inconsistent (and, since we know the enclosing - // module is consistent, then this should only happen if the length - // multipliers are inconsistent). - _, multiplier := schema.DetermineEnclosingModuleOfExpression(expr, p.env.schema) - p.env.schema.AddVanishingConstraint(handle, p.module, multiplier, domain, expr) + // Determine evaluation context of expression. + ctx := expr.Context(p.env.schema) + // Sanity check we have a sensible context here. + if ctx.IsConflicted() { + panic(fmt.Sprintf("source column %s has conflicted evaluation context", elements[2])) + } else if ctx.IsVoid() { + panic(fmt.Sprintf("source column %s has void evaluation context", elements[2])) + } + + p.env.schema.AddVanishingConstraint(handle, ctx, domain, expr) return nil } @@ -454,13 +462,13 @@ func columnAccessParserRule(parser *hirParser) func(col string) (Expr, bool, err return nil, false, nil } // Handle qualified accesses (where permitted) - module := parser.module + context := parser.module colname := col // Attempt to split column name into module / column pair. split := strings.Split(col, ".") if parser.global && len(split) == 2 { // Lookup module - if module, ok = parser.env.LookupModule(split[0]); !ok { + if context, ok = parser.env.LookupModule(split[0]); !ok { return nil, true, errors.New("unknown module") } @@ -473,7 +481,7 @@ func columnAccessParserRule(parser *hirParser) func(col string) (Expr, bool, err // Now lookup column in the appropriate module. var cid uint // Look up column in the environment using local scope. - cid, ok = parser.env.LookupColumn(module, colname) + cid, ok = parser.env.LookupColumn(context, colname) // Check column was found if !ok { return nil, true, errors.New("unknown column") diff --git a/pkg/hir/schema.go b/pkg/hir/schema.go index 4dc7b19..4b26841 100644 --- a/pkg/hir/schema.go +++ b/pkg/hir/schema.go @@ -3,10 +3,10 @@ package hir import ( "fmt" - "github.com/consensys/go-corset/pkg/schema" sc "github.com/consensys/go-corset/pkg/schema" "github.com/consensys/go-corset/pkg/schema/assignment" "github.com/consensys/go-corset/pkg/schema/constraint" + "github.com/consensys/go-corset/pkg/trace" "github.com/consensys/go-corset/pkg/util" ) @@ -69,17 +69,17 @@ func (p *Schema) AddModule(name string) uint { // AddDataColumn appends a new data column with a given type. Furthermore, the // type is enforced by the system when checking is enabled. -func (p *Schema) AddDataColumn(module uint, name string, base sc.Type) { - if module >= uint(len(p.modules)) { - panic(fmt.Sprintf("invalid module index (%d)", module)) +func (p *Schema) AddDataColumn(context trace.Context, name string, base sc.Type) { + if context.Module() >= uint(len(p.modules)) { + panic(fmt.Sprintf("invalid module index (%d)", context.Module())) } - p.inputs = append(p.inputs, assignment.NewDataColumn(module, name, base)) + p.inputs = append(p.inputs, assignment.NewDataColumn(context, name, base)) } // AddLookupConstraint appends a new lookup constraint. -func (p *Schema) AddLookupConstraint(handle string, source uint, source_multiplier uint, target uint, - target_multiplier uint, sources []UnitExpr, targets []UnitExpr) { +func (p *Schema) AddLookupConstraint(handle string, source trace.Context, target trace.Context, + sources []UnitExpr, targets []UnitExpr) { if len(targets) != len(sources) { panic("differeng number of target / source lookup columns") } @@ -88,13 +88,13 @@ func (p *Schema) AddLookupConstraint(handle string, source uint, source_multipli // Finally add constraint p.constraints = append(p.constraints, - constraint.NewLookupConstraint(handle, source, source_multiplier, target, target_multiplier, sources, targets)) + constraint.NewLookupConstraint(handle, source, target, sources, targets)) } // AddAssignment appends a new assignment (i.e. set of computed columns) to be // used during trace expansion for this schema. Computed columns are introduced // by the process of lowering from HIR / MIR to AIR. -func (p *Schema) AddAssignment(c schema.Assignment) uint { +func (p *Schema) AddAssignment(c sc.Assignment) uint { index := p.Columns().Count() p.assignments = append(p.assignments, c) @@ -102,13 +102,13 @@ func (p *Schema) AddAssignment(c schema.Assignment) uint { } // AddVanishingConstraint appends a new vanishing constraint. -func (p *Schema) AddVanishingConstraint(handle string, module uint, multiplier uint, domain *int, expr Expr) { - if module >= uint(len(p.modules)) { - panic(fmt.Sprintf("invalid module index (%d)", module)) +func (p *Schema) AddVanishingConstraint(handle string, context trace.Context, domain *int, expr Expr) { + if context.Module() >= uint(len(p.modules)) { + panic(fmt.Sprintf("invalid module index (%d)", context.Module())) } p.constraints = append(p.constraints, - constraint.NewVanishingConstraint(handle, module, multiplier, domain, ZeroArrayTest{expr})) + constraint.NewVanishingConstraint(handle, context, domain, ZeroArrayTest{expr})) } // AddTypeConstraint appends a new range constraint. @@ -167,6 +167,6 @@ func (p *Schema) Declarations() util.Iterator[sc.Declaration] { // Modules returns an iterator over the declared set of modules within this // schema. -func (p *Schema) Modules() util.Iterator[schema.Module] { +func (p *Schema) Modules() util.Iterator[sc.Module] { return util.NewArrayIterator(p.modules) } diff --git a/pkg/hir/util.go b/pkg/hir/util.go index 3f2e2b9..b3c444b 100644 --- a/pkg/hir/util.go +++ b/pkg/hir/util.go @@ -46,7 +46,7 @@ func (p ZeroArrayTest) Bounds() util.Bounds { // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p ZeroArrayTest) Context(schema sc.Schema) (uint, uint, bool) { +func (p ZeroArrayTest) Context(schema sc.Schema) trace.Context { return p.Expr.Context(schema) } @@ -95,6 +95,6 @@ func (e UnitExpr) Bounds() util.Bounds { // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (e UnitExpr) Context(schema sc.Schema) (uint, uint, bool) { +func (e UnitExpr) Context(schema sc.Schema) trace.Context { return e.expr.Context(schema) } diff --git a/pkg/mir/expr.go b/pkg/mir/expr.go index c27647c..6c192ea 100644 --- a/pkg/mir/expr.go +++ b/pkg/mir/expr.go @@ -1,12 +1,10 @@ package mir import ( - "math" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/air" - "github.com/consensys/go-corset/pkg/schema" sc "github.com/consensys/go-corset/pkg/schema" + "github.com/consensys/go-corset/pkg/trace" "github.com/consensys/go-corset/pkg/util" ) @@ -17,7 +15,7 @@ import ( // appropriate computed columns and constraints. type Expr interface { util.Boundable - schema.Evaluable + sc.Evaluable // Lower this expression into the Arithmetic Intermediate // Representation. Essentially, this means eliminating // normalising expressions by introducing new columns into the @@ -40,7 +38,7 @@ func (p *Add) Bounds() util.Bounds { return util.BoundsForArray(p.Args) } // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *Add) Context(schema sc.Schema) (uint, uint, bool) { +func (p *Add) Context(schema sc.Schema) trace.Context { return sc.JoinContexts[Expr](p.Args, schema) } @@ -57,7 +55,7 @@ func (p *Sub) Bounds() util.Bounds { return util.BoundsForArray(p.Args) } // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *Sub) Context(schema sc.Schema) (uint, uint, bool) { +func (p *Sub) Context(schema sc.Schema) trace.Context { return sc.JoinContexts[Expr](p.Args, schema) } @@ -74,7 +72,7 @@ func (p *Mul) Bounds() util.Bounds { return util.BoundsForArray(p.Args) } // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *Mul) Context(schema sc.Schema) (uint, uint, bool) { +func (p *Mul) Context(schema sc.Schema) trace.Context { return sc.JoinContexts[Expr](p.Args, schema) } @@ -91,8 +89,8 @@ func (p *Constant) Bounds() util.Bounds { return util.EMPTY_BOUND } // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *Constant) Context(schema sc.Schema) (uint, uint, bool) { - return math.MaxUint, math.MaxUint, true +func (p *Constant) Context(schema sc.Schema) trace.Context { + return trace.VoidContext() } // ============================================================================ @@ -109,7 +107,7 @@ func (p *Normalise) Bounds() util.Bounds { return p.Arg.Bounds() } // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *Normalise) Context(schema sc.Schema) (uint, uint, bool) { +func (p *Normalise) Context(schema sc.Schema) trace.Context { return p.Arg.Context(schema) } @@ -141,7 +139,7 @@ func (p *ColumnAccess) Bounds() util.Bounds { // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p *ColumnAccess) Context(schema sc.Schema) (uint, uint, bool) { +func (p *ColumnAccess) Context(schema sc.Schema) trace.Context { col := schema.Columns().Nth(p.Column) - return col.Module(), col.LengthMultiplier(), true + return col.Context() } diff --git a/pkg/mir/lower.go b/pkg/mir/lower.go index fc4e860..d669e75 100644 --- a/pkg/mir/lower.go +++ b/pkg/mir/lower.go @@ -19,7 +19,7 @@ func (p *Schema) LowerToAir() *air.Schema { // Add data columns. for _, c := range p.inputs { col := c.(DataColumn) - airSchema.AddColumn(col.Module(), col.Name(), col.Type()) + airSchema.AddColumn(col.Context(), col.Name(), col.Type()) } // Add Assignments. Again this has to be done first for things to work. // Essentially to reflect the fact that these columns have been added above @@ -60,7 +60,7 @@ func lowerConstraintToAir(c sc.Constraint, schema *air.Schema) { lowerLookupConstraintToAir(v, schema) } else if v, ok := c.(VanishingConstraint); ok { air_expr := v.Constraint().Expr.LowerTo(schema) - schema.AddVanishingConstraint(v.Handle(), v.Module(), v.LengthMultiplier(), v.Domain(), air_expr) + schema.AddVanishingConstraint(v.Handle(), v.Context(), v.Domain(), air_expr) } else if v, ok := c.(*constraint.TypeConstraint); ok { if t := v.Type().AsUint(); t != nil { // Yes, a constraint is implied. Now, decide whether to use a range @@ -102,9 +102,7 @@ func lowerLookupConstraintToAir(c LookupConstraint, schema *air.Schema) { sources[i] = air_gadgets.Expand(source, schema) } // finally add the constraint - src_mod, src_mul := c.SourceContext() - dst_mod, dst_mul := c.TargetContext() - schema.AddLookupConstraint(c.Handle(), src_mod, src_mul, dst_mod, dst_mul, sources, targets) + schema.AddLookupConstraint(c.Handle(), c.SourceContext(), c.TargetContext(), sources, targets) } // Lower a permutation to the AIR level. This has quite a few diff --git a/pkg/mir/schema.go b/pkg/mir/schema.go index 385af0b..0a2d635 100644 --- a/pkg/mir/schema.go +++ b/pkg/mir/schema.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/go-corset/pkg/schema" "github.com/consensys/go-corset/pkg/schema/assignment" "github.com/consensys/go-corset/pkg/schema/constraint" + "github.com/consensys/go-corset/pkg/trace" "github.com/consensys/go-corset/pkg/util" ) @@ -69,12 +70,12 @@ func (p *Schema) AddModule(name string) uint { } // AddDataColumn appends a new data column. -func (p *Schema) AddDataColumn(module uint, name string, base schema.Type) { - if module >= uint(len(p.modules)) { - panic(fmt.Sprintf("invalid module index (%d)", module)) +func (p *Schema) AddDataColumn(context trace.Context, name string, base schema.Type) { + if context.Module() >= uint(len(p.modules)) { + panic(fmt.Sprintf("invalid module index (%d)", context.Module())) } - p.inputs = append(p.inputs, assignment.NewDataColumn(module, name, base)) + p.inputs = append(p.inputs, assignment.NewDataColumn(context, name, base)) } // AddAssignment appends a new assignment (i.e. set of computed columns) to be @@ -88,25 +89,25 @@ func (p *Schema) AddAssignment(c schema.Assignment) uint { } // AddLookupConstraint appends a new lookup constraint. -func (p *Schema) AddLookupConstraint(handle string, source uint, source_context uint, target uint, - target_context uint, sources []Expr, targets []Expr) { +func (p *Schema) AddLookupConstraint(handle string, source trace.Context, target trace.Context, + sources []Expr, targets []Expr) { if len(targets) != len(sources) { panic("differeng number of target / source lookup columns") } // TODO: sanity source columns are in the same module, and likewise target // columns (though they don't have to be in the same column together). p.constraints = append(p.constraints, - constraint.NewLookupConstraint(handle, source, source_context, target, target_context, sources, targets)) + constraint.NewLookupConstraint(handle, source, target, sources, targets)) } // AddVanishingConstraint appends a new vanishing constraint. -func (p *Schema) AddVanishingConstraint(handle string, module uint, multiplier uint, domain *int, expr Expr) { - if module >= uint(len(p.modules)) { - panic(fmt.Sprintf("invalid module index (%d)", module)) +func (p *Schema) AddVanishingConstraint(handle string, context trace.Context, domain *int, expr Expr) { + if context.Module() >= uint(len(p.modules)) { + panic(fmt.Sprintf("invalid module index (%d)", context.Module())) } p.constraints = append(p.constraints, - constraint.NewVanishingConstraint(handle, module, multiplier, domain, constraint.ZeroTest[Expr]{Expr: expr})) + constraint.NewVanishingConstraint(handle, context, domain, constraint.ZeroTest[Expr]{Expr: expr})) } // AddTypeConstraint appends a new range constraint. diff --git a/pkg/schema/alignment.go b/pkg/schema/alignment.go index 07df664..e44c0f9 100644 --- a/pkg/schema/alignment.go +++ b/pkg/schema/alignment.go @@ -75,18 +75,18 @@ func alignWith(expand bool, p tr.Trace, schema Schema) error { for j := ith.Columns(); j.HasNext(); { // Extract schema column & module schemaCol := j.Next() - schemaMod := schema.Modules().Nth(schemaCol.Module()) + schemaMod := schema.Modules().Nth(schemaCol.Context().Module()) // Sanity check column exists if colIndex >= ncols { return fmt.Errorf("trace missing column %s.%s (too few columns)", schemaMod.Name(), schemaCol.Name()) } // Extract trace column and module traceCol := columns.Get(colIndex) - traceMod := modules.Get(traceCol.Module()) + traceMod := modules.Get(traceCol.Context().Module()) // Check alignment if traceCol.Name() != schemaCol.Name() || traceMod.Name() != schemaMod.Name() { // Not aligned --- so fix - k, ok := p.Columns().IndexOf(schemaCol.Module(), schemaCol.Name()) + k, ok := p.Columns().IndexOf(schemaCol.Context().Module(), schemaCol.Name()) // check exists if !ok { return fmt.Errorf("trace missing column %s.%s", schemaMod.Name(), schemaCol.Name()) diff --git a/pkg/schema/assignment/byte_decomposition.go b/pkg/schema/assignment/byte_decomposition.go index 811e45d..62ecd3c 100644 --- a/pkg/schema/assignment/byte_decomposition.go +++ b/pkg/schema/assignment/byte_decomposition.go @@ -19,7 +19,7 @@ type ByteDecomposition struct { } // NewByteDecomposition creates a new sorted permutation -func NewByteDecomposition(prefix string, module uint, multiplier uint, source uint, width uint) *ByteDecomposition { +func NewByteDecomposition(prefix string, context trace.Context, source uint, width uint) *ByteDecomposition { if width == 0 { panic("zero byte decomposition encountered") } @@ -30,7 +30,7 @@ func NewByteDecomposition(prefix string, module uint, multiplier uint, source ui for i := uint(0); i < width; i++ { name := fmt.Sprintf("%s:%d", prefix, i) - targets[i] = schema.NewColumn(module, name, multiplier, U8) + targets[i] = schema.NewColumn(context, name, U8) } // Done return &ByteDecomposition{source, targets} @@ -88,7 +88,7 @@ func (p *ByteDecomposition) ExpandTrace(tr trace.Trace) error { // Finally, add byte columns to trace for i := 0; i < n; i++ { ith := p.targets[i] - columns.Add(trace.NewFieldColumn(ith.Module(), ith.Name(), ith.LengthMultiplier(), cols[i], padding[i])) + columns.Add(trace.NewFieldColumn(ith.Context(), ith.Name(), cols[i], padding[i])) } // Done return nil diff --git a/pkg/schema/assignment/computed_column.go b/pkg/schema/assignment/computed_column.go index f173491..7e08cec 100644 --- a/pkg/schema/assignment/computed_column.go +++ b/pkg/schema/assignment/computed_column.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/go-corset/pkg/schema" sc "github.com/consensys/go-corset/pkg/schema" "github.com/consensys/go-corset/pkg/trace" "github.com/consensys/go-corset/pkg/util" @@ -17,7 +16,7 @@ import ( // give rise to "trace expansion". That is where the initial trace provided by // the user is expanded by determining the value of all computed columns. type ComputedColumn[E sc.Evaluable] struct { - target schema.Column + target sc.Column // The computation which accepts a given trace and computes // the value of this column at a given row. expr E @@ -26,9 +25,9 @@ type ComputedColumn[E sc.Evaluable] struct { // NewComputedColumn constructs a new computed column with a given name and // determining expression. More specifically, that expression is used to // compute the values for this column during trace expansion. -func NewComputedColumn[E sc.Evaluable](module uint, name string, multiplier uint, expr E) *ComputedColumn[E] { +func NewComputedColumn[E sc.Evaluable](context trace.Context, name string, expr E) *ComputedColumn[E] { // FIXME: Determine computed columns type? - column := schema.NewColumn(module, name, multiplier, &schema.FieldType{}) + column := sc.NewColumn(context, name, &sc.FieldType{}) return &ComputedColumn[E]{column, expr} } @@ -47,9 +46,9 @@ func (p *ComputedColumn[E]) Name() string { // ============================================================================ // Columns returns the columns declared by this computed column. -func (p *ComputedColumn[E]) Columns() util.Iterator[schema.Column] { +func (p *ComputedColumn[E]) Columns() util.Iterator[sc.Column] { // TODO: figure out appropriate type for computed column - return util.NewUnitIterator[schema.Column](p.target) + return util.NewUnitIterator[sc.Column](p.target) } // IsComputed Determines whether or not this declaration is computed (which it @@ -83,9 +82,9 @@ func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error { return fmt.Errorf("column already exists ({%s})", p.Name()) } // Extract length multipiler - multiplier := p.target.LengthMultiplier() + multiplier := p.target.Context().LengthMultiplier() // Determine multiplied height - height := tr.Modules().Get(p.target.Module()).Height() * multiplier + height := tr.Modules().Get(p.target.Context().Module()).Height() * multiplier // Make space for computed data data := make([]*fr.Element, height) // Expand the trace @@ -103,7 +102,7 @@ func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error { // the padding value for *this* column. padding := p.expr.EvalAt(-1, tr) // Colunm needs to be expanded. - columns.Add(trace.NewFieldColumn(p.target.Module(), p.Name(), multiplier, data, padding)) + columns.Add(trace.NewFieldColumn(p.target.Context(), p.Name(), data, padding)) // Done return nil } diff --git a/pkg/schema/assignment/data_column.go b/pkg/schema/assignment/data_column.go index 36799b6..6081113 100644 --- a/pkg/schema/assignment/data_column.go +++ b/pkg/schema/assignment/data_column.go @@ -4,13 +4,14 @@ import ( "fmt" "github.com/consensys/go-corset/pkg/schema" + "github.com/consensys/go-corset/pkg/trace" "github.com/consensys/go-corset/pkg/util" ) // DataColumn represents a column of user-provided values. type DataColumn struct { - // Module where this data column is located. - module uint + // Context where this data column is located. + context trace.Context // Name of this datacolumn name string // Expected type of values held in this column. Observe that this should be @@ -20,13 +21,18 @@ type DataColumn struct { } // NewDataColumn constructs a new data column with a given name. -func NewDataColumn(module uint, name string, base schema.Type) *DataColumn { - return &DataColumn{module, name, base} +func NewDataColumn(context trace.Context, name string, base schema.Type) *DataColumn { + return &DataColumn{context, name, base} +} + +// Context returns the evaluation context for this column. +func (p *DataColumn) Context() trace.Context { + return p.context } // Module identifies the module which encloses this column. func (p *DataColumn) Module() uint { - return p.module + return p.context.Module() } // Name provides access to information about the ith column in a schema. @@ -42,10 +48,10 @@ func (p *DataColumn) Type() schema.Type { //nolint:revive func (c *DataColumn) String() string { if c.datatype.AsField() != nil { - return fmt.Sprintf("(column %s)", c.Name()) + return fmt.Sprintf("(column #%d.%s)", c.Module(), c.Name()) } - return fmt.Sprintf("(column %s :%s)", c.Name(), c.datatype) + return fmt.Sprintf("(column #%d.%s :%s)", c.Module(), c.Name(), c.datatype) } // ============================================================================ @@ -55,7 +61,7 @@ func (c *DataColumn) String() string { // Columns returns the columns declared by this computed column. func (p *DataColumn) Columns() util.Iterator[schema.Column] { // Datacolumns always have a multiplier of 1. - column := schema.NewColumn(p.module, p.name, 1, p.datatype) + column := schema.NewColumn(p.context, p.name, p.datatype) return util.NewUnitIterator[schema.Column](column) } diff --git a/pkg/schema/assignment/interleave.go b/pkg/schema/assignment/interleave.go index 0c20a2a..c622c3f 100644 --- a/pkg/schema/assignment/interleave.go +++ b/pkg/schema/assignment/interleave.go @@ -15,8 +15,6 @@ import ( // a trace X=[1,2], Y=[3,4]. Then, the interleaved column Z has the values // Z=[1,3,2,4]. type Interleaving struct { - // Module where this interleaving is located. - module uint // The new (interleaved) column target schema.Column // The source columns @@ -24,18 +22,18 @@ type Interleaving struct { } // NewInterleaving constructs a new interleaving assignment. -func NewInterleaving(module uint, name string, multiplier uint, sources []uint) *Interleaving { +func NewInterleaving(context tr.Context, name string, sources []uint) *Interleaving { // Update multiplier - multiplier = multiplier * uint(len(sources)) + context = context.Multiply(uint(len(sources))) // Fixme: determine interleaving type - target := schema.NewColumn(module, name, multiplier, &schema.FieldType{}) + target := schema.NewColumn(context, name, &schema.FieldType{}) - return &Interleaving{module, target, sources} + return &Interleaving{target, sources} } // Module returns the module which encloses this interleaving. func (p *Interleaving) Module() uint { - return p.module + return p.target.Context().Module() } // Sources returns the columns used by this interleaving to define the new @@ -74,6 +72,7 @@ func (p *Interleaving) RequiredSpillage() uint { // the interleaved column. func (p *Interleaving) ExpandTrace(tr tr.Trace) error { columns := tr.Columns() + ctx := p.target.Context() // Ensure target column doesn't exist for i := p.Columns(); i.HasNext(); { name := i.Next().Name() @@ -86,10 +85,10 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error { width := uint(len(p.sources)) // Following division should always produce whole value because the length // multiplier already includes the width as a factor. - multiplier := p.target.LengthMultiplier() / width + multiplier := ctx.LengthMultiplier() / width // Determine module height (as this can be used to determine the height of // the interleaved column) - height := tr.Modules().Get(p.module).Height() * multiplier + height := tr.Modules().Get(ctx.Module()).Height() * multiplier // Construct empty array data := make([]*fr.Element, height*width) // Offset just gives the column index @@ -109,7 +108,7 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error { // column in the interleaving. padding := columns.Get(0).Padding() // Colunm needs to be expanded. - columns.Add(trace.NewFieldColumn(p.module, p.target.Name(), multiplier*width, data, padding)) + columns.Add(trace.NewFieldColumn(ctx, p.target.Name(), data, padding)) // return nil } diff --git a/pkg/schema/assignment/lexicographic_sort.go b/pkg/schema/assignment/lexicographic_sort.go index 7dc7ec4..68bd139 100644 --- a/pkg/schema/assignment/lexicographic_sort.go +++ b/pkg/schema/assignment/lexicographic_sort.go @@ -14,11 +14,9 @@ import ( // columns. Specifically, a delta column is required along with one selector // column (binary) for each source column. type LexicographicSort struct { - // Module in which source and target columns to be located. All target and - // source columns should be contained within this module. - module uint - // Length multiplier for all columns in this gadget - multiplier uint + // Context in which source and target columns to be located. All target and + // source columns should be contained within this. + context trace.Context // The target columns to be filled. The first entry is for the delta // column, and the remaining n entries are for the selector columns. targets []schema.Column @@ -29,19 +27,19 @@ type LexicographicSort struct { } // NewLexicographicSort constructs a new LexicographicSorting assignment. -func NewLexicographicSort(prefix string, module uint, multiplier uint, +func NewLexicographicSort(prefix string, context trace.Context, sources []uint, signs []bool, bitwidth uint) *LexicographicSort { // targets := make([]schema.Column, len(sources)+1) // Create delta column - targets[0] = schema.NewColumn(module, fmt.Sprintf("%s:delta", prefix), multiplier, schema.NewUintType(bitwidth)) + targets[0] = schema.NewColumn(context, fmt.Sprintf("%s:delta", prefix), schema.NewUintType(bitwidth)) // Create selector columns for i := range sources { ithName := fmt.Sprintf("%s:%d", prefix, i) - targets[1+i] = schema.NewColumn(module, ithName, multiplier, schema.NewUintType(1)) + targets[1+i] = schema.NewColumn(context, ithName, schema.NewUintType(1)) } - return &LexicographicSort{module, multiplier, targets, sources, signs, bitwidth} + return &LexicographicSort{context, targets, sources, signs, bitwidth} } // ============================================================================ @@ -78,9 +76,9 @@ func (p *LexicographicSort) ExpandTrace(tr trace.Trace) error { // Exact number of columns involved in the sort ncols := len(p.sources) // - multiplier := p.multiplier + multiplier := p.context.LengthMultiplier() // Determine how many rows to be constrained. - nrows := tr.Modules().Get(p.module).Height() * multiplier + nrows := tr.Modules().Get(p.context.Module()).Height() * multiplier // Initialise new data columns delta := make([]*fr.Element, nrows) bit := make([][]*fr.Element, ncols) @@ -119,11 +117,11 @@ func (p *LexicographicSort) ExpandTrace(tr trace.Trace) error { } // Add delta column data first := p.targets[0] - columns.Add(trace.NewFieldColumn(first.Module(), first.Name(), multiplier, delta, &zero)) + columns.Add(trace.NewFieldColumn(first.Context(), first.Name(), delta, &zero)) // Add bit column data for i := 0; i < ncols; i++ { ith := p.targets[1+i] - columns.Add(trace.NewFieldColumn(ith.Module(), ith.Name(), multiplier, bit[i], &zero)) + columns.Add(trace.NewFieldColumn(ith.Context(), ith.Name(), bit[i], &zero)) } // Done. return nil diff --git a/pkg/schema/assignment/sorted_permutation.go b/pkg/schema/assignment/sorted_permutation.go index a9d18ae..0050ef3 100644 --- a/pkg/schema/assignment/sorted_permutation.go +++ b/pkg/schema/assignment/sorted_permutation.go @@ -13,9 +13,8 @@ import ( // SortedPermutation declares one or more columns as sorted permutations of // existing columns. type SortedPermutation struct { - module uint - // Length multiplier - multiplier uint + // Context where this data column is located. + context trace.Context // The new (sorted) columns targets []schema.Column // The sorting criteria @@ -25,26 +24,24 @@ type SortedPermutation struct { } // NewSortedPermutation creates a new sorted permutation -func NewSortedPermutation(module uint, multiplier uint, targets []schema.Column, +func NewSortedPermutation(context tr.Context, targets []schema.Column, signs []bool, sources []uint) *SortedPermutation { if len(targets) != len(signs) || len(signs) != len(sources) { panic("target and source column widths must match") } // Check modules for _, c := range targets { - if c.Module() != module { - panic("inconsistent target modules") - } else if c.LengthMultiplier() != multiplier { - panic("inconsistent length multipliers for target columns") + if c.Context() != context { + panic("inconsistent evaluation contexts") } } - return &SortedPermutation{module, multiplier, targets, signs, sources} + return &SortedPermutation{context, targets, signs, sources} } // Module returns the module which encloses this sorted permutation. func (p *SortedPermutation) Module() uint { - return p.module + return p.context.Module() } // Sources returns the columns used by this sorted permutation to define the new @@ -154,7 +151,7 @@ func (p *SortedPermutation) ExpandTrace(tr tr.Trace) error { ith := i.Next() dstColName := ith.Name() srcCol := tr.Columns().Get(p.sources[index]) - columns.Add(trace.NewFieldColumn(ith.Module(), dstColName, p.multiplier, cols[index], srcCol.Padding())) + columns.Add(trace.NewFieldColumn(ith.Context(), dstColName, cols[index], srcCol.Padding())) } // return nil diff --git a/pkg/schema/constraint/lookup.go b/pkg/schema/constraint/lookup.go index c3b3499..7f43fbd 100644 --- a/pkg/schema/constraint/lookup.go +++ b/pkg/schema/constraint/lookup.go @@ -26,16 +26,10 @@ import ( // makes sense. type LookupConstraint[E schema.Evaluable] struct { handle string - // Enclosing module for source columns. - source uint - // Length multiplier partly determines the evaluation context for source - // expressions. - source_multiplier uint - // Enclosing module for target columns. - target uint - // Length multiplier partly determines the evaluation context for target - // expressions. - target_multiplier uint + // Evaluation context for source columns. + source trace.Context + // Evaluation context for target columns. + target trace.Context // Source rows represent the subset of rows. sources []E // Target rows represent the set of rows. @@ -43,13 +37,13 @@ type LookupConstraint[E schema.Evaluable] struct { } // NewLookupConstraint creates a new lookup constraint with a given handle. -func NewLookupConstraint[E schema.Evaluable](handle string, source uint, source_multiplier uint, - target uint, target_multiplier uint, sources []E, targets []E) *LookupConstraint[E] { +func NewLookupConstraint[E schema.Evaluable](handle string, source trace.Context, + target trace.Context, sources []E, targets []E) *LookupConstraint[E] { if len(targets) != len(sources) { panic("differeng number of target / source lookup columns") } - return &LookupConstraint[E]{handle, source, source_multiplier, target, target_multiplier, sources, targets} + return &LookupConstraint[E]{handle, source, target, sources, targets} } // Handle returns the handle for this lookup constraint which is simply an @@ -60,14 +54,14 @@ func (p *LookupConstraint[E]) Handle() string { return p.handle } -// SourceContext returns the module in which all source columns are located. -func (p *LookupConstraint[E]) SourceContext() (uint, uint) { - return p.source, p.source_multiplier +// SourceContext returns the contezt in which all target expressions are evaluated. +func (p *LookupConstraint[E]) SourceContext() trace.Context { + return p.source } -// TargetContext returns the module in which all target columns are located. -func (p *LookupConstraint[E]) TargetContext() (uint, uint) { - return p.target, p.target_multiplier +// TargetContext returns the contezt in which all target expressions are evaluated. +func (p *LookupConstraint[E]) TargetContext() trace.Context { + return p.target } // Sources returns the source expressions which are used to lookup into the @@ -88,8 +82,8 @@ func (p *LookupConstraint[E]) Targets() []E { //nolint:revive func (p *LookupConstraint[E]) Accepts(tr trace.Trace) error { // Determine height of enclosing module for source columns - src_height := tr.Modules().Get(p.source).Height() * p.source_multiplier - tgt_height := tr.Modules().Get(p.target).Height() * p.target_multiplier + src_height := tr.Modules().Get(p.source.Module()).Height() * p.source.LengthMultiplier() + tgt_height := tr.Modules().Get(p.target.Module()).Height() * p.target.LengthMultiplier() // Go through every row of the source columns checking they are present in // the target columns. // diff --git a/pkg/schema/constraint/vanishing.go b/pkg/schema/constraint/vanishing.go index ef841df..ddf42b2 100644 --- a/pkg/schema/constraint/vanishing.go +++ b/pkg/schema/constraint/vanishing.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" - "github.com/consensys/go-corset/pkg/schema" sc "github.com/consensys/go-corset/pkg/schema" "github.com/consensys/go-corset/pkg/trace" "github.com/consensys/go-corset/pkg/util" @@ -13,7 +12,7 @@ import ( // ZeroTest is a wrapper which converts an Evaluable expression into a Testable // constraint. Specifically, by checking whether or not the given expression // vanishes (i.e. evaluates to zero). -type ZeroTest[E schema.Evaluable] struct { +type ZeroTest[E sc.Evaluable] struct { Expr E } @@ -31,7 +30,7 @@ func (p ZeroTest[E]) Bounds() util.Bounds { // Context determines the evaluation context (i.e. enclosing module) for this // expression. -func (p ZeroTest[E]) Context(schema sc.Schema) (uint, uint, bool) { +func (p ZeroTest[E]) Context(schema sc.Schema) trace.Context { return p.Expr.Context(schema) } @@ -48,16 +47,13 @@ func (p ZeroTest[E]) String() string { // ignored. This is parameterised by the type of the constraint expression. // Thus, we can reuse this definition across the various intermediate // representations (e.g. Mid-Level IR, Arithmetic IR, etc). -type VanishingConstraint[T schema.Testable] struct { +type VanishingConstraint[T sc.Testable] struct { // A unique identifier for this constraint. This is primarily // useful for debugging. handle string - // Enclosing module for this assertion. This restricts the constraint to - // access only columns from within this module. - module uint - // Length multiplier. This is used to the column's actual height as a - // multipler of the enclosing module's height. - multiplier uint + // Evaluation context for this constraint which must match that of the + // constrained expression itself. + context trace.Context // Indicates (when nil) a global constraint that applies to all rows. // Otherwise, indicates a local constraint which applies to the specific row // given here. @@ -68,9 +64,9 @@ type VanishingConstraint[T schema.Testable] struct { } // NewVanishingConstraint constructs a new vanishing constraint! -func NewVanishingConstraint[T schema.Testable](handle string, module uint, multiplier uint, +func NewVanishingConstraint[T sc.Testable](handle string, context trace.Context, domain *int, constraint T) *VanishingConstraint[T] { - return &VanishingConstraint[T]{handle, module, multiplier, domain, constraint} + return &VanishingConstraint[T]{handle, context, domain, constraint} } // Handle returns the handle associated with this constraint. @@ -92,18 +88,10 @@ func (p *VanishingConstraint[T]) Domain() *int { return p.domain } -// Module returns the enclosing module for this constraint, a.k.a the evaluation -// context. Every constraint must be situated within exactly one module in -// order to be well-formed. -func (p *VanishingConstraint[T]) Module() uint { - return p.module -} - -// LengthMultiplier returns the length multiplier used by this vanishing -// constraint. This should match the evaluation context of the vanishing -// expression. -func (p *VanishingConstraint[T]) LengthMultiplier() uint { - return p.multiplier +// Context returns the evaluation context for this constraint. Every constraint +// must be situated within exactly one module in order to be well-formed. +func (p *VanishingConstraint[T]) Context() trace.Context { + return p.context } // Accepts checks whether a vanishing constraint evaluates to zero on every row @@ -113,14 +101,14 @@ func (p *VanishingConstraint[T]) LengthMultiplier() uint { func (p *VanishingConstraint[T]) Accepts(tr trace.Trace) error { if p.domain == nil { // Global Constraint - return HoldsGlobally(p.handle, p.module, p.multiplier, p.constraint, tr) + return HoldsGlobally(p.handle, p.context, p.constraint, tr) } // Local constraint var start uint // Handle negative domains if *p.domain < 0 { // Determine height of enclosing module - height := tr.Modules().Get(p.module).Height() + height := tr.Modules().Get(p.context.Module()).Height() * p.context.LengthMultiplier() // Negative rows calculated from end of trace. start = height + uint(*p.domain) } else { @@ -132,9 +120,9 @@ func (p *VanishingConstraint[T]) Accepts(tr trace.Trace) error { // HoldsGlobally checks whether a given expression vanishes (i.e. evaluates to // zero) for all rows of a trace. If not, report an appropriate error. -func HoldsGlobally[T schema.Testable](handle string, module uint, multiplier uint, constraint T, tr trace.Trace) error { +func HoldsGlobally[T sc.Testable](handle string, ctx trace.Context, constraint T, tr trace.Trace) error { // Determine height of enclosing module - height := tr.Modules().Get(module).Height() * multiplier + height := tr.Modules().Get(ctx.Module()).Height() * ctx.LengthMultiplier() // Determine well-definedness bounds for this constraint bounds := constraint.Bounds() // Sanity check enough rows @@ -152,7 +140,7 @@ func HoldsGlobally[T schema.Testable](handle string, module uint, multiplier uin // HoldsLocally checks whether a given constraint holds (e.g. vanishes) on a // specific row of a trace. If not, report an appropriate error. -func HoldsLocally[T schema.Testable](k uint, handle string, constraint T, tr trace.Trace) error { +func HoldsLocally[T sc.Testable](k uint, handle string, constraint T, tr trace.Trace) error { // Check whether it holds or not if !constraint.TestAt(int(k), tr) { // Construct useful error message diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index d1cc96f..d1599bf 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -68,25 +68,6 @@ type Constraint interface { Accepts(tr.Trace) error } -// Contextual captures something which requires an evaluation context (i.e. a -// single enclosing module) in order to make sense. For example, expressions -// require a single context. This interface is separated from Evaluable (and -// Testable) because HIR expressions do not implement Evaluable. -type Contextual interface { - // Context returns the evaluation context (i.e. enclosing module + length - // multiplier) for this constraint. Every expression must have a single - // evaluation context. This function therefore attempts to determine what - // that is, or return false to signal an error. There are several failure - // modes which need to be considered. Firstly, if the expression has no - // enclosing module (e.g. because it is a constant expression) then it will - // return 'math.MaxUint` to signal this. Secondly, if the expression has - // multiple (i.e. conflicting) enclosing modules then it will return false - // to signal this. Likewise, the expression could have a single enclosing - // module but multiple conflicting length multipliers, in which case it also - // returns false. - Context(Schema) (uint, uint, bool) -} - // Evaluable captures something which can be evaluated on a given table row to // produce an evaluation point. For example, expressions in the // Mid-Level or Arithmetic-Level IR can all be evaluated at rows of a @@ -121,6 +102,25 @@ type Testable interface { TestAt(int, tr.Trace) bool } +// Contextual captures something which requires an evaluation context (i.e. a +// single enclosing module) in order to make sense. For example, expressions +// require a single context. This interface is separated from Evaluable (and +// Testable) because HIR expressions do not implement Evaluable. +type Contextual interface { + // Context returns the evaluation context (i.e. enclosing module + length + // multiplier) for this constraint. Every expression must have a single + // evaluation context. This function therefore attempts to determine what + // that is, or return false to signal an error. There are several failure + // modes which need to be considered. Firstly, if the expression has no + // enclosing module (e.g. because it is a constant expression) then it will + // return 'math.MaxUint` to signal this. Secondly, if the expression has + // multiple (i.e. conflicting) enclosing modules then it will return false + // to signal this. Likewise, the expression could have a single enclosing + // module but multiple conflicting length multipliers, in which case it also + // returns false. + Context(Schema) tr.Context +} + // ============================================================================ // Column // ============================================================================ @@ -128,25 +128,23 @@ type Testable interface { // Column represents a specific column in the schema that, ultimately, will // correspond 1:1 with a column in the trace. type Column struct { - // Returns the index of the module which contains this column - module uint + // Evaluation context of this column. + context tr.Context // Returns the name of this column name string - // Length multiplier. This is used to the column'ss actual height as a - // multipler of the enclosing module's height. - multiplier uint // Returns the expected type of data in this column datatype Type } // NewColumn constructs a new column -func NewColumn(module uint, name string, multiplier uint, datatype Type) Column { - return Column{module, name, multiplier, datatype} +func NewColumn(context tr.Context, name string, datatype Type) Column { + return Column{context, name, datatype} } -// Module returns the index of the module which contains this column -func (p Column) Module() uint { - return p.module +// Context returns the evaluation context for this column access, which is +// determined by the column itself. +func (p Column) Context() tr.Context { + return p.context } // Name returns the name of this column @@ -154,12 +152,6 @@ func (p Column) Name() string { return p.name } -// LengthMultiplier is needed to the column's actual height as a -// multipler of the enclosing module's height. -func (p Column) LengthMultiplier() uint { - return p.multiplier -} - // Type returns the expected type of data in this column func (p Column) Type() Type { return p.datatype diff --git a/pkg/schema/schemas.go b/pkg/schema/schemas.go index 00e21ae..4985571 100644 --- a/pkg/schema/schemas.go +++ b/pkg/schema/schemas.go @@ -1,110 +1,40 @@ package schema import ( - "errors" - "math" - tr "github.com/consensys/go-corset/pkg/trace" ) -// JoinContexts combines one or more evaluation contexts together. There are a -// number of scenarios. The simple path is when each expression has the same -// evaluation context (in which case this is returned). Its also possible one -// or more expressions have no evaluation context (signaled by math.MaxUint) and -// this can be ignored. Finally, we might have two expressions with conflicting -// evaluation contexts, and this clearly signals an error. -func JoinContexts[E Contextual](args []E, schema Schema) (uint, uint, bool) { - var mid uint = math.MaxUint - - var multiplier = uint(1) +// JoinContexts combines one or more evaluation contexts together. If all +// expressions have the void context, then this is returned. Likewise, if any +// expression has a conflicting context then this is returned. Finally, if any +// two expressions have conflicting contexts between them, then the conflicting +// context is returned. Otherwise, the common context to all expressions is +// returned. +func JoinContexts[E Contextual](args []E, schema Schema) tr.Context { + ctx := tr.VoidContext() // for _, e := range args { - c, m, b := e.Context(schema) - if !b { - // Indicates conflict detected upstream, therefore propagate this - // down. - return 0, 0, false - } else if mid == math.MaxUint { - // No evaluation context determined yet, therefore can overwrite - // with whatever we got. Observe that this might still actually - mid = c - multiplier = m - } else if c != math.MaxUint && (c != mid || m != multiplier) { - // This indicates a conflict is detected, therefore we must - // propagate this down. - return 0, 0, false - } + ctx = ctx.Join(e.Context(schema)) } // If we get here, then no conflicts were detected. - return mid, multiplier, true -} - -// DetermineEnclosingModuleOfExpression determines (and checks) the enclosing -// module for a given expression. The expectation is that there is a single -// enclosing module, and this function will panic if that does not hold. -func DetermineEnclosingModuleOfExpression[E Contextual](expr E, schema Schema) (uint, uint) { - if mid, multiplier, ok := expr.Context(schema); ok && mid != math.MaxUint { - return mid, multiplier - } - // - panic("expression has no evaluation context") + return ctx } -// DetermineEnclosingModuleOfExpressions determines (and checks) the enclosing -// module for a given set of expressions. The expectation is that there is a single -// enclosing module, and this function will panic if that does not hold. -func DetermineEnclosingModuleOfExpressions[E Contextual](exprs []E, schema Schema) (uint, uint, error) { - // Sanity check input - if len(exprs) == 0 { - panic("cannot determine enclosing module for empty expression array") - } - // Determine first - mid, multiplier, ok := exprs[0].Context(schema) - // Sanity check this made sense - if !ok { - return 0, 0, errors.New("conflicting enclosing modules") - } - // Check rest against this - for i := 1; i < len(exprs); i++ { - m, f, ok := exprs[i].Context(schema) - if !ok { - return uint(i), 0, errors.New("conflicting enclosing modules") - } else if mid == math.MaxUint { - mid = m - } else if m != math.MaxUint && m != mid { - return uint(i), 0, errors.New("conflicting enclosing modules") - } else if m != math.MaxUint && f != multiplier { - return uint(i), 0, errors.New("conflicting length multipliers") - } - } - // success - return mid, multiplier, nil -} - -// DetermineEnclosingModuleOfColumns determines (and checks) the enclosing module for a -// given set of columns. The expectation is that there is a single enclosing -// module, and this function will panic if that does not hold. -func DetermineEnclosingModuleOfColumns(cols []uint, schema Schema) (uint, uint) { - head := schema.Columns().Nth(cols[0]) - // First, determine module of first column. - mid := head.Module() - multiplier := head.LengthMultiplier() - // Second, check other columns in the same module. +// ContextOfColumns determines the enclosing context for a given set of columns. +// If all columns have the void context, then this is returned. Likewise, +// if any column has a conflicting context then this is returned. Finally, +// if any two columns have conflicting contexts between them, then the +// conflicting context is returned. Otherwise, the common context to all +// columns is returned. +func ContextOfColumns(cols []uint, schema Schema) tr.Context { + ctx := tr.VoidContext() // - // NOTE: this could potentially be made more efficient by checking the - // columns of the module for the first column. - for i := 1; i < len(cols); i++ { + for i := 0; i < len(cols); i++ { col := schema.Columns().Nth(cols[i]) - if mid != col.Module() { - // This is an internal failure which should be prevented by upstream - // checking (e.g. in the parser). - panic("columns have different enclosing module") - } else if multiplier != col.LengthMultiplier() { - panic("columns have different length multipliers") - } + ctx = ctx.Join(col.Context()) } // Done - return mid, multiplier + return ctx } // RequiredSpillage returns the minimum amount of spillage required to ensure @@ -167,6 +97,6 @@ func Accepts(schema Schema, trace tr.Trace) error { // returns false if no matching column exists. func ColumnIndexOf(schema Schema, module uint, name string) (uint, bool) { return schema.Columns().Find(func(c Column) bool { - return c.Module() == module && c.Name() == name + return c.Context().Module() == module && c.Name() == name }) } diff --git a/pkg/trace/array_trace.go b/pkg/trace/array_trace.go index b655a4d..b3276b9 100644 --- a/pkg/trace/array_trace.go +++ b/pkg/trace/array_trace.go @@ -62,7 +62,7 @@ func (p *ArrayTrace) String() string { id.WriteString(",") } - modName := p.modules[ith.Module()].Name() + modName := p.modules[ith.Context().Module()].Name() if modName != "" { id.WriteString(modName) id.WriteString(".") @@ -103,11 +103,12 @@ type arrayTraceColumnSet struct { // Add a new column to this column set. func (p arrayTraceColumnSet) Add(column Column) uint { - m := &p.trace.modules[column.Module()] + ctx := column.Context() + m := &p.trace.modules[ctx.Module()] // Sanity check effective height - if column.Height() != (column.LengthMultiplier() * m.Height()) { + if column.Height() != (ctx.LengthMultiplier() * m.Height()) { panic(fmt.Sprintf("invalid column height for %s: %d vs %d*%d", column.Name(), - column.Height(), m.Height(), column.LengthMultiplier())) + column.Height(), m.Height(), ctx.LengthMultiplier())) } // Proceed index := uint(len(p.trace.columns)) @@ -139,7 +140,7 @@ func (p arrayTraceColumnSet) HasColumn(name string) bool { func (p arrayTraceColumnSet) IndexOf(module uint, name string) (uint, bool) { for i := 0; i < len(p.trace.columns); i++ { c := p.trace.columns[i] - if c.Module() == module && c.Name() == name { + if c.Context().Module() == module && c.Name() == name { return uint(i), true } } @@ -167,10 +168,10 @@ func (p arrayTraceColumnSet) Swap(l uint, r uint) { // Update modules notion of which columns they own. Observe that this only // makes sense when the modules for each column differ. Otherwise, this // leads to broken results. - if lth.Module() != rth.Module() { + if lth.Context().Module() != rth.Context().Module() { // Extract modules being swapped - lthmod := &modules[lth.Module()] - rthmod := &modules[rth.Module()] + lthmod := &modules[lth.Context().Module()] + rthmod := &modules[rth.Context().Module()] // Update their columns caches util.ReplaceFirstOrPanic(lthmod.columns, l, r) util.ReplaceFirstOrPanic(rthmod.columns, r, l) diff --git a/pkg/trace/builder.go b/pkg/trace/builder.go index 9e3b053..9add645 100644 --- a/pkg/trace/builder.go +++ b/pkg/trace/builder.go @@ -52,9 +52,12 @@ func (p *Builder) Add(name string, padding *fr.Element, data []*fr.Element) erro return err } } - // Register new column. Observe that user-provided columns always have a - // factor of 1. - return p.registerColumn(NewFieldColumn(mid, colname, 1, data, padding)) + // We assume (for now) that user-provided columns always have a length + // multiplier of 1. In general, this will be true. However, in situations + // where we are importing expanded traces, then this might not be true. + context := NewContext(mid, 1) + // Register new column. + return p.registerColumn(NewFieldColumn(context, colname, data, padding)) } // HasModule checks whether a given module has already been registered with this @@ -98,7 +101,7 @@ func (p *Builder) splitQualifiedColumnName(name string) (string, string) { // if the column's module does not exist, or if the column's height does not // match that of its enclosing module. func (p *Builder) registerColumn(col Column) error { - mid := col.Module() + mid := col.Context().Module() // Sanity check module exists if mid >= uint(len(p.modules)) { return errors.New("column has invalid enclosing module index") diff --git a/pkg/trace/bytes_column.go b/pkg/trace/bytes_column.go index d0526f2..4cb27f3 100644 --- a/pkg/trace/bytes_column.go +++ b/pkg/trace/bytes_column.go @@ -11,16 +11,16 @@ import ( // in this column is potentially slower than for a FieldColumn, as the raw bytes // must be converted into a field element. type BytesColumn struct { - module uint - name string + // Evaluation context of this column + context Context + // Holds the name of this column + name string // Determines how many bytes each field element takes. For the BLS12-377 // curve, this should be 32. In the future, when other curves are // supported, this could be less. width uint8 // The number of data elements in this column. length uint - // Length multiplier (i.e. of length) - multiplier uint // The data stored in this column (as bytes). bytes []byte // Value to be used when padding this column @@ -28,19 +28,19 @@ type BytesColumn struct { } // NewBytesColumn constructs a new BytesColumn from its constituent parts. -func NewBytesColumn(module uint, name string, width uint8, length uint, multiplier uint, +func NewBytesColumn(context Context, name string, width uint8, length uint, bytes []byte, padding *fr.Element) *BytesColumn { // Sanity check data length - if length%multiplier != 0 { + if length%context.LengthMultiplier() != 0 { panic("data length has incorrect multiplier") } - return &BytesColumn{module, name, width, length, multiplier, bytes, padding} + return &BytesColumn{context, name, width, length, bytes, padding} } -// Module returns the enclosing module of this column -func (p *BytesColumn) Module() uint { - return p.module +// Context returns the evaluation context this column provides. +func (p *BytesColumn) Context() Context { + return p.context } // Name returns the name of this column @@ -58,14 +58,6 @@ func (p *BytesColumn) Height() uint { return p.length } -// LengthMultiplier is a multiplier of the enclosing module's height used to -// determine this column's height. For example, if the multiplier is 2 then the -// height of this column must always be a multiple of 2, etc. This affects -// padding also, as we must pad to this multiplier. -func (p *BytesColumn) LengthMultiplier() uint { - return p.multiplier -} - // Padding returns the value which will be used for padding this column. func (p *BytesColumn) Padding() *fr.Element { return p.padding @@ -85,11 +77,10 @@ func (p *BytesColumn) Get(i int) *fr.Element { // Clone an BytesColumn func (p *BytesColumn) Clone() Column { clone := new(BytesColumn) - clone.module = p.module + clone.context = p.context clone.name = p.name clone.length = p.length clone.width = p.width - clone.multiplier = p.multiplier clone.padding = p.padding // NOTE: the following is as we never actually mutate the underlying bytes // array. @@ -126,7 +117,7 @@ func (p *BytesColumn) Data() []*fr.Element { // Pad this column with n copies of the column's padding value. func (p *BytesColumn) Pad(n uint) { // Apply the length multiplier - n = n * p.multiplier + n = n * p.context.LengthMultiplier() // Computing padding length (in bytes) padding_len := n * uint(p.width) // Access bytes to use for padding @@ -156,7 +147,7 @@ func (p *BytesColumn) Pad(n uint) { // Reseat updates the module index of this column (e.g. as a result of a // realignment). func (p *BytesColumn) Reseat(mid uint) { - p.module = mid + p.context = NewContext(mid, p.context.LengthMultiplier()) } // Write the raw bytes of this column to a given writer, returning an error diff --git a/pkg/trace/context.go b/pkg/trace/context.go new file mode 100644 index 0000000..ad94a16 --- /dev/null +++ b/pkg/trace/context.go @@ -0,0 +1,118 @@ +package trace + +import ( + "fmt" + "math" +) + +// Context represents the evaluation context in which an expression can be +// evaluated. Firstly, every expression must have a single enclosing module +// (i.e. all columns accessed by the expression are in that module); secondly, +// the length multiplier for all columns accessed by the expression must be the +// same. Constant expressions are something of an anomily here since they have +// neither an enclosing module, nor a length modifier. Instead, we consider +// that constant expressions are evaluated in the empty --- or void --- context. +// This is something like a bottom type which is contained within all other +// contexts. +// +// NOTE: Whilst the evaluation context provides a general abstraction, there are +// a number of restrictions imposed on it in practice which limit its use. These +// restrictions arise from what is and is not supported by the underlying +// constraint system (i.e. the prover). Example restrictions imposed include: +// multipliers must be powers of 2. Likewise, non-normal expressions (i.e those +// with a multipler > 1) can only be used in a fairly limited number of +// situtions (e.g. lookups). +type Context struct { + // Identifies the module in which this evaluation context exists. The empty + // module is given by the maximum index (math.MaxUint). + module uint + // Identifies the length multiplier required to complete this context. In + // essence, length multiplies divide up a given module into several disjoint + // "subregions", such than every expression exists only in one of them. + multiplier uint +} + +// VoidContext returns the void (or empty) context. This is the bottom type in +// the lattice, and is the context contained in all other contexts. It is +// needed, for example, as the context for constant expressions. +func VoidContext() Context { + return Context{math.MaxUint, 0} +} + +// ConflictingContext represents the case where multiple different contexts have +// been joined together. For example, when determining the context of an +// expression, the conflicting context is returned when no single context can be +// deteremed. This value is generally considered to indicate an error. +func ConflictingContext() Context { + return Context{math.MaxUint - 1, 0} +} + +// NewContext returns a context representing the given module with the given +// length multiplier. +func NewContext(module uint, multiplier uint) Context { + return Context{module, multiplier} +} + +// Module returns the module for this context. Note, however, that this is +// nonsensical in the case of either the void or the conflicted context. In +// this cases, this method will panic. +func (p Context) Module() uint { + if !p.IsVoid() && !p.IsConflicted() { + return p.module + } else if p.IsVoid() { + panic("void context has no module") + } + + panic("conflicted context has no module") +} + +// LengthMultiplier returns the length multiplier for this context. Note, +// however, that this is nonsensical in the case of either the void or the +// conflicted context. In this cases, this method will panic. +func (p Context) LengthMultiplier() uint { + if !p.IsVoid() && !p.IsConflicted() { + return p.multiplier + } else if p.IsVoid() { + panic("void context has no module") + } + + panic("conflicted context has no module") +} + +// IsVoid checks whether this context is the void context (or not). This is the +// bottom element in the lattice. +func (p Context) IsVoid() bool { + return p.module == math.MaxUint +} + +// IsConflicted checks whether this context represents the conflicted context. +// This is the top element in the lattice, and is used to represent the case +// where e.g. an expression has multiple conflicting contexts. +func (p Context) IsConflicted() bool { + return p.module == math.MaxUint-1 +} + +// Multiply updates the length multiplier by multiplying it by a given factor, +// producing the updated context. +func (p Context) Multiply(factor uint) Context { + return NewContext(p.module, p.multiplier*factor) +} + +// Join returns the least upper bound of the two contexts, or false if this does +// not exist (i.e. the two context's are in conflict). +func (p Context) Join(other Context) Context { + if p.IsVoid() { + return other + } else if other.IsVoid() { + return p + } else if p != other || p.IsConflicted() || other.IsConflicted() { + // Conflicting contexts + return ConflictingContext() + } + // Matching contexts + return p +} + +func (p Context) String() string { + return fmt.Sprintf("%d*%d", p.module, p.multiplier) +} diff --git a/pkg/trace/field_column.go b/pkg/trace/field_column.go index 199ed5c..94d6bd5 100644 --- a/pkg/trace/field_column.go +++ b/pkg/trace/field_column.go @@ -13,11 +13,10 @@ import ( // use quite a lot of memory. In particular, when there are many different // field elements which have smallish values then this requires excess data. type FieldColumn struct { - module uint + // Evaluation context of this column + context Context // Holds the name of this column name string - // Length multiplier (i.e. of the data array) - multiplier uint // Holds the raw data making up this column data []*fr.Element // Value to be used when padding this column @@ -25,18 +24,18 @@ type FieldColumn struct { } // NewFieldColumn constructs a FieldColumn with the give name, data and padding. -func NewFieldColumn(module uint, name string, multiplier uint, data []*fr.Element, padding *fr.Element) *FieldColumn { +func NewFieldColumn(context Context, name string, data []*fr.Element, padding *fr.Element) *FieldColumn { // Sanity check data length - if uint(len(data))%multiplier != 0 { + if uint(len(data))%context.LengthMultiplier() != 0 { panic("data length has incorrect multiplier") } // Done - return &FieldColumn{module, name, multiplier, data, padding} + return &FieldColumn{context, name, data, padding} } -// Module returns the enclosing module of this column -func (p *FieldColumn) Module() uint { - return p.module +// Context returns the evaluation context this column provides. +func (p *FieldColumn) Context() Context { + return p.context } // Name returns the name of the given column. @@ -55,13 +54,6 @@ func (p *FieldColumn) Height() uint { return uint(len(p.data)) } -// LengthMultiplier is a multiplier which must be a factor of the height. For -// example, if the factor is 2 then the height must always be a multiple of 2, -// etc. This affects padding also, as we must pad to this factor. -func (p *FieldColumn) LengthMultiplier() uint { - return p.multiplier -} - // Padding returns the value which will be used for padding this column. func (p *FieldColumn) Padding() *fr.Element { return p.padding @@ -87,9 +79,8 @@ func (p *FieldColumn) Get(row int) *fr.Element { // Clone an FieldColumn func (p *FieldColumn) Clone() Column { clone := new(FieldColumn) - clone.module = p.module + clone.context = p.context clone.name = p.name - clone.multiplier = p.multiplier clone.padding = p.padding // NOTE: the following is as we never actually mutate the underlying bytes // array. @@ -101,7 +92,7 @@ func (p *FieldColumn) Clone() Column { // Pad this column with n copies of the column's padding value. func (p *FieldColumn) Pad(n uint) { // Apply the length multiplier - n = n * p.multiplier + n = n * p.context.LengthMultiplier() // Allocate sufficient memory ndata := make([]*fr.Element, uint(len(p.data))+n) // Copy over the data @@ -117,7 +108,7 @@ func (p *FieldColumn) Pad(n uint) { // Reseat updates the module index of this column (e.g. as a result of a // realignment). func (p *FieldColumn) Reseat(mid uint) { - p.module = mid + p.context = NewContext(mid, p.context.LengthMultiplier()) } // Write the raw bytes of this column to a given writer, returning an error diff --git a/pkg/trace/json/writer.go b/pkg/trace/json/writer.go index e071bac..0f4e4c1 100644 --- a/pkg/trace/json/writer.go +++ b/pkg/trace/json/writer.go @@ -17,7 +17,7 @@ func ToJsonString(tr trace.Trace) string { // for i := uint(0); i < columns.Len(); i++ { ith := columns.Get(i) - mod := tr.Modules().Get(ith.Module()) + mod := tr.Modules().Get(ith.Context().Module()) // Determine fully qualified column name name := ith.Name() // Prepend module name (if applicable) diff --git a/pkg/trace/lt/writer.go b/pkg/trace/lt/writer.go index fd34a87..4118244 100644 --- a/pkg/trace/lt/writer.go +++ b/pkg/trace/lt/writer.go @@ -42,7 +42,7 @@ func WriteBytes(tr trace.Trace, buf io.Writer) error { // Write header information for i := uint(0); i < ncols; i++ { col := columns.Get(i) - mod := modules.Get(col.Module()) + mod := modules.Get(col.Context().Module()) name := col.Name() // Prepend module name (if applicable) if mod.Name() != "" { diff --git a/pkg/trace/trace.go b/pkg/trace/trace.go index 180ea5f..231ce57 100644 --- a/pkg/trace/trace.go +++ b/pkg/trace/trace.go @@ -45,13 +45,12 @@ type Column interface { Get(row int) *fr.Element // Return the height (i.e. number of rows) of this column. Height() uint - // Returns the length multiplier (which must be a factor of the height). For - // example, if the multiplier is 2 then the height must always be a multiple - // of 2, etc. This affects padding also, as we must pad to this multiplier, - // etc. - LengthMultiplier() uint - // Get the module index of the enclosing module. - Module() uint + // Returns the evaluation context for this column. That identifies the + // enclosing module, and then length multiplier (which must be a factor of + // the height). For example, if the multiplier is 2 then the height must + // always be a multiple of 2, etc. This affects padding also, as we must + // pad to this multiplier, etc. + Context() Context // Get the name of this column Name() string // Return the value to use for padding this column.