Skip to content

Commit

Permalink
Merge pull request #1348 from AayushSabharwal/as/ia-solve-inverses
Browse files Browse the repository at this point in the history
  • Loading branch information
n0rbed authored Nov 6, 2024
2 parents 8998fa8 + 1d71dd4 commit 8c518c2
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 67 deletions.
9 changes: 9 additions & 0 deletions docs/src/manual/solver.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ to `solve_univar`. We can see that essentially, `solve_univar` is the building b
it to `ia_solve`, which attempts solving by attraction and isolation [^2]. This only works when the input is a single expression
and the user wants the answer in terms of a single variable. For example, solving `log(x) - a == 0` for `x` gives us `[e^a]`.

```@docs
Symbolics.solve_univar
Symbolics.solve_multivar
Symbolics.ia_solve
Symbolics.ia_conditions!
Symbolics.is_periodic
Symbolics.fundamental_period
```

#### Nice examples

```@example solver
Expand Down
4 changes: 4 additions & 0 deletions src/inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ inverse(::typeof(NaNMath.log10)) = inverse(log10)
inverse(::typeof(NaNMath.log1p)) = inverse(log1p)
inverse(::typeof(NaNMath.log2)) = inverse(log2)
left_inverse(::typeof(NaNMath.sqrt)) = left_inverse(sqrt)
# Left inverses of the solver's helper functions: each helper simply reuses the
# left inverse of the Base function it stands in for.
for (helper_fn, base_fn) in ((ssqrt, sqrt), (scbrt, cbrt), (slog, log))
    @eval left_inverse(::typeof($helper_fn)) = left_inverse($base_fn)
end

function inverse(f::ComposedFunction)
return inverse(f.inner) inverse(f.outer)
Expand Down
82 changes: 82 additions & 0 deletions src/solver/ia_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,85 @@ function find_logandexpon(arg, var, oper, poly_index)
!isequal(oper_term, 0) && !isequal(constant_term, 0) && return true
return false
end

"""
    ia_conditions!(f, lhs, rhs, conditions)

Record the validity conditions implied by isolating `f(lhs) ~ rhs[i]`.

`f` is a left-invertible function, `lhs` is a univariate expression and `rhs` is a
collection of univariate expressions such that `f(lhs) ~ rhs[i]` for every
`i in eachindex(rhs)`. Methods push onto `conditions` every relevant condition on `lhs`
or on the entries of `rhs`.

Each condition is a tuple `(sym, op)`, where `sym` is an expression involving `lhs`
and/or `rhs[i]` and `op` is a binary relational operator. `op(sym, 0)` must hold for the
equation `f(lhs) ~ rhs[i]` to be valid.

For example, with `f = log`, `lhs = x` and `rhs = [y, z]`, the condition `x > 0` must
hold, so `(lhs, >)` is pushed to `conditions`. With `f = sqrt`, `rhs[i] >= 0` must hold
for every `i`, so `(y, >=)` and `(z, >=)` are appended to `conditions`.

The fallback method accepts any arguments, adds no conditions and returns `nothing`.
"""
ia_conditions!(args...; kwargs...) = nothing

# Logarithms are only defined for a positive argument: require `lhs > 0`.
for logfn in (log, log2, log10, NaNMath.log, NaNMath.log2, NaNMath.log10, slog)
    @eval function ia_conditions!(::typeof($logfn), lhs, rhs, conditions)
        return push!(conditions, (lhs, >))
    end
end

# `log1p(x) = log(1 + x)`, so the argument of the underlying logarithm is `lhs + 1`.
# The equation is therefore valid only when `lhs + 1 > 0` (i.e. `lhs > -1`).
for fn in [log1p, NaNMath.log1p]
    @eval function ia_conditions!(::typeof($fn), lhs, rhs, conditions)
        # Conditions are stored as `(sym, op)` and interpreted as `op(sym, 0)`.
        push!(conditions, (lhs + 1, >))
    end
end

# A square root never yields a negative value: require every `rhs[i] >= 0`.
for rootfn in (sqrt, NaNMath.sqrt, ssqrt)
    @eval function ia_conditions!(::typeof($rootfn), lhs, rhs, conditions)
        # One condition per candidate right-hand side; `foreach` returns `nothing`.
        return foreach(candidate -> push!(conditions, (candidate, >=)), rhs)
    end
end

"""
    is_periodic(f)

Return `true` if `f` is a single-input single-output periodic function, and `false`
otherwise. The generic fallback returns `false`.

Whenever `is_periodic(f) == true`, a method of `fundamental_period(f)` must also be
defined.

See also: [`fundamental_period`](@ref)
"""
function is_periodic(f)
    return false
end

# Mark the supported trigonometric functions as periodic.
for trigfn in (
    # radian-argument functions
    sin, cos, tan, csc, sec, cot,
    NaNMath.sin, NaNMath.cos, NaNMath.tan,
    # degree-argument functions
    sind, cosd, tand, cscd, secd, cotd,
    # argument scaled by π
    cospi,
)
    @eval function is_periodic(::typeof($trigfn))
        return true
    end
end

"""
    fundamental_period(f)

Return the fundamental period of the periodic function `f`.

This function has no generic fallback: it must only be called when
`is_periodic(f) == true`, and a method must be defined for every such `f`.

See also: [`is_periodic`](@ref)
"""
function fundamental_period end

# Radian-argument functions that repeat every full turn.
for trigfn in (sin, cos, csc, sec, NaNMath.sin, NaNMath.cos)
    @eval function fundamental_period(::typeof($trigfn))
        return 2pi
    end
end

# Degree-argument functions that repeat every full turn (360 degrees).
for degfn in (sind, cosd, cscd, secd)
    @eval function fundamental_period(::typeof($degfn))
        return 360.0
    end
end

# `cospi(x) = cos(π * x)`, so the period in `x` is 2.
function fundamental_period(::typeof(cospi))
    return 2.0
end

# Degree-argument tangent and cotangent repeat every half turn (180 degrees).
for degfn in (tand, cotd)
    @eval function fundamental_period(::typeof($degfn))
        return 180.0
    end
end

# Radian-argument tangent and cotangent repeat every half turn.
for trigfn in (tan, cot, NaNMath.tan)
    @eval function fundamental_period(::typeof($trigfn))
        # `1pi isa Float64` whereas `pi isa Irrational{:π}`
        return 1pi
    end
end
126 changes: 61 additions & 65 deletions src/solver/ia_main.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using Symbolics

function isolate(lhs, var; warns=true, conditions=[])
const SAFE_ALTERNATIVES = Dict(log => slog, sqrt => ssqrt, cbrt => scbrt)

function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, periodic_roots = true)
rhs = Vector{Any}([0])
original_lhs = deepcopy(lhs)
lhs = unwrap(lhs)
Expand Down Expand Up @@ -72,12 +74,21 @@ function isolate(lhs, var; warns=true, conditions=[])
power = args[2]
new_roots = []

for i in eachindex(rhs)
for k in 0:(args[2] - 1)
r = wrap(term(^, rhs[i], (1 // power)))
c = wrap(term(*, 2 * (k), pi)) * im / power
root = r * Base.MathConstants.e^c
push!(new_roots, root)
if complex_roots
for i in eachindex(rhs)
for k in 0:(args[2] - 1)
r = term(^, rhs[i], (1 // power))
c = term(*, 2 * (k), pi) * im / power
root = r * Base.MathConstants.e^c
push!(new_roots, root)
end
end
else
for i in eachindex(rhs)
push!(new_roots, term(^, rhs[i], (1 // power)))
if iseven(power)
push!(new_roots, term(-, new_roots[end]))
end
end
end
rhs = []
Expand All @@ -90,57 +101,23 @@ function isolate(lhs, var; warns=true, conditions=[])
lhs = args[2]
rhs = map(sol -> term(/, term(slog, sol), term(slog, args[1])), rhs)
end

elseif oper === (log) || oper === (slog)
lhs = args[1]
rhs = map(sol -> term(^, Base.MathConstants.e, sol), rhs)
push!(conditions, (args[1], >))

elseif oper === (log2)
lhs = args[1]
rhs = map(sol -> term(^, 2, sol), rhs)
push!(conditions, (args[1], >))

elseif oper === (log10)
elseif has_left_inverse(oper)
lhs = args[1]
rhs = map(sol -> term(^, 10, sol), rhs)
push!(conditions, (args[1], >))

elseif oper === (sqrt)
lhs = args[1]
append!(conditions, [(r, >=) for r in rhs])
rhs = map(sol -> term(^, sol, 2), rhs)

elseif oper === (cbrt)
lhs = args[1]
rhs = map(sol -> term(^, sol, 3), rhs)

elseif oper === (sin) || oper === (cos) || oper === (tan)
rev_oper = Dict(sin => asin, cos => acos, tan => atan)
lhs = args[1]
# make this global somehow so the user doesnt need to declare it on his own
new_var = gensym()
new_var = (@variables $new_var)[1]
rhs = map(
sol -> term(rev_oper[oper], sol) +
term(*, Base.MathConstants.pi, new_var),
rhs)
@info string(new_var) * " ϵ" * " Ζ"

elseif oper === (asin)
lhs = args[1]
rhs = map(sol -> term(sin, sol), rhs)

elseif oper === (acos)
lhs = args[1]
rhs = map(sol -> term(cos, sol), rhs)

elseif oper === (atan)
lhs = args[1]
rhs = map(sol -> term(tan, sol), rhs)
elseif oper === (exp)
lhs = args[1]
rhs = map(sol -> term(slog, sol), rhs)
ia_conditions!(oper, lhs, rhs, conditions)
invop = left_inverse(oper)
invop = get(SAFE_ALTERNATIVES, invop, invop)
if is_periodic(oper) && periodic_roots
new_var = gensym()
new_var = (@variables $new_var)[1]
period = fundamental_period(oper)
rhs = map(
sol -> term(invop, sol) +
term(*, period, new_var),
rhs)
@info string(new_var) * " ϵ" * " Ζ"
else
rhs = map(sol -> term(invop, sol), rhs)
end
end

lhs = simplify(lhs)
Expand All @@ -149,7 +126,7 @@ function isolate(lhs, var; warns=true, conditions=[])
return rhs, conditions
end

function attract(lhs, var; warns = true)
function attract(lhs, var; warns = true, complex_roots = true, periodic_roots = true)
if n_func_occ(simplify(lhs), var) <= n_func_occ(lhs, var)
lhs = simplify(lhs)
end
Expand All @@ -164,7 +141,9 @@ function attract(lhs, var; warns = true)
end
lhs = attract_trig(lhs, var)

n_func_occ(lhs, var) == 1 && return isolate(lhs, var, warns = warns, conditions=conditions)
if n_func_occ(lhs, var) == 1
return isolate(lhs, var; warns, conditions, complex_roots, periodic_roots)
end

lhs, sub = turn_to_poly(lhs, var)

Expand All @@ -182,12 +161,12 @@ function attract(lhs, var; warns = true)
new_var = collect(keys(sub))[1]
new_var_val = collect(values(sub))[1]

roots, new_conds = isolate(lhs, new_var, warns = warns)
roots, new_conds = isolate(lhs, new_var; warns = warns, complex_roots, periodic_roots)
append!(conditions, new_conds)
new_roots = []

for root in roots
new_sol, new_conds = isolate(new_var_val - root, var, warns = warns)
new_sol, new_conds = isolate(new_var_val - root, var; warns = warns, complex_roots, periodic_roots)
append!(conditions, new_conds)
push!(new_roots, new_sol)
end
Expand All @@ -197,7 +176,7 @@ function attract(lhs, var; warns = true)
end

"""
ia_solve(lhs, var)
ia_solve(lhs, var; kwargs...)
This function attempts to solve transcendental functions by first checking
the "smart" number of occurrences in the input LHS. By smart here we mean
that polynomials are counted as 1 occurrence. for example `x^2 + 2x` is 1
Expand Down Expand Up @@ -226,6 +205,13 @@ we throw an error to tell the user that this is currently unsolvable by our cove
- lhs: a Num/SymbolicUtils.BasicSymbolic
- var: variable to solve for.
# Keyword arguments
- `warns = true`: Whether to emit warnings for unsolvable expressions.
- `complex_roots = true`: Whether to consider complex roots of `x ^ n ~ y`, where `n` is an integer.
- `periodic_roots = true`: If `true`, isolate `f(x) ~ y` as `x ~ finv(y) + n * period` where
`is_periodic(f) == true`, `finv = left_inverse(f)` and `period = fundamental_period(f)`. `n`
is a new anonymous symbolic variable.
# Examples
```jldoctest
julia> solve(a*x^b + c, x)
Expand Down Expand Up @@ -256,20 +242,30 @@ julia> RootFinding.ia_solve(expr, x)
-2 + π*2var"##230" + asin((1//2)*(-1 + RootFinding.ssqrt(-39)))
-2 + π*2var"##234" + asin((1//2)*(-1 - RootFinding.ssqrt(-39)))
```
All transcendental functions for which `left_inverse` is defined are supported.
To enable `ia_solve` to handle custom transcendental functions, define an inverse or
left inverse. If the function is periodic, `is_periodic` and `fundamental_period` must
be defined. If the function imposes certain conditions on its input or output (for
example, `log` requires that its input be positive) define `ia_conditions!`.
See also: [`left_inverse`](@ref), [`inverse`](@ref), [`is_periodic`](@ref),
[`fundamental_period`](@ref), [`ia_conditions!`](@ref).
# References
[^1]: [R. W. Hamming, Coding and Information Theory, ScienceDirect, 1980](https://www.sciencedirect.com/science/article/pii/S0747717189800070).
"""
function ia_solve(lhs, var; warns = true)
function ia_solve(lhs, var; warns = true, complex_roots = true, periodic_roots = true)
nx = n_func_occ(lhs, var)
sols = []
conditions = []
if nx == 0
warns && @warn("Var not present in given expression")
return []
elseif nx == 1
sols, conditions = isolate(lhs, var, warns = warns)
sols, conditions = isolate(lhs, var; warns = warns, complex_roots, periodic_roots)
elseif nx > 1
sols, conditions = attract(lhs, var, warns = warns)
sols, conditions = attract(lhs, var; warns = warns, complex_roots, periodic_roots)
end

isequal(sols, nothing) && return nothing
Expand Down
29 changes: 27 additions & 2 deletions test/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ end
#@test isequal(lhs, rhs)

lhs = symbolic_solve(log(a*x)-b,x)[1]
@test isequal(Symbolics.arguments(Symbolics.unwrap(Symbolics.ssubs(lhs, Dict(a=>1, b=>1))))[1], E)
@test isequal(Symbolics.unwrap(Symbolics.ssubs(lhs, Dict(a=>1, b=>1))), 1E)

expr = x + 2
lhs = eval.(Symbolics.toexpr.(ia_solve(expr, x)))
Expand All @@ -431,7 +431,7 @@ end
@test isapprox(eval(Symbolics.toexpr(symbolic_solve(expr, x)[1])), sqrt(2), atol=1e-6)

expr = 2^(x+1) + 5^(x+3)
lhs = eval.(Symbolics.toexpr.(ia_solve(expr, x)))
lhs = ComplexF64.(eval.(Symbolics.toexpr.(ia_solve(expr, x))))
lhs_solve = eval.(Symbolics.toexpr.(symbolic_solve(expr, x)))
rhs = [(-im*Base.MathConstants.pi - log(2) + 3log(5))/(log(2) - log(5))]
@test lhs[1] rhs[1]
Expand Down Expand Up @@ -488,6 +488,31 @@ end

@test all(lhs .≈ rhs)
@test all(lhs_solve .≈ rhs)

@testset "Keyword arguments" begin
expr = sec(x ^ 2 + 4x + 4) ^ 3 - 3
roots = ia_solve(expr, x)
@test length(roots) == 6 # 2 quadratic roots * 3 roots from cbrt(3)
@test length(Symbolics.get_variables(roots[1])) == 1
_n = only(Symbolics.get_variables(roots[1]))
vals = substitute.(roots, (Dict(_n => 0),))
@test all(x -> isapprox(norm(sec(x^2 + 4x + 4) ^ 3 - 3), 0.0, atol = 1e-14), vals)

roots = ia_solve(expr, x; complex_roots = false)
@test length(roots) == 2
# the `n` in `θ + n * 2π`
@test length(Symbolics.get_variables(roots[1])) == 1
_n = only(Symbolics.get_variables(roots[1]))
vals = substitute.(roots, (Dict(_n => 0),))
@test all(x -> isapprox(norm(sec(x^2 + 4x + 4) ^ 3 - 3), 0.0, atol = 1e-14), vals)

roots = ia_solve(expr, x; complex_roots = false, periodic_roots = false)
@test length(roots) == 2
@test length(Symbolics.get_variables(roots[1])) == 0
@test length(Symbolics.get_variables(roots[2])) == 0
vals = eval.(Symbolics.toexpr.(roots))
@test all(x -> isapprox(norm(sec(x^2 + 4x + 4) ^ 3 - 3), 0.0, atol = 1e-14), vals)
end
end

@testset "Sqrt case poly" begin
Expand Down

0 comments on commit 8c518c2

Please sign in to comment.