Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a rewrite for not in #312

Draft
wants to merge 1 commit into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .rewrite.rewrite_import_uplc_builtins import RewriteImportUPLCBuiltins
from .rewrite.rewrite_inject_builtins import RewriteInjectBuiltins
from .rewrite.rewrite_inject_builtin_constr import RewriteInjectBuiltinsConstr
from .rewrite.rewrite_notin import RewriteNotIn
from .rewrite.rewrite_orig_name import RewriteOrigName
from .rewrite.rewrite_remove_type_stuff import RewriteRemoveTypeStuff
from .rewrite.rewrite_scoping import RewriteScoping
Expand Down Expand Up @@ -1012,6 +1013,7 @@ def compile(
RewriteAugAssign(),
RewriteComparisonChaining(),
RewriteTupleAssign(),
RewriteNotIn(),
RewriteImportIntegrityCheck(),
RewriteImportPlutusData(),
RewriteImportHashlib(),
Expand Down
21 changes: 21 additions & 0 deletions opshin/rewrite/rewrite_notin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from ast import *
from copy import copy

from ..util import CompilingNodeTransformer

"""
Rewrites all occurences of (a not in b) into (not (a in b)).
"""


class RewriteNotIn(CompilingNodeTransformer):
step = "Rewriting (not in)"

def visit_Compare(self, node: Compare) -> AST:
assert (
len(node.ops) == 1
), "RewriteNotIn only works on single comparisons, need to run RewriteComparisonChaining first"
if not isinstance(node.ops[0], NotIn):
return self.generic_visit(node)
new_node = UnaryOp(Not(), Compare(node.left, [In()], node.comparators))
return self.generic_visit(new_node)