Skip to content

Commit

Permalink
changed ssqrt and scbrt prints, and added sqrt case poly tests
Browse files Browse the repository at this point in the history
  • Loading branch information
n0rbed committed Jul 28, 2024
1 parent de86f9a commit b41f8e5
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 7 deletions.
6 changes: 3 additions & 3 deletions src/solver/ia_main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ function attract(lhs, var)
lhs, sub = turn_to_poly(lhs, var)

if (isequal(sub, Dict()) || n_func_occ(lhs, collect(keys(sub))[1]) != 1)
tuff_poly = detect_tuffpoly(lhs, var)
if tuff_poly
return attract_tuffpoly(lhs, var)
sqrt_poly = detect_sqrtpoly(lhs, var)
if sqrt_poly
return attract_and_solve_sqrtpoly(lhs, var)
else
throw("This expression cannot be solved with the methods available to solve. Try \
a numerical method instead.")
Expand Down
11 changes: 7 additions & 4 deletions src/solver/polynomialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ function check_sqrt(arg, sqrt_term, var)
end
end

function detect_tuffpoly(lhs, var)
# f(x) + sqrt(g(x)) + c
function detect_sqrtpoly(lhs, var)
lhs = unwrap(expand(lhs))
!iscall(lhs) && return false
args = arguments(lhs)
Expand Down Expand Up @@ -351,7 +352,7 @@ end



function attract_tuffpoly(lhs, var)
function attract_and_solve_sqrtpoly(lhs, var)
sqrt_term = 0
poly_term = 0
subs, filtered_expr = filter_poly(lhs, var)
Expand Down Expand Up @@ -386,12 +387,14 @@ function attract_tuffpoly(lhs, var)
end

end
eq_to_solve = postprocess_root(expand((poly_term)^2 - (sqrt_term)^2))
lhs = lhs - sqrt_term + ssqrt(arguments(sqrt_term)[1])
eq_to_solve = expand((poly_term)^2 - arguments(sqrt_term)[1])
eq_to_solve = ssubs(eq_to_solve, subs)
roots = solve(eq_to_solve, var)
answers = []

for root in roots
if isapprox(ssubs(lhs, Dict(var=>root)), 0, atol=1e-4)
if isapprox(substitute(lhs, Dict(var=>eval(Symbolics.toexpr(root)))), 0, atol=1e-4)
push!(answers, root)
end
end
Expand Down
6 changes: 6 additions & 0 deletions src/solver/solve_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ struct RootsOf
end

Base.show(io::IO, r::RootsOf) = print(io, "roots_of(", r.poly, ", ", x, ")")
Base.show(io::IO, f::typeof(ssqrt)) = print(io, "")
Base.show(io::IO, r::typeof(scbrt)) = print(io, "")

# not sure if this is a good idea as it can hide from the
# user when it misbehaves
# Base.show(io::IO, r::typeof(slog)) = print(io, "log")


function check_expr_validity(expr)
Expand Down
17 changes: 17 additions & 0 deletions test/new_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,23 @@ end
@test all(lhs_solve .≈ rhs)
end

@tesetset "Sqrt case poly" begin
# f(x) + sqrt(g(x)) + c
expr = x + sqrt(x+1) - 5
lhs_ia = ia_solve(expr, x)[1]
lhs_att = Symbolics.attract_and_solve_sqrtpoly(expr, x)[1]
lhs_solve = solve(expr, x)[1]
@test all(isequal(answer, 3) for answer in [lhs_ia, lhs_att, lhs_solve])

expr = x^2 + x + sqrt(x) + 2
lhs = sort_roots(eval.(Symbolics.toexpr.(ia_solve(expr, x))))
lhs_solve = sort_roots(eval.(Symbolics.toexpr.(solve(expr, x))))
rhs = sort_roots([-0.860929555 - 1.604034315im, -0.860929555 + 1.604034315im])
@test all(isapprox.(lhs, rhs, atol=1e-6))
@test all(isapprox.(lhs_solve, rhs, atol=1e-6))
@test all(isequal.(lhs, lhs_solve))
end

@testset "Turn to poly" begin
@variables x
# does not sub because these can not be solved as polys
Expand Down

0 comments on commit b41f8e5

Please sign in to comment.