Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Migrate from deprecated unsorted_arguments to arguments #1179

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ StaticArrays = "1.1"
SymPy = "2"
SymbolicIndexingInterface = "0.3.14"
SymbolicLimits = "0.2.0"
SymbolicUtils = "2.0.2"
SymbolicUtils = "2.1"
TermInterface = "0.4"
julia = "1.10"

Expand Down
2 changes: 1 addition & 1 deletion src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ function numbered_expr(O::Symbolic,varnumbercache,args...;varordering = args[1],
Expr(:ref, toexpr(args[1], states), toexpr.(args[2:end] .+ offset, (states,))...)
else
Expr(:call, Symbol(operation(O)), (numbered_expr(x,varnumbercache,args...;offset=offset,lhsname=lhsname,
rhsnames=rhsnames,varordering=varordering) for x in arguments(O))...)
rhsnames=rhsnames,varordering=varordering) for x in sorted_arguments(O))...)
end
elseif issym(O)
tosymbol(O, escape=false)
Expand Down
10 changes: 5 additions & 5 deletions src/latexify_recipes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function latexify_derivatives(ex)
integrand
)
elseif x.args[1] === :_textbf
ls = latexify(latexify_derivatives(arguments(x)[1])).s
ls = latexify(latexify_derivatives(sorted_arguments(x)[1])).s
return "\\textbf{" * strip(ls, '\$') * "}"
else
return x
Expand Down Expand Up @@ -134,7 +134,7 @@ function _toexpr(O)

# We need to iterate over each term in m, ignoring the numeric coefficient.
# This iteration needs to be stable, so we can't iterate over m.dict.
for term in Iterators.drop(arguments(m), isone(m.coeff) ? 0 : 1)
for term in Iterators.drop(sorted_arguments(m), isone(m.coeff) ? 0 : 1)
if !ispow(term)
push!(numer, _toexpr(term))
continue
Expand Down Expand Up @@ -182,7 +182,7 @@ function _toexpr(O)
!iscall(O) && return O

op = operation(O)
args = arguments(O)
args = sorted_arguments(O)

if (op===(*)) && (args[1] === -1)
arg_mul = Expr(:call, :(*), _toexpr(args[2:end])...)
Expand Down Expand Up @@ -233,8 +233,8 @@ _toexpr(eqs::AbstractArray) = map(eq->_toexpr(eq), eqs)
_toexpr(x::Num) = _toexpr(value(x))

function getindex_to_symbol(t)
@assert iscall(t) && operation(t) === getindex && symtype(arguments(t)[1]) <: AbstractArray
args = arguments(t)
@assert iscall(t) && operation(t) === getindex && symtype(sorted_arguments(t)[1]) <: AbstractArray
args = sorted_arguments(t)
idxs = args[2:end]
try
sub = join(map(map_subscripts, idxs), "ˏ")
Expand Down
14 changes: 7 additions & 7 deletions src/semipoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function Base.:+(a::SemiMonomial, b::SemiMonomial)
end
function Base.:+(m::SemiMonomial, t)
if iscall(t) && operation(t) == (+)
return Term(+, [unsorted_arguments(t); m])
return Term(+, [arguments(t); m])
end
Term(+, [m, t])
end
Expand All @@ -42,7 +42,7 @@ function Base.:*(m::SemiMonomial, t::Symbolic)
args = collect(all_terms(t))
return Term(+, (m,) .* args)
elseif op == (*)
return Term(*, [unsorted_arguments(t); m])
return Term(*, [arguments(t); m])
end
end
Term(*, [t, m])
Expand Down Expand Up @@ -151,7 +151,7 @@ function mark_and_exponentiate(expr, vars)
@rule (~a::isop(+))^(~b::isreal) => expand(Pow((~a), real(~b)))
@rule *(~~xs::(xs -> all(issemimonomial, xs))) => *(~~xs...)
@rule *(~~xs::(xs -> any(isop(+), xs))) => expand(Term(*, ~~xs))
@rule (~a::isop(+)) / (~b::issemimonomial) => +(map(x->x/~b, unsorted_arguments(~a))...)
@rule (~a::isop(+)) / (~b::issemimonomial) => +(map(x->x/~b, arguments(~a))...)
@rule (~a::issemimonomial) / (~b::issemimonomial) => (~a) / (~b)]
expr′ = Postwalk(RestartedChain(rules), maketerm = simpleterm)(expr′)
end
Expand All @@ -178,7 +178,7 @@ function has_vars(expr, vars)::Bool
if expr in vars
return true
elseif iscall(expr)
for arg in unsorted_arguments(expr)
for arg in arguments(expr)
if has_vars(arg, vars)
return true
end
Expand All @@ -199,7 +199,7 @@ function mark_vars(expr, vars)
@assert length(args) == 2
return Term{symtype(expr)}(op, map(mark_vars(vars), args))
end
args = unsorted_arguments(expr)
args = arguments(expr)
if op === (+) || op === (*)
return Term{symtype(expr)}(op, map(mark_vars(vars), args))
elseif length(args) == 1
Expand Down Expand Up @@ -375,7 +375,7 @@ function semiquadratic_form(exprs, vars)
push!(V2, v)
else
@assert isop(k, *)
a, b = unsorted_arguments(k)
a, b = arguments(k)
p, q = extrema((idxmap[a], idxmap[b]))
j = div(q*(q-1), 2) + p
push!(J2, j)
Expand Down Expand Up @@ -403,7 +403,7 @@ end

## Utilities

all_terms(x) = iscall(x) && operation(x) == (+) ? collect(Iterators.flatten(map(all_terms, unsorted_arguments(x)))) : (x,)
all_terms(x) = iscall(x) && operation(x) == (+) ? collect(Iterators.flatten(map(all_terms, arguments(x)))) : (x,)

function unwrap_sp(m::SemiMonomial)
degree_dict = pdegrees(m.p)
Expand Down
56 changes: 28 additions & 28 deletions src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ function get_parts_list(a, b, a_list = Vector{Any}(), b_list = Vector{Any}())
push!(a_list, a)
push!(b_list, b)
elseif iscall(a) && iscall(b) && isequal(operation(a), operation(b))
a_args = arguments(a)
b_args = arguments(b)
a_args = sorted_arguments(a)
b_args = sorted_arguments(b)

length(a_args) != length(b_args) && return Nothing

Expand Down Expand Up @@ -163,7 +163,7 @@ function replace_term(expr, dic::Dict)
elseif iscall(expr)
args = Any[]

for arg in arguments(expr)
for arg in sorted_arguments(expr)
push!(args, replace_term(arg, dic))
end

Expand Down Expand Up @@ -205,11 +205,11 @@ function expr_similar(ref_expr, expr, check_matches = true)
SymbolicUtils.issym(expr) && iscall(ref_expr) && return false

if iscall(ref_expr)
ref_args = arguments(ref_expr)
ref_args = sorted_arguments(ref_expr)
ref_len = length(ref_args)
ref_op = operation(ref_expr)

args = arguments(expr)
args = sorted_arguments(expr)
len = length(args)
op = operation(expr)

Expand Down Expand Up @@ -250,12 +250,12 @@ end

function get_base(expr)
(!iscall(expr) || operation(expr) != (^)) && throw("not a power(^) -> $expr")
return arguments(expr)[1]
return sorted_arguments(expr)[1]
end

function get_exp(expr)
(!iscall(expr) || operation(expr) != (^)) && throw("not a power(^) -> $expr")
return arguments(expr)[2]
return sorted_arguments(expr)[2]
end

function solve_single_eq_unchecked(
Expand Down Expand Up @@ -326,10 +326,10 @@ end
function left_prod_right_zero(eq::Equation, var, single_solution)
if SymbolicUtils.ismul(eq.lhs) && isequal(0, eq.rhs)
if (single_solution)
eq = arguments(eq.lhs)[1] ~ 0
eq = sorted_arguments(eq.lhs)[1] ~ 0
else
solutions = Equation[]
for arg in arguments(eq.lhs)
for arg in sorted_arguments(eq.lhs)
temp = solve_single_eq(arg ~ 0, var)
temp = temp isa Equation ? [temp] : temp
push!(solutions, temp...)
Expand Down Expand Up @@ -371,7 +371,7 @@ function move_to_other_side(eq::Equation, var)
op = operation(eq.lhs)

if op in (+, *)
elements = arguments(eq.lhs)
elements = sorted_arguments(eq.lhs)

stays = []#has variable
move = []#does not have variable
Expand Down Expand Up @@ -424,7 +424,7 @@ function special_strategy(eq::Equation, var)
!iscall(eq.lhs) && return eq#make sure left side is tree form

op = operation(eq.lhs)
elements = arguments(eq.lhs)
elements = sorted_arguments(eq.lhs)

if (op == +) &&
length(elements) == 2 &&
Expand All @@ -446,13 +446,13 @@ function special_strategy(eq::Equation, var)
isequal(eq.rhs, 0) &&
length(elements) == 2 &&
sum(iscall.(elements)) == length(elements) &&
length(arguments(elements[1])) == 2 &&
isequal(arguments(elements[1])[1], -1) &&
iscall(arguments(elements[1])[2]) &&
operation(elements[2]) == operation(arguments(elements[1])[2])#-f(y)+f(x)=0 -> x-y=0
length(sorted_arguments(elements[1])) == 2 &&
isequal(sorted_arguments(elements[1])[1], -1) &&
iscall(sorted_arguments(elements[1])[2]) &&
operation(elements[2]) == operation(sorted_arguments(elements[1])[2])#-f(y)+f(x)=0 -> x-y=0

x = arguments(elements[2])[1]
y = arguments(arguments(elements[1])[2])[1]
x = sorted_arguments(elements[2])[1]
y = sorted_arguments(sorted_arguments(elements[1])[2])[1]

eq = x - y ~ 0
end
Expand All @@ -474,20 +474,20 @@ function reduce_root(a)
end

if iscall(a) && (operation(a) == sqrt)
a = SymbolicUtils.Pow(arguments(a)[1], 1 // 2)
a = SymbolicUtils.Pow(sorted_arguments(a)[1], 1 // 2)
elseif iscall(a) &&
(operation(a) == ^) &&
isequal(arguments(a)[2], 1 // 2) &&
!(arguments(a)[1] isa Number)
a = term(sqrt, arguments(a)[1])
isequal(sorted_arguments(a)[2], 1 // 2) &&
!(sorted_arguments(a)[1] isa Number)
a = term(sqrt, sorted_arguments(a)[1])
end

if iscall(a) &&
(operation(a) == ^) &&
arguments(a)[2] isa Rational &&
isequal((arguments(a)[2]).num, 1)
value = demote_rational(arguments(a)[1])
root = (arguments(a)[2]).den
sorted_arguments(a)[2] isa Rational &&
isequal((sorted_arguments(a)[2]).num, 1)
value = demote_rational(sorted_arguments(a)[1])
root = (sorted_arguments(a)[2]).den

if value isa Integer && value > 0
if isinteger(value^(1.0 / root))
Expand Down Expand Up @@ -596,13 +596,13 @@ function inverse_funcs(eq::Equation, var)

if haskey(inverseOps, op)
inverseOp = inverseOps[op]
inner = arguments(eq.lhs)[1]
inner = sorted_arguments(eq.lhs)[1]
eq = inner ~ term(inverseOp, eq.rhs)
elseif (op == sqrt)
inner = arguments(eq.lhs)[1]
inner = sorted_arguments(eq.lhs)[1]
eq = inner ~ (eq.rhs)^2
elseif (op == lambertw)
inner = arguments(eq.lhs)[1]
inner = sorted_arguments(eq.lhs)[1]
eq = inner ~ eq.rhs * term(exp, eq.rhs)
end

Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ function coeff(p, sym=nothing)
sum(coeff(k, sym) * v for (k, v) in p.dict)
end
elseif ismul(p)
args = unsorted_arguments(p)
args = arguments(p)
coeffs = map(a->coeff(a, sym), args)
if all(iszero, coeffs)
return 0
Expand Down
4 changes: 2 additions & 2 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ function fast_substitute(expr, subs; operator = Nothing)
end
iscall(expr) || return expr
op = fast_substitute(operation(expr), subs; operator)
args = SymbolicUtils.unsorted_arguments(expr)
args = SymbolicUtils.arguments(expr)
if !(op isa operator)
canfold = Ref(!(op isa Symbolic))
args = let canfold = canfold
Expand Down Expand Up @@ -504,7 +504,7 @@ function fast_substitute(expr, pair::Pair; operator = Nothing)
end
iscall(expr) || return expr
op = fast_substitute(operation(expr), pair; operator)
args = SymbolicUtils.unsorted_arguments(expr)
args = SymbolicUtils.arguments(expr)
if !(op isa operator)
canfold = Ref(!(op isa Symbolic))
args = let canfold = canfold
Expand Down
Loading