# ml-test/transformers.jl

using LuxCore
using Random: AbstractRNG
using Lux
"""
MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
Multi head attention layer used in Transformer model.
# Fields
- `embed_dim`: Queue vector length. `q_len`
- `kvdim::Int`: Key vector and value vector length.
- `num_heads`: Number of heads.
# Calculation
```math
\\begin{aligned}
Q &\\in \\mathbb{R}^{qdim} \\\\
W^Q &\\in M_{embed \\times qdim}(\\mathbb{R}) \\\\
K &\\in \\mathbb{R}^{kvdim} \\\\
W^K &\\in M_{embed \\times kvdim}(\\mathbb{R}) \\\\
V &\\in \\mathbb{R}^{kvdim} \\\\
W^V &\\in M_{vembed \\times kvdim}(\\mathbb{R}) \\\\
head_i &= \\operatorname{Attention}(W^Q_i Q, W^K_i K, W^V_i V) \\\\
\\operatorname{Attention}(Q, K, V) &= V \\operatorname{softmax}\\left(\\frac{K^T Q}{\\sqrt{embed}} \\right) \\\\
\\operatorname{MultiheadAttention}(Q, K, V) &= W^O \\operatorname{Concat}(head_1, \\ldots, head_h)
\\end{aligned}
```
Here ``Q, K, V`` are the layer inputs `x`,
``W^Q, W^K, W^V, W^O`` are the layer parameters `ps`,
and the layer has no states `st`.
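All heads are packed into single parameter matrices. For illustration, with
`embed_dim = 4`, `qdim = 6`, `kvdim = 5`, `v_embed_dim = 3`, `num_heads = 2`
(sizes chosen only for illustration), the parameter shapes are:
```julia
# weight_q : (embed_dim   * num_heads, qdim)        = (8, 6)
# weight_k : (embed_dim   * num_heads, kvdim)       = (8, 5)
# weight_v : (v_embed_dim * num_heads, kvdim)       = (6, 5)
# weight_o : (v_embed_dim, v_embed_dim * num_heads) = (3, 6)
```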
"""
struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
    embed_dim::Int
    qdim::Int
    kvdim::Int
    v_embed_dim::Int
    num_heads::Int
    init_weight::F
end
"""
MultiheadAttention(embed_dim::Int, num_heads::Int; init_weight=glorot_uniform, kw...)
Constructor.
# Arguments
- `embed_dim::Int`
- `num_heads::Int`
## Keyword Arguments
- `init_weight`: weight initialzer (rng generator)
- `qdim`: Default `embed_dim`
- `kvdim::Int`: Default: `embed_dim`
- `v_embed_dim`: Default: `embed_dim`
# Parameters and states
## Parameters
- `weight_q`: ``W^Q``
- `weight_k`: ``W^K``
- `weight_v`: ``W^V``
- `weight_o`: ``W^O``
## States
MultiheadAttention has no states.
# Inputs
A `NamedTuple` with the following three fields:
- `q`: ``Q``
- `k`: ``K``
- `v`: ``V``
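# Example
A minimal usage sketch (unbatched inputs; array sizes are illustrative):
```julia
using Random

l = MultiheadAttention(8, 2)        # embed_dim = 8, num_heads = 2
ps, st = LuxCore.setup(Random.default_rng(), l)
x = (
    q = rand(Float32, 8, 10),       # (qdim,  q_len)
    k = rand(Float32, 8, 12),       # (kvdim, kv_len)
    v = rand(Float32, 8, 12),       # (kvdim, kv_len)
)
y, _ = l(x, ps, st)
size(y)                             # (v_embed_dim, q_len) = (8, 10)
```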
"""
function MultiheadAttention(
    embed_dim::Int,
    num_heads::Int;
    init_weight = glorot_uniform,
    kw...,
)
    MultiheadAttention{typeof(init_weight)}(
        embed_dim,
        haskey(kw, :qdim) ? kw[:qdim] : embed_dim,
        haskey(kw, :kvdim) ? kw[:kvdim] : embed_dim,
        haskey(kw, :v_embed_dim) ? kw[:v_embed_dim] : embed_dim,
        num_heads,
        init_weight,
    )
end
function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
    # See the original Transformer paper for the weight dimensions; the q/k/v weights
    # store all `num_heads` head matrices stacked along the first dimension.
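    # weight_q : (embed_dim   * num_heads, qdim)
    # weight_k : (embed_dim   * num_heads, kvdim)
    # weight_v : (v_embed_dim * num_heads, kvdim)
    # weight_o : (v_embed_dim, v_embed_dim * num_heads)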
    (
        weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.qdim),
        weight_k = l.init_weight(rng, l.embed_dim * l.num_heads, l.kvdim),
        weight_v = l.init_weight(rng, l.v_embed_dim * l.num_heads, l.kvdim),
        weight_o = l.init_weight(rng, l.v_embed_dim, l.v_embed_dim * l.num_heads),
    )
end
function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
    NamedTuple()
end
function LuxCore.parameterlength(l::MultiheadAttention)
    dim_weight_q = l.embed_dim * l.num_heads * l.qdim
    dim_weight_k = l.embed_dim * l.num_heads * l.kvdim
    dim_weight_v = l.v_embed_dim * l.num_heads * l.kvdim
    # must match `initialparameters`: weight_o is (v_embed_dim, v_embed_dim * num_heads)
    dim_weight_o = l.v_embed_dim * l.v_embed_dim * l.num_heads
    dim_weight_q + dim_weight_k + dim_weight_v + dim_weight_o
end
function LuxCore.statelength(l::MultiheadAttention)
    0
end
function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
    if size(x.q, 1) != l.qdim
        ArgumentError(
            "Length of query must match the layer's qdim: size(q)[1] = $(size(x.q, 1)), qdim = $(l.qdim)",
        ) |> throw
    end
    if size(x.k, 1) != l.kvdim
        ArgumentError(
            "Length of key must match the layer's kvdim: size(k)[1] = $(size(x.k, 1)), kvdim = $(l.kvdim)",
        ) |> throw
    end
    if size(x.v, 1) != l.kvdim
        ArgumentError(
            "Length of value must match the layer's kvdim: size(v)[1] = $(size(x.v, 1)), kvdim = $(l.kvdim)",
        ) |> throw
    end
    # All heads are packed into the leading dimension:
    # qk_dim = embed_dim * num_heads and v_dim = v_embed_dim * num_heads,
    # so both are divisible by num_heads as required by `dot_product_attention`.
    # [q] = (qk_dim, q_len, batch_size...)
    q = ps.weight_q * x.q
    # [k] = (qk_dim, kv_len, batch_size...)
    k = ps.weight_k * x.k
    # [v] = (v_dim, kv_len, batch_size...)
    v = ps.weight_v * x.v
    # [y] = (v_dim, q_len, batch_size...)
    # [α] = (kv_len, q_len, nheads, batch_size...)
    y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
    ps.weight_o * y, _st
end
"""
TransformerEncoderLayer <: LuxCore.AbstractLuxLayer
One layer of encoder block of Transformer model.
# Structure
It consists of two sublayers, `MultiheadAttention` and feed forward network(two Dense layers).
Both of them are wrapped with residual connection and followed by [`Lux.Dropout`](@ref) and [`Lux.LayerNorm`](@ref).
They are combined using [`Chain`](@ref) and [`SkipConnection`](@ref).
See the constructor of this layer or displayed layer to see the structure.
## MultiheadAttention sublayer
```
========= SkipConnection ==========
-> MultiheadAttention -> Dropout -->
                                     +(add) -> LayerNorm ->
-> -------------------------------->
```
## FeedForwardNetworkSubLayer
The core is two chained `Dense` layers with hidden dimension `feedforward_dim`,
wrapped in a `SkipConnection` and followed by `LayerNorm`.
# Parameters & States
Since this layer is implemented as a [`LuxCore.AbstractLuxContainerLayer`](@ref),
both the parameters and the states are structured like the container,
with one field per sublayer. See the "Structure" section for more detail.
# Inputs
An `AbstractArray` `x` with `size(x, 1) == model_dim` (the same convention as `Dense`).
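For example, the parameters produced by `LuxCore.setup` form a `NamedTuple` with one entry per
sublayer (a sketch; `rng` is any `AbstractRNG` and `encoder_layer` a constructed instance):
```julia
ps, st = LuxCore.setup(rng, encoder_layer)
ps.sublayer_self_attention          # parameters of the self-attention sublayer `Chain`
ps.sublayer_feed_forward_network    # parameters of the feed-forward sublayer `Chain`
```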
"""
struct TransformerEncoderLayer{LSA, LFFN} <: LuxCore.AbstractLuxContainerLayer{(
    :sublayer_self_attention,
    :sublayer_feed_forward_network,
)}
    sublayer_self_attention::LSA
    sublayer_feed_forward_network::LFFN
end
"""
TransformerEncoderLayer(
model_dim::Int,
num_heads::Int;
feedforward_dim::Int = 2048,
dropout::T1 = 0.1,
activation::F = relu,
layer_norm_eps::T2 = 1e-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
Constructor of [`TransformerEncoderLayer`](@ref).
# Arguments
- `model_dim::Int`: model input size
- `num_heads::Int`: number of heads of MultiheadAttention
- `feedforward_dim::Int = 2048`: dimension of feed forward network
- `dropout::T1 = 0.1`: dropout rate for all Dropout layers
- `activation::F = relu`: activation used in feed forward network
- `layer_norm_epsilon::T2 = 1e-5`: eps of LayerNorm
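# Example
A minimal usage sketch (unbatched input; sizes are illustrative):
```julia
using Random

enc = TransformerEncoderLayer(16, 4)
ps, st = LuxCore.setup(Random.default_rng(), enc)
st = LuxCore.testmode(st)        # disable dropout for deterministic evaluation
x = rand(Float32, 16, 10)        # (model_dim, seq_len)
y, st = enc(x, ps, st)
size(y)                          # (16, 10)
```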
"""
function TransformerEncoderLayer(
    model_dim::Int,
    num_heads::Int;
    feedforward_dim::Int = 2048,
    dropout::T1 = 0.1,
    activation::F = relu,
    layer_norm_epsilon::T2 = 1.0f-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
    sublayer_self_attention = let
        layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
        layer_self_attention = MultiheadAttention(model_dim, num_heads)
        layer_dropout = Lux.Dropout(dropout)
        layer_residual_connection = Lux.SkipConnection(
            Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
            +,
        )
        Lux.Chain(;
            layer_residual_connection,
            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
            name = "SelfAttentionSubLayer",
        )
    end
    sublayer_feed_forward_network = let
        layer_feed_forward_network = Lux.Chain(
            Lux.Dense(model_dim => feedforward_dim, activation),
            Lux.Dropout(dropout),
            Lux.Dense(feedforward_dim => model_dim, activation),
            Lux.Dropout(dropout),
        )
        layer_residual_connection = Lux.SkipConnection(layer_feed_forward_network, +)
        Lux.Chain(;
            layer_residual_connection,
            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
            name = "FeedForwardNetworkSubLayer",
        )
    end
    TransformerEncoderLayer(sublayer_self_attention, sublayer_feed_forward_network)
end
function (encoder::TransformerEncoderLayer)(x, ps, st)
    x, st_sublayer_self_attention = Lux.apply(
        encoder.sublayer_self_attention,
        x,
        ps.sublayer_self_attention,
        st.sublayer_self_attention,
    )
    x, st_sublayer_feed_forward_network = Lux.apply(
        encoder.sublayer_feed_forward_network,
        x,
        ps.sublayer_feed_forward_network,
        st.sublayer_feed_forward_network,
    )
    return x,
    (
        sublayer_self_attention = st_sublayer_self_attention,
        sublayer_feed_forward_network = st_sublayer_feed_forward_network,
    )
end
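
# Example (sketch): stack several encoder layers into a full encoder.
# The hyperparameters below (model_dim = 512, num_heads = 8, 6 layers) are illustrative.
#
#   encoder = Lux.Chain([TransformerEncoderLayer(512, 8) for _ in 1:6]...)
#   ps, st = LuxCore.setup(Random.default_rng(), encoder)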