using LuxCore
using Random: AbstractRNG
using Lux
"""
# Fields
- `embed_dim::Int`: Query and key embedding dimension of each head, i.e. the number of rows of each ``W^Q_i`` and ``W^K_i``.
- `qdim::Int`: Query vector length.
- `kvdim::Int`: Key vector and value vector length.
- `v_embed_dim::Int`: Value embedding dimension of each head; also the output dimension of the layer.
- `num_heads::Int`: Number of heads.
- `init_weight`: Weight initializer.
# Calculation
```math
\\begin{aligned}
Q &\\in \\mathbb{R}^{qdim} \\\\
W^Q_i &\\in M_{embed \\times qdim}(\\mathbb{R}) \\\\
K &\\in \\mathbb{R}^{kvdim} \\\\
W^K_i &\\in M_{embed \\times kvdim}(\\mathbb{R}) \\\\
V &\\in \\mathbb{R}^{kvdim} \\\\
W^V_i &\\in M_{vembed \\times kvdim}(\\mathbb{R}) \\\\
head_i &= \\operatorname{Attention}(W^Q_i Q, W^K_i K, W^V_i V) \\\\
\\operatorname{Attention}(Q, K, V) &= V \\operatorname{softmax}\\left(\\frac{K^T Q}{\\sqrt{embed}} \\right) \\\\
\\operatorname{MultiheadAttention}(Q, K, V) &= W^O \\operatorname{Concat}(head_1, \\ldots, head_h)
\\end{aligned}
```
Here ``Q, K, V`` are the layer inputs `x`,
``W^Q, W^K, W^V, W^O`` are the parameters `ps`,
and the layer has no states `st`.
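A worked sketch of the single-head formula above, with several queries and keys
stacked as matrix columns (all sizes are illustrative; the layer itself delegates
the head splitting and scaling to NNlib's `dot_product_attention`):
```julia
using NNlib: softmax
embed, qdim, kvdim = 4, 6, 5
q_len, kv_len = 3, 7
Wq, Wk, Wv = randn(embed, qdim), randn(embed, kvdim), randn(embed, kvdim)
Q, K, V = randn(qdim, q_len), randn(kvdim, kv_len), randn(kvdim, kv_len)
attention(q, k, v) = v * softmax(k' * q ./ sqrt(size(k, 1)); dims = 1)
head = attention(Wq * Q, Wk * K, Wv * V)   # size (embed, q_len)
```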
"""
struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
    embed_dim::Int
    qdim::Int
    kvdim::Int
    v_embed_dim::Int
    num_heads::Int
    init_weight::F
end
"""
    MultiheadAttention(embed_dim::Int, num_heads::Int; init_weight = glorot_uniform, qdim = embed_dim, kvdim = embed_dim, v_embed_dim = embed_dim)
Constructor.
# Arguments
- `embed_dim::Int`
- `num_heads::Int`
## Keyword Arguments
- `init_weight`: Weight initializer, called as `init_weight(rng, dims...)`. Default: `glorot_uniform`.
- `qdim::Int`: Default: `embed_dim`.
- `kvdim::Int`: Default: `embed_dim`.
- `v_embed_dim::Int`: Default: `embed_dim`.
# Parameters and states
## Parameters
- `weight_q`: Query projection weight, size `(embed_dim * num_heads, qdim)`.
- `weight_k`: Key projection weight, size `(embed_dim * num_heads, kvdim)`.
- `weight_v`: Value projection weight, size `(v_embed_dim * num_heads, kvdim)`.
- `weight_o`: Output projection weight, size `(v_embed_dim, v_embed_dim * num_heads)`.
## States
MultiheadAttention has no states.
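# Example
A minimal construction sketch (all keyword values below are illustrative):
```julia
using Random
mha = MultiheadAttention(4, 2; kvdim = 6)
ps, st = LuxCore.setup(Random.default_rng(), mha)
size(ps.weight_q)   # (8, 4) == (embed_dim * num_heads, qdim)
```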
"""
function MultiheadAttention(
    embed_dim::Int,
    num_heads::Int;
    init_weight = glorot_uniform,
    qdim::Int = embed_dim,
    kvdim::Int = embed_dim,
    v_embed_dim::Int = embed_dim,
)
    MultiheadAttention{typeof(init_weight)}(
        embed_dim,
        qdim,
        kvdim,
        v_embed_dim,
        num_heads,
        init_weight,
    )
end
function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
    # Weight shapes follow the original attention paper; the q, k, v weights
    # stack the `num_heads` per-head matrices along their first dimension.
    (
        weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.qdim),
        weight_k = l.init_weight(rng, l.embed_dim * l.num_heads, l.kvdim),
        weight_v = l.init_weight(rng, l.v_embed_dim * l.num_heads, l.kvdim),
        weight_o = l.init_weight(rng, l.v_embed_dim, l.v_embed_dim * l.num_heads),
    )
end
function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
    NamedTuple()
end
function (l::MultiheadAttention)(x::NamedTuple, ps, st::NamedTuple)
    if size(x.q, 1) != l.qdim
        ArgumentError(
            "Length of query must match the layer's qdim: size(q, 1) = $(size(x.q, 1)), qdim = $(l.qdim)",
        ) |> throw
    end
    if size(x.k, 1) != l.kvdim
        ArgumentError(
            "Length of key must match the layer's kvdim: size(k, 1) = $(size(x.k, 1)), kvdim = $(l.kvdim)",
        ) |> throw
    end
    if size(x.v, 1) != l.kvdim
        ArgumentError(
            "Length of value must match the layer's kvdim: size(v, 1) = $(size(x.v, 1)), kvdim = $(l.kvdim)",
        ) |> throw
    end
    # qk_dim and v_dim are divisible by num_heads:
    # qk_dim = embed_dim * num_heads, v_dim = v_embed_dim * num_heads.
    # [q] = (qk_dim, q_len, batch_size...)
    q = ps.weight_q * x.q
    # [k] = (qk_dim, kv_len, batch_size...)
    k = ps.weight_k * x.k
    # [v] = (v_dim, kv_len, batch_size...)
    v = ps.weight_v * x.v
    # [y] = (v_dim, q_len, batch_size...)
    # [α] = (kv_len, q_len, nheads, batch_size...)
    y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
    # Lux layers return (output, states); this layer carries no state.
    ps.weight_o * y, st
end
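# A minimal end-to-end usage sketch (all sizes illustrative; unbatched 2-D
# inputs, since the projections above use plain matrix multiplication):
#
#   using Random
#   rng = Random.default_rng()
#   mha = MultiheadAttention(4, 2; kvdim = 6, v_embed_dim = 5)
#   ps, st = LuxCore.setup(rng, mha)
#   q = randn(Float32, 4, 10)   # (qdim, q_len)
#   k = randn(Float32, 6, 12)   # (kvdim, kv_len)
#   v = randn(Float32, 6, 12)   # (kvdim, kv_len)
#   y, _ = mha((q = q, k = k, v = v), ps, st)
#   size(y)                     # (5, 10) == (v_embed_dim, q_len)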
# struct TransformerEncoder <: LuxCore.AbstractLuxContainerLayer{}
# end
#