Compare commits

...

2 commits

Author SHA1 Message Date
5422c2fe83 update TransformerDecoderLayer to use mask in self attention
and found bug in MultiheadAttention
2025-03-15 22:35:59 +09:00
67ffd00de8 Handle mask in MultiheadAttention 2025-03-15 21:56:23 +09:00
3 changed files with 83 additions and 24 deletions

View file

@ -4,9 +4,11 @@ true || include("transformers.jl")
using Random
const EMBED_DIM = 10
const NUM_HEADS = 2
const KV_DIM = 12
const EMBED_DIM = 100
const NUM_HEADS = 20
const KV_DIM = 120
const Q_LEN = 50
const KV_LEN = 50
rng = TaskLocalRNG()
@ -22,4 +24,21 @@ k = rand(rng, Float32, (KV_DIM,))
v = rand(rng, Float32, (KV_DIM,))
@info "q k v" summary.((q, k, v))
l((; q, k, v), ps, st)
@info "" l((; q, k, v), ps, st)
# with mask
l = MultiheadAttention(EMBED_DIM, NUM_HEADS)
@info "layer" l
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
@info "parameters and states" ps st
q = rand(rng, Float32, (EMBED_DIM, Q_LEN))
k = rand(rng, Float32, (EMBED_DIM, KV_LEN))
v = rand(rng, Float32, (EMBED_DIM, KV_LEN))
mask = stack(make_causal_mask(k) for _ in 1:NUM_HEADS)
@info "q k v mask" summary.((q, k, v, mask))
@info "" l((; q, k, v, mask), ps, st)

View file

@ -4,12 +4,14 @@ true || include("transformers.jl")
using Random
const INPUT_DIM = 10
const NUM_HEADS = 1
const INPUT_DIM = 100
const NUM_HEADS = 20
const NUM_BATCH = 10
const INPUT_LEN = 5
rng = Random.default_rng()
model_encoder = TransformerEncoderLayer(10, 2)
model_encoder = TransformerEncoderLayer(INPUT_DIM, NUM_HEADS)
@info "model" model_encoder
ps, st = LuxCore.setup(rng, model_encoder)
@ -17,7 +19,7 @@ ps, st = LuxCore.setup(rng, model_encoder)
@info "status" st
# Unbatched
x = randn(rng, Float32, (10,))
x = randn(rng, Float32, (INPUT_DIM,))
@info "input" size(x)
y, st = model_encoder(x, ps, st)
@ -26,7 +28,7 @@ y, st = model_encoder(x, ps, st)
memory_unbatched = y
# Batched
x = randn(rng, Float32, (10, 10))
x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
@info "input" size(x)
y, st = model_encoder(x, ps, st)
@ -34,7 +36,7 @@ y, st = model_encoder(x, ps, st)
memory_batched = y
model_decoder = TransformerDecoderLayer(10, 2)
model_decoder = TransformerDecoderLayer(INPUT_DIM, NUM_HEADS)
@info "decoder model" model_decoder
ps, st = LuxCore.setup(rng, model_decoder)
@ -42,14 +44,14 @@ ps, st = LuxCore.setup(rng, model_decoder)
@info "status" st
# Unbatched
x = randn(rng, Float32, (10,))
x = randn(rng, Float32, (INPUT_DIM,))
@info "input" x
y, st = model_decoder((; x = x, memory = memory_unbatched), ps, st)
@info "output" y st
# Batched
x = randn(rng, Float32, (10, 10))
x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
@info "input" x
y, st = model_decoder((; x, memory = memory_batched), ps, st)

View file

@ -116,22 +116,50 @@ end
function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
if size(x.q, 1) != l.embed_dim
ArgumentError(
DimensionMismatch(
"Length of queue must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
) |> throw
end
if size(x.k, 1) != l.kvdim
ArgumentError(
DimensionMismatch(
"Length of key must match the layer's kvdim: size(k)[1] = $(size(x.k, 1)), kvdim = $(l.kvdim)",
) |> throw
end
if size(x.v, 1) != l.kvdim
ArgumentError(
DimensionMismatch(
"Length of value must match the layer's kvdim: size(v)[1] = $(size(x.v, 1)), kvdim = $(l.kvdim)",
) |> throw
end
if !(ndims(x.q) == ndims(x.k) == ndims(x.v))
DimensionMismatch(
"Queue and Key and Value must have the same number of dimensions",
) |> throw
end
if ndims(x.k) > 1 && size(x.k)[2:end] != size(x.v)[2:end]
DimensionMismatch(
"Some length of key and value must match: size(k)[2:end] = $(size(x.k)[2:end]), size(v)[2:end] = $(size(x.v)[2:end])",
) |> throw
end
if haskey(x, :mask) && !isnothing(x.mask)
if size(x.mask, 1) != size(x.k, 2)
DimensionMismatch(
"First dimension of mask must be the same as the second size of key: mask size $(size(x.mask)), key size $(size(x.k))",
) |> throw
end
if size(x.mask, 2) != size(x.q, 2)
DimensionMismatch(
"Second dimension of mask must be the same as the second size of key: mask size $(size(x.mask)), key size $(size(x.q))",
) |> throw
end
if size(x.mask, 3) != l.num_heads
DimensionMismatch(
"Third dimension of mask must be the same as the number of heads: mask $(size(x.mask)), num_heads $(l.num_heads)",
) |> throw
end
end
# TODO
# FIXME: expand this to multi dimensional matrix multiplication
# qk_dim, v_dim is divisible by num_heads. qk_dim = embed_dim * num_heads
# [q] = (qk_dim, q_len, batch_size...)
q = ps.weight_q * x.q
@ -139,9 +167,11 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
k = ps.weight_k * x.k
# [v] = (v_dim, kv_len, batch_size...)
v = ps.weight_v * x.v
# [mask] = (kv_len, q_len, nheads, batch_size)
mask = hasproperty(x, :mask) ? x.mask : nothing
# [y] = (v_dim, q_len, batch_size...)
# [α] = (kv_len, q_len, nheads, batch_size...)
y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
y, α = dot_product_attention(q, k, v; nheads = l.num_heads, mask)
ps.weight_o * y, _st
end
@ -296,12 +326,12 @@ function TransformerDecoderLayer(
layer_norm_epsilon::T2 = 1.0f-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
sublayer_self_attention = let
layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
layer_split = Lux.WrappedFunction(x -> (q = x.x, k = x.x, v = x.x, mask = x.mask))
layer_self_attention = MultiheadAttention(model_dim, num_heads)
layer_dropout = Lux.Dropout(dropout)
layer_residual_connection = Lux.SkipConnection(
Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
+,
(x, res) -> x + res.x,
)
Lux.Chain(;
layer_residual_connection,
@ -347,21 +377,29 @@ function TransformerDecoderLayer(
end
function (decoder::TransformerDecoderLayer)(x, ps, st)
# hasproperty?
haskey(x, :x) || throw(
hasproperty(x, :x) || throw(
ArgumentError(
"input to the decoder must have keyword argument \"x\"(output input)",
"input to the decoder must have keyword argument \"x\"(output input) and \"memory\"(output from encoder)",
),
)
haskey(x, :memory) || throw(
hasproperty(x, :memory) || throw(
ArgumentError(
"input to the decoder must have keyword argument \"memory\"(output from encoder)",
"input to the decoder must have keyword argument \"x\"(output input) and \"memory\"(output from encoder)",
),
)
num_heads =
decoder.sublayer_self_attention.layers.layer_residual_connection.layers.layers.layer_self_attention.num_heads
self_attention_mask = if ndims(x.x) == 1
stack(make_causal_mask(x.x) for _ in 1:(num_heads))
else
# batched
stack(stack(make_causal_mask(x.x) for _ in 1:(num_heads)) for _ in 1:size(x.x, 3))
end
x_sa, st_sublayer_self_attention = Lux.apply(
decoder.sublayer_self_attention,
x.x,
(; x = x.x, mask = self_attention_mask),
ps.sublayer_self_attention,
st.sublayer_self_attention,
)