From 5bdcb56725b2487fb2783048aee3213eb0c61873 Mon Sep 17 00:00:00 2001
From: qwjyh
Date: Sat, 15 Feb 2025 18:12:55 +0900
Subject: [PATCH] new(transformer): implement TransformerEncoderLayer and add test

---
 test_transformer.jl |  34 +++++-----
 transformers.jl     | 148 ++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 164 insertions(+), 18 deletions(-)

diff --git a/test_transformer.jl b/test_transformer.jl
index f1874f1..3c09b79 100644
--- a/test_transformer.jl
+++ b/test_transformer.jl
@@ -4,22 +4,28 @@ true || include("transformers.jl")
 
 using Random
 
-const EMBED_DIM = 10
-const NUM_HEADS = 2
-const KV_DIM = 12
+const INPUT_DIM = 10
+const NUM_HEADS = 1
 
-rng = TaskLocalRNG()
+rng = Random.default_rng()
 
-l = MultiheadAttention(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
-@info "layer" l
+model = TransformerEncoderLayer(10, 2)
+@info "model" model
 
-ps = LuxCore.initialparameters(rng, l)
-st = LuxCore.initialstates(rng, l)
-@info "parameters and states" ps st
+ps, st = LuxCore.setup(rng, model)
+@info "parameters" ps
+@info "states" st
 
-q = rand(rng, Float32, (EMBED_DIM,))
-k = rand(rng, Float32, (KV_DIM,))
-v = rand(rng, Float32, (KV_DIM,))
-@info "q k v" summary.((q, k, v))
+# Unbatched
+x = randn(rng, Float32, (10,))
+@info "input" size(x)
 
-l((; q, k, v), ps, st)
+y, st = model(x, ps, st)
+@info "output" y st
+
+# Batched
+x = randn(rng, Float32, (10, 10))
+@info "input" size(x)
+
+y, st = model(x, ps, st)
+@info "output" y st
diff --git a/transformers.jl b/transformers.jl
index ef0a9f4..3806c11 100644
--- a/transformers.jl
+++ b/transformers.jl
@@ -3,6 +3,10 @@ using Random: AbstractRNG
 using Lux
 
 """
+    MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
+
+Multi-head attention layer used in the Transformer model.
+
 # Fields
 - `embed_dim`: Queue vector length. `q_len`
 - `kvdim::Int`: Key vector and value vector length.
@@ -63,6 +67,7 @@ Constructor.
 MultiheadAttention has no states.
 
 # Inputs
+A `NamedTuple` of the following three fields:
 - `q`: ``Q``
 - `k`: ``K``
 - `v`: ``V``
@@ -97,6 +102,18 @@ function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
     NamedTuple()
 end
 
+function LuxCore.parameterlength(l::MultiheadAttention)
+    dim_weight_q = l.embed_dim * l.num_heads * l.qdim
+    dim_weight_k = l.embed_dim * l.num_heads * l.kvdim
+    dim_weight_v = l.v_embed_dim * l.num_heads * l.kvdim
+    dim_weight_o = l.embed_dim * l.v_embed_dim * l.num_heads
+    dim_weight_q + dim_weight_k + dim_weight_v + dim_weight_o
+end
+
+function LuxCore.statelength(l::MultiheadAttention)
+    0
+end
+
 function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
     if size(x.q, 1) != l.embed_dim
         ArgumentError(
@@ -125,9 +142,132 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
     # [y] = (v_dim, q_len, batch_size...)
     # [α] = (kv_len, q_len, nheads, batch_size...)
     y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
-    ps.weight_o * y
+    ps.weight_o * y, _st
 end
 
-# struct TransformerEncoder <: LuxCore.AbstractLuxContainerLayer{}
-# end
-#
+"""
+    TransformerEncoderLayer <: LuxCore.AbstractLuxContainerLayer
+
+A single encoder block of the Transformer model.
+
+# Structure
+
+It consists of two sublayers: `MultiheadAttention` and a feed-forward network (two `Dense` layers).
+Each sublayer is followed by [`Lux.Dropout`](@ref), wrapped in a residual connection, and normalized with [`Lux.LayerNorm`](@ref).
+
+They are combined using [`Lux.Chain`](@ref) and [`Lux.SkipConnection`](@ref).
+
+See the constructor of this layer, or display an instance of the layer, for the exact structure.
+
+## MultiheadAttention sublayer
+
+```
+          ========== SkipConnection ==========
+          -> MultiheadAttention -> Dropout ->
+ input --<                                    >-- +(add) --> LayerNorm --> output
+          -> ------------------------------->
+```
+
+## FeedForwardNetworkSubLayer
+
+The core is two chained `Dense` layers with hidden dimension `feedforward_dim`.
+This is wrapped in a `SkipConnection` and followed by `LayerNorm`.
+
+# Parameters & States
+
+Since this layer is implemented as a [`LuxCore.AbstractLuxContainerLayer`](@ref),
+both parameters and states are structured according to the container layer.
+See the "Structure" section above for the exact layout.
+
+# Inputs
+
+An `AbstractArray` with `size(x, 1) == model_dim` (the same convention as `Dense`).
+
+"""
+struct TransformerEncoderLayer{LSA, LFFN} <: LuxCore.AbstractLuxContainerLayer{(
+    :sublayer_self_attention,
+    :sublayer_feed_forward_network,
+)}
+    sublayer_self_attention::LSA
+    sublayer_feed_forward_network::LFFN
+end
+
+"""
+    TransformerEncoderLayer(
+        model_dim::Int,
+        num_heads::Int;
+        feedforward_dim::Int = 2048,
+        dropout::T1 = 0.1,
+        activation::F = relu,
+        layer_norm_epsilon::T2 = 1.0f-5,
+    ) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
+
+Constructor of [`TransformerEncoderLayer`](@ref).
+
+# Arguments
+- `model_dim::Int`: model input dimension
+- `num_heads::Int`: number of heads in `MultiheadAttention`
+- `feedforward_dim::Int = 2048`: hidden dimension of the feed-forward network
+- `dropout::T1 = 0.1`: dropout rate for all `Dropout` layers
+- `activation::F = relu`: activation function used in the feed-forward network
+- `layer_norm_epsilon::T2 = 1.0f-5`: epsilon of `LayerNorm`
+"""
+function TransformerEncoderLayer(
+    model_dim::Int,
+    num_heads::Int;
+    feedforward_dim::Int = 2048,
+    dropout::T1 = 0.1,
+    activation::F = relu,
+    layer_norm_epsilon::T2 = 1.0f-5,
+) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
+    sublayer_self_attention = let
+        layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
+        layer_self_attention = MultiheadAttention(model_dim, num_heads)
+        layer_dropout = Lux.Dropout(dropout)
+        layer_residual_connection = Lux.SkipConnection(
+            Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
+            +,
+        )
+        Lux.Chain(;
+            layer_residual_connection,
+            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
+            name = "SelfAttentionSubLayer",
+        )
+    end
+    sublayer_feed_forward_network = let
+        layer_feed_forward_network = Lux.Chain(
+            Lux.Dense(model_dim => feedforward_dim, activation),
+            Lux.Dropout(dropout),
+            Lux.Dense(feedforward_dim => model_dim, activation),
+            Lux.Dropout(dropout),
+        )
+        layer_residual_connection = Lux.SkipConnection(layer_feed_forward_network, +)
+        Lux.Chain(;
+            layer_residual_connection,
+            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
+            name = "FeedForwardNetworkSubLayer",
+        )
+    end
+    TransformerEncoderLayer(sublayer_self_attention, sublayer_feed_forward_network)
+end
+
+function (encoder::TransformerEncoderLayer)(x, ps, st)
+    x, st_sublayer_self_attention = Lux.apply(
+        encoder.sublayer_self_attention,
+        x,
+        ps.sublayer_self_attention,
+        st.sublayer_self_attention,
+    )
+    x, st_sublayer_feed_forward_network = Lux.apply(
+        encoder.sublayer_feed_forward_network,
+        x,
+        ps.sublayer_feed_forward_network,
+        st.sublayer_feed_forward_network,
+    )
+
+    return x,
+    (
+        sublayer_self_attention = st_sublayer_self_attention,
+        sublayer_feed_forward_network = st_sublayer_feed_forward_network,
+    )
+end
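
Not part of the diff: a minimal consistency check for the new `LuxCore.parameterlength` and `LuxCore.statelength` methods, written as a sketch only. It assumes `transformers.jl` has been included and that the `MultiheadAttention` parameters are exactly the four weight matrices counted above (no bias terms); the literal sizes mirror the old test script.

```
# Sketch only: compare parameterlength/statelength with an actual setup.
# Assumes include("transformers.jl") has been run (which brings in Lux/LuxCore).
using Random

rng = Random.default_rng()
mha = MultiheadAttention(10, 2, kvdim = 12)
ps, st = LuxCore.setup(rng, mha)

# The reported length should match the number of scalars actually allocated,
# and the layer should carry no states.
@assert LuxCore.parameterlength(mha) == sum(length, Tuple(ps))
@assert LuxCore.statelength(mha) == 0 && isempty(st)
```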
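
Also not part of the diff: a usage sketch stacking the new `TransformerEncoderLayer` with `Lux.Chain` to form a small encoder. The sizes (`model_dim = 10`, 2 heads, `feedforward_dim = 32`, batch of 16) are arbitrary, and `Lux.testmode` is only used to disable dropout so the forward pass is deterministic.

```
# Sketch only: two stacked encoder layers (arbitrary hyperparameters).
# Assumes include("transformers.jl") has been run.
using Lux, LuxCore, Random

rng = Random.default_rng()
encoder = Lux.Chain(
    TransformerEncoderLayer(10, 2; feedforward_dim = 32),
    TransformerEncoderLayer(10, 2; feedforward_dim = 32),
)
ps, st = LuxCore.setup(rng, encoder)

x = randn(rng, Float32, (10, 16))          # (model_dim, batch_size)
y, _ = encoder(x, ps, Lux.testmode(st))    # testmode disables the Dropout layers
@info "stacked encoder output" size(y)     # (10, 16)
```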