Skip to content

Commit

Permalink
Fix mypy errors (#1145)
Browse files Browse the repository at this point in the history
I've been going through fixing mypy errors, and have gotten through
about half of them. In the process I've added some extra error handling
that fixes some exceptions I've seen due to methods being called on the
wrong types.

Current master:

> Found 1718 errors in 183 files (checked 443 source files)

This branch:

> Found 872 errors in 155 files (checked 443 source files)

Feel free to squash all these commits, the only reason I've left them
separate is that I made sure the test suite was passing after each one.
  • Loading branch information
davidar authored Oct 26, 2024
1 parent 2f709d7 commit 6212284
Show file tree
Hide file tree
Showing 29 changed files with 517 additions and 395 deletions.
10 changes: 5 additions & 5 deletions mathics/builtin/atomic/numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ class MaxPrecision(Predefined):
= 3.141592654
"""

is_numeric = False
_is_numeric = False
messages = {
"precset": "Cannot set `1` to `2`; value must be a positive number or Infinity.",
"preccon": "Cannot set `1` such that $MaxPrecision < $MinPrecision.",
Expand Down Expand Up @@ -702,7 +702,7 @@ class MachineEpsilon_(Predefined):
= {0., 0., 2.22045×10^-16}
"""

is_numeric = True
_is_numeric = True
name = "$MachineEpsilon"

summary_text = "the difference between 1.0 and the next-nearest number representable as a machine-precision number"
Expand All @@ -729,7 +729,7 @@ class MachinePrecision_(Predefined):
summary_text = (
"the number of decimal digits of precision for machine-precision numbers"
)
is_numeric = True
_is_numeric = True
rules = {
"$MachinePrecision": "N[MachinePrecision]",
}
Expand All @@ -749,7 +749,7 @@ class MachinePrecision(Predefined):
= 15.9545897701910033463281614204
"""

is_numeric = True
_is_numeric = True
rules = {
"N[MachinePrecision, prec_]": (
"N[Log[10, 2] * %i, prec]" % FP_MANTISA_BINARY_DIGITS
Expand Down Expand Up @@ -786,7 +786,7 @@ class MinPrecision(Builtin):
}

name = "$MinPrecision"
is_numeric = True
_is_numeric = True
rules = {
"$MinPrecision": "0",
}
Expand Down
24 changes: 16 additions & 8 deletions mathics/builtin/box/expression.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# This is never intended to go in Mathics3 docs
no_doc = True

from typing import Optional, Sequence, Union

from mathics.core.attributes import A_PROTECTED, A_READ_PROTECTED
from mathics.core.builtin import BuiltinElement
from mathics.core.element import BoxElementMixin
Expand Down Expand Up @@ -109,8 +111,8 @@ def get_head_name(self):
def get_lookup_name(self):
return self.get_name()

def get_sort_key(self) -> tuple:
return self.to_expression().get_sort_key()
def get_sort_key(self, pattern_sort=False) -> tuple:
return self.to_expression().get_sort_key(pattern_sort)

def get_string_value(self):
return "-@" + self.get_head_name() + "@-"
Expand All @@ -119,7 +121,13 @@ def get_string_value(self):
def head(self):
return self.get_head()

def has_form(self, heads, *element_counts):
@head.setter
def head(self, value):
raise ValueError("BoxExpression.head is write protected.")

def has_form(
self, heads: Union[Sequence[str], str], *element_counts: Optional[int]
) -> bool:
"""
element_counts:
(,): no elements allowed
Expand All @@ -133,9 +141,13 @@ def has_form(self, heads, *element_counts):
if isinstance(heads, (tuple, list, set)):
if head_name not in [ensure_context(h) for h in heads]:
return False
else:
elif isinstance(heads, str):
if head_name != ensure_context(heads):
return False
else:
raise TypeError(
f"Heads must be a string or a sequence of strings, not {type(heads)}"
)
if not element_counts:
return False
if element_counts and element_counts[0] is not None:
Expand All @@ -151,10 +163,6 @@ def has_form(self, heads, *element_counts):
return False
return True

@head.setter
def head(self, value):
raise ValueError("BoxExpression.head is write protected.")

@property
def is_literal(self) -> bool:
"""
Expand Down
16 changes: 9 additions & 7 deletions mathics/builtin/list/constructing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
See also Constructing Vectors.
"""

import typing
from itertools import permutations
from typing import Optional, Tuple

from mathics.builtin.box.layout import RowBox
from mathics.core.atoms import Integer, is_integer_rational_or_real
Expand All @@ -20,7 +22,7 @@
from mathics.core.evaluation import Evaluation
from mathics.core.expression import Expression, structure
from mathics.core.list import ListExpression
from mathics.core.symbols import Atom
from mathics.core.symbols import Atom, Symbol
from mathics.core.systemsymbols import SymbolNormal
from mathics.eval.lists import get_tuples, list_boxes

Expand Down Expand Up @@ -268,9 +270,9 @@ def eval(self, imin, imax, di, evaluation: Evaluation):
and isinstance(di, Integer)
):
pm = 1 if di.value >= 0 else -1
result = [Integer(i) for i in range(imin.value, imax.value + pm, di.value)]
return ListExpression(
*result, elements_properties=range_list_elements_properties
*[Integer(i) for i in range(imin.value, imax.value + pm, di.value)],
elements_properties=range_list_elements_properties,
)

imin = imin.to_sympy()
Expand Down Expand Up @@ -344,7 +346,7 @@ def eval(self, li, evaluation: Evaluation):
def eval_n(self, li, n, evaluation: Evaluation):
"Permutations[li_List, n_]"

rs = None
rs: Optional[Tuple[int, ...]] = None
if isinstance(n, Integer):
py_n = min(n.get_int_value(), len(li.elements))
elif n.has_form("List", 1) and isinstance(n.elements[0], Integer):
Expand All @@ -359,12 +361,12 @@ def eval_n(self, li, n, evaluation: Evaluation):

if py_n is None or py_n < 0:
evaluation.message(
self.get_name(), "nninfseq", Expression(self.get_name(), li, n)
self.get_name(), "nninfseq", Expression(Symbol(self.get_name()), li, n)
)
return

if rs is None:
rs = range(py_n + 1)
rs = tuple(range(py_n + 1))

inner = structure("List", li, evaluation)
outer = structure("List", inner, evaluation)
Expand Down Expand Up @@ -431,7 +433,7 @@ def eval(self, expr, patterns, f, evaluation: Evaluation):
"Reap[expr_, {patterns___}, f_]"

patterns = patterns.get_sequence()
sown = [
sown: typing.List[typing.Tuple[BasePattern, list]] = [
(BasePattern.create(pattern, evaluation=evaluation), [])
for pattern in patterns
]
Expand Down
14 changes: 8 additions & 6 deletions mathics/builtin/list/rearrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def eval_zero(self, element, n, evaluation: Evaluation):
Integer0,
Integer0,
evaluation,
lambda: Expression(self.get_name(), element, n),
lambda: Expression(Symbol(self.get_name()), element, n),
)

def eval(self, element, n, x, evaluation: Evaluation):
Expand All @@ -243,7 +243,7 @@ def eval(self, element, n, x, evaluation: Evaluation):
x,
Integer0,
evaluation,
lambda: Expression(self.get_name(), element, n, x),
lambda: Expression(Symbol(self.get_name()), element, n, x),
)

def eval_margin(self, element, n, x, m, evaluation: Evaluation):
Expand All @@ -254,7 +254,7 @@ def eval_margin(self, element, n, x, m, evaluation: Evaluation):
x,
m,
evaluation,
lambda: Expression(self.get_name(), element, n, x, m),
lambda: Expression(Symbol(self.get_name()), element, n, x, m),
)


Expand Down Expand Up @@ -354,6 +354,8 @@ def _gather(self, keys, values, equivalence):
class _Rotate(Builtin):
messages = {"rspec": "`` should be an integer or a list of integers."}

_sign: int

def _rotate(self, expr, n, evaluation: Evaluation):
if not isinstance(expr, Expression):
return expr
Expand All @@ -363,12 +365,12 @@ def _rotate(self, expr, n, evaluation: Evaluation):
return expr

index = (self._sign * n[0]) % len(elements) # with Python's modulo: index >= 1
new_elements = chain(elements[index:], elements[:index])
new_elements = tuple(chain(elements[index:], elements[:index]))

if len(n) > 1:
new_elements = [
new_elements = tuple(
self._rotate(item, n[1:], evaluation) for item in new_elements
]
)

return expr.restructure(expr.head, new_elements, evaluation)

Expand Down
30 changes: 13 additions & 17 deletions mathics/builtin/numbers/algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,9 @@ def eval(self, expr, evaluation):


# Get a coefficient of form in an expression
def _coefficient(name, expr, form, n, evaluation):
def _coefficient(
name: str, expr: Expression, form: Expression, n: Integer, evaluation: Evaluation
) -> BaseElement:
if expr is SymbolNull or form is SymbolNull or n is SymbolNull:
return Integer0

Expand All @@ -450,18 +452,11 @@ def _coefficient(name, expr, form, n, evaluation):
sympy_var = form.to_sympy()
sympy_n = n.to_sympy()

def combine_exprs(exprs):
result = 0
for e in exprs:
result += e
return result

# expand sub expressions if they contain variables
sympy_exprs = [
sympy_expr: sympy.Expr = sum(
sympy.expand(e) if sympy_var.free_symbols.issubset(e.free_symbols) else e
for e in sympy_exprs
]
sympy_expr = combine_exprs(sympy_exprs)
)
sympy_result = sympy_expr.coeff(sympy_var, sympy_n)
return from_sympy(sympy_result)

Expand Down Expand Up @@ -519,16 +514,18 @@ class Coefficient(Builtin):

summary_text = "coefficient of a monomial in a polynomial expression"

def eval_noform(self, expr, evaluation):
def eval_noform(self, expr: Expression, evaluation: Evaluation):
"Coefficient[expr_]"
evaluation.message("Coefficient", "argtu")

def eval(self, expr, form, evaluation):
def eval(self, expr: Expression, form: Expression, evaluation: Evaluation):
"Coefficient[expr_, form_]"
return _coefficient(self.__class__.__name__, expr, form, Integer1, evaluation)

def eval_n(self, expr, form, n, evaluation):
"Coefficient[expr_, form_, n_]"
def eval_n(
self, expr: Expression, form: Expression, n: Integer, evaluation: Evaluation
):
"Coefficient[expr_, form_, n_Integer]"
return _coefficient(self.__class__.__name__, expr, form, n, evaluation)


Expand Down Expand Up @@ -559,9 +556,8 @@ def coeff_power_internal(

# ###### Auxiliary functions #########
def key_powers(lst: list) -> Union[int, float]:
key = Expression(SymbolPlus, *lst)
key = key.evaluate(evaluation)
if key.is_numeric(evaluation):
key = Expression(SymbolPlus, *lst).evaluate(evaluation)
if key is not None and key.is_numeric(evaluation):
return key.to_python()
return 0

Expand Down
Loading

0 comments on commit 6212284

Please sign in to comment.