complete MultiheadAttention

qwjyh 2025-02-03 22:09:23 +09:00
parent c07d9c30cc
commit c2c2872b46
2 changed files with 48 additions and 27 deletions

View file

@@ -5,11 +5,12 @@ true || include("transformers.jl")
 using Random
 const EMBED_DIM = 10
-const NUM_HEADS = 10
+const NUM_HEADS = 2
+const KV_DIM = 12
 rng = TaskLocalRNG()
-l = MultiheadAttention(EMBED_DIM, NUM_HEADS, vdim = 8)
+l = MultiheadAttention(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
 @info "layer" l
 ps = LuxCore.initialparameters(rng, l)
@@ -17,8 +18,8 @@ st = LuxCore.initialstates(rng, l)
 @info "parameters and states" ps st
 q = rand(rng, Float32, (EMBED_DIM,))
-k = rand(rng, Float32, (EMBED_DIM,))
-v = rand(rng, Float32, (EMBED_DIM,))
+k = rand(rng, Float32, (KV_DIM,))
+v = rand(rng, Float32, (KV_DIM,))
 @info "q k v" summary.((q, k, v))
 l((; q, k, v), ps, st)

View file

@@ -5,14 +5,33 @@ using Lux
 """
 # Fields
 - `embed_dim`: Query vector length. `q_len`
-- `kdim::Int`: Key vector length.
-- `vdim::Int`: Value vector length
+- `kvdim::Int`: Key vector and value vector length.
 - `num_heads`: Number of heads.
+
+# Calculation
+```math
+\\begin{aligned}
+Q &\\in \\mathbb{R}^{qdim} \\\\
+W^Q &\\in M_{embed \\times qdim}(\\mathbb{R}) \\\\
+K &\\in \\mathbb{R}^{kvdim} \\\\
+W^K &\\in M_{embed \\times kvdim}(\\mathbb{R}) \\\\
+V &\\in \\mathbb{R}^{kvdim} \\\\
+W^V &\\in M_{vembed \\times kvdim}(\\mathbb{R}) \\\\
+head_i &= \\operatorname{Attention}(W^Q_i Q, W^K_i K, W^V_i V) \\\\
+\\operatorname{Attention}(Q, K, V) &= V \\operatorname{softmax}\\left(\\frac{K^T Q}{\\sqrt{kvdim}} \\right) \\\\
+\\operatorname{MultiheadAttention}(Q, K, V) &= W^O \\operatorname{Concat}(head_1, \\ldots, head_h)
+\\end{aligned}
+```
+
+So far, ``Q, K, V`` are the inputs `x` to the layer,
+``W^Q, W^K, W^V, W^O`` are the parameters `ps`,
+and the layer has no states `st`.
 """
 struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
     embed_dim::Int
-    kdim::Int
-    vdim::Int
+    qdim::Int
+    kvdim::Int
+    v_embed_dim::Int
     num_heads::Int
     init_weight::F
 end
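
The docstring added above writes the per-head formula in terms of single vectors. As a point of reference, the sketch below implements the same `Attention(Q, K, V)` expression for column-major `(features, length)` matrices, with the softmax taken over the key axis; the names `single_head_attention` and `softmax_cols` and the sizes are made up for illustration and are not part of the layer.

```julia
# Minimal sketch of the docstring's Attention(Q, K, V) formula for
# (features, length) matrices; purely illustrative.
softmax_cols(A) = (E = exp.(A .- maximum(A; dims = 1)); E ./ sum(E; dims = 1))

function single_head_attention(Q, K, V)
    d = size(K, 1)                   # key feature dimension used for scaling
    scores = (K' * Q) ./ sqrt(d)     # (kv_len, q_len)
    A = softmax_cols(scores)         # normalize over the kv_len axis
    return V * A                     # (v_dim, q_len)
end

Q = rand(Float32, 10, 3)  # (qk_dim, q_len)
K = rand(Float32, 10, 5)  # (qk_dim, kv_len)
V = rand(Float32, 8, 5)   # (v_dim, kv_len)
size(single_head_attention(Q, K, V))  # (8, 3)
```

In the layer itself, this computation (including the split into `num_heads` heads) is delegated to `dot_product_attention` in the forward pass further down.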
@@ -28,13 +47,15 @@ Constructor.
 ## Keyword Arguments
 - `init_weight`: weight initializer (rng generator)
-- `kdim::Int`: Default: `embed_dim`
-- `vdim::Int`: Default: `embed_dim`
+- `qdim`: Default: `embed_dim`
+- `kvdim::Int`: Default: `embed_dim`
+- `v_embed_dim`: Default: `embed_dim`
 # Parameters and states
 ## Parameters
+- `weight_q`
+TODO
 """
 function MultiheadAttention(
     embed_dim::Int,
@@ -44,8 +65,9 @@ function MultiheadAttention(
 )
     MultiheadAttention{typeof(init_weight)}(
         embed_dim,
-        haskey(kw, :kdim) ? kw[:kdim] : embed_dim,
-        haskey(kw, :vdim) ? kw[:vdim] : embed_dim,
+        haskey(kw, :qdim) ? kw[:qdim] : embed_dim,
+        haskey(kw, :kvdim) ? kw[:kvdim] : embed_dim,
+        haskey(kw, :v_embed_dim) ? kw[:v_embed_dim] : embed_dim,
         num_heads,
         init_weight,
     )
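
For illustration, hypothetical constructor calls with the renamed keywords might look like the following; the defaults fall back to `embed_dim` exactly as in the `haskey` branches above.

```julia
# Hypothetical constructor calls (values are arbitrary).
l1 = MultiheadAttention(10, 2)                               # qdim = kvdim = v_embed_dim = 10
l2 = MultiheadAttention(10, 2, kvdim = 12)                   # keys/values arrive with length 12
l3 = MultiheadAttention(10, 2, kvdim = 12, v_embed_dim = 8)  # distinct value embedding size
```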
@@ -54,10 +76,10 @@ end
 function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
     # see the original paper for weight dimensions (note that the q, k, v weights each stack `num_heads` matrices)
     (
-        weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.embed_dim),
-        weight_k = l.init_weight(rng, l.embed_dim * l.num_heads, l.kdim),
-        weight_v = l.init_weight(rng, l.embed_dim * l.num_heads, l.vdim),
-        weight_o = l.init_weight(rng, 10), # TODO
+        weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.qdim),
+        weight_k = l.init_weight(rng, l.embed_dim * l.num_heads, l.kvdim),
+        weight_v = l.init_weight(rng, l.v_embed_dim * l.num_heads, l.kvdim),
+        weight_o = l.init_weight(rng, l.v_embed_dim, l.v_embed_dim * l.num_heads), # TODO: check whether this is now complete
     )
 end
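
With the example script's settings (`EMBED_DIM = 10`, `NUM_HEADS = 2`, `KV_DIM = 12`, remaining keywords at their defaults), the weight shapes implied by this method work out as below; this is a hand-derived check, not output captured from the code.

```julia
# Hand-derived weight shapes for embed_dim = 10, num_heads = 2, kvdim = 12,
# with qdim = v_embed_dim = 10 (the constructor defaults).
embed_dim, num_heads, kvdim = 10, 2, 12
qdim = v_embed_dim = embed_dim
(embed_dim * num_heads, qdim)            # weight_q: (20, 10)
(embed_dim * num_heads, kvdim)           # weight_k: (20, 12)
(v_embed_dim * num_heads, kvdim)         # weight_v: (20, 12)
(v_embed_dim, v_embed_dim * num_heads)   # weight_o: (10, 20)
```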
@@ -65,24 +87,20 @@ function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
     NamedTuple()
 end
 
-function (l::MultiheadAttention)(
-    x::NamedTuple,
-    ps,
-    st::NamedTuple,
-)
+function (l::MultiheadAttention)(x::NamedTuple, ps, st::NamedTuple)
     if size(x.q, 1) != l.embed_dim
         ArgumentError(
             "Length of query must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
         ) |> throw
     end
-    if size(x.k, 1) != l.kdim
+    if size(x.k, 1) != l.kvdim
         ArgumentError(
-            "Length of key must match the layer's kdim: size(k)[1] = $(size(x.k, 1)), kdim = $(l.kdim)",
+            "Length of key must match the layer's kvdim: size(k)[1] = $(size(x.k, 1)), kvdim = $(l.kvdim)",
         ) |> throw
     end
-    if size(x.v, 1) != l.vdim
+    if size(x.v, 1) != l.kvdim
         ArgumentError(
-            "Length of value must match the layer's vdim: size(v)[1] = $(size(x.v, 1)), vdim = $(l.vdim)",
+            "Length of value must match the layer's kvdim: size(v)[1] = $(size(x.v, 1)), kvdim = $(l.kvdim)",
         ) |> throw
     end
     # TODO
@@ -93,9 +111,11 @@ function (l::MultiheadAttention)(
     # [k] = (qk_dim, kv_len, batch_size...)
     k = ps.weight_k * x.k
     # [v] = (v_dim, kv_len, batch_size...)
-    v = ps.weight_v * x.v # TODO: dimension?? 2025-01-28T21:59:56+09:00
+    v = ps.weight_v * x.v
     # [y] = (v_dim, q_len, batch_size...)
+    # [α] = (kv_len, q_len, nheads, batch_size...)
     y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
+    ps.weight_o * y
 end
 
 # struct TransformerEncoder <: LuxCore.AbstractLuxContainerLayer{}
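
The shape comments in the forward pass follow the conventions of NNlib's `dot_product_attention`. Assuming the call above resolves to that implementation, a small self-contained check with made-up 3-D `(features, length, batch)` arrays would be:

```julia
# Hedged shape check; all sizes are made up and unrelated to the layer above.
using NNlib: dot_product_attention

nheads = 2
q = rand(Float32, 20, 3, 4)   # (qk_dim, q_len,  batch); qk_dim divisible by nheads
k = rand(Float32, 20, 5, 4)   # (qk_dim, kv_len, batch)
v = rand(Float32, 10, 5, 4)   # (v_dim,  kv_len, batch); v_dim divisible by nheads
y, α = dot_product_attention(q, k, v; nheads)
size(y)  # (10, 3, 4)   -> (v_dim, q_len, batch)
size(α)  # (5, 3, 2, 4) -> (kv_len, q_len, nheads, batch)
```

The size of `α` here matches the comment added in the diff, `(kv_len, q_len, nheads, batch_size...)`.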