# ml-test/transformers.jl
using LuxCore                       # AbstractLuxLayer, initialparameters, initialstates
using Random: AbstractRNG
using Lux                           # glorot_uniform is re-exported by Lux
using NNlib: dot_product_attention  # scaled dot-product attention
"""
# Fields
- `embed_dim::Int`: Query/key embedding length of each head.
- `qdim::Int`: Query vector length (first dimension of the input `q`).
- `kvdim::Int`: Key and value vector length (first dimension of the inputs `k` and `v`).
- `v_embed_dim::Int`: Value embedding length of each head.
- `num_heads::Int`: Number of attention heads.
- `init_weight`: Weight initializer.
# Calculation
```math
\\begin{aligned}
Q &\\in \\mathbb{R}^{qdim} \\\\
W^Q_i &\\in M_{embed \\times qdim}(\\mathbb{R}) \\\\
K &\\in \\mathbb{R}^{kvdim} \\\\
W^K_i &\\in M_{embed \\times kvdim}(\\mathbb{R}) \\\\
V &\\in \\mathbb{R}^{kvdim} \\\\
W^V_i &\\in M_{vembed \\times kvdim}(\\mathbb{R}) \\\\
W^O &\\in M_{vembed \\times vembed \\cdot h}(\\mathbb{R}) \\\\
head_i &= \\operatorname{Attention}(W^Q_i Q, W^K_i K, W^V_i V) \\\\
\\operatorname{Attention}(Q, K, V) &= V \\operatorname{softmax}\\left(\\frac{K^T Q}{\\sqrt{embed}} \\right) \\\\
\\operatorname{MultiheadAttention}(Q, K, V) &= W^O \\operatorname{Concat}(head_1, \\ldots, head_h)
\\end{aligned}
```
Here ``Q, K, V`` are the layer's inputs `x`,
``W^Q, W^K, W^V, W^O`` are its parameters `ps`,
and the layer has no states `st`.
"""
struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
    embed_dim::Int
    qdim::Int
    kvdim::Int
    v_embed_dim::Int
    num_heads::Int
    init_weight::F
end
"""
    MultiheadAttention(embed_dim::Int, num_heads::Int; init_weight=glorot_uniform, kw...)

Construct a `MultiheadAttention` layer.
# Arguments
- `embed_dim::Int`: Query/key embedding length of each head.
- `num_heads::Int`: Number of attention heads.
## Keyword Arguments
- `init_weight`: Weight initializer; a function called as `init_weight(rng, dims...)`, e.g. `glorot_uniform`.
- `qdim::Int`: Query vector length. Default: `embed_dim`.
- `kvdim::Int`: Key and value vector length. Default: `embed_dim`.
- `v_embed_dim::Int`: Value embedding length of each head. Default: `embed_dim`.
# Parameters and states
## Parameters
- `weight_q`: ``W^Q``
- `weight_k`: ``W^K``
- `weight_v`: ``W^V``
- `weight_o`: ``W^O``
## States
MultiheadAttention has no states.
# Inputs
A `NamedTuple` with fields
- `q`: ``Q``, of size `(qdim, q_len)`
- `k`: ``K``, of size `(kvdim, kv_len)`
- `v`: ``V``, of size `(kvdim, kv_len)`
"""
function MultiheadAttention(
    embed_dim::Int,
    num_heads::Int;
    init_weight = glorot_uniform,
    kw...,
)
    MultiheadAttention{typeof(init_weight)}(
        embed_dim,
        get(kw, :qdim, embed_dim),
        get(kw, :kvdim, embed_dim),
        get(kw, :v_embed_dim, embed_dim),
        num_heads,
        init_weight,
    )
end
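# Construction example (a sketch; the names `attn`/`cross` and the sizes are arbitrary choices):
# attn = MultiheadAttention(64, 8)               # 8 heads, 64-dimensional heads
# cross = MultiheadAttention(64, 8; kvdim = 32)  # keys/values of a different length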
function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
    # Weight shapes follow "Attention Is All You Need"; the q/k/v projections stack the
    # per-head matrices, so their first dimension carries a factor of num_heads.
    (
        weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.qdim),
        weight_k = l.init_weight(rng, l.embed_dim * l.num_heads, l.kvdim),
        weight_v = l.init_weight(rng, l.v_embed_dim * l.num_heads, l.kvdim),
        weight_o = l.init_weight(rng, l.v_embed_dim, l.v_embed_dim * l.num_heads),
    )
end
function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
    NamedTuple()
end
function (l::MultiheadAttention)(x::NamedTuple, ps, st::NamedTuple)
    if size(x.q, 1) != l.qdim
        ArgumentError(
            "Length of query must match the layer's qdim: size(q, 1) = $(size(x.q, 1)), qdim = $(l.qdim)",
        ) |> throw
    end
    if size(x.k, 1) != l.kvdim
        ArgumentError(
            "Length of key must match the layer's kvdim: size(k, 1) = $(size(x.k, 1)), kvdim = $(l.kvdim)",
        ) |> throw
    end
    if size(x.v, 1) != l.kvdim
        ArgumentError(
            "Length of value must match the layer's kvdim: size(v, 1) = $(size(x.v, 1)), kvdim = $(l.kvdim)",
        ) |> throw
    end
    # qk_dim and v_dim are divisible by num_heads by construction:
    # qk_dim = embed_dim * num_heads and v_dim = v_embed_dim * num_heads.
    # NOTE: the plain matrix products below assume unbatched 2-D inputs.
    # [q] = (qk_dim, q_len)
    q = ps.weight_q * x.q
    # [k] = (qk_dim, kv_len)
    k = ps.weight_k * x.k
    # [v] = (v_dim, kv_len)
    v = ps.weight_v * x.v
    # [y] = (v_dim, q_len)
    # [α] = (kv_len, q_len, nheads)
    y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
    # project the concatenated heads back to v_embed_dim and return (output, state)
    ps.weight_o * y, st
end
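# Forward-pass example (a sketch; unbatched 2-D inputs with arbitrary lengths,
# continuing from the `attn`, `ps`, `st` names above):
# q = rand(Float32, 64, 10)   # (qdim, q_len)
# k = rand(Float32, 64, 20)   # (kvdim, kv_len)
# v = rand(Float32, 64, 20)   # (kvdim, kv_len)
# y, _ = attn((q = q, k = k, v = v), ps, st)  # y has size (v_embed_dim, q_len) = (64, 10)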
# struct TransformerEncoder <: LuxCore.AbstractLuxContainerLayer{}
# end
#
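# A possible way to flesh out the encoder block above (a sketch only, not part of the
# original design): self-attention and a position-wise feed-forward network, each in a
# residual connection followed by LayerNorm. `encoder_block`, `hidden_dim`, and the
# layer choices are assumptions, composed from Lux's Chain/SkipConnection/Dense/LayerNorm.
#
# encoder_block(embed_dim, num_heads, hidden_dim) = Chain(
#     SkipConnection(
#         Chain(
#             WrappedFunction(x -> (q = x, k = x, v = x)),  # self-attention: reuse x as q, k, v
#             MultiheadAttention(embed_dim, num_heads),
#         ),
#         +,
#     ),
#     LayerNorm((embed_dim,)),
#     SkipConnection(
#         Chain(Dense(embed_dim => hidden_dim, relu), Dense(hidden_dim => embed_dim)),
#         +,
#     ),
#     LayerNorm((embed_dim,)),
# )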