From 6aa675969de14e53866f544c1019e2c498212086 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 14 Sep 2024 17:24:20 +0530 Subject: [PATCH] fix: define derivatives of ssqrt, scbrt, slog --- src/solver/solve_helpers.jl | 6 ++++++ test/diff.jl | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/src/solver/solve_helpers.jl b/src/solver/solve_helpers.jl index f2420969d..8f8ccdff0 100644 --- a/src/solver/solve_helpers.jl +++ b/src/solver/solve_helpers.jl @@ -37,6 +37,8 @@ function ssqrt(n) end end +derivative(::typeof(ssqrt), args...) = substitute(derivative(sqrt, args...), sqrt => ssqrt) + function scbrt(n) n = unwrap(n) @@ -53,6 +55,8 @@ function scbrt(n) end end +derivative(::typeof(scbrt), args...) = substitute(derivative(cbrt, args...), cbrt => scbrt) + function slog(n) n = unwrap(n) @@ -67,6 +71,8 @@ function slog(n) return term(slog, n) end +derivative(::typeof(slog), args...) = substitute(derivative(log, args...), log => slog) + const RootsOf = (SymbolicUtils.@syms roots_of(poly,var))[1] Base.show(io::IO, f::typeof(ssqrt)) = print(io, "√") diff --git a/test/diff.jl b/test/diff.jl index eb500cb9d..56b145665 100644 --- a/test/diff.jl +++ b/test/diff.jl @@ -398,3 +398,12 @@ let @test_throws TypeError Symbolics.derivative(f, Val(rand(Int))) end end + +# Check ssqrt, scbrt, slog +let + @variables x + D = Differential(x) + @test isequal(expand_derivatives(D(Symbolics.ssqrt(1 + x ^ 2))), simplify((2x) / (2Symbolics.ssqrt(1 + x^2)))) + @test isequal(expand_derivatives(D(Symbolics.scbrt(1 + x ^ 2))), simplify((2x) / (3Symbolics.scbrt(1 + x^2)^2))) + @test isequal(expand_derivatives(D(Symbolics.slog(1 + x ^ 2))), simplify((2x) / (1 + x ^ 2))) +end