update TransformerDecoderLayer to use mask in self attention

and found a bug in MultiheadAttention
qwjyh 2025-03-15 22:35:59 +09:00
parent 67ffd00de8
commit 5422c2fe83
2 changed files with 27 additions and 16 deletions
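
The gist of the change, as a minimal sketch: the decoder now builds a Bool causal mask with NNlib's make_causal_mask, replicates it per attention head (and per batch element in the batched case), and hands it to the self-attention sublayer. The sizes below mirror the updated test constants; the lowercase variable names are local to this sketch, and only NNlib and Base.stack are assumed.

    using NNlib: make_causal_mask

    input_dim, input_len, num_batch, num_heads = 100, 5, 10, 20  # illustrative sizes
    x = randn(Float32, (input_dim, input_len, num_batch))        # batched decoder input

    # make_causal_mask builds a square Bool mask over dimension 2 of x,
    # i.e. of size (input_len, input_len).
    mask_single = make_causal_mask(x)

    # Replicate the mask per head, then per batch element, mirroring the
    # stacked comprehensions added to the decoder below.
    mask_heads = stack(mask_single for _ in 1:num_heads)   # (len, len, num_heads)
    mask = stack(mask_heads for _ in 1:num_batch)          # (len, len, num_heads, num_batch)

    @assert size(mask) == (input_len, input_len, num_heads, num_batch)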

@@ -4,12 +4,14 @@ true || include("transformers.jl")
 using Random
-const INPUT_DIM = 10
-const NUM_HEADS = 1
+const INPUT_DIM = 100
+const NUM_HEADS = 20
+const NUM_BATCH = 10
+const INPUT_LEN = 5
 rng = Random.default_rng()
-model_encoder = TransformerEncoderLayer(10, 2)
+model_encoder = TransformerEncoderLayer(INPUT_DIM, NUM_HEADS)
 @info "model" model_encoder
 ps, st = LuxCore.setup(rng, model_encoder)
@@ -17,7 +19,7 @@ ps, st = LuxCore.setup(rng, model_encoder)
 @info "status" st
 # Unbatched
-x = randn(rng, Float32, (10,))
+x = randn(rng, Float32, (INPUT_DIM,))
 @info "input" size(x)
 y, st = model_encoder(x, ps, st)
@@ -26,7 +28,7 @@ y, st = model_encoder(x, ps, st)
 memory_unbatched = y
 # Batched
-x = randn(rng, Float32, (10, 10))
+x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
 @info "input" size(x)
 y, st = model_encoder(x, ps, st)
@@ -34,7 +36,7 @@ y, st = model_encoder(x, ps, st)
 memory_batched = y
-model_decoder = TransformerDecoderLayer(10, 2)
+model_decoder = TransformerDecoderLayer(INPUT_DIM, NUM_HEADS)
 @info "decoder model" model_decoder
 ps, st = LuxCore.setup(rng, model_decoder)
@@ -42,14 +44,14 @@ ps, st = LuxCore.setup(rng, model_decoder)
 @info "status" st
 # Unbatched
-x = randn(rng, Float32, (10,))
+x = randn(rng, Float32, (INPUT_DIM,))
 @info "input" x
 y, st = model_decoder((; x = x, memory = memory_unbatched), ps, st)
 @info "output" y st
 # Batched
-x = randn(rng, Float32, (10, 10))
+x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
 @info "input" x
 y, st = model_decoder((; x, memory = memory_batched), ps, st)

@@ -159,6 +159,7 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
     end
     # TODO
+    # FIXME: expand this to multi dimensional matrix multiplication
     # qk_dim, v_dim is divisible by num_heads. qk_dim = embed_dim * num_heads
     # [q] = (qk_dim, q_len, batch_size...)
     q = ps.weight_q * x.q
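
On the FIXME above: one possible way to extend the projection to inputs with batch dimensions (a sketch only, not necessarily the fix intended here) is to flatten the trailing dimensions, do a single 2-D multiplication, and reshape back. The helper name project and the sizes are illustrative.

    # Hypothetical helper: apply a 2-D projection weight to a
    # (features, len, batch...) array.
    function project(weight::AbstractMatrix, x::AbstractArray)
        trailing = size(x)[2:end]
        y = weight * reshape(x, size(x, 1), :)           # (out_dim, len * batches)
        return reshape(y, size(weight, 1), trailing...)  # (out_dim, len, batch...)
    end

    q = project(randn(Float32, 64, 100), randn(Float32, 100, 5, 10))
    @assert size(q) == (64, 5, 10)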
@@ -325,12 +326,12 @@ function TransformerDecoderLayer(
     layer_norm_epsilon::T2 = 1.0f-5,
 ) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
     sublayer_self_attention = let
-        layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
+        layer_split = Lux.WrappedFunction(x -> (q = x.x, k = x.x, v = x.x, mask = x.mask))
         layer_self_attention = MultiheadAttention(model_dim, num_heads)
         layer_dropout = Lux.Dropout(dropout)
         layer_residual_connection = Lux.SkipConnection(
             Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
-            +,
+            (x, res) -> x + res.x,
         )
         Lux.Chain(;
             layer_residual_connection,
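
The combiner change in the hunk above follows from the new layer_split: the sublayer's input is now a NamedTuple (; x, mask) rather than a plain array, so the residual branch has to extract .x before adding. A self-contained sketch of that pattern with a stand-in layer (the doubling function and sizes are illustrative, not this repository's attention chain):

    using Lux, Random

    inner = Lux.WrappedFunction(nt -> nt.x .* 2)   # stand-in for split + attention + dropout
    skip = Lux.SkipConnection(inner, (out, input) -> out + input.x)

    rng = Random.default_rng()
    ps, st = Lux.setup(rng, skip)
    nt = (; x = randn(rng, Float32, 8, 3), mask = trues(3, 3))
    y, _ = skip(nt, ps, st)
    @assert size(y) == (8, 3)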
@@ -376,21 +377,29 @@ function TransformerDecoderLayer(
 end
 function (decoder::TransformerDecoderLayer)(x, ps, st)
     # hasproperty?
-    haskey(x, :x) || throw(
+    hasproperty(x, :x) || throw(
         ArgumentError(
-            "input to the decoder must have keyword argument \"x\"(output input)",
+            "input to the decoder must have keyword argument \"x\"(output input) and \"memory\"(output from encoder)",
         ),
     )
-    haskey(x, :memory) || throw(
+    hasproperty(x, :memory) || throw(
         ArgumentError(
-            "input to the decoder must have keyword argument \"memory\"(output from encoder)",
+            "input to the decoder must have keyword argument \"x\"(output input) and \"memory\"(output from encoder)",
         ),
     )
+    num_heads =
+        decoder.sublayer_self_attention.layers.layer_residual_connection.layers.layers.layer_self_attention.num_heads
+    self_attention_mask = if ndims(x.x) == 1
+        stack(make_causal_mask(x.x) for _ in 1:(num_heads))
+    else
+        # batched
+        stack(stack(make_causal_mask(x.x) for _ in 1:(num_heads)) for _ in 1:size(x.x, 3))
+    end
     x_sa, st_sublayer_self_attention = Lux.apply(
         decoder.sublayer_self_attention,
-        x.x,
+        (; x = x.x, mask = self_attention_mask),
         ps.sublayer_self_attention,
         st.sublayer_self_attention,
     )
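
For reference, this is roughly how a Bool causal mask is consumed by an attention kernel. The sketch uses NNlib.dot_product_attention purely for illustration; this repository's MultiheadAttention does its own projections, so the kernel call and sizes below are assumptions, not the code this commit touches.

    using NNlib: dot_product_attention, make_causal_mask

    d, len, batch, nheads = 64, 5, 10, 4
    q = randn(Float32, d, len, batch)
    k = randn(Float32, d, len, batch)
    v = randn(Float32, d, len, batch)

    # A (len, len) Bool mask broadcasts over heads and batches inside the kernel;
    # masked-out positions are excluded from the softmax.
    y, scores = dot_product_attention(q, k, v; mask = make_causal_mask(q), nheads = nheads)

    @assert size(y) == (d, len, batch)
    @assert size(scores) == (len, len, nheads, batch)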