From c8829c96461e0d2f98deab28767d0f1ef0418b5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niels=20M=C3=BCndler?= Date: Fri, 12 Jan 2024 17:14:57 +0100 Subject: [PATCH] Add a rewrite for not in --- opshin/compiler.py | 2 ++ opshin/rewrite/rewrite_notin.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 opshin/rewrite/rewrite_notin.py diff --git a/opshin/compiler.py b/opshin/compiler.py index 89756629..97606eca 100644 --- a/opshin/compiler.py +++ b/opshin/compiler.py @@ -18,6 +18,7 @@ from .rewrite.rewrite_import_uplc_builtins import RewriteImportUPLCBuiltins from .rewrite.rewrite_inject_builtins import RewriteInjectBuiltins from .rewrite.rewrite_inject_builtin_constr import RewriteInjectBuiltinsConstr +from .rewrite.rewrite_notin import RewriteNotIn from .rewrite.rewrite_orig_name import RewriteOrigName from .rewrite.rewrite_remove_type_stuff import RewriteRemoveTypeStuff from .rewrite.rewrite_scoping import RewriteScoping @@ -1012,6 +1013,7 @@ def compile( RewriteAugAssign(), RewriteComparisonChaining(), RewriteTupleAssign(), + RewriteNotIn(), RewriteImportIntegrityCheck(), RewriteImportPlutusData(), RewriteImportHashlib(), diff --git a/opshin/rewrite/rewrite_notin.py b/opshin/rewrite/rewrite_notin.py new file mode 100644 index 00000000..a5440a26 --- /dev/null +++ b/opshin/rewrite/rewrite_notin.py @@ -0,0 +1,21 @@ +from ast import * +from copy import copy + +from ..util import CompilingNodeTransformer + +""" +Rewrites all occurences of (a not in b) into (not (a in b)). +""" + + +class RewriteNotIn(CompilingNodeTransformer): + step = "Rewriting (not in)" + + def visit_Compare(self, node: Compare) -> AST: + assert ( + len(node.ops) == 1 + ), "RewriteNotIn only works on single comparisons, need to run RewriteComparisonChaining first" + if not isinstance(node.ops[0], NotIn): + return self.generic_visit(node) + new_node = UnaryOp(Not(), Compare(node.left, [In()], node.comparators)) + return self.generic_visit(new_node)