new: implement transformer decoder layer

qwjyh 2025-02-27 22:01:48 +09:00
parent 47944b830b
commit f507418bc6
2 changed files with 146 additions and 5 deletions


@@ -9,10 +9,10 @@ const NUM_HEADS = 1
rng = Random.default_rng()
-model = TransformerEncoderLayer(10, 2)
-@info "model" model
+model_encoder = TransformerEncoderLayer(10, 2)
+@info "model" model_encoder
-ps, st = LuxCore.setup(rng, model)
+ps, st = LuxCore.setup(rng, model_encoder)
@info "parameters" ps
@info "status" st
@@ -20,12 +20,37 @@ ps, st = LuxCore.setup(rng, model)
x = randn(rng, Float32, (10,))
@info "input" size(x)
-y, st = model(x, ps, st)
+y, st = model_encoder(x, ps, st)
@info "output" y st
+memory_unbatched = y
# Batched
x = randn(rng, Float32, (10, 10))
@info "input" size(x)
-y, st = model(x, ps, st)
+y, st = model_encoder(x, ps, st)
@info "output" y st
+memory_batched = y
+model_decoder = TransformerDecoderLayer(10, 2)
+@info "decoder model" model_decoder
+ps, st = LuxCore.setup(rng, model_decoder)
+@info "parameter" ps
+@info "status" st
+# Unbatched
+x = randn(rng, Float32, (10,))
+@info "input" x
+y, st = model_decoder((; x = x, memory = memory_unbatched), ps, st)
+@info "output" y st
+# Batched
+x = randn(rng, Float32, (10, 10))
+@info "input" x
+y, st = model_decoder((; x, memory = memory_batched), ps, st)
+@info "output" y st


@@ -271,3 +271,119 @@ function (encoder::TransformerEncoderLayer)(x, ps, st)
sublayer_feed_forward_network = st_sublayer_feed_forward_network,
)
end
"""
TransformerDecoderLayer <: LuxCore.AbstractLuxContainerLayer
One layer of decoder block of Transformer model.
"""
struct TransformerDecoderLayer{LSA, LMHA, LFFN} <: LuxCore.AbstractLuxContainerLayer{(
    :sublayer_self_attention,
    :sublayer_multihead_attention,
    :sublayer_feed_forward_network,
)}
    sublayer_self_attention::LSA
    sublayer_multihead_attention::LMHA
    sublayer_feed_forward_network::LFFN
end
function TransformerDecoderLayer(
    model_dim::Int,
    num_heads::Int,
    feedforward_dim::Int = 2048,
    dropout::T1 = 0.1,
    activation::F = relu,
    layer_norm_epsilon::T2 = 1.0f-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
    # Self-attention sublayer: Q, K, and V are all taken from the decoder input.
    sublayer_self_attention = let
        layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
        layer_self_attention = MultiheadAttention(model_dim, num_heads)
        layer_dropout = Lux.Dropout(dropout)
        layer_residual_connection = Lux.SkipConnection(
            Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
            +,
        )
        Lux.Chain(;
            layer_residual_connection,
            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
            name = "SelfAttentionSubLayer",
        )
    end
    # Encoder-decoder attention sublayer: queries come from the decoder stream,
    # keys and values from the encoder output (`memory`).
    sublayer_multihead_attention = let
        layer_split =
            Lux.WrappedFunction(x::NamedTuple -> (q = x.x, k = x.memory, v = x.memory))
        layer_self_attention = MultiheadAttention(model_dim, num_heads)
        layer_dropout = Lux.Dropout(dropout)
        layer_residual_connection = Lux.Parallel(
            +,
            Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
            Lux.WrappedFunction(x::NamedTuple -> x.x),
        )
        Lux.Chain(;
            layer_residual_connection,
            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
            name = "MultiheadAttentionSubLayer",
        )
    end
    # Position-wise feed-forward sublayer.
    sublayer_feed_forward_network = let
        layer_feed_forward_network = Lux.Chain(
            Lux.Dense(model_dim => feedforward_dim, activation),
            Lux.Dropout(dropout),
            Lux.Dense(feedforward_dim => model_dim, activation),
            Lux.Dropout(dropout),
        )
        layer_residual_connection = Lux.SkipConnection(layer_feed_forward_network, +)
        Lux.Chain(;
            layer_residual_connection,
            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
            name = "FeedForwardNetworkSubLayer",
        )
    end
    TransformerDecoderLayer(
        sublayer_self_attention,
        sublayer_multihead_attention,
        sublayer_feed_forward_network,
    )
end
function (decoder::TransformerDecoderLayer)(x, ps, st)
    # hasproperty?
    haskey(x, :x) || throw(
        ArgumentError(
            "input to the decoder must be a NamedTuple with the key `x` (decoder input)",
        ),
    )
    haskey(x, :memory) || throw(
        ArgumentError(
            "input to the decoder must be a NamedTuple with the key `memory` (encoder output)",
        ),
    )
    x_sa, st_sublayer_self_attention = Lux.apply(
        decoder.sublayer_self_attention,
        x.x,
        ps.sublayer_self_attention,
        st.sublayer_self_attention,
    )
    x_mha, st_sublayer_multihead_attention = Lux.apply(
        decoder.sublayer_multihead_attention,
        (x = x_sa, memory = x.memory),
        ps.sublayer_multihead_attention,
        st.sublayer_multihead_attention,
    )
    y, st_sublayer_feed_forward_network = Lux.apply(
        decoder.sublayer_feed_forward_network,
        x_mha,
        ps.sublayer_feed_forward_network,
        st.sublayer_feed_forward_network,
    )
    return (
        y,
        (
            sublayer_self_attention = st_sublayer_self_attention,
            sublayer_multihead_attention = st_sublayer_multihead_attention,
            sublayer_feed_forward_network = st_sublayer_feed_forward_network,
        ),
    )
end
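
For reference, a minimal usage sketch distilled from the test script in this commit, assuming `Random`, `LuxCore`, and the package's `TransformerEncoderLayer` / `TransformerDecoderLayer` are already in scope as they are there:

using Random, LuxCore  # the module defining the layers is assumed to be loaded as well

rng = Random.default_rng()

# Run the encoder once to produce the memory consumed by the decoder.
model_encoder = TransformerEncoderLayer(10, 2)
ps_enc, st_enc = LuxCore.setup(rng, model_encoder)
memory, _ = model_encoder(randn(rng, Float32, (10, 10)), ps_enc, st_enc)

# The decoder is called with a NamedTuple of the decoder input `x` and the encoder `memory`.
model_decoder = TransformerDecoderLayer(10, 2)
ps_dec, st_dec = LuxCore.setup(rng, model_decoder)
x = randn(rng, Float32, (10, 10))
y, st_dec = model_decoder((; x, memory), ps_dec, st_dec)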