# ml-test/transformers.jl

using LuxCore
using Random: AbstractRNG
using Lux
using NNlib: dot_product_attention # used in the MultiheadAttention forward pass
"""
MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
Multi-head attention layer as used in the Transformer model.
# Fields
- `embed_dim::Int`: Query/key embedding dimension per head.
- `qdim::Int`: Input query vector length.
- `kvdim::Int`: Input key and value vector length.
- `v_embed_dim::Int`: Value embedding dimension per head.
- `num_heads::Int`: Number of heads.
- `init_weight`: Weight initializer.
# Calculation
```math
\\begin{aligned}
Q &\\in \\mathbb{R}^{qdim} \\\\
W^Q &\\in M_{embed \\times qdim}(\\mathbb{R}) \\\\
K &\\in \\mathbb{R}^{kvdim} \\\\
W^K &\\in M_{embed \\times kvdim}(\\mathbb{R}) \\\\
V &\\in \\mathbb{R}^{kvdim} \\\\
W^V &\\in M_{vembed \\times kvdim}(\\mathbb{R}) \\\\
head_i &= \\operatorname{Attention}(W^Q_i Q, W^K_i K, W^V_i V) \\\\
\\operatorname{Attention}(Q, K, V) &= V \\operatorname{softmax}\\left(\\frac{K^T Q}{\\sqrt{embed}} \\right) \\\\
\\operatorname{MultiheadAttention}(Q, K, V) &= W^O \\operatorname{Concat}(head_1, \\ldots, head_h)
\\end{aligned}
```
Here ``Q, K, V`` are the layer inputs `x`,
``W^Q, W^K, W^V, W^O`` are the layer parameters `ps`,
and the layer has no states `st`.
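# Example
For example, with `embed_dim = 4`, `num_heads = 2`, and the default `qdim`, `kvdim`,
and `v_embed_dim`, the projection weights created by [`LuxCore.initialparameters`](@ref)
have the sizes implied above (a small sketch; any `AbstractRNG` works):
```julia
julia> using Random, Lux

julia> l = MultiheadAttention(4, 2);

julia> ps = LuxCore.initialparameters(Random.default_rng(), l);

julia> size(ps.weight_q), size(ps.weight_v), size(ps.weight_o)
((8, 4), (8, 4), (4, 8))
```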
"""
struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
embed_dim::Int
qdim::Int
kvdim::Int
v_embed_dim::Int
num_heads::Int
init_weight::F
end
"""
MultiheadAttention(embed_dim::Int, num_heads::Int; init_weight=glorot_uniform, kw...)
Constructor.
# Arguments
- `embed_dim::Int`: query/key embedding dimension per head
- `num_heads::Int`: number of heads
## Keyword Arguments
- `init_weight`: weight initializer (default: `glorot_uniform`)
- `qdim::Int`: Default: `embed_dim`
- `kvdim::Int`: Default: `embed_dim`
- `v_embed_dim::Int`: Default: `embed_dim`
# Parameters and states
## Parameters
- `weight_q`: ``W^Q``
- `weight_k`: ``W^K``
- `weight_v`: ``W^V``
- `weight_o`: ``W^O``
## States
MultiheadAttention has no states.
# Inputs
A `NamedTuple` with these three keys (an optional `mask` may also be supplied).
- `q`: ``Q``
- `k`: ``K``
- `v`: ``V``
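# Example
A minimal forward-pass sketch (sizes chosen arbitrarily):
```julia
julia> using Random, Lux

julia> mha = MultiheadAttention(4, 2);

julia> ps, st = Lux.setup(Random.default_rng(), mha);

julia> q = rand(Float32, 4, 10); k = v = rand(Float32, 4, 12);

julia> y, _ = mha((q = q, k = k, v = v), ps, st);

julia> size(y)  # (v_embed_dim, q_len)
(4, 10)
```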
"""
function MultiheadAttention(
embed_dim::Int,
num_heads::Int;
init_weight = glorot_uniform,
kw...,
)
MultiheadAttention{typeof(init_weight)}(
embed_dim,
get(kw, :qdim, embed_dim),
get(kw, :kvdim, embed_dim),
get(kw, :v_embed_dim, embed_dim),
num_heads,
init_weight,
)
end
function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
# See the original paper for the weight dimensions; each of the q, k, v weights stacks num_heads projection matrices.
(
weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.qdim),
weight_k = l.init_weight(rng, l.embed_dim * l.num_heads, l.kvdim),
weight_v = l.init_weight(rng, l.v_embed_dim * l.num_heads, l.kvdim),
weight_o = l.init_weight(rng, l.v_embed_dim, l.v_embed_dim * l.num_heads),
)
end
function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
NamedTuple()
end
function LuxCore.parameterlength(l::MultiheadAttention)
dim_weight_q = l.embed_dim * l.num_heads * l.qdim
dim_weight_k = l.embed_dim * l.num_heads * l.kvdim
dim_weight_v = l.v_embed_dim * l.num_heads * l.kvdim
dim_weight_o = l.v_embed_dim * l.v_embed_dim * l.num_heads # matches weight_o in initialparameters
dim_weight_q + dim_weight_k + dim_weight_v + dim_weight_o
end
function LuxCore.statelength(l::MultiheadAttention)
0
end
function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
if size(x.q, 1) != l.qdim
DimensionMismatch(
"Length of query must match the layer's qdim: size(q)[1] = $(size(x.q, 1)), qdim = $(l.qdim)",
) |> throw
end
if size(x.k, 1) != l.kvdim
DimensionMismatch(
"Length of key must match the layer's kvdim: size(k)[1] = $(size(x.k, 1)), kvdim = $(l.kvdim)",
) |> throw
end
if size(x.v, 1) != l.kvdim
DimensionMismatch(
"Length of value must match the layer's kvdim: size(v)[1] = $(size(x.v, 1)), kvdim = $(l.kvdim)",
) |> throw
end
if !(ndims(x.q) == ndims(x.k) == ndims(x.v))
DimensionMismatch(
"Queue and Key and Value must have the same number of dimensions",
) |> throw
end
if ndims(x.k) > 1 && size(x.k)[2:end] != size(x.v)[2:end]
DimensionMismatch(
"Some length of key and value must match: size(k)[2:end] = $(size(x.k)[2:end]), size(v)[2:end] = $(size(x.v)[2:end])",
) |> throw
end
if haskey(x, :mask) && !isnothing(x.mask)
if size(x.mask, 1) != size(x.k, 2)
DimensionMismatch(
"First dimension of mask must be the same as the second size of key: mask size $(size(x.mask)), key size $(size(x.k))",
) |> throw
end
if size(x.mask, 2) != size(x.q, 2)
DimensionMismatch(
"Second dimension of mask must be the same as the second size of key: mask size $(size(x.mask)), key size $(size(x.q))",
) |> throw
end
if size(x.mask, 3) != l.num_heads
DimensionMismatch(
"Third dimension of mask must be the same as the number of heads: mask $(size(x.mask)), num_heads $(l.num_heads)",
) |> throw
end
end
# qk_dim and v_dim are divisible by num_heads; qk_dim = embed_dim * num_heads
# Project q, k, v. The trailing (length, batch...) dimensions are folded into one
# column dimension so that the matrix projections also work for batched inputs.
# [q] = (qk_dim, q_len, batch_size...)
q = reshape(ps.weight_q * reshape(x.q, size(x.q, 1), :), :, size(x.q)[2:end]...)
# [k] = (qk_dim, kv_len, batch_size...)
k = reshape(ps.weight_k * reshape(x.k, size(x.k, 1), :), :, size(x.k)[2:end]...)
# [v] = (v_dim, kv_len, batch_size...)
v = reshape(ps.weight_v * reshape(x.v, size(x.v, 1), :), :, size(x.v)[2:end]...)
# [mask] = (kv_len, q_len, nheads, batch_size...)
mask = hasproperty(x, :mask) ? x.mask : nothing
# [y] = (v_dim, q_len, batch_size...)
# [α] = (kv_len, q_len, nheads, batch_size...)
y, α = dot_product_attention(q, k, v; nheads = l.num_heads, mask)
# Output projection back to v_embed_dim.
return reshape(ps.weight_o * reshape(y, size(y, 1), :), :, size(y)[2:end]...), _st
end
"""
TransformerEncoderLayer <: LuxCore.AbstractLuxLayer
One layer of encoder block of Transformer model.
# Structure
It consists of two sublayers: `MultiheadAttention` and a feed-forward network (two `Dense` layers).
Each sublayer output passes through [`Lux.Dropout`](@ref), is added back to the sublayer input
through a residual connection, and is then normalized by [`Lux.LayerNorm`](@ref).
They are combined using [`Chain`](@ref) and [`SkipConnection`](@ref).
See the constructor of this layer, or display a constructed layer, to see the structure.
## MultiheadAttention sublayer
```
========= SkipConnection ==========
-> MultiheadAttention -> Dropout -->
-- +(add) -> LayerNorm
-> -------------------------------->
```
## FeedForwardNetworkSubLayer
The core is two chained `Dense` layers with internal dimension `feedforward_dim`,
wrapped in a `SkipConnection` and followed by `LayerNorm`.
# Parameters & States
Since this layer is implemented as a [`LuxCore.AbstractLuxContainerLayer`](@ref),
both the parameters and the states are structured like the container layer itself.
See the "Structure" section for more detail.
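For example, the parameter `NamedTuple` mirrors the two sublayer fields
(a small sketch using the keyword constructor defined below):
```julia
julia> using Random, Lux

julia> enc = TransformerEncoderLayer(8, 2; feedforward_dim = 32);

julia> ps, st = Lux.setup(Random.default_rng(), enc);

julia> keys(ps)
(:sublayer_self_attention, :sublayer_feed_forward_network)
```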
# Inputs
An `AbstractArray` with `size(x, 1) == model_dim` (same as `Dense`).
"""
struct TransformerEncoderLayer{LSA, LFFN} <: LuxCore.AbstractLuxContainerLayer{(
:sublayer_self_attention,
:sublayer_feed_forward_network,
)}
sublayer_self_attention::LSA
sublayer_feed_forward_network::LFFN
end
"""
TransformerEncoderLayer(
model_dim::Int,
num_heads::Int;
feedforward_dim::Int = 2048,
dropout::T1 = 0.1,
activation::F = relu,
layer_norm_epsilon::T2 = 1.0f-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
Constructor of [`TransformerEncoderLayer`](@ref).
# Arguments
- `model_dim::Int`: model input size
- `num_heads::Int`: number of heads of MultiheadAttention
## Keyword Arguments
- `feedforward_dim::Int = 2048`: dimension of the feed-forward network
- `dropout::T1 = 0.1`: dropout rate for all `Dropout` layers
- `activation::F = relu`: activation used in the feed-forward network
- `layer_norm_epsilon::T2 = 1.0f-5`: `epsilon` of the `LayerNorm` layers
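# Example
A minimal usage sketch (sizes chosen arbitrarily); the input is shaped `(model_dim, seq_len)`:
```julia
julia> using Random, Lux

julia> enc = TransformerEncoderLayer(8, 2; feedforward_dim = 32);

julia> ps, st = Lux.setup(Random.default_rng(), enc);

julia> y, st = enc(rand(Float32, 8, 10), ps, st);

julia> size(y)
(8, 10)
```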
"""
function TransformerEncoderLayer(
model_dim::Int,
num_heads::Int;
feedforward_dim::Int = 2048,
dropout::T1 = 0.1,
activation::F = relu,
layer_norm_epsilon::T2 = 1.0f-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
sublayer_self_attention = let
layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
layer_self_attention = MultiheadAttention(model_dim, num_heads)
layer_dropout = Lux.Dropout(dropout)
layer_residual_connection = Lux.SkipConnection(
Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
+,
)
Lux.Chain(;
layer_residual_connection,
layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
name = "SelfAttentionSubLayer",
)
end
sublayer_feed_forward_network = let
layer_feed_forward_network = Lux.Chain(
Lux.Dense(model_dim => feedforward_dim, activation),
Lux.Dropout(dropout),
Lux.Dense(feedforward_dim => model_dim, activation),
Lux.Dropout(dropout),
)
layer_residual_connection = Lux.SkipConnection(layer_feed_forward_network, +)
Lux.Chain(;
layer_residual_connection,
layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
name = "FeedForwardNetworkSubLayer",
)
end
TransformerEncoderLayer(sublayer_self_attention, sublayer_feed_forward_network)
end
function (encoder::TransformerEncoderLayer)(x, ps, st)
x, st_sublayer_self_attention = Lux.apply(
encoder.sublayer_self_attention,
x,
ps.sublayer_self_attention,
st.sublayer_self_attention,
)
x, st_sublayer_feed_forward_network = Lux.apply(
encoder.sublayer_feed_forward_network,
x,
ps.sublayer_feed_forward_network,
st.sublayer_feed_forward_network,
)
return x,
(
sublayer_self_attention = st_sublayer_self_attention,
sublayer_feed_forward_network = st_sublayer_feed_forward_network,
)
end
"""
TransformerDecoderLayer <: LuxCore.AbstractLuxContainerLayer
One layer of the decoder block of the Transformer model.
It consists of three sublayers: self-attention, multi-head attention over the encoder
output (`memory`), and a feed-forward network, each wrapped in a residual connection
and followed by `LayerNorm` as in [`TransformerEncoderLayer`](@ref).
# Inputs
A `NamedTuple` with keys `x` (decoder input) and `memory` (encoder output).
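# Example
A minimal encoder-to-decoder sketch (sizes chosen arbitrarily):
```julia
julia> using Random, Lux

julia> enc = TransformerEncoderLayer(8, 2; feedforward_dim = 32);

julia> dec = TransformerDecoderLayer(8, 2);

julia> ps_enc, st_enc = Lux.setup(Random.default_rng(), enc);

julia> ps_dec, st_dec = Lux.setup(Random.default_rng(), dec);

julia> memory, _ = enc(rand(Float32, 8, 12), ps_enc, st_enc);

julia> y, _ = dec((x = rand(Float32, 8, 10), memory = memory), ps_dec, st_dec);

julia> size(y)
(8, 10)
```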
"""
struct TransformerDecoderLayer{LSA, LMHA, LFFN} <: LuxCore.AbstractLuxContainerLayer{(
:sublayer_self_attention,
:sublayer_multihead_attention,
:sublayer_feed_forward_network,
)}
sublayer_self_attention::LSA
sublayer_multihead_attention::LMHA
sublayer_feed_forward_network::LFFN
end
function TransformerDecoderLayer(
model_dim::Int,
num_heads::Int;
feedforward_dim::Int = 2048,
dropout::T1 = 0.1,
activation::F = relu,
layer_norm_epsilon::T2 = 1.0f-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
sublayer_self_attention = let
layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
layer_self_attention = MultiheadAttention(model_dim, num_heads)
layer_dropout = Lux.Dropout(dropout)
layer_residual_connection = Lux.SkipConnection(
Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
+,
)
Lux.Chain(;
layer_residual_connection,
layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
name = "SelfAttentionSubLayer",
)
end
sublayer_multihead_attention = let
layer_split =
Lux.WrappedFunction(x::NamedTuple -> (q = x.x, k = x.memory, v = x.memory))
layer_self_attention = MultiheadAttention(model_dim, num_heads)
layer_dropout = Lux.Dropout(dropout)
layer_residual_connection = Lux.Parallel(
+,
Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
Lux.WrappedFunction(x::NamedTuple -> x.x),
)
Lux.Chain(;
layer_residual_connection,
layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
name = "MultiheadAttentionSubLayer",
)
end
sublayer_feed_forward_network = let
layer_feed_forward_network = Lux.Chain(
Lux.Dense(model_dim => feedforward_dim, activation),
Lux.Dropout(dropout),
Lux.Dense(feedforward_dim => model_dim, activation),
Lux.Dropout(dropout),
)
layer_residual_connection = Lux.SkipConnection(layer_feed_forward_network, +)
Lux.Chain(;
layer_residual_connection,
layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
name = "FeedForwardNetworkSubLayer",
)
end
TransformerDecoderLayer(
sublayer_self_attention,
sublayer_multihead_attention,
sublayer_feed_forward_network,
)
end
function (decoder::TransformerDecoderLayer)(x, ps, st)
haskey(x, :x) || throw(
ArgumentError(
"input to the decoder must be a NamedTuple with key \"x\" (decoder input)",
),
)
haskey(x, :memory) || throw(
ArgumentError(
"input to the decoder must be a NamedTuple with key \"memory\" (encoder output)",
),
)
x_sa, st_sublayer_self_attention = Lux.apply(
decoder.sublayer_self_attention,
x.x,
ps.sublayer_self_attention,
st.sublayer_self_attention,
)
x_mha, st_sublayer_multihead_attention = Lux.apply(
decoder.sublayer_multihead_attention,
(x = x_sa, memory = x.memory),
ps.sublayer_multihead_attention,
st.sublayer_multihead_attention,
)
y, st_sublayer_feed_forward_network = Lux.apply(
decoder.sublayer_feed_forward_network,
x_mha,
ps.sublayer_feed_forward_network,
st.sublayer_feed_forward_network,
)
return (
y,
(
sublayer_self_attention = st_sublayer_self_attention,
sublayer_multihead_attention = st_sublayer_multihead_attention,
sublayer_feed_forward_network = st_sublayer_feed_forward_network,
),
)
end