Can't use gemm! methods with Metal #423

Closed · andreyz4k opened this issue Sep 20, 2024 · 3 comments
Labels: needs info (Further information is requested)

Comments

andreyz4k commented Sep 20, 2024

I'm trying to run some small transformer models on my Mac, and I'm getting an error that it's not possible to convert a ::Metal.MtlPtr{Float32} to a ::Ptr{Float32}, which happens in a gemm! operation. Here is the full log from a minimal example.

julia> using Transformers

julia> using Transformers.HuggingFace

julia> using Transformers.TextEncoders

julia> tokenizer, enc_model = hgf"avsolatorio/GIST-small-Embedding-v0"
(TrfTextEncoder(
├─ TextTokenizer(MatchTokenization(WordPieceTokenization(bert_uncased_tokenizer, WordPiece(vocab_size = 30522, unk = [UNK], max_char = 100)), 5 patterns)),
├─ vocab = Vocab{String, SizedArray}(size = 30522, unk = [UNK], unki = 101),
├─ config = @NamedTuple{startsym::String, endsym::String, padsym::String, trunc::Union{Nothing, Int64}}(("[CLS]", "[SEP]", "[PAD]", 512)),
├─ annotate = annotate_strings,
├─ onehot = lookup_first,
├─ decode = nestedcall(remove_conti_prefix),
├─ textprocess = Pipelines(target[token] := join_text(source); target[token] := nestedcall(cleanup ∘ remove_prefix_space, target.token); target := (target.token)),
└─ process = Pipelines:
  ╰─ target[token] := TextEncodeBase.nestedcall(string_getvalue, source)
  ╰─ target[token] := Transformers.TextEncoders.grouping_sentence(target.token)
  ╰─ target[(token, segment)] := SequenceTemplate{String}([CLS]:<type=1> Input[1]:<type=1> [SEP]:<type=1> (Input[2]:<type=2> [SEP]:<type=2>)...)(target.token)
  ╰─ target[attention_mask] := (NeuralAttentionlib.LengthMask ∘ Transformers.TextEncoders.getlengths(512))(target.token)
  ╰─ target[token] := TextEncodeBase.trunc_and_pad(512, [PAD], tail, tail)(target.token)
  ╰─ target[token] := TextEncodeBase.nested2batch(target.token)
  ╰─ target[segment] := TextEncodeBase.trunc_and_pad(512, 1, tail, tail)(target.segment)
  ╰─ target[segment] := TextEncodeBase.nested2batch(target.segment)
  ╰─ target[sequence_mask] := identity(target.attention_mask)
  ╰─ target := (target.token, target.segment, target.attention_mask, target.sequence_mask)
), HGFBertModel(Chain(CompositeEmbedding(token = Embed(384, 30522), position = ApplyEmbed(.+, FixedLenPositionEmbed(384, 512)), segment = ApplyEmbed(.+, Embed(384, 2), Transformers.HuggingFace.bert_ones_like)), DropoutLayer<nothing>(LayerNorm(384, ϵ = 1.0e-12))), Transformer<12>(PostNormTransformerBlock(DropoutLayer<nothing>(SelfAttention(MultiheadQKVAttenOp(head = 12, p = nothing), Fork<3>(Dense(W = (384, 384), b = true)), Dense(W = (384, 384), b = true))), LayerNorm(384, ϵ = 1.0e-12), DropoutLayer<nothing>(Chain(Dense(σ = NNlib.gelu, W = (384, 1536), b = true), Dense(W = (1536, 384), b = true))), LayerNorm(384, ϵ = 1.0e-12))), Branch{(:pooled,) = (:hidden_state,)}(BertPooler(Dense(σ = NNlib.tanh_fast, W = (384, 384), b = true)))))

julia> data = "some test string"
"some test string"

julia> encoded = encode(tokenizer, data)
(token = Bool[0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0], segment = [1, 1, 1, 1, 1], attention_mask = NeuralAttentionlib.LengthMask{1, Vector{Int32}}(Int32[5]), sequence_mask = NeuralAttentionlib.LengthMask{1, Vector{Int32}}(Int32[5]))

julia> enc_model(encoded)               # works on CPU
(hidden_state = Float32[-0.62841725 -0.6865131 … -0.6662981 -0.6284234; 0.1254244 0.27522323 … -0.060979918 0.1254203; … ; 0.7561999 0.7735646 … 0.9089587 0.75621134; 0.42806438 0.92203164 … 0.123192824 0.42805216;;;], attention_mask = NeuralAttentionlib.LengthMask{1, Vector{Int32}}(Int32[5]), sequence_mask = NeuralAttentionlib.LengthMask{1, Vector{Int32}}(Int32[5]), pooled = Float32[0.030063635; -0.058565706; … ; -0.14084515; -0.038214456;;])

julia> enable_gpu()
todevice (generic function with 1 method)

julia> encoded2 = todevice(encoded)
(token = Bool[0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0], segment = [1, 1, 1, 1, 1], attention_mask = NeuralAttentionlib.LengthMask{1, MtlVector{Int32, Metal.PrivateStorage}}(Int32[5]), sequence_mask = NeuralAttentionlib.LengthMask{1, MtlVector{Int32, Metal.PrivateStorage}}(Int32[5]))

julia> enc_model2 = todevice(enc_model)
HGFBertModel(
  Chain(
    CompositeEmbedding(
      token = Embed(384, 30522),        # 11_720_448 parameters
      position = ApplyEmbed(.+, FixedLenPositionEmbed(384, 512)),  # 196_608 parameters
      segment = ApplyEmbed(.+, Embed(384, 2), Transformers.HuggingFace.bert_ones_like),  # 768 parameters
    ),
    DropoutLayer<nothing>(
      LayerNorm(384, ϵ = 1.0e-12),      # 768 parameters
    ),
  ),
  Transformer<12>(
    PostNormTransformerBlock(
      DropoutLayer<nothing>(
        SelfAttention(
          MultiheadQKVAttenOp(head = 12, p = nothing),
          Fork<3>(Dense(W = (384, 384), b = true)),  # 443_520 parameters
          Dense(W = (384, 384), b = true),  # 147_840 parameters
        ),
      ),
      LayerNorm(384, ϵ = 1.0e-12),      # 768 parameters
      DropoutLayer<nothing>(
        Chain(
          Dense(σ = NNlib.gelu, W = (384, 1536), b = true),  # 591_360 parameters
          Dense(W = (1536, 384), b = true),  # 590_208 parameters
        ),
      ),
      LayerNorm(384, ϵ = 1.0e-12),      # 768 parameters
    ),
  ),                  # Total: 192 arrays, 21_293_568 parameters, 28.312 KiB.
  Branch{(:pooled,) = (:hidden_state,)}(
    BertPooler(Dense(σ = NNlib.tanh_fast, W = (384, 384), b = true)),  # 147_840 parameters
  ),
)                   # Total: 199 arrays, 33_360_000 parameters, 31.031 KiB.

julia> enc_model2(encoded2)
ERROR: MethodError: no method matching unsafe_convert(::Type{Ptr{Float32}}, ::Metal.MtlPtr{Float32})

Closest candidates are:
  unsafe_convert(::Type{Ptr{T}}, ::Array{T}) where T
   @ Base pointer.jl:65
  unsafe_convert(::Type{Ptr{T}}, ::Base.Threads.Atomic{T}) where T
   @ Base atomics.jl:328
  unsafe_convert(::Type{Ptr{T}}, ::NeuralAttentionlib.CollapsedDimsArray{T, A} where A<:(AbstractArray{T})) where T
   @ NeuralAttentionlib ~/.julia/packages/NeuralAttentionlib/2ao7i/src/matmul/collapseddims.jl:49
  ...

Stacktrace:
  [1] unsafe_gemm!
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/matmul/gemm.jl:48 [inlined]
  [2] unsafe_gemm_strided_batched!
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/matmul/gemm.jl:100 [inlined]
  [3] gemm_strided_batched_impl!
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/matmul/gemm.jl:27 [inlined]
  [4] gemm_strided_batched!(transA::Char, transB::Char, alpha::Float32, A::MtlArray{…}, B::MtlArray{…}, beta::Float32, C::MtlArray{…}, Ai::Static.StaticInt{…}, Aj::Static.StaticInt{…}, Bi::Static.StaticInt{…}, Bj::Static.StaticInt{…}, Ci::Static.StaticInt{…}, Cj::Static.StaticInt{…})
    @ NeuralAttentionlib ~/.julia/packages/NeuralAttentionlib/2ao7i/src/matmul/gemm.jl:177
  [5] gemm_strided_batched_wrapper
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/matmul/matmul.jl:62 [inlined]
  [6] matmul_wrapper
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/matmul/matmul.jl:124 [inlined]
  [7] matmul
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/matmul/matmul.jl:24 [inlined]
  [8] scaled_dot_product_score
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/functional/score.jl:7 [inlined]
  [9] scaled_dot_product_score(q::NeuralAttentionlib.CollapsedDimsArray{…}, k::NeuralAttentionlib.CollapsedDimsArray{…})
    @ NeuralAttentionlib ~/.julia/packages/NeuralAttentionlib/2ao7i/src/functional/score.jl:7
 [10] masked_score(::NeuralAttentionlib.GenericMaskOp{…}, ::NeuralAttentionlib.BatchedMask{…}, ::typeof(NeuralAttentionlib.scaled_dot_product_score), ::NeuralAttentionlib.CollapsedDimsArray{…}, ::Vararg{…})
    @ NeuralAttentionlib ~/.julia/packages/NeuralAttentionlib/2ao7i/src/functional/optimized.jl:55
 [11] normalized_score
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/functional/optimized.jl:31 [inlined]
 [12] dropout_score
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/functional/score.jl:33 [inlined]
 [13] attention_score
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/functional/score.jl:43 [inlined]
 [14] attention_score
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/functional/score.jl:44 [inlined]
 [15] mixing
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/functional/mixing.jl:3 [inlined]
 [16] generic_qkv_attention
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/functional/attention.jl:2 [inlined]
 [17] generic_multihead_qkv_attention
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/functional/attention.jl:26 [inlined]
 [18] generic_multihead_qkv_attention
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/functional/attention.jl:11 [inlined]
 [19] multihead_qkv_attention(::Int64, ::MtlMatrix{…}, ::MtlMatrix{…}, ::MtlMatrix{…}, ::NeuralAttentionlib.BatchedMask{…}, ::Nothing)
    @ NeuralAttentionlib ~/.julia/packages/NeuralAttentionlib/2ao7i/src/functional/attention.jl:40
 [20] AbstractAttenOp
    @ ~/.julia/packages/NeuralAttentionlib/2ao7i/src/types.jl:9 [inlined]
 [21] apply_attention_op
    @ ~/.julia/packages/Transformers/qH1VW/src/layers/attention_op.jl:33 [inlined]
 [22] apply_on_namedtuple
    @ ~/.julia/packages/Transformers/qH1VW/src/layers/attention_op.jl:269 [inlined]
 [23] SelfAttention
    @ ~/.julia/packages/Transformers/qH1VW/src/layers/layer.jl:148 [inlined]
 [24] apply_on_namedtuple
    @ ~/.julia/packages/Transformers/qH1VW/src/layers/architecture.jl:80 [inlined]
 [25] DropoutLayer
    @ ~/.julia/packages/Transformers/qH1VW/src/layers/layer.jl:16 [inlined]
 [26] apply_on_namedtuple
    @ ~/.julia/packages/Transformers/qH1VW/src/layers/architecture.jl:80 [inlined]
 [27] (::Transformers.Layers.PostNormResidual{…})(nt::@NamedTuple{})
    @ Transformers.Layers ~/.julia/packages/Transformers/qH1VW/src/layers/layer.jl:92
 [28] apply_on_namedtuple
    @ ~/.julia/packages/Transformers/qH1VW/src/layers/architecture.jl:80 [inlined]
 [29] TransformerBlock
    @ ~/.julia/packages/Transformers/qH1VW/src/layers/layer.jl:44 [inlined]
 [30] macro expansion
    @ ~/.julia/packages/Transformers/qH1VW/src/layers/layer.jl:0 [inlined]
 [31] applyblocks(blocks::NTuple{…}, f::Nothing, x::@NamedTuple{})
    @ Transformers.Layers ~/.julia/packages/Transformers/qH1VW/src/layers/layer.jl:219
 [32] Transformer
    @ ~/.julia/packages/Transformers/qH1VW/src/layers/layer.jl:208 [inlined]
 [33] hgf_model_forward
    @ ~/.julia/packages/Transformers/qH1VW/src/huggingface/implementation/bert/load.jl:44 [inlined]
 [34] (::Transformers.HuggingFace.HGFBertModel{…})(nt::@NamedTuple{})
    @ Transformers.HuggingFace ~/.julia/packages/Transformers/qH1VW/src/huggingface/models/load.jl:35
 [35] top-level scope
    @ REPL[11]:1
 [36] top-level scope
    @ ~/.julia/packages/Metal/UcSBS/src/initialization.jl:58
Some type information was truncated. Use `show(err)` to see complete types.
maleadt (Member) commented Sep 20, 2024

Metal.jl only implements well-known interfaces like LinearAlgebra.mul!. That does not include the NeuralAttentionlib.jl-specific gemm! (or rather gemm_strided_batched!) you're calling here. If anything, it's up to NeuralAttentionlib.jl to provide a Metal.jl-accelerated version of this function.

For batched matmul, Metal.jl only provides the low-level MPS.matmul!. If you're interested in a portable way of using this method, maybe you can look at adding a Metal.jl extension to NNlib.jl, which IIRC has a batched matmul interface.
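
(For illustration, a minimal sketch of that distinction: the generic interface Metal.jl does implement, versus the low-level MPS entry point. The MPS.matmul! argument order and defaults shown here are assumptions and should be checked against the installed Metal.jl version.)

using Metal, LinearAlgebra

A = MtlArray(rand(Float32, 128, 64))
B = MtlArray(rand(Float32, 64, 32))
C = MtlArray{Float32}(undef, 128, 32)

# LinearAlgebra.mul! is one of the well-known interfaces Metal.jl implements;
# for Float32 MtlMatrix inputs it runs on the GPU.
mul!(C, A, B)

# Low-level single (non-batched) multiply via MPS; signature assumed, verify first.
Metal.MPS.matmul!(C, A, B)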

maleadt added the "needs info" label Sep 20, 2024
andreyz4k (Author) commented

Oh, I see. I was expecting that all Metal-related specialized methods would be grouped in this package.
I checked the NNlib extensions and see that the CUDA and AMD backends just forward gemm calls to cuBLAS and rocBLAS. In any case, I've managed to disable gemm usage for MtlArrays locally and added a _batched_mul! implementation that uses MPS.matmul!, and it seems to work.
On the other hand, I don't see any performance improvement. I suspect that's because it still falls back to a simple loop somewhere, but I don't think I'll be able to track it down in the near future.
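
(The exact patch isn't shown in the thread; below is a rough sketch of what such a batched wrapper could look like, assuming NNlib.batched_mul! is the right method to overload for MtlArrays and that MPS.matmul! accepts (C, A, B). The range indexing copies each batch slice into a plain MtlMatrix, so this is correct but not fast; a real extension would call a strided or batched kernel instead.)

using NNlib, Metal

# Route NNlib's batched matmul for 3-D MtlArrays through MPS, one slice at a time.
function NNlib.batched_mul!(C::MtlArray{T,3}, A::MtlArray{T,3}, B::MtlArray{T,3}) where {T<:Union{Float16,Float32}}
    for k in axes(C, 3)
        a = A[:, :, k]                          # contiguous MtlMatrix copy of slice k
        b = B[:, :, k]
        c = similar(C, size(C, 1), size(C, 2))
        Metal.MPS.matmul!(c, a, b)              # c = a * b on the GPU (signature assumed)
        C[:, :, k] = c                          # write the result back into the batch
    end
    return C
end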

What's worse, I've encountered some inconsistent behavior: my small transformer network sometimes produces NaNs in its outputs and sometimes runs just fine. I've tracked the problem down to the NeuralAttentionlib.matmul function, which produced some (not all) NaNs in its outputs about 10% of the time. I didn't go any further for lack of time. I should note that this never happens with the CPU backend, so I assume the problem is Metal-specific. It's probably worth opening a separate issue for that, but the only example I can provide will use these higher-level packages.

I'm not sure I'll keep trying to make the Metal backend work for me; I'll probably stick to CUDA for now. But let me know if I can help with any more information.

maleadt (Member) commented Sep 21, 2024

> What's worse, I've encountered some inconsistent behavior: my small transformer network sometimes produces NaNs in its outputs and sometimes runs just fine. I've tracked the problem down to the NeuralAttentionlib.matmul function, which produced some (not all) NaNs in its outputs about 10% of the time. I didn't go any further for lack of time. I should note that this never happens with the CPU backend, so I assume the problem is Metal-specific.

That's probably #381

I think we can close this issue then? It's probably worth opening an issue on NeuralAttentionlib.jl for Metal.jl support, either in the form of a package extension there that calls MPS.matmul! as the implementation of gemm_strided_batched!, or by using another package that already provides portable batched matrix-multiplication abstractions (like NNlib.jl).

maleadt closed this as completed Sep 21, 2024