diff --git a/src/activation.jl b/src/activation.jl
index 3e7dd22c4..12dcd691c 100644
--- a/src/activation.jl
+++ b/src/activation.jl
@@ -1,6 +1,7 @@
 export σ, sigmoid, hardσ, hardsigmoid, hardtanh, relu, leakyrelu, relu6, rrelu, elu, gelu, swish,
        selu, celu, softplus, softsign, logσ, logsigmoid, logcosh, mish, tanhshrink, softshrink,
        thresholdrelu, trelu, lisht
+import Base: tanh
 import LoopVectorization: vifelse
 using LoopVectorization.SLEEFPirates: FloatType
 
@@ -118,7 +119,7 @@ elu(x::RealOrFloatType, α = one(x)) = vifelse(x ≥ 0, x / one(x), α * (exp(x)
 activation function.
 """
 function gelu(x::RealOrFloatType)
-    p = oftype(x / 1, π)
+    p = oftype(x / 1, Float64(π))
     λ = oftype(x / 1, √(2 / p))
     α = oftype(x / 1, 0.044715)
     h = oftype(x / 1, 0.5)
@@ -166,7 +167,7 @@ end
 Continuously Differentiable Exponential Linear Units
 See [Continuously Differentiable Exponential Linear Units](https://arxiv.org/pdf/1704.07483.pdf).
 """
-celu(x::RealOrFloatType, α::Real = one(x)) = vifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x)))
+celu(x::RealOrFloatType, α::RealOrFloatType = one(x)) = vifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x)))
 
 
 """
@@ -227,11 +228,8 @@ See [Softshrink Activation Function](https://www.gabormelli.com/RKB/Softshrink_A
 softshrink(x::RealOrFloatType, λ = oftype(x/1, 0.5)) = min(max(zero(x), x - λ), x + λ)
 
 # Provide an informative error message if activation functions are called with an array
-for f in (:σ, :hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink)
+for f in (:σ, :hardσ, :logσ, :tanh, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink)
   @eval $(f)(x::AbstractArray, args...) = error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
-end
-
-for f in (:σ, :tanh)
   @eval Base.broadcasted(::typeof($f), x::Array{T, N}) where {T <: Union{Float64, Float32}, N} = vmap($f, x)
 end
 
diff --git a/test/activation.jl b/test/activation.jl
index 70558fc62..78249c145 100644
--- a/test/activation.jl
+++ b/test/activation.jl
@@ -1,6 +1,6 @@
 using NNlib, Test, Zygote
 
-ACTIVATION_FUNCTIONS = [σ, hardσ, logσ, hardtanh, relu, leakyrelu, relu6, rrelu, elu, gelu, celu, swish, lisht, selu, trelu, softplus, softsign, logcosh, mish, tanhshrink, softshrink];
+ACTIVATION_FUNCTIONS = [σ, hardσ, logσ, tanh, hardtanh, relu, leakyrelu, relu6, rrelu, elu, gelu, celu, swish, lisht, selu, trelu, softplus, softsign, logcosh, mish, tanhshrink, softshrink];
 
 function test_value_float_precision_preserving(a)
     @testset "$(a): " begin
@@ -112,6 +112,16 @@
         end
     end
 
+    @testset "Broadcasting" begin
+        for T in (Float32, Float64)
+            x = rand(T, 5)
+            for a in ACTIVATION_FUNCTIONS
+                @test a.(x) ≈ map(a, x)
+                @test isapprox(gradient(z -> sum(a.(z)), x)[1], a'.(x))
+            end
+        end
+    end
+
     @testset "Test Integer64 and Integer32 inputs will force Float64 outputs" begin
         test_value_int_input_forces_float64.(filter(x -> (x != relu && x != relu6 && x != hardtanh && x != trelu), ACTIVATION_FUNCTIONS))
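
For reference, a minimal sketch of what the merged `@eval` loop now generates for each activation: a scalar method, an informative error for array arguments, and a `Base.broadcasted` overload that reroutes dense `Float32`/`Float64` arrays through `LoopVectorization.vmap`. The name `myrelu` is a hypothetical stand-in for any exported activation, and the snippet assumes LoopVectorization is installed; it is an illustration of the pattern, not code from this PR.

```julia
using LoopVectorization: vmap

# Scalar rule, analogous to the per-element definitions in src/activation.jl.
myrelu(x::Real) = max(zero(x), x)

# Informative error when called directly on an array, as in the @eval loop.
myrelu(x::AbstractArray, args...) =
    error("Use broadcasting (`myrelu.(x)`) to apply activation functions to arrays.")

# Same shape as the generated broadcast fast path: dense Float32/Float64 arrays
# go through vmap instead of the generic broadcast machinery.
Base.broadcasted(::typeof(myrelu), x::Array{T, N}) where {T <: Union{Float64, Float32}, N} =
    vmap(myrelu, x)

x = rand(Float32, 5)
@assert myrelu.(x) ≈ map(myrelu, x)   # broadcasting now hits vmap but agrees with map
```

This is also what the new `@testset "Broadcasting"` checks for every function in `ACTIVATION_FUNCTIONS`: the vmap-backed broadcast must agree with `map`, and its Zygote gradient must match the elementwise derivative.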