Skip to content

Commit

Permalink
Convert RootSum to SymPy (#1136)
Browse files Browse the repository at this point in the history
I split this out into a separate PR as I think it's going to require
some more work. We already had a SymPy -> Mathics conversion for
RootSum, this adds one going the other way too. However, it results in
some tests failing as it means that Simplify automatically expands
RootSums now, not sure if we want to add some hints to prevent that from
happening.
  • Loading branch information
davidar authored Oct 27, 2024
1 parent 6212284 commit f88d3c2
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 13 deletions.
33 changes: 26 additions & 7 deletions mathics/builtin/functional/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@

from itertools import chain

import sympy

from mathics.core.atoms import Integer, Integer1
from mathics.core.attributes import A_HOLD_ALL, A_N_HOLD_ALL, A_PROTECTED
from mathics.core.builtin import Builtin, PostfixOperator
from mathics.core.builtin import Builtin, PostfixOperator, SympyFunction
from mathics.core.convert.sympy import SymbolFunction
from mathics.core.evaluation import Evaluation
from mathics.core.expression import Expression
from mathics.core.symbols import Symbol
from mathics.core.symbols import Symbol, sympy_slot_prefix
from mathics.core.systemsymbols import SymbolSlot


class Function(PostfixOperator):
class Function(PostfixOperator, SympyFunction):
"""
<dl>
<dt>'Function[$body$]'
Expand Down Expand Up @@ -119,9 +123,11 @@ def eval_named(self, vars, body, args, evaluation: Evaluation):
# this is not included in WL, and here does not have any impact, but it is needed for
# translating the function to a compiled version.
var_names = (
var.get_name()
if isinstance(var, Symbol)
else var.elements[0].get_name()
(
var.get_name()
if isinstance(var, Symbol)
else var.elements[0].get_name()
)
for var in vars
)
vars = dict(list(zip(var_names, args[: len(vars)])))
Expand All @@ -148,8 +154,17 @@ def eval_named_attr(self, vars, body, attr, args, evaluation: Evaluation):
except Exception:
return

def to_sympy(self, expr: Expression, **kwargs):
if len(expr.elements) == 1:
body = expr.elements[0]
slot = Expression(SymbolSlot, Integer1)
return sympy.Lambda(slot.to_sympy(), body.to_sympy())
else:
# TODO: Handle multiple and/or named arguments
raise NotImplementedError


class Slot(Builtin):
class Slot(SympyFunction):
"""
<dl>
<dt>'#$n$'
Expand Down Expand Up @@ -184,6 +199,10 @@ class Slot(Builtin):
}
summary_text = "one argument of a pure function"

def to_sympy(self, expr: Expression, **kwargs):
index: Integer = expr.elements[0]
return sympy.Symbol(f"{sympy_slot_prefix}{index.get_int_value()}")


class SlotSequence(Builtin):
"""
Expand Down
4 changes: 3 additions & 1 deletion mathics/builtin/list/constructing.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,12 @@ class Normal(Builtin):

summary_text = "convert objects to normal expressions"

def eval_general(self, expr, evaluation: Evaluation):
def eval_general(self, expr: Expression, evaluation: Evaluation):
"Normal[expr_]"
if isinstance(expr, Atom):
return
if expr.has_form("RootSum", 2):
return from_sympy(expr.to_sympy().doit(roots=True))
return Expression(
expr.get_head(),
*[Expression(SymbolNormal, element) for element in expr.elements],
Expand Down
57 changes: 55 additions & 2 deletions mathics/builtin/numbers/calculus.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@
from mathics.core.convert.expression import to_expression, to_mathics_list
from mathics.core.convert.function import expression_to_callable_and_args
from mathics.core.convert.python import from_python
from mathics.core.convert.sympy import SympyExpression, from_sympy, sympy_symbol_prefix
from mathics.core.convert.sympy import (
SymbolRootSum,
SympyExpression,
from_sympy,
sympy_symbol_prefix,
)
from mathics.core.evaluation import Evaluation
from mathics.core.expression import Expression
from mathics.core.list import ListExpression
Expand All @@ -63,6 +68,7 @@
SymbolConditionalExpression,
SymbolD,
SymbolDerivative,
SymbolFunction,
SymbolIndeterminate,
SymbolInfinity,
SymbolInfix,
Expand All @@ -76,6 +82,7 @@
SymbolSeries,
SymbolSeriesData,
SymbolSimplify,
SymbolSlot,
SymbolUndefined,
)
from mathics.eval.makeboxes import format_element
Expand Down Expand Up @@ -1627,7 +1634,7 @@ class Root(SympyFunction):
Roots that can't be represented by radicals:
>> Root[#1 ^ 5 + 2 #1 + 1&, 2]
= Root[#1 ^ 5 + 2 #1 + 1&, 2]
= Root[1 + #1 ^ 5 + 2 #1&, 2]
"""

messages = {
Expand Down Expand Up @@ -1691,6 +1698,52 @@ def to_sympy(self, expr, **kwargs):
return None


class RootSum(SympyFunction):
"""
<url>:WMA link: https://reference.wolfram.com/language/ref/RootSum.html</url>
<dl>
<dt>'RootSum[$f$, $form$]'
<dd>sums $form[x]$ for all roots of the polynomial $f[x]$.
</dl>
>> Integrate[1/(x^5 + 11 x + 1), {x, 1, 3}]
= RootSum[-1 - 212960 #1 ^ 3 - 9680 #1 ^ 2 - 165 #1 + 41232181 #1 ^ 5&, (Log[3749971 - 3512322106304 #1 ^ 4 + 453522741 #1 + 16326568676 #1 ^ 2 + 79825502416 #1 ^ 3] - 4 Log[5]) #1&] - RootSum[-1 - 212960 #1 ^ 3 - 9680 #1 ^ 2 - 165 #1 + 41232181 #1 ^ 5&, (Log[3748721 - 3512322106304 #1 ^ 4 + 453522741 #1 + 16326568676 #1 ^ 2 + 79825502416 #1 ^ 3] - 4 Log[5]) #1&]
>> N[%, 50]
= 0.051278805184286949884270940103072421286139857550894
>> RootSum[#^5 - 11 # + 1 &, (#^2 - 1)/(#^3 - 2 # + c) &]
= (538 - 88 c + 396 c ^ 2 + 5 c ^ 3 - 5 c ^ 4) / (97 - 529 c - 53 c ^ 2 + 88 c ^ 3 + c ^ 5)
>> RootSum[#^5 - 3 # - 7 &, Sin] //N//Chop
= 0.292188
Use Normal to expand RootSum:
>> RootSum[1+#+#^2+#^3+#^4 &, Log[x + #] &]
= RootSum[1 + #1 ^ 2 + #1 ^ 3 + #1 ^ 4 + #1&, Log[x + #1]&]
>> %//Normal
= Log[-1 / 4 - Sqrt[5] / 4 - I Sqrt[5 / 8 - Sqrt[5] / 8] + x] + Log[-1 / 4 - Sqrt[5] / 4 + I Sqrt[5 / 8 - Sqrt[5] / 8] + x] + Log[-1 / 4 - I Sqrt[5 / 8 + Sqrt[5] / 8] + Sqrt[5] / 4 + x] + Log[-1 / 4 + I Sqrt[5 / 8 + Sqrt[5] / 8] + Sqrt[5] / 4 + x]
"""

summary_text = "sum polynomial roots"

def eval(self, f, form, evaluation: Evaluation): # type: ignore[override]
"RootSum[f_, form_]"
return from_sympy(Expression(SymbolRootSum, f, form).to_sympy())

def to_sympy(self, expr: Expression, **kwargs):
func = expr.elements[1]
if not isinstance(func.to_sympy(), sympy.Lambda):
# eta conversion
func = Expression(
SymbolFunction, Expression(func, Expression(SymbolSlot, Integer1))
)

poly = expr.elements[0].to_sympy()
poly_x = sympy.Symbol("poly_x")
return sympy.RootSum(poly(poly_x), func.to_sympy(), x=poly_x)


class Series(Builtin):
"""
<url>:WMA link:https://reference.wolfram.com/language/ref/Series.html</url>
Expand Down
3 changes: 2 additions & 1 deletion mathics/core/convert/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def __new__(cls, *exprs):
if all(isinstance(expr, BasicSympy) for expr in exprs):
# called with SymPy arguments
obj = super().__new__(cls, *exprs)
obj.expr = None
elif len(exprs) == 1 and isinstance(exprs[0], Expression):
# called with Mathics argument
expr = exprs[0]
Expand Down Expand Up @@ -460,7 +461,7 @@ def old_from_sympy(expr) -> BaseElement:
result.append(Expression(SymbolTimes, *factors))
else:
result.append(Integer1)
return Expression(SymbolFunction, Expression(SymbolPlus, *result))
return Expression(SymbolFunction, Expression(SymbolPlus, *sorted(result)))
if isinstance(expr, sympy.CRootOf):
try:
e_root, indx = expr.args
Expand Down
2 changes: 1 addition & 1 deletion mathics/eval/nevaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def eval_NValues(

# Special case for the Root builtin
# This should be implemented as an NValue
if expr.has_form("Root", 2):
if expr.has_form("Root", 2) or expr.has_form("RootSum", 2):
return from_sympy(sympy.N(expr.to_sympy(), d))

# Here we look for the NValues associated to the
Expand Down
3 changes: 2 additions & 1 deletion mathics/eval/numbers/algebra/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def _default_complexity_function(x):

# At this point, ``complexity_function`` is a function that takes a
# sympy expression and returns an integer.
sympy_result = simplify(sympy_expr, measure=complexity_function)
sympy_result = simplify(sympy_expr, measure=complexity_function, doit=False)
sympy_result = sympy_result.doit(roots=False) # Don't expand RootSum

# and bring it back
result = from_sympy(sympy_result).evaluate(evaluation)
Expand Down

0 comments on commit f88d3c2

Please sign in to comment.