From 2f2c941712f9e2cd11f476666a63dc462ed6440a Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Tue, 17 Sep 2024 18:15:22 -0400 Subject: [PATCH] Add type promotion rules for `NoTangent` and `ZeroTangent`, and add `eltype` for `NoTangent` (#682) * Add promotion rules for ZeroTangent and NoTangent * Make NoTangent have an eltype of itself. * bump version --- Project.toml | 2 +- src/tangent_types/abstract_zero.jl | 3 +++ test/tangent_types/abstract_zero.jl | 21 +++++++++++++++++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index bad5f567f..e60124086 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.24.0" +version = "1.25.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index d526abffe..c5260489e 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -32,6 +32,7 @@ Base.:/(z::AbstractZero, ::Any) = z Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T) # (::Type{T})(::AbstractZero, ::AbstractZero...) where {T<:Number} = zero(T) +Base.promote_rule(T::Type{<:Number}, S::Type{<:AbstractZero}) = T (::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y) (::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false) @@ -92,6 +93,8 @@ end """ struct NoTangent <: AbstractZero end +Base.eltype(::Type{NoTangent}) = NoTangent + """ zero_tangent(primal) diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 7da7cfadc..85a9e1d07 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -82,6 +82,15 @@ @test convert(Float32, ZeroTangent()) === 0.0f0 @test convert(ComplexF64, ZeroTangent()) === 0.0 + 0.0im + @test promote_type(ZeroTangent, Bool) == Bool + @test promote_type(Bool, ZeroTangent) == Bool + @test promote_type(ZeroTangent, Int64) == Int64 + @test promote_type(Int64, ZeroTangent) == Int64 + @test promote_type(ZeroTangent, Float32) == Float32 + @test promote_type(Float32, ZeroTangent) == Float32 + @test promote_type(ZeroTangent, ComplexF64) == ComplexF64 + @test promote_type(ComplexF64, ZeroTangent) == ComplexF64 + @test z[1] === z @test z[1:3] === z @test z[1, 2] === z @@ -110,6 +119,18 @@ @test dot(dne, 17.2) == dne @test dot(11.9, dne) == dne + @test eltype(dne) === NoTangent + @test eltype(NoTangent) === NoTangent + + @test promote_type(NoTangent, Bool) == Bool + @test promote_type(Bool, NoTangent) == Bool + @test promote_type(NoTangent, Int64) == Int64 + @test promote_type(Int64, NoTangent) == Int64 + @test promote_type(NoTangent, Float32) == Float32 + @test promote_type(Float32, NoTangent) == Float32 + @test promote_type(NoTangent, ComplexF64) == ComplexF64 + @test promote_type(ComplexF64, NoTangent) == ComplexF64 + @test ZeroTangent() + dne == dne @test dne + ZeroTangent() == dne @test ZeroTangent() - dne == dne