update TransformerDecoderLayer to use mask in self attention
and found bug in MultiheadAttention
This commit is contained in:
parent
67ffd00de8
commit
5422c2fe83
2 changed files with 27 additions and 16 deletions
|
@ -4,12 +4,14 @@ true || include("transformers.jl")
|
|||
|
||||
using Random
|
||||
|
||||
const INPUT_DIM = 10
|
||||
const NUM_HEADS = 1
|
||||
const INPUT_DIM = 100
|
||||
const NUM_HEADS = 20
|
||||
const NUM_BATCH = 10
|
||||
const INPUT_LEN = 5
|
||||
|
||||
rng = Random.default_rng()
|
||||
|
||||
model_encoder = TransformerEncoderLayer(10, 2)
|
||||
model_encoder = TransformerEncoderLayer(INPUT_DIM, NUM_HEADS)
|
||||
@info "model" model_encoder
|
||||
|
||||
ps, st = LuxCore.setup(rng, model_encoder)
|
||||
|
@ -17,7 +19,7 @@ ps, st = LuxCore.setup(rng, model_encoder)
|
|||
@info "status" st
|
||||
|
||||
# Unbatched
|
||||
x = randn(rng, Float32, (10,))
|
||||
x = randn(rng, Float32, (INPUT_DIM,))
|
||||
@info "input" size(x)
|
||||
|
||||
y, st = model_encoder(x, ps, st)
|
||||
|
@ -26,7 +28,7 @@ y, st = model_encoder(x, ps, st)
|
|||
memory_unbatched = y
|
||||
|
||||
# Batched
|
||||
x = randn(rng, Float32, (10, 10))
|
||||
x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
|
||||
@info "input" size(x)
|
||||
|
||||
y, st = model_encoder(x, ps, st)
|
||||
|
@ -34,7 +36,7 @@ y, st = model_encoder(x, ps, st)
|
|||
|
||||
memory_batched = y
|
||||
|
||||
model_decoder = TransformerDecoderLayer(10, 2)
|
||||
model_decoder = TransformerDecoderLayer(INPUT_DIM, NUM_HEADS)
|
||||
@info "decoder model" model_decoder
|
||||
|
||||
ps, st = LuxCore.setup(rng, model_decoder)
|
||||
|
@ -42,14 +44,14 @@ ps, st = LuxCore.setup(rng, model_decoder)
|
|||
@info "status" st
|
||||
|
||||
# Unbatched
|
||||
x = randn(rng, Float32, (10,))
|
||||
x = randn(rng, Float32, (INPUT_DIM,))
|
||||
@info "input" x
|
||||
|
||||
y, st = model_decoder((; x = x, memory = memory_unbatched), ps, st)
|
||||
@info "output" y st
|
||||
|
||||
# Batched
|
||||
x = randn(rng, Float32, (10, 10))
|
||||
x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
|
||||
@info "input" x
|
||||
|
||||
y, st = model_decoder((; x, memory = memory_batched), ps, st)
|
||||
|
|
|
@ -159,6 +159,7 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
|
|||
end
|
||||
# TODO
|
||||
|
||||
# FIXME: expand this to multi dimensional matrix multiplication
|
||||
# qk_dim, v_dim is divisible by num_heads. qk_dim = embed_dim * num_heads
|
||||
# [q] = (qk_dim, q_len, batch_size...)
|
||||
q = ps.weight_q * x.q
|
||||
|
@ -325,12 +326,12 @@ function TransformerDecoderLayer(
|
|||
layer_norm_epsilon::T2 = 1.0f-5,
|
||||
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
|
||||
sublayer_self_attention = let
|
||||
layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
|
||||
layer_split = Lux.WrappedFunction(x -> (q = x.x, k = x.x, v = x.x, mask = x.mask))
|
||||
layer_self_attention = MultiheadAttention(model_dim, num_heads)
|
||||
layer_dropout = Lux.Dropout(dropout)
|
||||
layer_residual_connection = Lux.SkipConnection(
|
||||
Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
|
||||
+,
|
||||
(x, res) -> x + res.x,
|
||||
)
|
||||
Lux.Chain(;
|
||||
layer_residual_connection,
|
||||
|
@ -376,21 +377,29 @@ function TransformerDecoderLayer(
|
|||
end
|
||||
|
||||
function (decoder::TransformerDecoderLayer)(x, ps, st)
|
||||
# hasproperty?
|
||||
haskey(x, :x) || throw(
|
||||
hasproperty(x, :x) || throw(
|
||||
ArgumentError(
|
||||
"input to the decoder must have keyword argument \"x\"(output input)",
|
||||
"input to the decoder must have keyword argument \"x\"(output input) and \"memory\"(output from encoder)",
|
||||
),
|
||||
)
|
||||
haskey(x, :memory) || throw(
|
||||
hasproperty(x, :memory) || throw(
|
||||
ArgumentError(
|
||||
"input to the decoder must have keyword argument \"memory\"(output from encoder)",
|
||||
"input to the decoder must have keyword argument \"x\"(output input) and \"memory\"(output from encoder)",
|
||||
),
|
||||
)
|
||||
|
||||
num_heads =
|
||||
decoder.sublayer_self_attention.layers.layer_residual_connection.layers.layers.layer_self_attention.num_heads
|
||||
self_attention_mask = if ndims(x.x) == 1
|
||||
stack(make_causal_mask(x.x) for _ in 1:(num_heads))
|
||||
else
|
||||
# batched
|
||||
stack(stack(make_causal_mask(x.x) for _ in 1:(num_heads)) for _ in 1:size(x.x, 3))
|
||||
end
|
||||
|
||||
x_sa, st_sublayer_self_attention = Lux.apply(
|
||||
decoder.sublayer_self_attention,
|
||||
x.x,
|
||||
(; x = x.x, mask = self_attention_mask),
|
||||
ps.sublayer_self_attention,
|
||||
st.sublayer_self_attention,
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue