Skip to content

Commit

Permalink
Earlier initialization of patterns (#1103)
Browse files Browse the repository at this point in the history
When a ExpressionPattern is created, the way in which the matching with
expressions is determined in part by the attributes of its head. For
example, if `S` is `Orderless`, the match method of the pattern
`S[a_, v__Integer]` should check for different orders of the arguments.

Also, the match against patterns like `DirectedInfinity[1]` can be
determined using the faster `sameQ` method, instead of using all the
machinery for looking for named patterns inside.

This PR introduces some modifications in the way Pattern objects are
created, to have earlier access to the attributes, and to determine if
the pattern is a "Literal"
  • Loading branch information
mmatera authored Sep 30, 2024
1 parent 9d0147c commit 6aadef6
Show file tree
Hide file tree
Showing 13 changed files with 99 additions and 37 deletions.
2 changes: 1 addition & 1 deletion mathics/builtin/exp_structure/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class FreeQ(Builtin):
def eval(self, expr, form, evaluation: Evaluation):
"FreeQ[expr_, form_]"

form = BasePattern.create(form)
form = BasePattern.create(form, evaluation=evaluation)
if expr.is_free(form, evaluation):
return SymbolTrue
else:
Expand Down
5 changes: 4 additions & 1 deletion mathics/builtin/list/constructing.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,10 @@ def eval(self, expr, patterns, f, evaluation: Evaluation):
"Reap[expr_, {patterns___}, f_]"

patterns = patterns.get_sequence()
sown = [(BasePattern.create(pattern), []) for pattern in patterns]
sown = [
(BasePattern.create(pattern, evaluation=evaluation), [])
for pattern in patterns
]

def listener(e, tag):
result = False
Expand Down
10 changes: 5 additions & 5 deletions mathics/builtin/list/eol.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def eval(self, items, pattern, ls, evaluation, options):
results = []

if pattern.has_form("Rule", 2) or pattern.has_form("RuleDelayed", 2):
match = Matcher(pattern.elements[0]).match
match = Matcher(pattern.elements[0], evaluation).match
rule = Rule(pattern.elements[0], pattern.elements[1])

def callback(level):
Expand All @@ -224,7 +224,7 @@ def callback(level):
return level

else:
match = Matcher(pattern).match
match = Matcher(pattern, evaluation).match

def callback(level):
if match(level, evaluation):
Expand Down Expand Up @@ -467,7 +467,7 @@ def eval_ls_n(self, items, pattern, levelspec, n, evaluation):
return deletecases_with_levelspec(items, pattern, evaluation, levelspec, n)
# A more efficient way to proceed if levelspec == 1

match = Matcher(pattern).match
match = Matcher(pattern, evaluation).match
if n == -1:

def cond(element):
Expand Down Expand Up @@ -1187,7 +1187,7 @@ def eval(self, items, sel, evaluation):
def eval_pattern(self, items, sel, pattern, evaluation):
"Pick[items_, sel_, pattern_]"

match = Matcher(pattern).match
match = Matcher(pattern, evaluation).match
return self._do(items, sel, lambda s: match(s, evaluation), evaluation)


Expand Down Expand Up @@ -1245,7 +1245,7 @@ def eval_level(self, expr, patt, ls, evaluation, options={}):
evaluation.message("Position", "level", ls)
return

match = Matcher(patt).match
match = Matcher(patt, evaluation).match
result = []

def callback(level, pos):
Expand Down
6 changes: 3 additions & 3 deletions mathics/builtin/numbers/calculus.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def eval(self, f, x, evaluation: Evaluation):
if f == x:
return Integer1

x_pattern = BasePattern.create(x)
x_pattern = BasePattern.create(x, evaluation=evaluation)
if f.is_free(x_pattern, evaluation):
return Integer0

Expand Down Expand Up @@ -1919,7 +1919,7 @@ def eval_times(
nummax.get_int_value(),
den.get_int_value(),
)
x_pattern = BasePattern.create(x)
x_pattern = BasePattern.create(x, evaluation=evaluation)
incompat_series = []
max_exponent = Integer(int(series[2] / series[3] + 1))
if coeff.get_head() is SymbolSequence:
Expand Down Expand Up @@ -2265,7 +2265,7 @@ def eval(self, eqs, vars, evaluation: Evaluation):
vars = []
vars_sympy = []
for var, var_sympy in zip(all_vars, all_vars_sympy):
pattern = BasePattern.create(var)
pattern = BasePattern.create(var, evaluation=evaluation)
if not eqs.is_free(pattern, evaluation):
vars.append(var)
vars_sympy.append(var_sympy)
Expand Down
2 changes: 1 addition & 1 deletion mathics/builtin/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class FilterRules(Builtin):
def eval(self, rules, pattern, evaluation):
"FilterRules[rules_List, pattern_]"

match = Matcher(pattern).match
match = Matcher(pattern, evaluation).match

def matched():
for rule in rules.elements:
Expand Down
2 changes: 1 addition & 1 deletion mathics/builtin/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ def init(
self, expr: Expression, evaluation: OptionalType[Evaluation] = None
) -> None:
super(Except, self).init(expr, evaluation=evaluation)
self.c = BasePattern.create(expr.elements[0])
self.c = BasePattern.create(expr.elements[0], evaluation=evaluation)
if len(expr.elements) == 2:
self.p = BasePattern.create(expr.elements[1], evaluation=evaluation)
else:
Expand Down
36 changes: 31 additions & 5 deletions mathics/core/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def contribute(self, definitions, is_pymodule=False):
if not self.context:
self.context = "Pymathics`" if is_pymodule else "System`"
name = self.get_name()
attributes = self.attributes
options = {}

# - 'Strict': warn and fail with unsupported options
Expand Down Expand Up @@ -268,19 +269,41 @@ def contribute(self, definitions, is_pymodule=False):
for pattern, function in self.get_functions(
prefix="eval", is_pymodule=is_pymodule
):
pat_attr = attributes if pattern.get_head_name() == name else None
rules.append(
FunctionApplyRule(name, pattern, function, check_options, system=True)
FunctionApplyRule(
name,
pattern,
function,
check_options,
attributes=pat_attr,
system=True,
)
)
for pattern, function in self.get_functions(is_pymodule=is_pymodule):
pat_attr = attributes if pattern.get_head_name() == name else None
rules.append(
FunctionApplyRule(name, pattern, function, check_options, system=True)
FunctionApplyRule(
name,
pattern,
function,
check_options,
attributes=pat_attr,
system=True,
)
)
for pattern_str, replace_str in self.rules.items():
pattern_str = pattern_str % {"name": name}
pattern = parse_builtin_rule(pattern_str, definition_class)
replace_str = replace_str % {"name": name}
pat_attr = attributes if pattern.get_head_name() == name else None
rules.append(
Rule(pattern, parse_builtin_rule(replace_str), system=not is_pymodule)
Rule(
pattern,
parse_builtin_rule(replace_str),
attributes=pat_attr,
system=not is_pymodule,
)
)

box_rules = []
Expand Down Expand Up @@ -321,11 +344,14 @@ def contextify_form_name(f):
formatvalues = {"": []}
for pattern, function in self.get_functions("format_"):
forms, pattern = extract_forms(pattern)
pat_attr = attributes if pattern.get_head_name() == name else None
for form in forms:
if form not in formatvalues:
formatvalues[form] = []
formatvalues[form].append(
FunctionApplyRule(name, pattern, function, None, system=True)
FunctionApplyRule(
name, pattern, function, None, attributes=pat_attr, system=True
)
)
for pattern, replace in self.formats.items():
forms, pattern = extract_forms(pattern)
Expand Down Expand Up @@ -377,7 +403,7 @@ def contextify_form_name(f):
rules=rules,
formatvalues=formatvalues,
messages=messages,
attributes=self.attributes,
attributes=attributes,
options=options,
defaultvalues=defaults,
builtin=self,
Expand Down
43 changes: 35 additions & 8 deletions mathics/core/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ class BasePattern(ABC):

expr: BaseElement

# this attribute allows for a faster match algorithm based on sameq.
# Probably we should split ExpressionPattern into two different classes,
# one for literal patterns and the other for "Regular" ExpressionPatterns.
isliteral: bool = False

# TODO: In WMA, when a BasePattern is created, the attributes
# from the head are read from the evaluation context and
# stored as a part of a rule.
Expand Down Expand Up @@ -168,7 +173,9 @@ class BasePattern(ABC):
#
@staticmethod
def create(
expr: BaseElement, evaluation: Optional[Evaluation] = None
expr: BaseElement,
attributes: Optional[int] = None,
evaluation: Optional[Evaluation] = None,
) -> "BasePattern":
"""
If ``expr`` is listed in ``pattern_object`` return the pattern found there.
Expand All @@ -181,7 +188,7 @@ def create(
return pattern_object(expr, evaluation=evaluation)
if isinstance(expr, Atom):
return AtomPattern(expr, evaluation)
return ExpressionPattern(expr, evaluation)
return ExpressionPattern(expr, attributes, evaluation)

def get_attributes(self, definitions):
"""The attributes of the expression"""
Expand Down Expand Up @@ -320,6 +327,9 @@ class AtomPattern(BasePattern):
A pattern that matches with an atom.
"""

# Atoms are always literals
isliteral: bool = True

def __init__(self, expr: Atom, evaluation: Optional[Evaluation] = None) -> None:
self.expr = expr
self.atom = expr
Expand Down Expand Up @@ -405,15 +415,22 @@ class ExpressionPattern(BasePattern):

attributes: Optional[int] = None

def __init__(self, expr: Expression, evaluation: Optional[Evaluation] = None):
def __init__(
self,
expr: Expression,
attributes: Optional[int] = None,
evaluation: Optional[Evaluation] = None,
):
self.expr = expr
head = expr.head
attributes = (
None if evaluation is None else head.get_attributes(evaluation.definition)
)
if attributes is None and evaluation:
attributes = head.get_attributes(evaluation.definitions)
self.head = BasePattern.create(head, evaluation=evaluation)
self.elements = [
BasePattern.create(element, evaluation=evaluation)
for element in expr.elements
]
self.__set_pattern_attributes__(attributes)
self.head = BasePattern.create(head)
self.elements = [BasePattern.create(element) for element in expr.elements]

def __set_pattern_attributes__(self, attributes):
if attributes is None or self.attributes is not None:
Expand All @@ -425,6 +442,10 @@ def __set_pattern_attributes__(self, attributes):
self.get_pre_choices = get_pre_choices_orderless
else:
self.get_pre_choices = get_pre_choices_with_order
if not (A_ONE_IDENTITY + A_FLAT) & attributes:
self.isliteral = self.head.isliteral and all(
element.isliteral for element in self.elements
)

def match(
self,
Expand All @@ -439,6 +460,12 @@ def match(
):
"""Try to match the pattern against an Expression"""
evaluation.check_stopped()
if self.isliteral:
if expression.sameQ(self.expr):
# yield vars, None
yield_func(vars_dict, None)
return

if self.attributes is None:
self.__set_pattern_attributes__(
self.head.get_attributes(evaluation.definitions)
Expand Down
13 changes: 10 additions & 3 deletions mathics/core/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,11 @@ def __init__(
pattern: Expression,
system: bool = False,
evaluation: Optional[Evaluation] = None,
attributes: Optional[int] = None,
) -> None:
self.pattern = BasePattern.create(pattern, evaluation=evaluation)
self.pattern = BasePattern.create(
pattern, attributes=attributes, evaluation=evaluation
)
self.system = system

def apply(
Expand Down Expand Up @@ -222,8 +225,11 @@ def __init__(
replace: Expression,
system=False,
evaluation: Optional[Evaluation] = None,
attributes: Optional[int] = None,
) -> None:
super(Rule, self).__init__(pattern, system=system, evaluation=evaluation)
super(Rule, self).__init__(
pattern, system=system, evaluation=evaluation, attributes=attributes
)
self.replace = replace

def apply_rule(
Expand Down Expand Up @@ -310,9 +316,10 @@ def __init__(
check_options: Optional[Callable],
system: bool = False,
evaluation: Optional[Evaluation] = None,
attributes: Optional[int] = None,
) -> None:
super(FunctionApplyRule, self).__init__(
pattern, system=system, evaluation=evaluation
pattern, system=system, attributes=attributes, evaluation=evaluation
)
self.name = name
self.function = function
Expand Down
2 changes: 1 addition & 1 deletion mathics/eval/numbers/calculus/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def build_series(f, x, x0, n, evaluation):
vars = {
x_name: x0,
}
x_pattern = BasePattern.create(x)
x_pattern = BasePattern.create(x, evaluation=evaluation)

if f.is_free(x_pattern, evaluation):
print(x, " not in ", f)
Expand Down
5 changes: 2 additions & 3 deletions mathics/eval/parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def deletecases_with_levelspec(expr, pattern, evaluation, levelspec=1, n=-1):
"""
nothing = SymbolNothing

match = Matcher(pattern)
match = Matcher(pattern, evaluation)
match = match.match
if type(levelspec) is int:
lsmin = 1
Expand Down Expand Up @@ -631,9 +631,8 @@ def find_matching_indices_with_levelspec(expr, pattern, evaluation, levelspec=1,
n indicates the number of occurrences to return. By default, it
returns all the occurrences.
"""
from mathics.builtin.patterns import Matcher

match = Matcher(pattern)
match = Matcher(pattern, evaluation)
match = match.match
if type(levelspec) is int:
lsmin = 0
Expand Down
6 changes: 3 additions & 3 deletions mathics/eval/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ class _StopGeneratorMatchQ(StopGenerator):


class Matcher:
def __init__(self, form):
def __init__(self, form, evaluation):
if isinstance(form, BasePattern):
self.form = form
else:
self.form = BasePattern.create(form)
self.form = BasePattern.create(form, evaluation=evaluation)

def match(self, expr, evaluation: Evaluation):
def yield_func(vars, rest):
Expand All @@ -25,4 +25,4 @@ def yield_func(vars, rest):


def match(expr, form, evaluation: Evaluation):
return Matcher(form).match(expr, evaluation)
return Matcher(form, evaluation).match(expr, evaluation)
4 changes: 2 additions & 2 deletions mathics/eval/testing_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def is_number(sympy_value) -> bool:
def check_ArrayQ(expr, pattern, test, evaluation: Evaluation):
"Check if expr is an Array which test yields true for each of its elements."

pattern = BasePattern.create(pattern)
pattern = BasePattern.create(pattern, evaluation=evaluation)

dims = [len(expr.get_elements())] # to ensure an atom is not an array

Expand Down Expand Up @@ -152,7 +152,7 @@ def check_SparseArrayQ(expr, pattern, test, evaluation: Evaluation):
if not expr.head.sameQ(SymbolSparseArray):
return SymbolFalse

pattern = BasePattern.create(pattern)
pattern = BasePattern.create(pattern, evaluation=evaluation)
dims, default_value, rules = expr.elements[1:]
if not pattern.does_match(Integer(len(dims.elements)), evaluation):
return SymbolFalse
Expand Down

0 comments on commit 6aadef6

Please sign in to comment.