From c2c2872b4647329b4d1a8719e528d5d4cd957960 Mon Sep 17 00:00:00 2001
From: qwjyh
Date: Mon, 3 Feb 2025 22:09:23 +0900
Subject: [PATCH] complete MultiheadAttention

---
 test_transformer.jl |  9 ++++---
 transformers.jl     | 66 +++++++++++++++++++++++++++++----------------
 2 files changed, 48 insertions(+), 27 deletions(-)

diff --git a/test_transformer.jl b/test_transformer.jl
index 072c952..f1874f1 100644
--- a/test_transformer.jl
+++ b/test_transformer.jl
@@ -5,11 +5,12 @@ true || include("transformers.jl")
 using Random
 
 const EMBED_DIM = 10
-const NUM_HEADS = 10
+const NUM_HEADS = 2
+const KV_DIM = 12
 
 rng = TaskLocalRNG()
 
-l = MultiheadAttention(EMBED_DIM, NUM_HEADS, vdim = 8)
+l = MultiheadAttention(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
 @info "layer" l
 
 ps = LuxCore.initialparameters(rng, l)
@@ -17,8 +18,8 @@ st = LuxCore.initialstates(rng, l)
 @info "parameters and states" ps st
 
 q = rand(rng, Float32, (EMBED_DIM,))
-k = rand(rng, Float32, (EMBED_DIM,))
-v = rand(rng, Float32, (EMBED_DIM,))
+k = rand(rng, Float32, (KV_DIM,))
+v = rand(rng, Float32, (KV_DIM,))
 @info "q k v" summary.((q, k, v))
 
 l((; q, k, v), ps, st)
diff --git a/transformers.jl b/transformers.jl
index 65f3bb7..bb7b17f 100644
--- a/transformers.jl
+++ b/transformers.jl
@@ -5,14 +5,33 @@ using Lux
 """
 # Fields
 - `embed_dim`: Queue vector length. `q_len`
-- `kdim::Int`: Key vector length.
-- `vdim::Int`: Value vector length
+- `kvdim::Int`: Key and value vector length.
 - `num_heads`: Number of heads.
+
+# Calculation
+```math
+\\begin{aligned}
+    Q &\\in \\mathbb{R}^{qdim} \\\\
+    W^Q &\\in M_{embed \\times qdim}(\\mathbb{R}) \\\\
+    K &\\in \\mathbb{R}^{kvdim} \\\\
+    W^K &\\in M_{embed \\times kvdim}(\\mathbb{R}) \\\\
+    V &\\in \\mathbb{R}^{kvdim} \\\\
+    W^V &\\in M_{vembed \\times kvdim}(\\mathbb{R}) \\\\
+    head_i &= \\operatorname{Attention}(W^Q_i Q, W^K_i K, W^V_i V) \\\\
+    \\operatorname{Attention}(Q, K, V) &= V \\operatorname{softmax}\\left(\\frac{K^T Q}{\\sqrt{embed}} \\right) \\\\
+    \\operatorname{MultiheadAttention}(Q, K, V) &= W^O \\operatorname{Concat}(head_1, \\ldots, head_h)
+\\end{aligned}
+```
+
+Here, ``Q, K, V`` are the inputs `x` of the layer,
+``W^Q, W^K, W^V, W^O`` are the parameters `ps`,
+and the layer has no states `st`.
 """
 struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
     embed_dim::Int
-    kdim::Int
-    vdim::Int
+    qdim::Int
+    kvdim::Int
+    v_embed_dim::Int
     num_heads::Int
     init_weight::F
 end
@@ -28,13 +47,15 @@ Constructor.
 
 ## Keyword Arguments
 - `init_weight`: weight initialzer (rng generator)
-- `kdim::Int`: Default: `embed_dim`
-- `vdim::Int`: Default: `embed_dim`
+- `qdim::Int`: Default: `embed_dim`
+- `kvdim::Int`: Default: `embed_dim`
+- `v_embed_dim::Int`: Default: `embed_dim`
 
 # Parameters and states
 
 ## Parameters
-
+- `weight_q`
+TODO
 """
 function MultiheadAttention(
     embed_dim::Int,
     num_heads::Int;
     init_weight = glorot_uniform,
     kw...,
 )
     MultiheadAttention{typeof(init_weight)}(
         embed_dim,
-        haskey(kw, :kdim) ? kw[:kdim] : embed_dim,
-        haskey(kw, :vdim) ? kw[:vdim] : embed_dim,
+        haskey(kw, :qdim) ? kw[:qdim] : embed_dim,
+        haskey(kw, :kvdim) ? kw[:kvdim] : embed_dim,
+        haskey(kw, :v_embed_dim) ? kw[:v_embed_dim] : embed_dim,
         num_heads,
         init_weight,
     )
@@ -54,10 +76,10 @@ end
 
 function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
     # see the original paper for weight dimensions (note that q,k,v weights have `num_heads` of matrices)
     (
-        weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.embed_dim),
-        weight_k = l.init_weight(rng, l.embed_dim * l.num_heads, l.kdim),
-        weight_v = l.init_weight(rng, l.embed_dim * l.num_heads, l.vdim),
-        weight_o = l.init_weight(rng, 10), # TODO
+        weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.qdim),
+        weight_k = l.init_weight(rng, l.embed_dim * l.num_heads, l.kvdim),
+        weight_v = l.init_weight(rng, l.v_embed_dim * l.num_heads, l.kvdim),
+        weight_o = l.init_weight(rng, l.v_embed_dim, l.v_embed_dim * l.num_heads), # TODO: next here maybe finished?
     )
 end
 
@@ -65,24 +87,20 @@ function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
     NamedTuple()
 end
 
-function (l::MultiheadAttention)(
-    x::NamedTuple,
-    ps,
-    st::NamedTuple,
-)
+function (l::MultiheadAttention)(x::NamedTuple, ps, st::NamedTuple)
     if size(x.q, 1) != l.embed_dim
         ArgumentError(
             "Length of queue must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
         ) |> throw
     end
-    if size(x.k, 1) != l.kdim
+    if size(x.k, 1) != l.kvdim
         ArgumentError(
-            "Length of key must match the layer's kdim: size(k)[1] = $(size(x.k, 1)), kdim = $(l.kdim)",
+            "Length of key must match the layer's kvdim: size(k)[1] = $(size(x.k, 1)), kvdim = $(l.kvdim)",
         ) |> throw
     end
-    if size(x.v, 1) != l.vdim
+    if size(x.v, 1) != l.kvdim
         ArgumentError(
-            "Length of value must match the layer's vdim: size(v)[1] = $(size(x.v, 1)), vdim = $(l.vdim)",
+            "Length of value must match the layer's kvdim: size(v)[1] = $(size(x.v, 1)), kvdim = $(l.kvdim)",
         ) |> throw
     end
     # TODO
@@ -93,9 +111,11 @@ function (l::MultiheadAttention)(
     # [q] = (q_dim, q_len, batch_size...)
     q = ps.weight_q * x.q
     # [k] = (qk_dim, kv_len, batch_size...)
     k = ps.weight_k * x.k
     # [v] = (v_dim, kv_len, batch_size...)
-    v = ps.weight_v * x.v # TODO: dimension?? 2025-01-28T21:59:56+09:00
+    v = ps.weight_v * x.v
     # [y] = (v_dim, q_len, batch_size...)
+    # [α] = (kv_len, q_len, nheads, batch_size...)
     y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
+    ps.weight_o * y
 end
 
 # struct TransformerEncoder <: LuxCore.AbstractLuxContainerLayer{}
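
For reference, the sketch below is not part of the patch: it spells out the docstring's "Calculation" section with plain matrices so the weight shapes chosen in `LuxCore.initialparameters` can be checked by hand. The dimension names (`embed`, `qdim`, `kvdim`, `v_embed`), the sequence lengths, and the random placeholder weights are illustrative assumptions, and `softmax` is taken from NNlib (the package `dot_product_attention` also comes from), not from the code above.

```julia
# Minimal by-hand version of the docstring math, for shape checking only.
using NNlib: softmax
using Random

rng = Xoshiro(0)
embed, qdim, kvdim, v_embed, nheads = 4, 6, 5, 3, 2
q_len, kv_len = 7, 9

Q = rand(rng, Float32, qdim, q_len)     # queries, one per column
K = rand(rng, Float32, kvdim, kv_len)   # keys
V = rand(rng, Float32, kvdim, kv_len)   # values

# One projection per head, as in head_i = Attention(W^Q_i Q, W^K_i K, W^V_i V).
WQ = [rand(rng, Float32, embed, qdim) for _ in 1:nheads]
WK = [rand(rng, Float32, embed, kvdim) for _ in 1:nheads]
WV = [rand(rng, Float32, v_embed, kvdim) for _ in 1:nheads]
WO = rand(rng, Float32, v_embed, v_embed * nheads)

# Attention(Q, K, V) = V softmax(K^T Q / sqrt(embed)), softmax over the kv axis,
# so the attention weights have size (kv_len, q_len) as in the forward-pass comments.
function attention(q, k, v)
    α = softmax(k' * q ./ sqrt(Float32(size(q, 1))); dims = 1)
    return v * α
end

heads = [attention(WQ[i] * Q, WK[i] * K, WV[i] * V) for i in 1:nheads]
out = WO * vcat(heads...)   # MultiheadAttention(Q, K, V) = W^O Concat(head_1, ..., head_h)
@assert size(out) == (v_embed, q_len)
```

Concatenating `nheads` heads of height `v_embed` yields a block with `v_embed * nheads` rows, which is why `weight_o` is initialized with size `(l.v_embed_dim, l.v_embed_dim * l.num_heads)` in `LuxCore.initialparameters`.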