diff --git a/docs/src/manual/solver.md b/docs/src/manual/solver.md index ffe2fd0ef..9857cf9bd 100644 --- a/docs/src/manual/solver.md +++ b/docs/src/manual/solver.md @@ -32,6 +32,15 @@ to `solve_univar`. We can see that essentially, `solve_univar` is the building b it to `ia_solve`, which attempts solving by attraction and isolation [^2]. This only works when the input is a single expression and the user wants the answer in terms of a single variable. Say `log(x) - a == 0` gives us `[e^a]`. +```@docs +Symbolics.solve_univar +Symbolics.solve_multivar +Symbolics.ia_solve +Symbolics.ia_conditions! +Symbolics.is_periodic +Symbolics.fundamental_period +``` + #### Nice examples ```@example solver diff --git a/src/inverse.jl b/src/inverse.jl index 0d541d5c2..3fb87a46a 100644 --- a/src/inverse.jl +++ b/src/inverse.jl @@ -157,6 +157,10 @@ inverse(::typeof(NaNMath.log10)) = inverse(log10) inverse(::typeof(NaNMath.log1p)) = inverse(log1p) inverse(::typeof(NaNMath.log2)) = inverse(log2) left_inverse(::typeof(NaNMath.sqrt)) = left_inverse(sqrt) +# inverses of solve helpers +left_inverse(::typeof(ssqrt)) = left_inverse(sqrt) +left_inverse(::typeof(scbrt)) = left_inverse(cbrt) +left_inverse(::typeof(slog)) = left_inverse(log) function inverse(f::ComposedFunction) return inverse(f.inner) ∘ inverse(f.outer) diff --git a/src/solver/ia_helpers.jl b/src/solver/ia_helpers.jl index 7768106ef..914e68d2d 100644 --- a/src/solver/ia_helpers.jl +++ b/src/solver/ia_helpers.jl @@ -140,3 +140,85 @@ function find_logandexpon(arg, var, oper, poly_index) !isequal(oper_term, 0) && !isequal(constant_term, 0) && return true return false end + +""" + ia_conditions!(f, lhs, rhs::Vector{Any}, conditions::Vector{Tuple}) + +If `f` is a left-invertible function, `lhs` and `rhs[i]` are univariate functions and +`f(lhs) ~ rhs[i]` for all `i in eachindex(rhss)`, push to `conditions` all the relevant +conditions on `lhs` or `rhs[i]`. Each condition is of the form `(sym, op)` where `sym` +is an expression involving `lhs` and/or `rhs[i]` and `op` is a binary relational operator. +The condition `op(sym, 0)` is then required to be true for the equation `f(lhs) ~ rhs[i]` +to be valid. + +For example, if `f = log`, `lhs = x` and `rhss = [y, z]` then the condition `x > 0` must +be true. Thus, `(lhs, >)` is pushed to `conditions`. Similarly, if `f = sqrt`, `rhs[i] >= 0` +must be true for all `i`, and so `(y, >=)` and `(z, >=)` will be appended to `conditions`. +""" +function ia_conditions!(args...; kwargs...) end + +for fn in [log, log2, log10, NaNMath.log, NaNMath.log2, NaNMath.log10, slog] + @eval function ia_conditions!(::typeof($fn), lhs, rhs, conditions) + push!(conditions, (lhs, >)) + end +end + +for fn in [log1p, NaNMath.log1p] + @eval function ia_conditions!(::typeof($fn), lhs, rhs, conditions) + push!(conditions, (lhs - 1, >)) + end +end + +for fn in [sqrt, NaNMath.sqrt, ssqrt] + @eval function ia_conditions!(::typeof($fn), lhs, rhs, conditions) + for r in rhs + push!(conditions, (r, >=)) + end + end +end + +""" + is_periodic(f) + +Return `true` if `f` is a single-input single-output periodic function. Return `false` by +default. If `is_periodic(f) == true`, then `fundamental_period(f)` must also be defined. + +See also: [`fundamental_period`](@ref) +""" +is_periodic(f) = false + +for fn in [ + sin, cos, tan, csc, sec, cot, NaNMath.sin, NaNMath.cos, NaNMath.tan, sind, cosd, tand, + cscd, secd, cotd, cospi +] + @eval is_periodic(::typeof($fn)) = true +end + +""" + fundamental_period(f) + +Return the fundamental period of periodic function `f`. Must only be called if +`is_periodic(f) == true`. + +see also: [`is_periodic`](@ref) +""" +function fundamental_period end + +for fn in [sin, cos, csc, sec, NaNMath.sin, NaNMath.cos] + @eval fundamental_period(::typeof($fn)) = 2pi +end + +for fn in [sind, cosd, cscd, secd] + @eval fundamental_period(::typeof($fn)) = 360.0 +end + +fundamental_period(::typeof(cospi)) = 2.0 + +for fn in [tand, cotd] + @eval fundamental_period(::typeof($fn)) = 180.0 +end + +for fn in [tan, cot, NaNMath.tan] + # `1pi isa Float64` whereas `pi isa Irrational{:π}` + @eval fundamental_period(::typeof($fn)) = 1pi +end diff --git a/src/solver/ia_main.jl b/src/solver/ia_main.jl index 8f83ca02b..ae0e6e717 100644 --- a/src/solver/ia_main.jl +++ b/src/solver/ia_main.jl @@ -1,6 +1,8 @@ using Symbolics -function isolate(lhs, var; warns=true, conditions=[]) +const SAFE_ALTERNATIVES = Dict(log => slog, sqrt => ssqrt, cbrt => scbrt) + +function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, periodic_roots = true) rhs = Vector{Any}([0]) original_lhs = deepcopy(lhs) lhs = unwrap(lhs) @@ -72,12 +74,21 @@ function isolate(lhs, var; warns=true, conditions=[]) power = args[2] new_roots = [] - for i in eachindex(rhs) - for k in 0:(args[2] - 1) - r = wrap(term(^, rhs[i], (1 // power))) - c = wrap(term(*, 2 * (k), pi)) * im / power - root = r * Base.MathConstants.e^c - push!(new_roots, root) + if complex_roots + for i in eachindex(rhs) + for k in 0:(args[2] - 1) + r = term(^, rhs[i], (1 // power)) + c = term(*, 2 * (k), pi) * im / power + root = r * Base.MathConstants.e^c + push!(new_roots, root) + end + end + else + for i in eachindex(rhs) + push!(new_roots, term(^, rhs[i], (1 // power))) + if iseven(power) + push!(new_roots, term(-, new_roots[end])) + end end end rhs = [] @@ -90,57 +101,23 @@ function isolate(lhs, var; warns=true, conditions=[]) lhs = args[2] rhs = map(sol -> term(/, term(slog, sol), term(slog, args[1])), rhs) end - - elseif oper === (log) || oper === (slog) - lhs = args[1] - rhs = map(sol -> term(^, Base.MathConstants.e, sol), rhs) - push!(conditions, (args[1], >)) - - elseif oper === (log2) - lhs = args[1] - rhs = map(sol -> term(^, 2, sol), rhs) - push!(conditions, (args[1], >)) - - elseif oper === (log10) + elseif has_left_inverse(oper) lhs = args[1] - rhs = map(sol -> term(^, 10, sol), rhs) - push!(conditions, (args[1], >)) - - elseif oper === (sqrt) - lhs = args[1] - append!(conditions, [(r, >=) for r in rhs]) - rhs = map(sol -> term(^, sol, 2), rhs) - - elseif oper === (cbrt) - lhs = args[1] - rhs = map(sol -> term(^, sol, 3), rhs) - - elseif oper === (sin) || oper === (cos) || oper === (tan) - rev_oper = Dict(sin => asin, cos => acos, tan => atan) - lhs = args[1] - # make this global somehow so the user doesnt need to declare it on his own - new_var = gensym() - new_var = (@variables $new_var)[1] - rhs = map( - sol -> term(rev_oper[oper], sol) + - term(*, Base.MathConstants.pi, new_var), - rhs) - @info string(new_var) * " ϵ" * " Ζ" - - elseif oper === (asin) - lhs = args[1] - rhs = map(sol -> term(sin, sol), rhs) - - elseif oper === (acos) - lhs = args[1] - rhs = map(sol -> term(cos, sol), rhs) - - elseif oper === (atan) - lhs = args[1] - rhs = map(sol -> term(tan, sol), rhs) - elseif oper === (exp) - lhs = args[1] - rhs = map(sol -> term(slog, sol), rhs) + ia_conditions!(oper, lhs, rhs, conditions) + invop = left_inverse(oper) + invop = get(SAFE_ALTERNATIVES, invop, invop) + if is_periodic(oper) && periodic_roots + new_var = gensym() + new_var = (@variables $new_var)[1] + period = fundamental_period(oper) + rhs = map( + sol -> term(invop, sol) + + term(*, period, new_var), + rhs) + @info string(new_var) * " ϵ" * " Ζ" + else + rhs = map(sol -> term(invop, sol), rhs) + end end lhs = simplify(lhs) @@ -149,7 +126,7 @@ function isolate(lhs, var; warns=true, conditions=[]) return rhs, conditions end -function attract(lhs, var; warns = true) +function attract(lhs, var; warns = true, complex_roots = true, periodic_roots = true) if n_func_occ(simplify(lhs), var) <= n_func_occ(lhs, var) lhs = simplify(lhs) end @@ -164,7 +141,9 @@ function attract(lhs, var; warns = true) end lhs = attract_trig(lhs, var) - n_func_occ(lhs, var) == 1 && return isolate(lhs, var, warns = warns, conditions=conditions) + if n_func_occ(lhs, var) == 1 + return isolate(lhs, var; warns, conditions, complex_roots, periodic_roots) + end lhs, sub = turn_to_poly(lhs, var) @@ -182,12 +161,12 @@ function attract(lhs, var; warns = true) new_var = collect(keys(sub))[1] new_var_val = collect(values(sub))[1] - roots, new_conds = isolate(lhs, new_var, warns = warns) + roots, new_conds = isolate(lhs, new_var; warns = warns, complex_roots, periodic_roots) append!(conditions, new_conds) new_roots = [] for root in roots - new_sol, new_conds = isolate(new_var_val - root, var, warns = warns) + new_sol, new_conds = isolate(new_var_val - root, var; warns = warns, complex_roots, periodic_roots) append!(conditions, new_conds) push!(new_roots, new_sol) end @@ -197,7 +176,7 @@ function attract(lhs, var; warns = true) end """ - ia_solve(lhs, var) + ia_solve(lhs, var; kwargs...) This function attempts to solve transcendental functions by first checking the "smart" number of occurrences in the input LHS. By smart here we mean that polynomials are counted as 1 occurrence. for example `x^2 + 2x` is 1 @@ -226,6 +205,13 @@ we throw an error to tell the user that this is currently unsolvable by our cove - lhs: a Num/SymbolicUtils.BasicSymbolic - var: variable to solve for. +# Keyword arguments +- `warns = true`: Whether to emit warnings for unsolvable expressions. +- `complex_roots = true`: Whether to consider complex roots of `x ^ n ~ y`, where `n` is an integer. +- `periodic_roots = true`: If `true`, isolate `f(x) ~ y` as `x ~ finv(y) + n * period` where + `is_periodic(f) == true`, `finv = left_inverse(f)` and `period = fundamental_period(f)`. `n` + is a new anonymous symbolic variable. + # Examples ```jldoctest julia> solve(a*x^b + c, x) @@ -256,10 +242,20 @@ julia> RootFinding.ia_solve(expr, x) -2 + π*2var"##230" + asin((1//2)*(-1 + RootFinding.ssqrt(-39))) -2 + π*2var"##234" + asin((1//2)*(-1 - RootFinding.ssqrt(-39))) ``` + +All transcendental functions for which `left_inverse` is defined are supported. +To enable `ia_solve` to handle custom transcendental functions, define an inverse or +left inverse. If the function is periodic, `is_periodic` and `fundamental_period` must +be defined. If the function imposes certain conditions on its input or output (for +example, `log` requires that its input be positive) define `ia_conditions!`. + +See also: [`left_inverse`](@ref), [`inverse`](@ref), [`is_periodic`](@ref), +[`fundamental_period`](@ref), [`ia_conditions!`](@ref). + # References [^1]: [R. W. Hamming, Coding and Information Theory, ScienceDirect, 1980](https://www.sciencedirect.com/science/article/pii/S0747717189800070). """ -function ia_solve(lhs, var; warns = true) +function ia_solve(lhs, var; warns = true, complex_roots = true, periodic_roots = true) nx = n_func_occ(lhs, var) sols = [] conditions = [] @@ -267,9 +263,9 @@ function ia_solve(lhs, var; warns = true) warns && @warn("Var not present in given expression") return [] elseif nx == 1 - sols, conditions = isolate(lhs, var, warns = warns) + sols, conditions = isolate(lhs, var; warns = warns, complex_roots, periodic_roots) elseif nx > 1 - sols, conditions = attract(lhs, var, warns = warns) + sols, conditions = attract(lhs, var; warns = warns, complex_roots, periodic_roots) end isequal(sols, nothing) && return nothing diff --git a/test/solver.jl b/test/solver.jl index 8e85bfbce..1c82a899b 100644 --- a/test/solver.jl +++ b/test/solver.jl @@ -411,7 +411,7 @@ end #@test isequal(lhs, rhs) lhs = symbolic_solve(log(a*x)-b,x)[1] - @test isequal(Symbolics.arguments(Symbolics.unwrap(Symbolics.ssubs(lhs, Dict(a=>1, b=>1))))[1], E) + @test isequal(Symbolics.unwrap(Symbolics.ssubs(lhs, Dict(a=>1, b=>1))), 1E) expr = x + 2 lhs = eval.(Symbolics.toexpr.(ia_solve(expr, x))) @@ -431,7 +431,7 @@ end @test isapprox(eval(Symbolics.toexpr(symbolic_solve(expr, x)[1])), sqrt(2), atol=1e-6) expr = 2^(x+1) + 5^(x+3) - lhs = eval.(Symbolics.toexpr.(ia_solve(expr, x))) + lhs = ComplexF64.(eval.(Symbolics.toexpr.(ia_solve(expr, x)))) lhs_solve = eval.(Symbolics.toexpr.(symbolic_solve(expr, x))) rhs = [(-im*Base.MathConstants.pi - log(2) + 3log(5))/(log(2) - log(5))] @test lhs[1] ≈ rhs[1] @@ -488,6 +488,31 @@ end @test all(lhs .≈ rhs) @test all(lhs_solve .≈ rhs) + + @testset "Keyword arguments" begin + expr = sec(x ^ 2 + 4x + 4) ^ 3 - 3 + roots = ia_solve(expr, x) + @test length(roots) == 6 # 2 quadratic roots * 3 roots from cbrt(3) + @test length(Symbolics.get_variables(roots[1])) == 1 + _n = only(Symbolics.get_variables(roots[1])) + vals = substitute.(roots, (Dict(_n => 0),)) + @test all(x -> isapprox(norm(sec(x^2 + 4x + 4) ^ 3 - 3), 0.0, atol = 1e-14), vals) + + roots = ia_solve(expr, x; complex_roots = false) + @test length(roots) == 2 + # the `n` in `θ + n * 2π` + @test length(Symbolics.get_variables(roots[1])) == 1 + _n = only(Symbolics.get_variables(roots[1])) + vals = substitute.(roots, (Dict(_n => 0),)) + @test all(x -> isapprox(norm(sec(x^2 + 4x + 4) ^ 3 - 3), 0.0, atol = 1e-14), vals) + + roots = ia_solve(expr, x; complex_roots = false, periodic_roots = false) + @test length(roots) == 2 + @test length(Symbolics.get_variables(roots[1])) == 0 + @test length(Symbolics.get_variables(roots[2])) == 0 + vals = eval.(Symbolics.toexpr.(roots)) + @test all(x -> isapprox(norm(sec(x^2 + 4x + 4) ^ 3 - 3), 0.0, atol = 1e-14), vals) + end end @testset "Sqrt case poly" begin