Skip to content

Commit

Permalink
Ec opcodes
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonpaulos committed Dec 22, 2023
1 parent 289b3cb commit 58cc4ea
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 2 deletions.
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ per-file-ignores =
examples/signature/recurring_swap.py: F403, F405
examples/signature/split.py: F403, F405
pyteal/__init__.py: F401, F403
pyteal/ast/ec.py: E222
pyteal/compiler/flatten.py: F821
pyteal/compiler/optimizer/__init__.py: F401
pyteal/ir/ops.py: E221
Expand Down
7 changes: 7 additions & 0 deletions pyteal/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,19 @@ __all__ = [
"Div",
"Divw",
"DynamicScratchVar",
"EcAdd",
"EcMapTo",
"EcMultiScalarMul",
"EcPairingCheck",
"EcScalarMul",
"EcSubgroupCheck",
"EcdsaCurve",
"EcdsaDecompress",
"EcdsaRecover",
"EcdsaVerify",
"Ed25519Verify",
"Ed25519Verify_Bare",
"EllipticCurve",
"EnumInt",
"Eq",
"Err",
Expand Down
16 changes: 16 additions & 0 deletions pyteal/ast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,15 @@
from pyteal.ast.multi import MultiValue
from pyteal.ast.opup import OpUp, OpUpMode, OpUpFeeSource
from pyteal.ast.ecdsa import EcdsaCurve, EcdsaVerify, EcdsaDecompress, EcdsaRecover
from pyteal.ast.ec import (
EllipticCurve,
EcAdd,
EcScalarMul,
EcPairingCheck,
EcMultiScalarMul,
EcSubgroupCheck,
EcMapTo,
)
from pyteal.ast.router import (
BareCallActions,
CallConfig,
Expand Down Expand Up @@ -256,6 +265,13 @@
"EcdsaVerify",
"Ed25519Verify_Bare",
"Ed25519Verify",
"EllipticCurve",
"EcAdd",
"EcScalarMul",
"EcPairingCheck",
"EcMultiScalarMul",
"EcSubgroupCheck",
"EcMapTo",
"EnumInt",
"Eq",
"Err",
Expand Down
2 changes: 1 addition & 1 deletion pyteal/ast/abstractvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def alloc_abstract_var(stack_type: TealType) -> AbstractVar:
stack_type: TealType that represents stack type.
"""

from pyteal.ast import ScratchVar
from pyteal.ast.scratchvar import ScratchVar
from pyteal.ast.subroutine import SubroutineEval
from pyteal.ast.frame import FrameVar, MAX_FRAME_LOCAL_VARS

Expand Down
148 changes: 148 additions & 0 deletions pyteal/ast/ec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from typing import TYPE_CHECKING
from enum import Enum

from pyteal.ast.expr import Expr

from pyteal.ir import Op, TealBlock, TealOp
from pyteal.types import TealType, require_type
from pyteal.errors import verifyFieldVersion, verifyProgramVersion

if TYPE_CHECKING:
from pyteal.compiler import CompileOptions


class EllipticCurve(Enum):
# fmt: off
# id | name | min version
BN254g1 = (0, "BN254g1", 10)
BN254g2 = (1, "BN254g2", 10)
BLS12_381g1 = (2, "BLS12_381g1", 10)
BLS12_381g2 = (3, "BLS12_381g2", 10)
# fmt: on

def __init__(self, id: int, name: str, min_version: int) -> None:
self.id = id
self.arg_name = name
self.min_version = min_version


class EcOperation(Expr):
def __init__(self, op: Op, curve: EllipticCurve, args: list[Expr]) -> None:
super().__init__()
self.op = op
assert curve in EllipticCurve
self.curve = curve
for arg in args:
require_type(arg, TealType.bytes)
self.args = args

def __teal__(self, options: "CompileOptions"):
verifyProgramVersion(
self.op.min_version,
options.version,
f"Program version too low to use op {self.op}",
)

verifyFieldVersion(self.curve.arg_name, self.curve.min_version, options.version)

op = TealOp(self, self.op, self.curve.arg_name)
return TealBlock.FromOp(options, op, *self.args)

def __str__(self):
return f"(EcOperation {self.op} {self.curve} {self.args})"

def type_of(self):
return TealType.bytes

def has_return(self):
return False


def EcAdd(curve: EllipticCurve, a: Expr, b: Expr) -> Expr:
"""Add two points on the given elliptic curve.
Args:
curve: The elliptic curve to use.
a: The first point to add. Must evaluate to bytes.
b: The second point to add. Must evaluate to bytes.
Returns:
An expression which evaluates to the sum of the two points on the given
curve.
"""
return EcOperation(Op.ec_add, curve, [a, b])


def EcScalarMul(curve: EllipticCurve, point: Expr, scalar: Expr) -> Expr:
"""Multiply a point on the given elliptic curve by a scalar.
Args:
curve: The elliptic curve to use.
point: The point to multiply. Must evaluate to bytes.
scalar: The scalar to multiply by, encoded as a big-endian unsigned
integer. Must evaluate to bytes. Fails if this value exceeds 32 bytes.
Returns:
An expression which evaluates to the product of the point and scalar on
the given curve.
"""
return EcOperation(Op.ec_scalar_mul, curve, [point, scalar])


def EcPairingCheck(curve: EllipticCurve, a: Expr, b: Expr) -> Expr:
"""Check the pairing of two points on the given elliptic curve.
Args:
curve: The elliptic curve to use.
a: The first point to check. Must evaluate to bytes.
b: The second point to check. Must evaluate to bytes.
Returns:
An expression which evaluates to 1 if the product of the pairing of each
point in `a` with its respective point in `b` is equal to the identity
element of the target group. Otherwise, evaluates to 0.
"""
return EcOperation(Op.ec_pairing_check, curve, [a, b])


def EcMultiScalarMul(curve: EllipticCurve, a: Expr, b: Expr) -> Expr:
"""Multiply a point on the given elliptic curve by a series of scalars.
Args:
curve: The elliptic curve to use.
a: The point to multiply. Must evaluate to bytes.
b: A list of concatenated, big-endian, 32-byte scalar integers to
multiply by.
Returns:
An expression that evaluates to curve point :code:`b_0a_0 + b_1a_1 + b_2a_2 + ... + b_Na_N`.
"""
return EcOperation(Op.ec_multi_scalar_mul, curve, [a, b])


def EcSubgroupCheck(curve: EllipticCurve, a: Expr) -> Expr:
"""Check if a point is in the main prime-order subgroup of the given elliptic curve.
Args:
curve: The elliptic curve to use.
a: The point to check. Must evaluate to bytes.
Returns:
An expression that evaluates to 1 if the point is in the main prime-order
subgroup of the curve (including the point at infinity) else 0. Program
fails if the point is not in the curve at all.
"""
return EcOperation(Op.ec_subgroup_check, curve, [a])


def EcMapTo(curve: EllipticCurve, a: Expr) -> Expr:
"""Map field element `a` to group `curve`.
Args:
curve: The elliptic curve to use.
a: The field element to map. Must evaluate to bytes.
Returns:
An expression that evaluates to the mapped point.
"""
return EcOperation(Op.ec_map_to, curve, [a])
47 changes: 47 additions & 0 deletions pyteal/ast/ec_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Callable

import pytest

import pyteal as pt

OPERATIONS: list[
tuple[
Callable[[pt.EllipticCurve, pt.Expr], pt.Expr]
| Callable[[pt.EllipticCurve, pt.Expr, pt.Expr], pt.Expr],
pt.Op,
int,
]
] = [
(pt.EcAdd, pt.Op.ec_add, 2),
(pt.EcScalarMul, pt.Op.ec_scalar_mul, 2),
(pt.EcPairingCheck, pt.Op.ec_pairing_check, 2),
(pt.EcMultiScalarMul, pt.Op.ec_multi_scalar_mul, 2),
(pt.EcSubgroupCheck, pt.Op.ec_subgroup_check, 1),
(pt.EcMapTo, pt.Op.ec_map_to, 1),
]


def test_EcOperation():
for operation, expected_op, num_args in OPERATIONS:
for curve in pt.EllipticCurve:
args = [pt.Bytes(f"arg{i}") for i in range(num_args)]
expr = operation(curve, *args)
assert expr.type_of() == pt.TealType.bytes

expected = pt.TealSimpleBlock(
[pt.TealOp(arg, pt.Op.byte, f'"arg{i}"') for i, arg in enumerate(args)]
+ [pt.TealOp(expr, expected_op, curve.arg_name)]
)

actual, _ = expr.__teal__(pt.CompileOptions(version=10))
actual.addIncoming()
actual = pt.TealBlock.NormalizeBlocks(actual)

assert actual == expected

# Test wrong arg types
for i in range(num_args):
bad_args = args.copy()
bad_args[i] = pt.Int(1)
with pytest.raises(pt.TealTypeError):
operation(curve, *bad_args)
3 changes: 2 additions & 1 deletion pyteal/ast/ecdsa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from enum import Enum
from typing import Tuple, TYPE_CHECKING

from pyteal.ast import Expr, MultiValue
from pyteal.ast.expr import Expr
from pyteal.ast.multi import MultiValue
from pyteal.errors import (
TealTypeError,
verifyFieldVersion,
Expand Down

0 comments on commit 58cc4ea

Please sign in to comment.