update TransformerDecoderLayer to use mask in self attention
and found bug in MultiheadAttention
This commit is contained in:
parent 67ffd00de8
commit 5422c2fe83
2 changed files with 27 additions and 16 deletions
Changed file 1 of 2 — the script that includes "transformers.jl":

@@ -4,12 +4,14 @@ true || include("transformers.jl")
 using Random

-const INPUT_DIM = 10
-const NUM_HEADS = 1
+const INPUT_DIM = 100
+const NUM_HEADS = 20
+const NUM_BATCH = 10
+const INPUT_LEN = 5

 rng = Random.default_rng()

-model_encoder = TransformerEncoderLayer(10, 2)
+model_encoder = TransformerEncoderLayer(INPUT_DIM, NUM_HEADS)
 @info "model" model_encoder

 ps, st = LuxCore.setup(rng, model_encoder)
@@ -17,7 +19,7 @@ ps, st = LuxCore.setup(rng, model_encoder)
 @info "status" st

 # Unbatched
-x = randn(rng, Float32, (10,))
+x = randn(rng, Float32, (INPUT_DIM,))
 @info "input" size(x)

 y, st = model_encoder(x, ps, st)
@@ -26,7 +28,7 @@ y, st = model_encoder(x, ps, st)
 memory_unbatched = y

 # Batched
-x = randn(rng, Float32, (10, 10))
+x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
 @info "input" size(x)

 y, st = model_encoder(x, ps, st)
@@ -34,7 +36,7 @@ y, st = model_encoder(x, ps, st)

 memory_batched = y

-model_decoder = TransformerDecoderLayer(10, 2)
+model_decoder = TransformerDecoderLayer(INPUT_DIM, NUM_HEADS)
 @info "decoder model" model_decoder

 ps, st = LuxCore.setup(rng, model_decoder)
@@ -42,14 +44,14 @@ ps, st = LuxCore.setup(rng, model_decoder)
 @info "status" st

 # Unbatched
-x = randn(rng, Float32, (10,))
+x = randn(rng, Float32, (INPUT_DIM,))
 @info "input" x

 y, st = model_decoder((; x = x, memory = memory_unbatched), ps, st)
 @info "output" y st

 # Batched
-x = randn(rng, Float32, (10, 10))
+x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
 @info "input" x

 y, st = model_decoder((; x, memory = memory_batched), ps, st)
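For orientation, the batched input now carries an explicit sequence-length dimension: the script's arrays follow a (features, sequence, batch) layout rather than the earlier bare (10, 10) matrix. A small shape check using the new constants, not part of the commit:

using Random

const INPUT_DIM = 100   # feature/model dimension
const INPUT_LEN = 5     # sequence length
const NUM_BATCH = 10    # batch size

x = randn(Random.default_rng(), Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
@assert size(x) == (100, 5, 10)   # (features, sequence, batch)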
Changed file 2 of 2 — the transformer layer implementation:

@@ -159,6 +159,7 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
     end
     # TODO

+    # FIXME: expand this to multi dimensional matrix multiplication
     # qk_dim, v_dim is divisible by num_heads. qk_dim = embed_dim * num_heads
     # [q] = (qk_dim, q_len, batch_size...)
     q = ps.weight_q * x.q
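The new FIXME points at the projection just below it: ps.weight_q * x.q is a plain matrix product, so it only works while x.q is a vector or matrix; for the (qk_dim, q_len, batch_size...) layout the comment promises, a 3-D array on the right would raise a MethodError. One conventional way to generalize the projection, shown only as a sketch (the helper name and the assumption that the feature dimension comes first are mine, not the commit's):

# Flatten trailing dims, do one matmul, then restore them.
# Assumes weight is (out_dim, in_dim) and x is (in_dim, len, batch...).
function project_features(weight::AbstractMatrix, x::AbstractArray)
    in_dim = size(x, 1)
    y = weight * reshape(x, in_dim, :)                      # one flat matrix product over all positions
    return reshape(y, size(weight, 1), size(x)[2:end]...)   # reinstate sequence/batch dimensions
end

For a vector or matrix input this reduces to the current weight * x behaviour, so it could drop in without changing the unbatched path.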
@@ -325,12 +326,12 @@ function TransformerDecoderLayer(
     layer_norm_epsilon::T2 = 1.0f-5,
 ) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
     sublayer_self_attention = let
-        layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
+        layer_split = Lux.WrappedFunction(x -> (q = x.x, k = x.x, v = x.x, mask = x.mask))
         layer_self_attention = MultiheadAttention(model_dim, num_heads)
         layer_dropout = Lux.Dropout(dropout)
         layer_residual_connection = Lux.SkipConnection(
             Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
-            +,
+            (x, res) -> x + res.x,
         )
         Lux.Chain(;
             layer_residual_connection,
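The combine function of the SkipConnection changes from + to (x, res) -> x + res.x because the wrapped chain now receives a NamedTuple (; x, mask) instead of a bare array: Lux's SkipConnection applies connection(layer_output, layer_input), so the residual has to pull the array out of the input's x field. A minimal sketch of that pattern with a stand-in for the split/attention/dropout chain (the names inner and block are hypothetical):

using Lux, LuxCore, Random

inner = Lux.WrappedFunction(nt -> 2f0 .* nt.x)                # stands in for split -> attention -> dropout
block = Lux.SkipConnection(inner, (out, res) -> out + res.x)  # residual adds the array inside the NamedTuple

rng = Random.default_rng()
ps, st = LuxCore.setup(rng, block)

input = (; x = ones(Float32, 4, 3), mask = trues(3, 3))
y, _ = Lux.apply(block, input, ps, st)
@assert y == 3f0 .* input.x   # 2x from the stand-in layer plus the residual x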
@@ -376,21 +377,29 @@ function TransformerDecoderLayer(
 end

 function (decoder::TransformerDecoderLayer)(x, ps, st)
-    # hasproperty?
-    haskey(x, :x) || throw(
+    hasproperty(x, :x) || throw(
         ArgumentError(
-            "input to the decoder must have keyword argument \"x\"(output input)",
+            "input to the decoder must have keyword argument \"x\"(output input) and \"memory\"(output from encoder)",
         ),
     )
-    haskey(x, :memory) || throw(
+    hasproperty(x, :memory) || throw(
         ArgumentError(
-            "input to the decoder must have keyword argument \"memory\"(output from encoder)",
+            "input to the decoder must have keyword argument \"x\"(output input) and \"memory\"(output from encoder)",
         ),
     )

+    num_heads =
+        decoder.sublayer_self_attention.layers.layer_residual_connection.layers.layers.layer_self_attention.num_heads
+    self_attention_mask = if ndims(x.x) == 1
+        stack(make_causal_mask(x.x) for _ in 1:(num_heads))
+    else
+        # batched
+        stack(stack(make_causal_mask(x.x) for _ in 1:(num_heads)) for _ in 1:size(x.x, 3))
+    end
+
     x_sa, st_sublayer_self_attention = Lux.apply(
         decoder.sublayer_self_attention,
-        x.x,
+        (; x = x.x, mask = self_attention_mask),
         ps.sublayer_self_attention,
         st.sublayer_self_attention,
     )
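With the new forward pass the decoder builds a boolean causal mask for every head (and for every batch element in the batched case) and hands it to the self-attention sublayer alongside x. Assuming make_causal_mask is NNlib's helper, which returns a square mask sized by the second dimension of its argument, the shapes come out as follows (a sketch using the test script's constants, not code from the commit):

using NNlib  # assumption: make_causal_mask comes from NNlib; the commit does not show the import

x = randn(Float32, 100, 5, 10)                # (INPUT_DIM, INPUT_LEN, NUM_BATCH)
num_heads = 20

mask = make_causal_mask(x)                    # (5, 5): one causal mask over the sequence dimension
per_head = stack(mask for _ in 1:num_heads)   # (5, 5, 20): the same mask replicated per head
per_batch = stack(stack(mask for _ in 1:num_heads) for _ in 1:size(x, 3))
@assert size(per_batch) == (5, 5, num_heads, size(x, 3))   # (5, 5, 20, 10)

The mask is simply replicated because the causal pattern is identical for every head and batch element; whether MultiheadAttention expects exactly this (len, len, heads, batch) layout is determined by its own forward pass.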