Skip to content

Commit

Permalink
reverted solve_for merge
Browse files Browse the repository at this point in the history
  • Loading branch information
n0rbed committed Jul 30, 2024
1 parent 3c302d2 commit f3e6cf9
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 30 deletions.
4 changes: 1 addition & 3 deletions src/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ function solve_for(eq, var; simplify=false, check=true) # scalar case
else
x = a \ -b
end
x = length(var) == 1 ? x : Dict(v => ans for (v, ans) in zip(var, x))
simplify || return [x]

simplify || return x
if x isa AbstractArray
SymbolicUtils.simplify.(simplify_fractions.(x))
else
Expand Down
15 changes: 5 additions & 10 deletions src/solver/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,22 +134,17 @@ function solve(expr, x, multiplicities=false)
expr = Vector{Num}(expr)
end

# the islinear outputted here is breaks a lot:
# a, b, islinear = linear_expansion(expr, x)
# islinear && return map(postprocess_root, solve_for(expr, x))

if x_univar

sols = []
if expr_univar
sols = !check_poly_inunivar(expr, x) ? ia_solve(expr, x) :
islinear(expr, [x]) ? solve_for(expr, x) :
solve_univar(expr, x, multiplicities)
sols = check_poly_inunivar(expr, x) ? solve_univar(expr, x, multiplicities) : ia_solve(expr, x)
else
for i in eachindex(expr)
!check_poly_inunivar(expr[i], x) && throw("Solve can not solve this input currently")
end
sols = all(e->islinear(e, [x]), expr) ? solve_for(expr, x) :
solve_multipoly(expr, x, multiplicities)
sols = solve_multipoly(expr, x, multiplicities)
end

sols = map(postprocess_root, sols)
Expand All @@ -162,8 +157,8 @@ function solve(expr, x, multiplicities=false)
@assert check_poly_inunivar(e, var) "This system can not be currently solved by solve."
end
end
sols = all(e->islinear(e, x), expr) ? solve_for(expr, x) : solve_multivar(expr, x, multiplicities)

sols = solve_multivar(expr, x, multiplicities)
for sol in sols
for var in x
sol[var] = postprocess_root(sol[var])
Expand Down
17 changes: 0 additions & 17 deletions src/solver/solve_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,23 +98,6 @@ function check_poly_inunivar(poly, var)
return isequal(constant, 0)
end

function islinear(expr, vars)
for var in vars
subs, filtered_expr = filter_poly(expr, var)
coeffs, constant = polynomial_coeffs(filtered_expr, [var])

!isequal(constant, 0) && return false
sdegree(coeffs, var) > 1 && return false
delete!(coeffs, 1)

vals = collect(values(coeffs))
for val in vals
any(x->n_occurrences(val, x) > 0, vars) && return false
end
end
return true
end

function f_numbers(n)
n = unwrap(n)
if n isa ComplexTerm || n isa Float64 || n isa Irrational
Expand Down

0 comments on commit f3e6cf9

Please sign in to comment.