From 0c15b58094fa1a5cf0e45f86407df10ee496d379 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 16 Sep 2024 16:44:31 -0400 Subject: [PATCH] Define `dict` function --- src/SymbolicUtils.jl | 2 +- src/inspect.jl | 4 ++-- src/polyform.jl | 8 ++++---- src/types.jl | 38 +++++++++++++++++++++----------------- test/basics.jl | 2 +- 5 files changed, 29 insertions(+), 25 deletions(-) diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index 542fb133..8ccd8d60 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -5,7 +5,7 @@ module SymbolicUtils using DocStringExtensions -export @syms, term, showraw, hasmetadata, getmetadata, setmetadata, name, coeff +export @syms, term, showraw, hasmetadata, getmetadata, setmetadata, name, coeff, dict using TermInterface using DataStructures diff --git a/src/inspect.jl b/src/inspect.jl index 961c9f0f..8592da5f 100644 --- a/src/inspect.jl +++ b/src/inspect.jl @@ -12,10 +12,10 @@ function AbstractTrees.nodevalue(x::BasicSymbolic) string(x.impl.val) elseif isadd(x) string(exprtype(x), - (scalar = coeff(x), coeffs = Tuple(k => v for (k, v) in x.impl.dict))) + (scalar = coeff(x), coeffs = Tuple(k => v for (k, v) in dict(x)))) elseif ismul(x) string(exprtype(x), - (scalar = coeff(x), powers = Tuple(k => v for (k, v) in x.impl.dict))) + (scalar = coeff(x), powers = Tuple(k => v for (k, v) in dict(x)))) elseif isdiv(x) || ispow(x) string(exprtype(x)) else diff --git a/src/polyform.jl b/src/polyform.jl index 6969d20b..8455bdd8 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -502,12 +502,12 @@ end # mul, pow case function quick_mulpow(x, y) y.impl.exp isa Number || return (x, y) - if haskey(x.impl.dict, y.impl.base) - d = copy(x.impl.dict) - if x.impl.dict[y.impl.base] > y.impl.exp + if haskey(dict(x), y.impl.base) + d = copy(dict(x)) + if dict(x)[y.impl.base] > y.impl.exp d[y.impl.base] -= y.impl.exp den = 1 - elseif x.impl.dict[y.impl.base] == y.impl.exp + elseif dict(x)[y.impl.base] == y.impl.exp delete!(d, y.impl.base) den = 1 else diff --git a/src/types.jl b/src/types.jl index 0434f090..c099badf 100644 --- a/src/types.jl +++ b/src/types.jl @@ -72,6 +72,10 @@ function coeff(x::BasicSymbolic) x.impl.coeff end +function dict(x::BasicSymbolic) + x.impl.dict +end + # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @noinline error_sym() = error("Sym doesn't have a operation or arguments!") @@ -297,7 +301,7 @@ function _isequal(a, b, E) if E === SYM nameof(a) === nameof(b) elseif E === ADD || E === MUL - coeff_isequal(coeff(a), coeff(b)) && isequal(a.impl.dict, b.impl.dict) + coeff_isequal(coeff(a), coeff(b)) && isequal(dict(a), dict(b)) elseif E === DIV isequal(a.impl.num, b.impl.num) && isequal(a.impl.den, b.impl.den) elseif E === POW @@ -341,7 +345,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt h = s.hash[] !iszero(h) && return h hashoffset = isadd(s) ? ADD_SALT : SUB_SALT - h′ = hash(hashoffset, hash(coeff(s), hash(s.impl.dict, salt))) + h′ = hash(hashoffset, hash(coeff(s), hash(dict(s), salt))) s.hash[] = h′ return h′ elseif E === DIV @@ -461,7 +465,7 @@ function maybe_intcoeff(x) if ismul(x) coeff = coeff(x) if coeff isa Rational && isone(denominator(coeff)) - _Mul(symtype(x), coeff.num, x.impl.dict; metadata = x.metadata) + _Mul(symtype(x), coeff.num, dict(x); metadata = x.metadata) else x end @@ -542,7 +546,7 @@ function toterm(t::BasicSymbolic{T}) where {T} elseif E === ADD || E === MUL args = BasicSymbolic[] push!(args, coeff(t)) - for (k, coeff) in t.impl.dict + for (k, coeff) in dict(t) push!( args, coeff == 1 ? k : _Term(T, E === MUL ? (^) : (*), [_Const(coeff), k])) end @@ -567,7 +571,7 @@ function makeadd(sign, coeff, xs...) for x in xs if isadd(x) coeff += coeff(x) - _merge!(+, d, x.impl.dict, filter = _iszero) + _merge!(+, d, dict(x), filter = _iszero) continue end if x isa Number @@ -575,7 +579,7 @@ function makeadd(sign, coeff, xs...) continue end if ismul(x) - k = _Mul(symtype(x), 1, x.impl.dict) + k = _Mul(symtype(x), 1, dict(x)) v = sign * coeff(x) + get(d, k, 0) else k = x @@ -598,7 +602,7 @@ function makemul(coeff, xs...; d = Dict{BasicSymbolic, Any}()) coeff *= x elseif ismul(x) coeff *= coeff(x) - _merge!(+, d, x.impl.dict, filter = _iszero) + _merge!(+, d, dict(x), filter = _iszero) else v = 1 + get(d, x, 0) if _iszero(v) @@ -1223,10 +1227,10 @@ function +(a::SN, b::SN) !issafecanon(+, a, b) && return term(+, a, b) # Don't flatten if args have metadata if isadd(a) && isadd(b) return _Add( - add_t(a, b), coeff(a) + coeff(b), _merge(+, a.impl.dict, b.impl.dict, filter = _iszero)) + add_t(a, b), coeff(a) + coeff(b), _merge(+, dict(a), dict(b), filter = _iszero)) elseif isadd(a) coeff, dict = makeadd(1, 0, b) - return _Add(add_t(a, b), coeff(a) + coeff, _merge(+, a.impl.dict, dict, filter = _iszero)) + return _Add(add_t(a, b), coeff(a) + coeff, _merge(+, dict(a), dict, filter = _iszero)) elseif isadd(b) return b + a end @@ -1240,7 +1244,7 @@ function +(a::Number, b::SN) !issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata iszero(a) && return b if isadd(b) - _Add(add_t(a, b), a + coeff(b), b.impl.dict) + _Add(add_t(a, b), a + coeff(b), dict(b)) else _Add(add_t(a, b), makeadd(1, a, b)...) end @@ -1258,7 +1262,7 @@ function -(a::SN) return term(-, a) end if isadd(a) - _Add(sub_t(a), -coeff(a), mapvalues((_, v) -> -v, a.impl.dict)) + _Add(sub_t(a), -coeff(a), mapvalues((_, v) -> -v, dict(a))) else _Add(sub_t(a), makeadd(-1, 0, a)...) end @@ -1266,7 +1270,7 @@ end function -(a::SN, b::SN) (!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b) if isadd(a) && isadd(b) - _Add(sub_t(a, b), coeff(a) - coeff(b), _merge(-, a.impl.dict, b.impl.dict, filter = _iszero)) + _Add(sub_t(a, b), coeff(a) - coeff(b), _merge(-, dict(a), dict(b), filter = _iszero)) else a + (-b) end @@ -1294,16 +1298,16 @@ function *(a::SN, b::SN) _Div(a * b.impl.num, b.impl.den) elseif ismul(a) && ismul(b) _Mul(mul_t(a, b), coeff(a) * coeff(b), - _merge(+, a.impl.dict, b.impl.dict, filter = _iszero)) + _merge(+, dict(a), dict(b), filter = _iszero)) elseif ismul(a) && ispow(b) if b.impl.exp isa Number _Mul(mul_t(a, b), coeff(a), - _merge(+, a.impl.dict, Base.ImmutableDict(b.impl.base => b.impl.exp), + _merge(+, dict(a), Base.ImmutableDict(b.impl.base => b.impl.exp), filter = _iszero)) else _Mul(mul_t(a, b), coeff(a), - _merge(+, a.impl.dict, Base.ImmutableDict(b => 1), filter = _iszero)) + _merge(+, dict(a), Base.ImmutableDict(b => 1), filter = _iszero)) end elseif ispow(a) && ismul(b) b * a @@ -1326,7 +1330,7 @@ function *(a::Number, b::SN) # -1(a+b) -> -a - b T = promote_symtype(+, typeof(a), symtype(b)) _Add(T, coeff(b) * a, - Dict{BasicSymbolic, Any}(k => v * a for (k, v) in b.impl.dict)) + Dict{BasicSymbolic, Any}(k => v * a for (k, v) in dict(b))) else _Mul(mul_t(a, b), makemul(a, b)...) end @@ -1352,7 +1356,7 @@ function ^(a::SN, b) elseif ismul(a) && b isa Number coeff = unstable_pow(coeff(a), b) _Mul(promote_symtype(^, symtype(a), symtype(b)), - coeff, mapvalues((k, v) -> b * v, a.impl.dict)) + coeff, mapvalues((k, v) -> b * v, dict(a))) else _Pow(a, b) end diff --git a/test/basics.jl b/test/basics.jl index a7db995f..0bce95f2 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -233,7 +233,7 @@ end @testset "maketerm" begin @syms a b c - @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).impl.dict, Dict(a=>1,b=>1,c=>1)) + @test isequal(dict(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing)), Dict(a=>1,b=>1,c=>1)) @test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], nothing), b) # test that maketerm doesn't hard-code BasicSymbolic subtype