using LuxCore
using Random: AbstractRNG
using Lux

"""
    MultiheadAttention{F} <: LuxCore.AbstractLuxLayer

Multi-head attention layer used in the Transformer model.

# Fields
- `embed_dim::Int`: projection dimension of each head's query and key. Denoted `embed` below.
- `qdim::Int`: query vector length (first dimension of the input `q`).
- `kvdim::Int`: key vector and value vector length.
- `v_embed_dim::Int`: projection dimension of each head's value. Denoted `vembed` below.
- `num_heads::Int`: number of heads.
- `init_weight`: weight initializer.

# Calculation
```math
\\begin{aligned}
    Q &\\in \\mathbb{R}^{qdim} \\\\
    W^Q &\\in M_{embed \\times qdim}(\\mathbb{R}) \\\\
    K &\\in \\mathbb{R}^{kvdim} \\\\
    W^K &\\in M_{embed \\times kvdim}(\\mathbb{R}) \\\\
    V &\\in \\mathbb{R}^{kvdim} \\\\
    W^V &\\in M_{vembed \\times kvdim}(\\mathbb{R}) \\\\
    head_i &= \\operatorname{Attention}(W^Q_i Q, W^K_i K, W^V_i V) \\\\
    \\operatorname{Attention}(Q, K, V) &= V \\operatorname{softmax}\\left(\\frac{K^T Q}{\\sqrt{embed}} \\right) \\\\
    \\operatorname{MultiheadAttention}(Q, K, V) &= W^O \\operatorname{Concat}(head_1, \\ldots, head_h)
\\end{aligned}
```

Here, ``Q, K, V`` are the inputs `x` of the layer, ``W^Q, W^K, W^V, W^O`` are the parameters `ps`,
and the layer has no states `st`.
"""
struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
    embed_dim::Int
    qdim::Int
    kvdim::Int
    v_embed_dim::Int
    num_heads::Int
    init_weight::F
end

"""
    MultiheadAttention(embed_dim::Int, num_heads::Int; init_weight=glorot_uniform, kw...)

Constructor.

# Arguments
- `embed_dim::Int`
- `num_heads::Int`

## Keyword Arguments
- `init_weight`: weight initializer (a function called as `init_weight(rng, dims...)`). Default: `glorot_uniform`.
- `qdim::Int`: Default: `embed_dim`
- `kvdim::Int`: Default: `embed_dim`
- `v_embed_dim::Int`: Default: `embed_dim`

# Parameters and states

## Parameters
- `weight_q`: ``W^Q``
- `weight_k`: ``W^K``
- `weight_v`: ``W^V``
- `weight_o`: ``W^O``

## States
MultiheadAttention has no states.

# Inputs
A NamedTuple with these three fields.
- `q`: ``Q``
- `k`: ``K``
- `v`: ``V``
"""
function MultiheadAttention(
    embed_dim::Int,
    num_heads::Int;
    init_weight = glorot_uniform,
    kw...,
)
    MultiheadAttention{typeof(init_weight)}(
        embed_dim,
        haskey(kw, :qdim) ? kw[:qdim] : embed_dim,
        haskey(kw, :kvdim) ? kw[:kvdim] : embed_dim,
        haskey(kw, :v_embed_dim) ? kw[:v_embed_dim] : embed_dim,
        num_heads,
        init_weight,
    )
end

function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
    # See the original paper for the weight dimensions; the q, k, v weights stack
    # the `num_heads` projection matrices along their first dimension.
    (
        weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.qdim),
        weight_k = l.init_weight(rng, l.embed_dim * l.num_heads, l.kvdim),
        weight_v = l.init_weight(rng, l.v_embed_dim * l.num_heads, l.kvdim),
        weight_o = l.init_weight(rng, l.v_embed_dim, l.v_embed_dim * l.num_heads),
    )
end
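# --- Illustrative sketch (not part of the layer definition) -------------------
# A minimal, hypothetical helper showing how the layer is constructed and which
# shapes `LuxCore.initialparameters` produces. The dimensions (embed_dim = 8,
# num_heads = 2) are arbitrary assumptions for demonstration; call it manually
# with any `AbstractRNG`, e.g. `_mha_shape_demo(Random.default_rng())`.
function _mha_shape_demo(rng::AbstractRNG)
    mha = MultiheadAttention(8, 2)  # qdim, kvdim, v_embed_dim default to embed_dim = 8
    ps = LuxCore.initialparameters(rng, mha)
    # weight_q: (embed_dim * num_heads, qdim)          = (16, 8)
    # weight_k: (embed_dim * num_heads, kvdim)         = (16, 8)
    # weight_v: (v_embed_dim * num_heads, kvdim)       = (16, 8)
    # weight_o: (v_embed_dim, v_embed_dim * num_heads) = (8, 16)
    map(size, ps)
end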
function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
    NamedTuple()
end

function LuxCore.parameterlength(l::MultiheadAttention)
    dim_weight_q = l.embed_dim * l.num_heads * l.qdim
    dim_weight_k = l.embed_dim * l.num_heads * l.kvdim
    dim_weight_v = l.v_embed_dim * l.num_heads * l.kvdim
    # weight_o has size (v_embed_dim, v_embed_dim * num_heads)
    dim_weight_o = l.v_embed_dim * l.v_embed_dim * l.num_heads
    dim_weight_q + dim_weight_k + dim_weight_v + dim_weight_o
end

function LuxCore.statelength(l::MultiheadAttention)
    0
end

function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
    if size(x.q, 1) != l.qdim
        ArgumentError(
            "Length of query must match the layer's qdim: size(q, 1) = $(size(x.q, 1)), qdim = $(l.qdim)",
        ) |> throw
    end
    if size(x.k, 1) != l.kvdim
        ArgumentError(
            "Length of key must match the layer's kvdim: size(k, 1) = $(size(x.k, 1)), kvdim = $(l.kvdim)",
        ) |> throw
    end
    if size(x.v, 1) != l.kvdim
        ArgumentError(
            "Length of value must match the layer's kvdim: size(v, 1) = $(size(x.v, 1)), kvdim = $(l.kvdim)",
        ) |> throw
    end
    # qk_dim and v_dim are divisible by num_heads:
    # qk_dim = embed_dim * num_heads, v_dim = v_embed_dim * num_heads.
    # [q] = (qk_dim, q_len, batch_size...)
    q = ps.weight_q * x.q
    # [k] = (qk_dim, kv_len, batch_size...)
    k = ps.weight_k * x.k
    # [v] = (v_dim, kv_len, batch_size...)
    v = ps.weight_v * x.v
    # [y] = (v_dim, q_len, batch_size...)
    # [α] = (kv_len, q_len, nheads, batch_size...)
    y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
    ps.weight_o * y, _st
end
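# --- Illustrative sketch (not part of the layer definition) -------------------
# A hypothetical helper demonstrating one forward pass through `MultiheadAttention`:
# parameters and states come from `LuxCore.setup`, and the input is the NamedTuple
# `(q, k, v)` described in the docstring. The dimensions and sequence lengths are
# arbitrary assumptions; call it manually with any `AbstractRNG`.
function _mha_forward_demo(rng::AbstractRNG)
    mha = MultiheadAttention(8, 2)              # embed_dim = 8, num_heads = 2
    ps, st = LuxCore.setup(rng, mha)
    q = rand(Float32, 8, 5)                     # (qdim, q_len)
    k = rand(Float32, 8, 7)                     # (kvdim, kv_len)
    v = rand(Float32, 8, 7)                     # (kvdim, kv_len)
    y, _ = mha((q = q, k = k, v = v), ps, st)
    size(y)                                     # (v_embed_dim, q_len) == (8, 5)
end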
"""
    TransformerEncoderLayer <: LuxCore.AbstractLuxContainerLayer

One encoder block of the Transformer model.

# Structure

It consists of two sublayers, `MultiheadAttention` and a feed forward network (two `Dense` layers).
Each is wrapped in a residual connection with [`Lux.Dropout`](@ref) applied before the addition,
and followed by [`Lux.LayerNorm`](@ref). They are combined using [`Chain`](@ref) and
[`SkipConnection`](@ref).

See the constructor of this layer or the displayed layer for the exact structure.

## MultiheadAttention sublayer

```
========= SkipConnection ==========
-> MultiheadAttention -> Dropout -->
                                     -- +(add) -> LayerNorm ->
---------------------------------->
```

## FeedForwardNetworkSubLayer

The core is two chained `Dense` layers with internal dimension `feedforward_dim`.
This is wrapped in a `SkipConnection` and followed by `LayerNorm`.

# Parameters & States

Since this layer is implemented as a [`LuxCore.AbstractLuxContainerLayer`](@ref), both parameters
and states are structured according to its sublayers. See the "Structure" section for details.

# Inputs

An `AbstractArray` with `size(x, 1) == model_dim` (same as `Dense`).
"""
struct TransformerEncoderLayer{LSA, LFFN} <: LuxCore.AbstractLuxContainerLayer{(
    :sublayer_self_attention,
    :sublayer_feed_forward_network,
)}
    sublayer_self_attention::LSA
    sublayer_feed_forward_network::LFFN
end

"""
    TransformerEncoderLayer(
        model_dim::Int,
        num_heads::Int;
        feedforward_dim::Int = 2048,
        dropout::T1 = 0.1,
        activation::F = relu,
        layer_norm_epsilon::T2 = 1e-5,
    ) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}

Constructor of [`TransformerEncoderLayer`](@ref).

# Arguments
- `model_dim::Int`: model input size
- `num_heads::Int`: number of heads of MultiheadAttention
- `feedforward_dim::Int = 2048`: dimension of the feed forward network
- `dropout::T1 = 0.1`: dropout rate for all Dropout layers
- `activation::F = relu`: activation used in the feed forward network
- `layer_norm_epsilon::T2 = 1e-5`: `epsilon` of LayerNorm
"""
function TransformerEncoderLayer(
    model_dim::Int,
    num_heads::Int;
    feedforward_dim::Int = 2048,
    dropout::T1 = 0.1,
    activation::F = relu,
    layer_norm_epsilon::T2 = 1.0f-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
    # Self-attention sublayer: split the input into (q, k, v), apply attention and
    # dropout inside a residual connection, then normalize.
    sublayer_self_attention = let
        layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
        layer_self_attention = MultiheadAttention(model_dim, num_heads)
        layer_dropout = Lux.Dropout(dropout)
        layer_residual_connection = Lux.SkipConnection(
            Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
            +,
        )
        Lux.Chain(;
            layer_residual_connection,
            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
            name = "SelfAttentionSubLayer",
        )
    end
    # Feed forward sublayer: two Dense layers with dropout inside a residual
    # connection, then normalize.
    sublayer_feed_forward_network = let
        layer_feed_forward_network = Lux.Chain(
            Lux.Dense(model_dim => feedforward_dim, activation),
            Lux.Dropout(dropout),
            Lux.Dense(feedforward_dim => model_dim, activation),
            Lux.Dropout(dropout),
        )
        layer_residual_connection = Lux.SkipConnection(layer_feed_forward_network, +)
        Lux.Chain(;
            layer_residual_connection,
            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
            name = "FeedForwardNetworkSubLayer",
        )
    end
    TransformerEncoderLayer(sublayer_self_attention, sublayer_feed_forward_network)
end

function (encoder::TransformerEncoderLayer)(x, ps, st)
    x, st_sublayer_self_attention = Lux.apply(
        encoder.sublayer_self_attention,
        x,
        ps.sublayer_self_attention,
        st.sublayer_self_attention,
    )
    x, st_sublayer_feed_forward_network = Lux.apply(
        encoder.sublayer_feed_forward_network,
        x,
        ps.sublayer_feed_forward_network,
        st.sublayer_feed_forward_network,
    )

    return x, (
        sublayer_self_attention = st_sublayer_self_attention,
        sublayer_feed_forward_network = st_sublayer_feed_forward_network,
    )
end
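# --- Illustrative sketch (not part of the layer definition) -------------------
# A hypothetical end-to-end example: build a `TransformerEncoderLayer`, initialize
# it with `Lux.setup`, and run one forward pass with `Lux.apply`. The model size,
# the reduced `feedforward_dim`, and the sequence length are arbitrary assumptions
# for demonstration; call it manually with any `AbstractRNG`.
function _encoder_demo(rng::AbstractRNG)
    encoder = TransformerEncoderLayer(16, 2; feedforward_dim = 32)
    ps, st = Lux.setup(rng, encoder)
    x = rand(Float32, 16, 10)             # (model_dim, seq_len)
    y, st = Lux.apply(encoder, x, ps, st)
    size(y)                               # (model_dim, seq_len) == (16, 10)
end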