Commit

Use NNlib.bias_act! (#2327)
* use NNlib.bias_act

rm comments

* mend

* add to news

* Update src/layers/basic.jl

Co-authored-by: Carlo Lucibello <[email protected]>

---------

Co-authored-by: Carlo Lucibello <[email protected]>
mcabbott and CarloLucibello authored Nov 8, 2024
1 parent c86580b commit af1e5fc
Showing 4 changed files with 7 additions and 10 deletions.
1 change: 1 addition & 0 deletions NEWS.md
@@ -11,6 +11,7 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl
* The `Flux.Optimise` module has been deprecated in favor of the Optimisers.jl package.
Now Flux re-exports the optimisers from Optimisers.jl. Most users will be unaffected by this change.
The module is still available for now, but will be removed in a future release.
+* Most Flux layers will [re-use memory via `NNlib.bias_act!`](https://github.com/FluxML/Flux.jl/pull/2327), when possible.

## v0.14.22
* Data movement between devices is now provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl).
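For context: `NNlib.bias_act!(σ, x, b)` computes `σ.(x .+ b)` and, when it is safe, writes the result back into `x` instead of allocating a new array. A minimal sketch of the behaviour (array sizes are arbitrary; assumes a recent NNlib):

```julia
using NNlib

W, b = randn(Float32, 4, 3), randn(Float32, 4)
x = randn(Float32, 3, 8)

y = W * x                          # freshly allocated, so safe to overwrite
out = NNlib.bias_act!(tanh, y, b)  # roughly tanh.(y .+ b), reusing y's memory when possible

out ≈ tanh.(W * x .+ b)            # true (up to tanh_fast-style fast paths)
```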
5 changes: 2 additions & 3 deletions src/layers/basic.jl
@@ -186,9 +186,8 @@ end

function (a::Dense)(x::AbstractVecOrMat)
_size_check(a, x, 1 => size(a.weight, 2))
-σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
xT = _match_eltype(a, x) # fixes Float64 input, etc.
-return σ.(a.weight * xT .+ a.bias)
+return NNlib.bias_act!(a.σ, a.weight * xT, a.bias) # does σ.(W*x .+ b), with fast paths
end

function (a::Dense)(x::AbstractArray)
@@ -466,7 +465,7 @@ function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
Z = reshape(Wyx, (d_z, :))

# @einsum out[o,s] := σ(Z[o,i] + b[o])
-σ.(Z .+ b)
+NNlib.bias_act!(σ, Z, b) # σ.(Z .+ b)
end

(a::Bilinear)(x::AbstractVecOrMat) = a(x, x)
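The `Dense` output is unchanged; only the extra broadcast allocation goes away, since the freshly computed `a.weight * xT` buffer is reused for the activation. A quick equivalence check with made-up sizes, assuming a Flux build that includes this commit:

```julia
using Flux

d = Dense(3 => 4, relu)
x = randn(Float32, 3, 8)

# The forward pass is now NNlib.bias_act!(relu, d.weight * x, d.bias),
# which overwrites the fresh matmul result rather than broadcasting into a new array.
d(x) ≈ relu.(d.weight * x .+ d.bias)   # true
```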
9 changes: 3 additions & 6 deletions src/layers/conv.jl
@@ -196,10 +196,9 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)

function (c::Conv)(x::AbstractArray)
_conv_size_check(c, x)
-σ = NNlib.fast_act(c.σ, x)
cdims = conv_dims(c, x)
xT = _match_eltype(c, x)
-σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+NNlib.bias_act!(c.σ, conv(xT, c.weight, cdims), conv_reshape_bias(c))
end

_channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
@@ -350,10 +349,9 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)

function (c::ConvTranspose)(x::AbstractArray)
_conv_size_check(c, x)
-σ = NNlib.fast_act(c.σ, x)
cdims = conv_transpose_dims(c, x)
xT = _match_eltype(c, x)
-σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+NNlib.bias_act!(c.σ, ∇conv_data(xT, c.weight, cdims), conv_reshape_bias(c))
end

function Base.show(io::IO, l::ConvTranspose)
@@ -493,10 +491,9 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)

function (c::CrossCor)(x::AbstractArray)
_conv_size_check(c, x)
-σ = NNlib.fast_act(c.σ, x)
cdims = crosscor_dims(c, x)
xT = _match_eltype(c, x)
-σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+NNlib.bias_act!(c.σ, crosscor(xT, c.weight, cdims), conv_reshape_bias(c))
end

function Base.show(io::IO, l::CrossCor)
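`Conv`, `ConvTranspose` and `CrossCor` all follow the same pattern: the freshly allocated convolution output is handed straight to `NNlib.bias_act!`, with `conv_reshape_bias(c)` supplying either the bias reshaped to broadcast along the channel dimension or `false` for bias-free layers. A small sketch of the no-bias path (the array here just stands in for a convolution output):

```julia
using NNlib

y = randn(Float32, 6, 6, 4, 1)                # stand-in for conv(xT, c.weight, cdims)
out = NNlib.bias_act!(relu, copy(y), false)   # with bias == false this is just relu.(y), in place

out ≈ relu.(y)                                # true
```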
2 changes: 1 addition & 1 deletion src/layers/normalise.jl
@@ -245,7 +245,7 @@ function _norm_layer_forward(
β = reshape(l.β, affine_shape)

scale = γ ./ sqrt.(σ² .+ eps)
-bias = -scale .* μ .+ β
+bias = .-scale .* μ .+ β
l.λ.(scale .* x .+ bias)
end

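The one-character change in `normalise.jl` swaps `-scale` for `.-scale`, which presumably folds the negation into the single fused broadcast rather than first materialising a negated copy of `scale`; the computed values are identical. A sketch with made-up shapes:

```julia
scale = rand(Float32, 1, 1, 4, 1)
μ, β = rand(Float32, 1, 1, 4, 1), rand(Float32, 1, 1, 4, 1)

bias_old = -scale .* μ .+ β    # `-scale` allocates a negated copy before the broadcast
bias_new = .-scale .* μ .+ β   # `.-` fuses the negation into the same broadcast kernel
bias_old ≈ bias_new            # true
```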
