# using LuxCore
using Random: AbstractRNG
using Lux

"""
|
2025-02-15 18:12:55 +09:00
|
|
|
|
MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
|
|
|
|
|
|
|
|
|
|
Multi head attention layer used in Transformer model.
|
|
|
|
|
|
2025-01-28 22:19:46 +09:00
|
|
|
|
# Fields
|
|
|
|
|
- `embed_dim`: Queue vector length. `q_len`
|
2025-02-03 22:09:23 +09:00
|
|
|
|
- `kvdim::Int`: Key vector and value vector length.
|
2025-01-28 22:19:46 +09:00
|
|
|
|
- `num_heads`: Number of heads.
|
2025-02-03 22:09:23 +09:00
|
|
|
|
|
|
|
|
|
# Calculation
|
|
|
|
|
```math
\\begin{aligned}
Q &\\in \\mathbb{R}^{qdim} \\\\
W^Q &\\in M_{embed \\times qdim}(\\mathbb{R}) \\\\
K &\\in \\mathbb{R}^{kvdim} \\\\
W^K &\\in M_{embed \\times kvdim}(\\mathbb{R}) \\\\
V &\\in \\mathbb{R}^{kvdim} \\\\
W^V &\\in M_{vembed \\times kvdim}(\\mathbb{R}) \\\\
head_i &= \\operatorname{Attention}(W^Q_i Q, W^K_i K, W^V_i V) \\\\
\\operatorname{Attention}(Q, K, V) &= V \\operatorname{softmax}\\left(\\frac{K^T Q}{\\sqrt{embed}} \\right) \\\\
\\operatorname{MultiheadAttention}(Q, K, V) &= W^O \\operatorname{Concat}(head_1, \\ldots, head_h)
\\end{aligned}
```

Here, ``Q, K, V`` are the inputs `x` of the layer,
``W^Q, W^K, W^V, W^O`` are the parameters `ps`,
and the layer has no states `st`.
"""
struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
    embed_dim::Int
    qdim::Int
    kvdim::Int
    v_embed_dim::Int
    num_heads::Int
    init_weight::F
end

"""
    MultiheadAttention(embed_dim::Int, num_heads::Int; init_weight = glorot_uniform, kw...)

Constructor of [`MultiheadAttention`](@ref).

# Arguments

- `embed_dim::Int`
- `num_heads::Int`

## Keyword Arguments

- `init_weight`: weight initializer, called as `init_weight(rng, dims...)`. Default: `glorot_uniform`
- `qdim::Int`: Default: `embed_dim`
- `kvdim::Int`: Default: `embed_dim`
- `v_embed_dim::Int`: Default: `embed_dim`

# Parameters and states

## Parameters

- `weight_q`: ``W^Q``
- `weight_k`: ``W^K``
- `weight_v`: ``W^V``
- `weight_o`: ``W^O``

## States

MultiheadAttention has no states.

# Inputs

A NamedTuple with these three keys:

- `q`: ``Q``
- `k`: ``K``
- `v`: ``V``
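
# Example

A minimal usage sketch (the dimensions here are arbitrary; `Random.default_rng()` is just one possible RNG):

```julia
using Random

rng = Random.default_rng()
mha = MultiheadAttention(4, 2)              # embed_dim = 4, num_heads = 2
ps, st = LuxCore.setup(rng, mha)

q = rand(Float32, 4, 10)                    # (qdim, q_len); qdim defaults to embed_dim
k = rand(Float32, 4, 12)                    # (kvdim, kv_len)
v = rand(Float32, 4, 12)                    # (kvdim, kv_len)

y, _ = mha((q = q, k = k, v = v), ps, st)   # y has size (v_embed_dim, q_len) == (4, 10)
```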
"""
function MultiheadAttention(
    embed_dim::Int,
    num_heads::Int;
    init_weight = glorot_uniform,
    kw...,
)
    MultiheadAttention{typeof(init_weight)}(
        embed_dim,
        haskey(kw, :qdim) ? kw[:qdim] : embed_dim,
        haskey(kw, :kvdim) ? kw[:kvdim] : embed_dim,
        haskey(kw, :v_embed_dim) ? kw[:v_embed_dim] : embed_dim,
        num_heads,
        init_weight,
    )
end

function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
    # See the original paper for the weight dimensions; the q, k, v weights stack the
    # projection matrices of all `num_heads` heads along the first dimension.
    (
        weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.qdim),
        weight_k = l.init_weight(rng, l.embed_dim * l.num_heads, l.kvdim),
        weight_v = l.init_weight(rng, l.v_embed_dim * l.num_heads, l.kvdim),
        weight_o = l.init_weight(rng, l.v_embed_dim, l.v_embed_dim * l.num_heads),
    )
end
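
# Shape sketch (hypothetical sizes): for `MultiheadAttention(4, 2; kvdim = 6, v_embed_dim = 8)`
# (so `qdim` defaults to 4), the parameters come out as
#   weight_q: (8, 4), weight_k: (8, 6), weight_v: (16, 6), weight_o: (8, 16).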

function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
    NamedTuple()
end

function LuxCore.parameterlength(l::MultiheadAttention)
    dim_weight_q = l.embed_dim * l.num_heads * l.qdim
    dim_weight_k = l.embed_dim * l.num_heads * l.kvdim
    dim_weight_v = l.v_embed_dim * l.num_heads * l.kvdim
    dim_weight_o = l.v_embed_dim * l.v_embed_dim * l.num_heads
    dim_weight_q + dim_weight_k + dim_weight_v + dim_weight_o
end
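
# Consistency check (same hypothetical sizes as the sketch above): for
# `MultiheadAttention(4, 2; kvdim = 6, v_embed_dim = 8)` this returns
# 4*2*4 + 4*2*6 + 8*2*6 + 8*8*2 == 304, the total number of weight entries.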

function LuxCore.statelength(l::MultiheadAttention)
    0
end

function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
    if size(x.q, 1) != l.qdim
        ArgumentError(
            "Length of query must match the layer's qdim: size(q)[1] = $(size(x.q, 1)), qdim = $(l.qdim)",
        ) |> throw
    end
    if size(x.k, 1) != l.kvdim
        ArgumentError(
            "Length of key must match the layer's kvdim: size(k)[1] = $(size(x.k, 1)), kvdim = $(l.kvdim)",
        ) |> throw
    end
    if size(x.v, 1) != l.kvdim
        ArgumentError(
            "Length of value must match the layer's kvdim: size(v)[1] = $(size(x.v, 1)), kvdim = $(l.kvdim)",
        ) |> throw
    end

    # qk_dim and v_dim are divisible by num_heads; qk_dim = embed_dim * num_heads.
    # [q] = (qk_dim, q_len, batch_size...)
    q = ps.weight_q * x.q
    # [k] = (qk_dim, kv_len, batch_size...)
    k = ps.weight_k * x.k
    # [v] = (v_dim, kv_len, batch_size...)
    v = ps.weight_v * x.v
    # [y] = (v_dim, q_len, batch_size...)
    # [α] = (kv_len, q_len, nheads, batch_size...)
    y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
    ps.weight_o * y, _st
end
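
# Shape sketch (assuming the default qdim = kvdim = v_embed_dim = embed_dim and 2-dimensional inputs):
# with embed_dim = 4, num_heads = 2, q of size (4, 10) and k, v of size (4, 12),
# the projections have sizes q: (8, 10), k: (8, 12), v: (8, 12);
# `dot_product_attention` splits the first dimension into 2 heads of size 4 and returns y of size (8, 10),
# so the final output `ps.weight_o * y` has size (4, 10).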

"""
    TransformerEncoderLayer <: LuxCore.AbstractLuxContainerLayer

A single encoder layer (block) of the Transformer model.

# Structure

It consists of two sublayers: a `MultiheadAttention` layer and a feed-forward network (two `Dense` layers).
Each of them is wrapped in a residual connection, with [`Lux.Dropout`](@ref) applied inside and [`Lux.LayerNorm`](@ref) applied afterwards.

They are combined using [`Chain`](@ref) and [`SkipConnection`](@ref).

See the constructor of this layer, or display a constructed layer, to see the exact structure.

## MultiheadAttention sublayer

```
========= SkipConnection ==========
-> MultiheadAttention -> Dropout -->
                                    -- +(add) -> LayerNorm
-> -------------------------------->
```

## FeedForwardNetworkSubLayer

The core is two chained `Dense` layers with internal dimension `feedforward_dim`.
This is wrapped in a `SkipConnection` and followed by `LayerNorm`.

# Parameters & States

Since this layer is implemented as a [`LuxCore.AbstractLuxContainerLayer`](@ref),
both the parameters and the states are structured according to the container layer.
See the "Structure" section for details.

# Inputs

An `AbstractArray` with `size(x, 1) == model_dim` (same as `Dense`).
"""
struct TransformerEncoderLayer{LSA, LFFN} <: LuxCore.AbstractLuxContainerLayer{(
    :sublayer_self_attention,
    :sublayer_feed_forward_network,
)}
    sublayer_self_attention::LSA
    sublayer_feed_forward_network::LFFN
end

"""
    TransformerEncoderLayer(
        model_dim::Int,
        num_heads::Int;
        feedforward_dim::Int = 2048,
        dropout::T1 = 0.1,
        activation::F = relu,
        layer_norm_epsilon::T2 = 1.0f-5,
    ) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}

Constructor of [`TransformerEncoderLayer`](@ref).

# Arguments

- `model_dim::Int`: model input size
- `num_heads::Int`: number of heads of `MultiheadAttention`
- `feedforward_dim::Int = 2048`: dimension of the feed-forward network
- `dropout::T1 = 0.1`: dropout rate for all `Dropout` layers
- `activation::F = relu`: activation used in the feed-forward network
- `layer_norm_epsilon::T2 = 1e-5`: eps of LayerNorm
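
# Example

A minimal usage sketch (sizes are arbitrary):

```julia
using Random

rng = Random.default_rng()
encoder = TransformerEncoderLayer(16, 4)    # model_dim = 16, num_heads = 4
ps, st = LuxCore.setup(rng, encoder)

x = rand(Float32, 16, 10)                   # (model_dim, seq_len)
y, _ = encoder(x, ps, st)                   # y has size (16, 10)
```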
"""
function TransformerEncoderLayer(
    model_dim::Int,
    num_heads::Int;
    feedforward_dim::Int = 2048,
    dropout::T1 = 0.1,
    activation::F = relu,
    layer_norm_epsilon::T2 = 1.0f-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
    sublayer_self_attention = let
        layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
        layer_self_attention = MultiheadAttention(model_dim, num_heads)
        layer_dropout = Lux.Dropout(dropout)
        layer_residual_connection = Lux.SkipConnection(
            Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
            +,
        )
        Lux.Chain(;
            layer_residual_connection,
            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
            name = "SelfAttentionSubLayer",
        )
    end
    sublayer_feed_forward_network = let
        layer_feed_forward_network = Lux.Chain(
            Lux.Dense(model_dim => feedforward_dim, activation),
            Lux.Dropout(dropout),
            Lux.Dense(feedforward_dim => model_dim, activation),
            Lux.Dropout(dropout),
        )
        layer_residual_connection = Lux.SkipConnection(layer_feed_forward_network, +)
        Lux.Chain(;
            layer_residual_connection,
            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
            name = "FeedForwardNetworkSubLayer",
        )
    end
    TransformerEncoderLayer(sublayer_self_attention, sublayer_feed_forward_network)
end

function (encoder::TransformerEncoderLayer)(x, ps, st)
    x, st_sublayer_self_attention = Lux.apply(
        encoder.sublayer_self_attention,
        x,
        ps.sublayer_self_attention,
        st.sublayer_self_attention,
    )
    x, st_sublayer_feed_forward_network = Lux.apply(
        encoder.sublayer_feed_forward_network,
        x,
        ps.sublayer_feed_forward_network,
        st.sublayer_feed_forward_network,
    )

    return x, (
        sublayer_self_attention = st_sublayer_self_attention,
        sublayer_feed_forward_network = st_sublayer_feed_forward_network,
    )
end
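
# A deeper encoder can be built by chaining several of these layers, e.g. (a sketch)
#     Lux.Chain(TransformerEncoderLayer(16, 4), TransformerEncoderLayer(16, 4))
# since each layer maps arrays with `size(x, 1) == model_dim` to arrays of the same size.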