From a94771b7b83241f564ad613d899da923bfa4e1c9 Mon Sep 17 00:00:00 2001 From: n0rbed Date: Mon, 4 Nov 2024 20:45:45 +0200 Subject: [PATCH 1/2] Tests and bug fix --- src/solver/ia_rules.jl | 2 +- src/solver/main.jl | 8 +++----- test/solver.jl | 19 ++++++++++++++++++- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/solver/ia_rules.jl b/src/solver/ia_rules.jl index 58ab28d3e..33351ae64 100644 --- a/src/solver/ia_rules.jl +++ b/src/solver/ia_rules.jl @@ -50,7 +50,7 @@ function solve_interms_ofvar(eq, s; dropmultiplicity=true, warns=true) coeffs, constant = polynomial_coeffs(eq, [s]) eqs = wrap.(collect(values(coeffs))) - solve_multivar(eqs, vars, dropmultiplicity=dropmultiplicity, warns=warns) + symbolic_solve(eqs, vars, dropmultiplicity=dropmultiplicity, warns=warns) end # an attempt at using ia_solve recursively. diff --git a/src/solver/main.jl b/src/solver/main.jl index 7254ea152..a451b93ce 100644 --- a/src/solver/main.jl +++ b/src/solver/main.jl @@ -286,7 +286,7 @@ function solve_univar(expression, x; dropmultiplicity=true) factors_subbed = map(factor -> ssubs(factor, subs), factors) arr_roots = [] - if degree < 5 && length(factors) == 1 + if degree < 5 && isequal(expression, factors_subbed[1]) arr_roots = get_roots(expression, x) # multiplicities (repeated roots) @@ -296,10 +296,8 @@ function solve_univar(expression, x; dropmultiplicity=true) append!(arr_roots, og_arr_roots) end end - end - - if length(factors) != 1 - for i in eachindex(factors_subbed) + elseif length(factors) > 1 || (length(factors) == 1 && !isequal(factors_subbed[1], expression)) + for i in eachindex(factors_subbed) if !any(isequal(x, var) for var in get_variables(factors[i])) continue end diff --git a/test/solver.jl b/test/solver.jl index 58e5feb26..8e85bfbce 100644 --- a/test/solver.jl +++ b/test/solver.jl @@ -54,7 +54,24 @@ function check_approx(arr1, arr2) return true end -@variables x y z a b c d e +@variables x y z a b c d e s + +@testset "Solving in terms of a constant var" begin + eq = ((s^2 + 1)/(s^2 + 2*s + 1)) - ((s^2 + a)/(b*c*s^2 + (b+c)*s + d)) + calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b,c,d]) + known_roots = sort_arr([Dict(a=>1, b=>1, c=>1, d=>1)], [a,b,c,d]) + @test check_approx(calcd_roots, known_roots) + + eq = (a+b)*s^2 - 2s^2 + 2*b*s - 3*s + calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b]) + known_roots = sort_arr([Dict(a=>1/2, b=>3/2)], [a,b]) + @test check_approx(calcd_roots, known_roots) + + eq = (a*x^2+b)*s^2 - 2s^2 + 2*b*s - 3*s + 2(x^2)*(s^3) + 10*s^3 + calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b]) + known_roots = sort_arr([Dict(a=>-1/10, b=>3/2, x=>-im*sqrt(5)), Dict(a=>-1/10, b=>3/2, x=>im*sqrt(5))], [a,b,x]) + @test check_approx(calcd_roots, known_roots) +end @testset "Invalid input" begin @test_throws AssertionError symbolic_solve(x, x^2) From 62d8d0ecf1f8f280ac51516933ce877aaf1c86dc Mon Sep 17 00:00:00 2001 From: n0rbed Date: Mon, 4 Nov 2024 21:07:20 +0200 Subject: [PATCH 2/2] small fix --- src/solver/main.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/solver/main.jl b/src/solver/main.jl index a451b93ce..62c766b09 100644 --- a/src/solver/main.jl +++ b/src/solver/main.jl @@ -286,7 +286,7 @@ function solve_univar(expression, x; dropmultiplicity=true) factors_subbed = map(factor -> ssubs(factor, subs), factors) arr_roots = [] - if degree < 5 && isequal(expression, factors_subbed[1]) + if degree < 5 && isequal(factors_subbed[1], wrap(expression)) arr_roots = get_roots(expression, x) # multiplicities (repeated roots) @@ -296,7 +296,7 @@ function solve_univar(expression, x; dropmultiplicity=true) append!(arr_roots, og_arr_roots) end end - elseif length(factors) > 1 || (length(factors) == 1 && !isequal(factors_subbed[1], expression)) + elseif length(factors) > 1 || (length(factors) == 1 && !isequal(factors_subbed[1], wrap(expression))) for i in eachindex(factors_subbed) if !any(isequal(x, var) for var in get_variables(factors[i])) continue