From f507418bc668fcb6f65b9a2fac5810dcc77b3337 Mon Sep 17 00:00:00 2001
From: qwjyh
Date: Thu, 27 Feb 2025 22:01:48 +0900
Subject: [PATCH] new: implement transformer decoder layer

---
 test_transformer.jl |  35 +++++++++++--
 transformers.jl     | 116 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 146 insertions(+), 5 deletions(-)

diff --git a/test_transformer.jl b/test_transformer.jl
index 3c09b79..df1c14a 100644
--- a/test_transformer.jl
+++ b/test_transformer.jl
@@ -9,10 +9,10 @@ const NUM_HEADS = 1
 
 rng = Random.default_rng()
 
-model = TransformerEncoderLayer(10, 2)
-@info "model" model
+model_encoder = TransformerEncoderLayer(10, 2)
+@info "model" model_encoder
 
-ps, st = LuxCore.setup(rng, model)
+ps, st = LuxCore.setup(rng, model_encoder)
 @info "parameters" ps
 @info "status" st
 
@@ -20,12 +20,37 @@ ps, st = LuxCore.setup(rng, model)
 x = randn(rng, Float32, (10,))
 @info "input" size(x)
 
-y, st = model(x, ps, st)
+y, st = model_encoder(x, ps, st)
 @info "output" y st
 
+memory_unbatched = y
+
 # Batched
 x = randn(rng, Float32, (10, 10))
 @info "input" size(x)
 
-y, st = model(x, ps, st)
+y, st = model_encoder(x, ps, st)
+@info "output" y st
+
+memory_batched = y
+
+model_decoder = TransformerDecoderLayer(10, 2)
+@info "decoder model" model_decoder
+
+ps, st = LuxCore.setup(rng, model_decoder)
+@info "parameters" ps
+@info "status" st
+
+# Unbatched
+x = randn(rng, Float32, (10,))
+@info "input" x
+
+y, st = model_decoder((; x = x, memory = memory_unbatched), ps, st)
+@info "output" y st
+
+# Batched
+x = randn(rng, Float32, (10, 10))
+@info "input" x
+
+y, st = model_decoder((; x, memory = memory_batched), ps, st)
 @info "output" y st
diff --git a/transformers.jl b/transformers.jl
index 3806c11..eb2e905 100644
--- a/transformers.jl
+++ b/transformers.jl
@@ -271,3 +271,119 @@ function (encoder::TransformerEncoderLayer)(x, ps, st)
         sublayer_feed_forward_network = st_sublayer_feed_forward_network,
     )
 end
+
+"""
+    TransformerDecoderLayer <: LuxCore.AbstractLuxContainerLayer
+
+One decoder layer of the Transformer model.
+"""
+struct TransformerDecoderLayer{LSA, LMHA, LFFN} <: LuxCore.AbstractLuxContainerLayer{(
+    :sublayer_self_attention,
+    :sublayer_multihead_attention,
+    :sublayer_feed_forward_network,
+)}
+    sublayer_self_attention::LSA
+    sublayer_multihead_attention::LMHA
+    sublayer_feed_forward_network::LFFN
+end
+
+function TransformerDecoderLayer(
+    model_dim::Int,
+    num_heads::Int,
+    feedforward_dim::Int = 2048,
+    dropout::T1 = 0.1,
+    activation::F = relu,
+    layer_norm_epsilon::T2 = 1.0f-5,
+) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
+    sublayer_self_attention = let
+        layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
+        layer_self_attention = MultiheadAttention(model_dim, num_heads)
+        layer_dropout = Lux.Dropout(dropout)
+        layer_residual_connection = Lux.SkipConnection(
+            Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
+            +,
+        )
+        Lux.Chain(;
+            layer_residual_connection,
+            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
+            name = "SelfAttentionSubLayer",
+        )
+    end
+    sublayer_multihead_attention = let
+        layer_split =
+            Lux.WrappedFunction(x::NamedTuple -> (q = x.x, k = x.memory, v = x.memory))
+        layer_multihead_attention = MultiheadAttention(model_dim, num_heads)
+        layer_dropout = Lux.Dropout(dropout)
+        layer_residual_connection = Lux.Parallel(
+            +,
+            Lux.Chain(; layer_split, layer_multihead_attention, layer_dropout),
+            Lux.WrappedFunction(x::NamedTuple -> x.x),
+        )
+        Lux.Chain(;
+            layer_residual_connection,
+            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
+            name = "MultiheadAttentionSubLayer",
+        )
+    end
+    sublayer_feed_forward_network = let
+        layer_feed_forward_network = Lux.Chain(
+            Lux.Dense(model_dim => feedforward_dim, activation),
+            Lux.Dropout(dropout),
+            Lux.Dense(feedforward_dim => model_dim, activation),
+            Lux.Dropout(dropout),
+        )
+        layer_residual_connection = Lux.SkipConnection(layer_feed_forward_network, +)
+        Lux.Chain(;
+            layer_residual_connection,
+            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
+            name = "FeedForwardNetworkSubLayer",
+        )
+    end
+    TransformerDecoderLayer(
+        sublayer_self_attention,
+        sublayer_multihead_attention,
+        sublayer_feed_forward_network,
+    )
+end
+
+function (decoder::TransformerDecoderLayer)(x, ps, st)
+    # NOTE: `hasproperty` could be used here instead of `haskey`
+    haskey(x, :x) || throw(
+        ArgumentError(
+            "input to the decoder must have the field \"x\" (decoder input)",
+        ),
+    )
+    haskey(x, :memory) || throw(
+        ArgumentError(
+            "input to the decoder must have the field \"memory\" (encoder output)",
+        ),
+    )
+
+    x_sa, st_sublayer_self_attention = Lux.apply(
+        decoder.sublayer_self_attention,
+        x.x,
+        ps.sublayer_self_attention,
+        st.sublayer_self_attention,
+    )
+    x_mha, st_sublayer_multihead_attention = Lux.apply(
+        decoder.sublayer_multihead_attention,
+        (x = x_sa, memory = x.memory),
+        ps.sublayer_multihead_attention,
+        st.sublayer_multihead_attention,
+    )
+    y, st_sublayer_feed_forward_network = Lux.apply(
+        decoder.sublayer_feed_forward_network,
+        x_mha,
+        ps.sublayer_feed_forward_network,
+        st.sublayer_feed_forward_network,
+    )
+
+    return (
+        y,
+        (
+            sublayer_self_attention = st_sublayer_self_attention,
+            sublayer_multihead_attention = st_sublayer_multihead_attention,
+            sublayer_feed_forward_network = st_sublayer_feed_forward_network,
+        ),
+    )
+end
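
For reference, a minimal usage sketch of the new TransformerDecoderLayer, assuming the definitions above live in transformers.jl and that Lux, LuxCore, and Random are available; the include path, batch size, and variable names are illustrative assumptions, while the model dimensions match test_transformer.jl:

using Lux, LuxCore, Random

include("transformers.jl")  # provides TransformerEncoderLayer / TransformerDecoderLayer

rng = Random.default_rng()
model_dim, num_heads = 10, 2

encoder = TransformerEncoderLayer(model_dim, num_heads)
decoder = TransformerDecoderLayer(model_dim, num_heads)

ps_enc, st_enc = LuxCore.setup(rng, encoder)
ps_dec, st_dec = LuxCore.setup(rng, decoder)

# Encoder pass: source features of shape (model_dim,) or (model_dim, batch) -> memory
src = randn(rng, Float32, (model_dim, 4))
memory, st_enc = encoder(src, ps_enc, st_enc)

# Decoder pass: a NamedTuple with fields `x` (decoder input) and `memory`
# (encoder output), matching the checks in the decoder's apply method
tgt = randn(rng, Float32, (model_dim, 4))
y, st_dec = decoder((; x = tgt, memory), ps_dec, st_dec)

This mirrors the batched case exercised in test_transformer.jl; note that TransformerDecoderLayer neither accepts nor forwards an attention mask, so any causal masking would have to be handled inside MultiheadAttention or added later.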