From 5422c2fe83e3e6c28cfa0262a5744d684412f856 Mon Sep 17 00:00:00 2001 From: qwjyh Date: Sat, 15 Mar 2025 22:35:59 +0900 Subject: [PATCH] update TransformerDecoderLayer to use mask in self attention and found bug in MultiheadAttention --- test_transformer.jl | 18 ++++++++++-------- transformers.jl | 25 +++++++++++++++++-------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/test_transformer.jl b/test_transformer.jl index df1c14a..57600fe 100644 --- a/test_transformer.jl +++ b/test_transformer.jl @@ -4,12 +4,14 @@ true || include("transformers.jl") using Random -const INPUT_DIM = 10 -const NUM_HEADS = 1 +const INPUT_DIM = 100 +const NUM_HEADS = 20 +const NUM_BATCH = 10 +const INPUT_LEN = 5 rng = Random.default_rng() -model_encoder = TransformerEncoderLayer(10, 2) +model_encoder = TransformerEncoderLayer(INPUT_DIM, NUM_HEADS) @info "model" model_encoder ps, st = LuxCore.setup(rng, model_encoder) @@ -17,7 +19,7 @@ ps, st = LuxCore.setup(rng, model_encoder) @info "status" st # Unbatched -x = randn(rng, Float32, (10,)) +x = randn(rng, Float32, (INPUT_DIM,)) @info "input" size(x) y, st = model_encoder(x, ps, st) @@ -26,7 +28,7 @@ y, st = model_encoder(x, ps, st) memory_unbatched = y # Batched -x = randn(rng, Float32, (10, 10)) +x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH)) @info "input" size(x) y, st = model_encoder(x, ps, st) @@ -34,7 +36,7 @@ y, st = model_encoder(x, ps, st) memory_batched = y -model_decoder = TransformerDecoderLayer(10, 2) +model_decoder = TransformerDecoderLayer(INPUT_DIM, NUM_HEADS) @info "decoder model" model_decoder ps, st = LuxCore.setup(rng, model_decoder) @@ -42,14 +44,14 @@ ps, st = LuxCore.setup(rng, model_decoder) @info "status" st # Unbatched -x = randn(rng, Float32, (10,)) +x = randn(rng, Float32, (INPUT_DIM,)) @info "input" x y, st = model_decoder((; x = x, memory = memory_unbatched), ps, st) @info "output" y st # Batched -x = randn(rng, Float32, (10, 10)) +x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH)) @info "input" x y, st = model_decoder((; x, memory = memory_batched), ps, st) diff --git a/transformers.jl b/transformers.jl index 7011c46..a731555 100644 --- a/transformers.jl +++ b/transformers.jl @@ -159,6 +159,7 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple) end # TODO + # FIXME: expand this to multi dimensional matrix multiplication # qk_dim, v_dim is divisible by num_heads. qk_dim = embed_dim * num_heads # [q] = (qk_dim, q_len, batch_size...) q = ps.weight_q * x.q @@ -325,12 +326,12 @@ function TransformerDecoderLayer( layer_norm_epsilon::T2 = 1.0f-5, ) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat} sublayer_self_attention = let - layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x)) + layer_split = Lux.WrappedFunction(x -> (q = x.x, k = x.x, v = x.x, mask = x.mask)) layer_self_attention = MultiheadAttention(model_dim, num_heads) layer_dropout = Lux.Dropout(dropout) layer_residual_connection = Lux.SkipConnection( Lux.Chain(; layer_split, layer_self_attention, layer_dropout), - +, + (x, res) -> x + res.x, ) Lux.Chain(; layer_residual_connection, @@ -376,21 +377,29 @@ function TransformerDecoderLayer( end function (decoder::TransformerDecoderLayer)(x, ps, st) - # hasproperty? 
- haskey(x, :x) || throw( + hasproperty(x, :x) || throw( ArgumentError( - "input to the decoder must have keyword argument \"x\"(output input)", + "input to the decoder must have keyword argument \"x\"(output input) and \"memory\"(output from encoder)", ), ) - haskey(x, :memory) || throw( + hasproperty(x, :memory) || throw( ArgumentError( - "input to the decoder must have keyword argument \"memory\"(output from encoder)", + "input to the decoder must have keyword argument \"x\"(output input) and \"memory\"(output from encoder)", ), ) + num_heads = + decoder.sublayer_self_attention.layers.layer_residual_connection.layers.layers.layer_self_attention.num_heads + self_attention_mask = if ndims(x.x) == 1 + stack(make_causal_mask(x.x) for _ in 1:(num_heads)) + else + # batched + stack(stack(make_causal_mask(x.x) for _ in 1:(num_heads)) for _ in 1:size(x.x, 3)) + end + x_sa, st_sublayer_self_attention = Lux.apply( decoder.sublayer_self_attention, - x.x, + (; x = x.x, mask = self_attention_mask), ps.sublayer_self_attention, st.sublayer_self_attention, )
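
Note (not part of the patch): the new FIXME in MultiheadAttention points at `q = ps.weight_q * x.q`, which throws a MethodError once `x.q` carries a batch dimension, since a 2-D weight matrix cannot be multiplied with a 3-D array directly; this appears to be the bug mentioned in the subject line. Below is a minimal sketch of one way to lift the projection over the trailing dimensions, assuming the projection weights stay plain matrices; the `project` helper is hypothetical and not part of the repository.

# Hypothetical helper: apply a (out_dim, in_dim) weight matrix along the first
# axis of a (in_dim, len, batch...) array by flattening the trailing dimensions,
# multiplying once, and restoring the original shape.
function project(weight::AbstractMatrix, x::AbstractArray)
    y = weight * reshape(x, size(x, 1), :)
    return reshape(y, size(weight, 1), size(x)[2:end]...)
end

# Usage sketch inside the attention layer:
#   q = project(ps.weight_q, x.q)   # works for (dim,), (dim, len) and (dim, len, batch)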
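
Note (not part of the patch): the causal mask in TransformerDecoderLayer is rebuilt by the nested `stack` generators once per head and once per batch element, even though every copy is identical. A sketch of an equivalent construction that builds the mask once and replicates it, assuming `make_causal_mask` is NNlib's; the `causal_attention_mask` helper is hypothetical.

using NNlib: make_causal_mask

# Hypothetical helper: one causal mask, replicated across heads (and across the
# batch when the input is batched), matching the shapes produced by the nested
# `stack` calls in the patch.
function causal_attention_mask(x::AbstractArray, num_heads::Integer)
    base = make_causal_mask(x)                        # (len, len) Bool matrix
    ndims(x) == 1 && return repeat(base, 1, 1, num_heads)
    return repeat(base, 1, 1, num_heads, size(x, 3))  # (len, len, heads, batch)
end

If the score computation ends up going through NNlib.dot_product_attention, its mask argument broadcasts over the head and batch dimensions, so the bare (len, len) mask may already be sufficient; that has not been checked against this repository's MultiheadAttention.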