Skip to content

Commit

Permalink
Merge pull request #285 from OpShin/fix/determinism_union_types
Browse files Browse the repository at this point in the history
Extend assert sum example and ensure determinism among union types
  • Loading branch information
nielstron authored Nov 1, 2023
2 parents c065e23 + f0402bb commit 6738356
Show file tree
Hide file tree
Showing 10 changed files with 165 additions and 120 deletions.
3 changes: 3 additions & 0 deletions examples/smart_contracts/assert_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@


def validator(datum: int, redeemer: int, context: ScriptContext) -> None:
purpose = context.purpose
if not isinstance(purpose, Spending):
print(f"Wrong script purpose: {purpose}")
assert (
datum + redeemer == 42
), f"Expected datum and redeemer to sum to 42, but they sum to {datum + redeemer}"
11 changes: 6 additions & 5 deletions opshin/optimize/optimize_const_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging

from ast import *
from ordered_set import OrderedSet

from pycardano import PlutusData

Expand Down Expand Up @@ -98,7 +99,7 @@ class ShallowNameDefCollector(CompilingNodeVisitor):
step = "Collecting occuring variable names"

def __init__(self):
self.vars = set()
self.vars = OrderedSet()

def visit_Name(self, node: Name) -> None:
if isinstance(node.ctx, Store):
Expand Down Expand Up @@ -172,13 +173,13 @@ class OptimizeConstantFolding(CompilingNodeTransformer):

def __init__(self):
self.scopes_visible = [
set(INITIAL_SCOPE.keys()).difference(SAFE_GLOBALS.keys())
OrderedSet(INITIAL_SCOPE.keys()).difference(SAFE_GLOBALS.keys())
]
self.scopes_constants = [dict()]
self.constants = set()
self.constants = OrderedSet()

def enter_scope(self):
self.scopes_visible.append(set())
self.scopes_visible.append(OrderedSet())
self.scopes_constants.append(dict())

def add_var_visible(self, var: str):
Expand All @@ -191,7 +192,7 @@ def add_constant(self, var: str, value: typing.Any):
self.scopes_constants[-1][var] = value

def visible_vars(self):
res_set = set()
res_set = OrderedSet()
for s in self.scopes_visible:
res_set.update(s)
return res_set
Expand Down
6 changes: 4 additions & 2 deletions opshin/optimize/optimize_remove_deadvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from copy import copy
from collections import defaultdict

from ordered_set import OrderedSet

from ..util import CompilingNodeVisitor, CompilingNodeTransformer
from ..type_inference import INITIAL_SCOPE
from ..typed_ast import TypedAnnAssign
Expand Down Expand Up @@ -93,7 +95,7 @@ def visit_Module(self, node: Module) -> Module:
# collect all variable names
collector = NameLoadCollector()
collector.visit(node_cp)
loaded_vars = set(collector.loaded.keys()) | {"validator_0"}
loaded_vars = OrderedSet(collector.loaded.keys()) | {"validator_0"}
# break if the set of loaded vars did not change -> set of vars to remove does also not change
if loaded_vars == self.loaded_vars:
break
Expand All @@ -115,7 +117,7 @@ def visit_If(self, node: If):
scope_orelse_cp = self.guaranteed_avail_names[-1].copy()
self.exit_scope()
# what remains after this in the scope is the intersection of both
for var in set(scope_body_cp).intersection(scope_orelse_cp):
for var in OrderedSet(scope_body_cp).intersection(scope_orelse_cp):
self.set_guaranteed(var)
return node_cp

Expand Down
3 changes: 2 additions & 1 deletion opshin/rewrite/rewrite_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import typing
import sys
from ast import *
from ordered_set import OrderedSet

from ..util import CompilingNodeTransformer

Expand Down Expand Up @@ -57,7 +58,7 @@ class RewriteImport(CompilingNodeTransformer):
def __init__(self, filename=None, package=None, resolved_imports=None):
self.filename = filename
self.package = package
self.resolved_imports = resolved_imports or set()
self.resolved_imports = resolved_imports or OrderedSet()

def visit_ImportFrom(
self, node: ImportFrom
Expand Down
8 changes: 5 additions & 3 deletions opshin/rewrite/rewrite_scoping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from copy import copy
from collections import defaultdict

from ordered_set import OrderedSet

from ..type_inference import INITIAL_SCOPE, PolymorphicFunctionInstanceType
from ..util import CompilingNodeTransformer, CompilingNodeVisitor

Expand All @@ -14,7 +16,7 @@ class ShallowNameDefCollector(CompilingNodeVisitor):
step = "Collecting occuring variable names"

def __init__(self):
self.vars = set()
self.vars = OrderedSet()

def visit_Name(self, node: Name) -> None:
if isinstance(node.ctx, Store) or isinstance(
Expand All @@ -36,7 +38,7 @@ class RewriteScoping(CompilingNodeTransformer):

def __init__(self):
self.latest_scope_id = 0
self.scopes = [(set(INITIAL_SCOPE.keys()), -1)]
self.scopes = [(OrderedSet(INITIAL_SCOPE.keys()), -1)]

def variable_scope_id(self, name: str) -> int:
"""find the id of the scope in which this variable is defined (closest to its usage)"""
Expand All @@ -49,7 +51,7 @@ def variable_scope_id(self, name: str) -> int:
)

def enter_scope(self):
self.scopes.append((set(), self.latest_scope_id))
self.scopes.append((OrderedSet(), self.latest_scope_id))
self.latest_scope_id += 1

def exit_scope(self):
Expand Down
32 changes: 25 additions & 7 deletions opshin/tests/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import sys

import subprocess
Expand Down Expand Up @@ -30,6 +32,13 @@
from pycardano import RawPlutusData
from cbor2 import CBORTag

ALL_EXAMPLES = [
os.path.join(root, f)
for root, dirs, files in os.walk("examples")
for f in files
if f.endswith(".py") and not f.startswith("broken") and not f.startswith("extract")
]


def fib(n):
a, b = 0, 1
Expand Down Expand Up @@ -60,7 +69,16 @@ def test_assert_sum_contract_succeed(self):
input_file = "examples/smart_contracts/assert_sum.py"
with open(input_file) as fp:
source_code = fp.read()
ret = eval_uplc(source_code, 20, 22, Unit())
ret = eval_uplc(
source_code,
20,
22,
uplc.data_from_cbor(
bytes.fromhex(
"d8799fd8799f9fd8799fd8799fd8799f582055d353acacaab6460b37ed0f0e3a1a0aabf056df4a7fa1e265d21149ccacc527ff01ffd8799fd8799fd87a9f581cdbe769758f26efb21f008dc097bb194cffc622acc37fcefc5372eee3ffd87a80ffa140a1401a00989680d87a9f5820dfab81872ce2bbe6ee5af9bbfee4047f91c1f57db5e30da727d5fef1e7f02f4dffd87a80ffffff809fd8799fd8799fd8799f581cdc315c289fee4484eda07038393f21dc4e572aff292d7926018725c2ffd87a80ffa140a14000d87980d87a80ffffa140a14000a140a1400080a0d8799fd8799fd87980d87a80ffd8799fd87b80d87a80ffff80a1d87a9fd8799fd8799f582055d353acacaab6460b37ed0f0e3a1a0aabf056df4a7fa1e265d21149ccacc527ff01ffffd87980a15820dfab81872ce2bbe6ee5af9bbfee4047f91c1f57db5e30da727d5fef1e7f02f4dd8799f581cdc315c289fee4484eda07038393f21dc4e572aff292d7926018725c2ffd8799f5820746957f0eb57f2b11119684e611a98f373afea93473fefbb7632d579af2f6259ffffd87a9fd8799fd8799f582055d353acacaab6460b37ed0f0e3a1a0aabf056df4a7fa1e265d21149ccacc527ff01ffffff"
)
),
)
self.assertEqual(ret, uplc.PlutusConstr(0, []))

@unittest.expectedFailure
Expand Down Expand Up @@ -2176,24 +2194,24 @@ def validator(
"""
builder._compile(source_code)

def test_compilation_deterministic_local(self):
input_file = "examples/smart_contracts/assert_sum.py"
@parameterized.expand(ALL_EXAMPLES)
def test_compilation_deterministic_local(self, input_file):
with open(input_file) as fp:
source_code = fp.read()
code = builder._compile(source_code)
for i in range(10):
code_2 = builder._compile(source_code)
self.assertEqual(code.dumps(), code_2.dumps())

def test_compilation_deterministic_external(self):
input_file = "examples/smart_contracts/assert_sum.py"
@parameterized.expand(ALL_EXAMPLES)
def test_compilation_deterministic_external(self, input_file):
code = subprocess.run(
[
sys.executable,
"-m",
"opshin",
"compile",
"spending",
"any",
input_file,
],
capture_output=True,
Expand All @@ -2205,7 +2223,7 @@ def test_compilation_deterministic_external(self):
"-m",
"opshin",
"compile",
"spending",
"any",
input_file,
],
capture_output=True,
Expand Down
11 changes: 6 additions & 5 deletions opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""
import typing
from collections import defaultdict
from ordered_set import OrderedSet

from copy import copy
from pycardano import PlutusData
Expand Down Expand Up @@ -142,7 +143,7 @@ def constant_type(c):


def union_types(*ts: Type):
ts = list(set(ts))
ts = OrderedSet(ts)
if len(ts) == 1:
return ts[0]
assert ts, "Union must combine multiple classes"
Expand All @@ -151,7 +152,7 @@ def union_types(*ts: Type):
isinstance(e, UnionType) and all(isinstance(e2, RecordType) for e2 in e.typs)
for e in ts
), "Union must combine multiple PlutusData classes"
union_set = set()
union_set = OrderedSet()
for t in ts:
union_set.update(t.typs)
assert distinct(
Expand All @@ -161,12 +162,12 @@ def union_types(*ts: Type):


def intersection_types(*ts: Type):
ts = list(set(ts))
ts = OrderedSet(ts)
if len(ts) == 1:
return ts[0]
ts = [t if isinstance(t, UnionType) else UnionType(frozenlist([t])) for t in ts]
assert ts, "Must have at least one type to intersect"
intersection_set = set(ts[0].typs)
intersection_set = OrderedSet(ts[0].typs)
for t in ts[1:]:
intersection_set.intersection_update(t.typs)
return UnionType(frozenlist(intersection_set))
Expand Down Expand Up @@ -261,7 +262,7 @@ def visit_UnaryOp(self, node: UnaryOp) -> PairType:


def merge_scope(s1: typing.Dict[str, Type], s2: typing.Dict[str, Type]):
keys = set(s1.keys()).union(s2.keys())
keys = OrderedSet(s1.keys()).union(s2.keys())
merged = {}
for k in keys:
if k not in s1.keys():
Expand Down
3 changes: 2 additions & 1 deletion opshin/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from ast import *

import itertools
from ordered_set import OrderedSet

import uplc.ast

Expand Down Expand Up @@ -535,7 +536,7 @@ def attribute_type(self, attr) -> "Type":
return IntegerInstanceType
# need to have a common field with the same name
if all(attr in (n for n, t in x.record.fields) for x in self.typs):
attr_types = set(
attr_types = OrderedSet(
t for x in self.typs for n, t in x.record.fields if n == attr
)
for at in attr_types:
Expand Down
Loading

0 comments on commit 6738356

Please sign in to comment.