Skip to content

Commit

Permalink
Starting to tidy up mathics.core.pattern (#1086)
Browse files Browse the repository at this point in the history
In line with the last changes proposed by @rocky, I was doing a pass
over mathics.core.pattern and fixing some issues reported by the linter.

---------

Co-authored-by: R. Bernstein <[email protected]>
  • Loading branch information
mmatera and rocky authored Sep 19, 2024
1 parent 6c02434 commit 77db22f
Show file tree
Hide file tree
Showing 11 changed files with 764 additions and 517 deletions.
4 changes: 2 additions & 2 deletions mathics/builtin/exp_structure/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mathics.core.exceptions import InvalidLevelspecError
from mathics.core.expression import Evaluation, Expression
from mathics.core.list import ListExpression
from mathics.core.rules import Pattern
from mathics.core.rules import BasePattern
from mathics.core.symbols import Atom, SymbolFalse, SymbolTrue
from mathics.core.systemsymbols import SymbolMap
from mathics.eval.parts import python_levelspec, walk_levels
Expand Down Expand Up @@ -114,7 +114,7 @@ class FreeQ(Builtin):
def eval(self, expr, form, evaluation: Evaluation):
"FreeQ[expr_, form_]"

form = Pattern.create(form)
form = BasePattern.create(form)
if expr.is_free(form, evaluation):
return SymbolTrue
else:
Expand Down
4 changes: 2 additions & 2 deletions mathics/builtin/list/constructing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mathics.builtin.box.layout import RowBox
from mathics.core.atoms import Integer, is_integer_rational_or_real
from mathics.core.attributes import A_HOLD_FIRST, A_LISTABLE, A_LOCKED, A_PROTECTED
from mathics.core.builtin import Builtin, IterationFunction, Pattern
from mathics.core.builtin import BasePattern, Builtin, IterationFunction
from mathics.core.convert.expression import to_expression
from mathics.core.convert.sympy import from_sympy
from mathics.core.element import ElementsProperties
Expand Down Expand Up @@ -431,7 +431,7 @@ def eval(self, expr, patterns, f, evaluation: Evaluation):
"Reap[expr_, {patterns___}, f_]"

patterns = patterns.get_sequence()
sown = [(Pattern.create(pattern), []) for pattern in patterns]
sown = [(BasePattern.create(pattern), []) for pattern in patterns]

def listener(e, tag):
result = False
Expand Down
12 changes: 6 additions & 6 deletions mathics/builtin/numbers/algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
MATHICS3_NEG_INFINITY,
)
from mathics.core.list import ListExpression
from mathics.core.rules import Pattern
from mathics.core.rules import BasePattern
from mathics.core.symbols import (
Atom,
Symbol,
Expand Down Expand Up @@ -549,11 +549,11 @@ def coeff_power_internal(
else:
return [([], expr)]
if len(var_exprs) == 1:
target_pat = Pattern.create(var_exprs[0])
target_pat = BasePattern.create(var_exprs[0])
var_pats = [target_pat]
else:
target_pat = Pattern.create(Expression(SymbolAlternatives, *var_exprs))
var_pats = [Pattern.create(var) for var in var_exprs]
target_pat = BasePattern.create(Expression(SymbolAlternatives, *var_exprs))
var_pats = [BasePattern.create(var) for var in var_exprs]

# ###### Auxiliary functions #########
def key_powers(lst: list) -> Union[int, float]:
Expand Down Expand Up @@ -1172,7 +1172,7 @@ def eval_patt(self, expr, target, evaluation: Evaluation, options: dict):
return

if target:
kwargs["pattern"] = Pattern.create(target)
kwargs["pattern"] = BasePattern.create(target)
kwargs["evaluation"] = evaluation
return expand(expr, True, False, **kwargs)

Expand Down Expand Up @@ -1235,7 +1235,7 @@ def eval_patt(self, expr, target, evaluation: Evaluation, options: dict):
return

if target:
kwargs["pattern"] = Pattern.create(target)
kwargs["pattern"] = BasePattern.create(target)
kwargs["evaluation"] = evaluation
return expand(expr, numer=True, denom=True, deep=True, **kwargs)

Expand Down
8 changes: 4 additions & 4 deletions mathics/builtin/numbers/calculus.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from mathics.core.expression import Expression
from mathics.core.list import ListExpression
from mathics.core.number import MACHINE_EPSILON, dps
from mathics.core.rules import Pattern
from mathics.core.rules import BasePattern
from mathics.core.symbols import (
BaseElement,
Symbol,
Expand Down Expand Up @@ -228,7 +228,7 @@ def eval(self, f, x, evaluation: Evaluation):
if f == x:
return Integer1

x_pattern = Pattern.create(x)
x_pattern = BasePattern.create(x)
if f.is_free(x_pattern, evaluation):
return Integer0

Expand Down Expand Up @@ -1917,7 +1917,7 @@ def eval_times(
nummax.get_int_value(),
den.get_int_value(),
)
x_pattern = Pattern.create(x)
x_pattern = BasePattern.create(x)
incompat_series = []
max_exponent = Integer(int(series[2] / series[3] + 1))
if coeff.get_head() is SymbolSequence:
Expand Down Expand Up @@ -2263,7 +2263,7 @@ def eval(self, eqs, vars, evaluation: Evaluation):
vars = []
vars_sympy = []
for var, var_sympy in zip(all_vars, all_vars_sympy):
pattern = Pattern.create(var)
pattern = BasePattern.create(var)
if not eqs.is_free(pattern, evaluation):
vars.append(var)
vars_sympy.append(var_sympy)
Expand Down
55 changes: 29 additions & 26 deletions mathics/builtin/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@
The attributes 'Flat', 'Orderless', and 'OneIdentity' affect pattern matching.
"""

# This tells documentation how to sort this module
sort_order = "mathics.builtin.rules-and-patterns"

from typing import Callable, List, Optional as OptionalType, Tuple, Union

from mathics.core.atoms import Integer, Number, Rational, Real, String
Expand All @@ -63,12 +60,15 @@
from mathics.core.exceptions import InvalidLevelspecError
from mathics.core.expression import Expression, SymbolVerbatim
from mathics.core.list import ListExpression
from mathics.core.pattern import Pattern, StopGenerator
from mathics.core.pattern import BasePattern, StopGenerator
from mathics.core.rules import Rule
from mathics.core.symbols import Atom, Symbol, SymbolList, SymbolTrue
from mathics.core.systemsymbols import SymbolBlank, SymbolDefault, SymbolDispatch
from mathics.eval.parts import python_levelspec

# This tells documentation how to sort this module
sort_order = "mathics.builtin.rules-and-patterns"


class Rule_(BinaryOperator):
"""
Expand Down Expand Up @@ -159,7 +159,7 @@ def create_rules(
if any_lists:
all_lists = True
for item in rules:
if not item.get_head() is SymbolList:
if item.get_head() is not SymbolList:
all_lists = False
break

Expand Down Expand Up @@ -568,7 +568,7 @@ def init(
"System`NonNegative": self.match_nonnegative,
}

self.pattern = Pattern.create(expr.elements[0], evaluation=evaluation)
self.pattern = BasePattern.create(expr.elements[0], evaluation=evaluation)
self.test = expr.elements[1]
testname = self.test.get_name()
self.test_name = testname
Expand Down Expand Up @@ -765,7 +765,8 @@ def init(
) -> None:
super(Alternatives, self).init(expr, evaluation=evaluation)
self.alternatives = [
Pattern.create(element, evaluation=evaluation) for element in expr.elements
BasePattern.create(element, evaluation=evaluation)
for element in expr.elements
]

def match(self, yield_func, expression, vars, evaluation, **kwargs):
Expand Down Expand Up @@ -826,11 +827,11 @@ def init(
self, expr: Expression, evaluation: OptionalType[Evaluation] = None
) -> None:
super(Except, self).init(expr, evaluation=evaluation)
self.c = Pattern.create(expr.elements[0])
self.c = BasePattern.create(expr.elements[0])
if len(expr.elements) == 2:
self.p = Pattern.create(expr.elements[1], evaluation=evaluation)
self.p = BasePattern.create(expr.elements[1], evaluation=evaluation)
else:
self.p = Pattern.create(Expression(SymbolBlank), evaluation=evaluation)
self.p = BasePattern.create(Expression(SymbolBlank), evaluation=evaluation)

def match(self, yield_func, expression, vars, evaluation, **kwargs):
def except_yield_func(vars, rest):
Expand Down Expand Up @@ -910,7 +911,7 @@ def init(
self, expr: Expression, evaluation: OptionalType[Evaluation] = None
) -> None:
super(HoldPattern, self).init(expr, evaluation=evaluation)
self.pattern = Pattern.create(expr.elements[0], evaluation=evaluation)
self.pattern = BasePattern.create(expr.elements[0], evaluation=evaluation)

def match(self, yield_func, expression, vars, evaluation, **kwargs):
# for new_vars, rest in self.pattern.match(
Expand All @@ -919,7 +920,7 @@ def match(self, yield_func, expression, vars, evaluation, **kwargs):
self.pattern.match(yield_func, expression, vars, evaluation)


class Pattern_(PatternObject):
class Pattern(PatternObject):
"""
<url>:WMA link:https://reference.wolfram.com/language/ref/Pattern.html</url>
Expand Down Expand Up @@ -995,20 +996,20 @@ def init(
varname = expr.elements[0].get_name()
if varname is None or varname == "":
self.error("patvar", expr)
super(Pattern_, self).init(expr, evaluation=evaluation)
super(Pattern, self).init(expr, evaluation=evaluation)
self.varname = varname
self.pattern = Pattern.create(expr.elements[1], evaluation=evaluation)
self.pattern = BasePattern.create(expr.elements[1], evaluation=evaluation)

def __repr__(self):
return "<Pattern: %s>" % repr(self.pattern)

def get_match_count(self, vars={}):
return self.pattern.get_match_count(vars)

def match(self, yield_func, expression, vars, evaluation, **kwargs):
existing = vars.get(self.varname, None)
def match(self, yield_func, expression, vars_dict, evaluation, **kwargs):
existing = vars_dict.get(self.varname, None)
if existing is None:
new_vars = vars.copy()
new_vars = vars_dict.copy()
new_vars[self.varname] = expression
# for vars_2, rest in self.pattern.match(
# expression, new_vars, evaluation):
Expand All @@ -1021,22 +1022,24 @@ def match(self, yield_func, expression, vars, evaluation, **kwargs):
self.pattern.match(yield_func, expression, new_vars, evaluation)
else:
if existing.sameQ(expression):
yield_func(vars, None)
yield_func(vars_dict, None)

def get_match_candidates(
self, elements: tuple, expression, attributes, evaluation, vars={}
self, elements: tuple, expression, attributes, evaluation, vars_dict=None
):
existing = vars.get(self.varname, None)
if vars_dict is None:
vars_dict = {}
existing = vars_dict.get(self.varname, None)
if existing is None:
return self.pattern.get_match_candidates(
elements, expression, attributes, evaluation, vars
elements, expression, attributes, evaluation, vars_dict
)
else:
# Treat existing variable as verbatim
verbatim_expr = Expression(SymbolVerbatim, existing)
verbatim = Verbatim(verbatim_expr)
return verbatim.get_match_candidates(
elements, expression, attributes, evaluation, vars
elements, expression, attributes, evaluation, vars_dict
)


Expand Down Expand Up @@ -1097,7 +1100,7 @@ def init(
self, expr: Expression, evaluation: OptionalType[Evaluation] = None
) -> None:
super(Optional, self).init(expr, evaluation=evaluation)
self.pattern = Pattern.create(expr.elements[0], evaluation=evaluation)
self.pattern = BasePattern.create(expr.elements[0], evaluation=evaluation)
if len(expr.elements) == 2:
self.default = expr.elements[1]
else:
Expand Down Expand Up @@ -1398,7 +1401,7 @@ def init(
min: int = 1,
evaluation: OptionalType[Evaluation] = None,
):
self.pattern = Pattern.create(expr.elements[0], evaluation=evaluation)
self.pattern = BasePattern.create(expr.elements[0], evaluation=evaluation)
self.max = None
self.min = min
if len(expr.elements) == 2:
Expand Down Expand Up @@ -1550,9 +1553,9 @@ def init(
# if (expr.elements[0].get_head_name() == "System`Condition" and
# len(expr.elements[0].elements) == 2):
# self.test = Expression(SymbolAnd, self.test, expr.elements[0].elements[1])
# self.pattern = Pattern.create(expr.elements[0].elements[0])
# self.pattern = BasePattern.create(expr.elements[0].elements[0])
# else:
self.pattern = Pattern.create(expr.elements[0], evaluation=evaluation)
self.pattern = BasePattern.create(expr.elements[0], evaluation=evaluation)

def match(
self,
Expand Down
9 changes: 5 additions & 4 deletions mathics/core/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from mathics.core.list import ListExpression
from mathics.core.number import PrecisionValueError, dps, get_precision, min_prec
from mathics.core.parser.util import PyMathicsDefinitions, SystemDefinitions
from mathics.core.pattern import Pattern
from mathics.core.pattern import BasePattern
from mathics.core.rules import FunctionApplyRule, Rule
from mathics.core.symbols import (
BaseElement,
Expand Down Expand Up @@ -1131,7 +1131,7 @@ def __init__(self, name, count, expected):
super().__init__(name, "argr", count, expected)


class PatternObject(BuiltinElement, Pattern):
class PatternObject(BuiltinElement, BasePattern):
needs_verbatim = True

arg_counts: List[int] = []
Expand All @@ -1142,9 +1142,10 @@ def init(self, expr, evaluation: Optional[Evaluation] = None):
if len(expr.elements) not in self.arg_counts:
self.error_args(len(expr.elements), *self.arg_counts)
self.expr = expr
self.head = Pattern.create(expr.head, evaluation=evaluation)
self.head = BasePattern.create(expr.head, evaluation=evaluation)
self.elements = [
Pattern.create(element, evaluation=evaluation) for element in expr.elements
BasePattern.create(element, evaluation=evaluation)
for element in expr.elements
]

def error(self, tag, *args):
Expand Down
Loading

0 comments on commit 77db22f

Please sign in to comment.