diff --git a/src/solver/ia_main.jl b/src/solver/ia_main.jl index 174eeb083..5614bb1fa 100644 --- a/src/solver/ia_main.jl +++ b/src/solver/ia_main.jl @@ -146,9 +146,9 @@ function attract(lhs, var) lhs, sub = turn_to_poly(lhs, var) if (isequal(sub, Dict()) || n_func_occ(lhs, collect(keys(sub))[1]) != 1) - tuff_poly = detect_tuffpoly(lhs, var) - if tuff_poly - return attract_tuffpoly(lhs, var) + sqrt_poly = detect_sqrtpoly(lhs, var) + if sqrt_poly + return attract_and_solve_sqrtpoly(lhs, var) else throw("This expression cannot be solved with the methods available to solve. Try \ a numerical method instead.") diff --git a/src/solver/polynomialization.jl b/src/solver/polynomialization.jl index 9e12daf26..b9692e4ac 100644 --- a/src/solver/polynomialization.jl +++ b/src/solver/polynomialization.jl @@ -303,7 +303,8 @@ function check_sqrt(arg, sqrt_term, var) end end -function detect_tuffpoly(lhs, var) +# f(x) + sqrt(g(x)) + c +function detect_sqrtpoly(lhs, var) lhs = unwrap(expand(lhs)) !iscall(lhs) && return false args = arguments(lhs) @@ -351,7 +352,7 @@ end -function attract_tuffpoly(lhs, var) +function attract_and_solve_sqrtpoly(lhs, var) sqrt_term = 0 poly_term = 0 subs, filtered_expr = filter_poly(lhs, var) @@ -386,12 +387,14 @@ function attract_tuffpoly(lhs, var) end end - eq_to_solve = postprocess_root(expand((poly_term)^2 - (sqrt_term)^2)) + lhs = lhs - sqrt_term + ssqrt(arguments(sqrt_term)[1]) + eq_to_solve = expand((poly_term)^2 - arguments(sqrt_term)[1]) eq_to_solve = ssubs(eq_to_solve, subs) roots = solve(eq_to_solve, var) answers = [] + for root in roots - if isapprox(ssubs(lhs, Dict(var=>root)), 0, atol=1e-4) + if isapprox(substitute(lhs, Dict(var=>eval(Symbolics.toexpr(root)))), 0, atol=1e-4) push!(answers, root) end end diff --git a/src/solver/solve_helpers.jl b/src/solver/solve_helpers.jl index f2c664cca..e99a81fba 100644 --- a/src/solver/solve_helpers.jl +++ b/src/solver/solve_helpers.jl @@ -72,6 +72,12 @@ struct RootsOf end Base.show(io::IO, r::RootsOf) = print(io, "roots_of(", r.poly, ", ", x, ")") +Base.show(io::IO, f::typeof(ssqrt)) = print(io, "√") +Base.show(io::IO, r::typeof(scbrt)) = print(io, "∛") + +# not sure if this is a good idea as it can hide from the +# user when it misbehaves +# Base.show(io::IO, r::typeof(slog)) = print(io, "log") function check_expr_validity(expr) diff --git a/test/new_solver.jl b/test/new_solver.jl index 849a99ac3..41cf5568d 100644 --- a/test/new_solver.jl +++ b/test/new_solver.jl @@ -359,6 +359,23 @@ end @test all(lhs_solve .≈ rhs) end +@tesetset "Sqrt case poly" begin + # f(x) + sqrt(g(x)) + c + expr = x + sqrt(x+1) - 5 + lhs_ia = ia_solve(expr, x)[1] + lhs_att = Symbolics.attract_and_solve_sqrtpoly(expr, x)[1] + lhs_solve = solve(expr, x)[1] + @test all(isequal(answer, 3) for answer in [lhs_ia, lhs_att, lhs_solve]) + + expr = x^2 + x + sqrt(x) + 2 + lhs = sort_roots(eval.(Symbolics.toexpr.(ia_solve(expr, x)))) + lhs_solve = sort_roots(eval.(Symbolics.toexpr.(solve(expr, x)))) + rhs = sort_roots([-0.860929555 - 1.604034315im, -0.860929555 + 1.604034315im]) + @test all(isapprox.(lhs, rhs, atol=1e-6)) + @test all(isapprox.(lhs_solve, rhs, atol=1e-6)) + @test all(isequal.(lhs, lhs_solve)) +end + @testset "Turn to poly" begin @variables x # does not sub because these can not be solved as polys