Add a switch to turn slow paths into errors #612

Open · wants to merge 4 commits into master
7 changes: 7 additions & 0 deletions docs/src/reference.md
@@ -164,3 +164,10 @@ NNlib.glu
NNlib.within_gradient
bias_act!
```

+Finally, this switch turns the warnings printed on various slow fallback paths into errors.
+It's a bit like `CUDA.allowscalar(false)`.
+
+```@docs
+allowslow
+```
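As a usage sketch for this new docs section (the integer inputs below take the non-BLAS `batched_mul!` fallback, which is also what the new test in this PR relies on):

```julia
using NNlib

A = rand(1:9, 3, 4, 2)   # Int eltype: batched_mul! has no BLAS kernel for this,
B = rand(1:9, 4, 5, 2)   # so the generic fallback loop is used

batched_mul(A, B)        # default behaviour: warns once, then stays quiet

NNlib.allowslow(false)   # strict mode, in the spirit of CUDA.allowscalar(false)
batched_mul(A, B)        # now throws instead of warning

NNlib.allowslow(true)    # back to the default warn-only behaviour
```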
9 changes: 9 additions & 0 deletions src/NNlib.jl
@@ -18,6 +18,15 @@ using Statistics: mean

const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}

"""
allowslow(::Bool)

By default, NNlib will print warnings the first time various slow fallback paths are taken.
Calling `allowslow(false)` will instead make these into errors.
"""
allowslow(flag::Bool) = (SLOWERROR[] = !flag; nothing)
const SLOWERROR = Ref(false)

# Include APIs
include("dim_helpers.jl")
export ConvDims, DenseConvDims, PoolDims, DepthwiseConvDims
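Note the inversion: `allowslow(false)` sets `SLOWERROR[] = true`. Since this is a global `Ref`, callers who toggle it temporarily must restore the previous value themselves (as the new test below does by hand). A do-block method in the style of `CUDA.allowscalar(f, ::Bool)` could make that automatic; the following is only a sketch of a possible follow-up, written as it might appear inside `src/NNlib.jl`, and is not part of this PR:

```julia
# Hypothetical companion method (not in this PR): run `f` with the given
# setting, then restore the previous one even if `f` throws.
function allowslow(f::Function, flag::Bool)
    prev = SLOWERROR[]
    SLOWERROR[] = !flag      # same inverted convention as allowslow(::Bool)
    try
        return f()
    finally
        SLOWERROR[] = prev   # runs on both normal return and error
    end
end
```

Usage would then be `NNlib.allowslow(false) do; batched_mul(A, B); end`.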
5 changes: 4 additions & 1 deletion src/batched/batchedmul.jl
@@ -274,7 +274,10 @@ for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST

size(A, 3) == size(C, 3) || size(A, 3) == 1 || throw(DimensionMismatch("batch size mismatch: A != C"))
size(B, 3) == size(C, 3) || size(B, 3) == 1 || throw(DimensionMismatch("batch size mismatch: B != C"))
-@debug "calling fallback method for batched_mul!" typeof(A) size(A) typeof(B) size(B) typeof(C)
+@warn "calling fallback method for batched_mul!" typeof(A) size(A) typeof(B) size(B) typeof(C) maxlog=1
+if SLOWERROR[]
+    error("calling fallback method for batched_mul!")
+end

Abase, Bbase = _unbatch(A), _unbatch(B)
sA, oA = size(A,3) == 1 ? (0,1) : (1,0)
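For context on when this branch runs: the fast method only covers strided arrays whose common element type has a BLAS `gemm!`, so anything else lands here. A sketch of inputs on each side of that line (the exact dispatch rules are NNlib's, not spelled out in this diff):

```julia
using NNlib

# BLAS path: homogeneous Float32 (or Float64) strided arrays, no warning.
batched_mul(rand(Float32, 3, 4, 2), rand(Float32, 4, 5, 2))

# Generic fallback: non-BLAS element type, hits the new @warn / SLOWERROR check.
batched_mul(rand(1:9, 3, 4, 2), rand(1:9, 4, 5, 2))

# Generic fallback: mixed element types have no common gemm! method either.
batched_mul(rand(Float32, 3, 4, 2), rand(Float64, 4, 5, 2))
```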
4 changes: 4 additions & 0 deletions src/conv.jl
@@ -191,6 +191,7 @@ for (front_name, backend, signature) in (
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
+SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
end

x_cs = Iterators.partition(1:size(in1, 4),
@@ -232,6 +233,7 @@ for (front_name, backend, signature) in (
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
+SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
end


@@ -275,6 +277,7 @@ for (front_name, backend, signature) in (
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
+SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
end

dw_cs = Iterators.partition(1:size(out, 5),
@@ -326,6 +329,7 @@ for (front_name, backend, signature) in (
if $(string(backend)) == "direct" && yT == Float64 # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
@warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
"You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
+SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
end
$(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
end
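All four sites use the same guard: the `@warn ... maxlog=1` fires once per session, while the `SLOWERROR[]` check errors on every call. Under the scenario the comments describe (Float32 data with an accidental Float64 kernel, promoting the output to Float64 on the "direct" backend), strict mode turns a silent slowdown into a hard failure. A sketch of that scenario, assuming mixed eltypes route to the direct implementation:

```julia
using NNlib

x = rand(Float32, 32, 32, 3, 1)   # Float32 data (W, H, C, N)
w = rand(Float64, 3, 3, 3, 8)     # accidental Float64 kernel; output promotes

conv(x, w)               # warns: slow fallback implementation invoked for conv!
NNlib.allowslow(false)
conv(x, w)               # with this PR: throws instead of warning
```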
34 changes: 17 additions & 17 deletions src/fold.jl
@@ -16,35 +16,35 @@ and a potential inverse of `unfold`.
The below example demonstrates that `unfold` uses the same sliding windows as `conv`.
In general [`batched_mul`](@ref) + `unfold` should not be used to achieve convolution.
```jldoctest
-julia> x = reshape([100 2 3 40 5 6 700], 7, 1, 1); # 1D data, 1 channel, batch of 1
+julia> x = reshape(Float32[100 2 3 40 5 6 700], 7, 1, 1); # 1D data, 1 channel, batch of 1

-julia> w = reshape([1 0 -1], 3, 1, 1); # 1D conv kernel of length 3
+julia> w = reshape(Float32[1 0 -1], 3, 1, 1); # 1D conv kernel of length 3

julia> kws = (pad=1, stride=2, flipped=true); # use same args for conv and unfold

julia> z = NNlib.unfold(x, size(w); kws...)
-4×3×1 Array{Int64, 3}:
+4×3×1 Array{Float32, 3}:
[:, :, 1] =
-  0  100   2
-  2    3  40
- 40    5   6
-  6  700   0
+  0.0  100.0   2.0
+  2.0    3.0  40.0
+ 40.0    5.0   6.0
+  6.0  700.0   0.0

julia> y1 = conv(x, w; kws...)
-4×1×1 Array{Int64, 3}:
+4×1×1 Array{Float32, 3}:
[:, :, 1] =
-  -2
- -38
-  34
-   6
+  -2.0
+ -38.0
+  34.0
+   6.0

julia> y2 = z ⊠ w # ⊠ (\\boxtimes) is NNlib.batched_mul
-4×1×1 Array{Int64, 3}:
+4×1×1 Array{Float32, 3}:
[:, :, 1] =
-  -2
- -38
-  34
-   6
+  -2.0
+ -38.0
+  34.0
+   6.0
```
"""
function unfold(x::AbstractArray{T, N}, kernel_size::NTuple{K}; stride = 1, pad = 0, dilation = 1, flipped = true) where {T, K, N}
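The doctest switches from `Int` to `Float32` because `z ⊠ w` on integer arrays takes the `batched_mul!` fallback, and the warning added above would leak into (and break) the doctest output. The numbers themselves are unchanged; a quick equivalence check, as a sketch:

```julia
using NNlib

x = reshape(Float32[100 2 3 40 5 6 700], 7, 1, 1)
w = reshape(Float32[1 0 -1], 3, 1, 1)
kws = (pad=1, stride=2, flipped=true)

z = NNlib.unfold(x, size(w); kws...)
conv(x, w; kws...) == z ⊠ w   # true; Float32 keeps ⊠ on the BLAS path, no warning
```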
11 changes: 11 additions & 0 deletions test/batchedmul.jl
@@ -303,3 +303,14 @@ FiniteDifferences.to_vec(x::BatchedTranspose) = FiniteDifferences.to_vec(collect(x))

gradtest(batched_vec, randn(rng, M, P, B), randn(rng, P))
end

+@testset "warning / error" begin
+    prev = NNlib.SLOWERROR[]
+    NNlib.allowslow(true)
+    A = rand(1:99, 3, 4, 7)
+    B = rand(1:99, 4, 5, 7)
+    @test batched_mul(A, B) isa Array  # no error!
+    NNlib.allowslow(false)
+    @test_throws Exception batched_mul(A, B)
+    NNlib.SLOWERROR[] = prev
+end
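One robustness note on this testset: the flag is restored by a plain assignment on the last line, so if anything in the body throws outside a `@test` macro, the remaining lines are skipped and strict mode leaks into later testsets. A more defensive shape for the same test (assuming `using NNlib, Test` from the surrounding test file), as a sketch:

```julia
@testset "warning / error" begin
    prev = NNlib.SLOWERROR[]
    try
        NNlib.allowslow(true)
        A = rand(1:99, 3, 4, 7)
        B = rand(1:99, 4, 5, 7)
        @test batched_mul(A, B) isa Array         # fallback runs, warns only
        NNlib.allowslow(false)
        @test_throws Exception batched_mul(A, B)  # fallback now errors
    finally
        NNlib.SLOWERROR[] = prev                  # restore even on early exit
    end
end
```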