diff --git a/ext/SymbolicsGroebnerExt.jl b/ext/SymbolicsGroebnerExt.jl index ad437889f..90aa6d3d0 100644 --- a/ext/SymbolicsGroebnerExt.jl +++ b/ext/SymbolicsGroebnerExt.jl @@ -70,7 +70,7 @@ function Symbolics.is_groebner_basis(polynomials::Vector{Num}; kwargs...) Groebner.isgroebner(polynoms; kwargs...) end -function Symbolics.solve_multivar(eqs::Vector{Num}, vars::Vector{Num}, mult=false) +function Symbolics.solve_multivar(eqs::Vector, vars::Vector{Num}, mult=false) # Reference: Rouillier, F. Solving Zero-Dimensional Systems # Through the Rational Univariate Representation. diff --git a/src/solver/main.jl b/src/solver/main.jl index 2d73feec2..4f5722224 100644 --- a/src/solver/main.jl +++ b/src/solver/main.jl @@ -122,11 +122,16 @@ function solve(expr, x, multiplicities=false) if !(expr isa Vector) expr_univar = true + expr = expr isa Equation ? expr.lhs - expr.rhs : expr check_expr_validity(expr) else - for e in expr - check_expr_validity(e) + expr = Vector{Any}(expr) + for i in eachindex(expr) + expr[i] = expr[i] isa Equation ? expr[i].lhs - expr[i].rhs : expr[i] + check_expr_validity(expr[i]) + !check_poly_inunivar(expr[i], x) && throw("Solve can not solve this input currently") end + expr = Vector{Num}(expr) end @@ -136,12 +141,10 @@ function solve(expr, x, multiplicities=false) if expr_univar sols = check_poly_inunivar(expr, x) ? solve_univar(expr, x, multiplicities) : ia_solve(expr, x) else - exprs_ispoly = [] - for e in expr - push!(e, check_poly_inunivar(e, x)) + for i in eachindex(expr) + !check_poly_inunivar(expr[i], x) && throw("Solve can not solve this input currently") end - - sols = all(exprs_ispoly) ? solve_multipoly(expr, x, multiplicities) : throw("Solve can not solve this input currently") + sols = solve_multipoly(expr, x, multiplicities) end sols = map(postprocess_root, sols) @@ -284,6 +287,6 @@ function solve_multipoly(polys::Vector, x::Num, mult=false) end -function solve_multivar(eqs::Vector{Num}, vars::Vector{Num}, mult=false) +function solve_multivar(eqs::Vector, vars::Vector{Num}, mult=false) throw("Groebner bases engine is required. Execute `using Groebner` to enable this functionality.") end