From c07d9c30ccc524fd05d42f8feaa56b96d5bb50f2 Mon Sep 17 00:00:00 2001
From: qwjyh
Date: Tue, 28 Jan 2025 22:19:46 +0900
Subject: [PATCH] (WIP): implement MultiheadAttention

---
 test_transformer.jl |  24 +++++++++++
 transformers.jl     | 103 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 127 insertions(+)
 create mode 100644 test_transformer.jl
 create mode 100644 transformers.jl

diff --git a/test_transformer.jl b/test_transformer.jl
new file mode 100644
index 0000000..072c952
--- /dev/null
+++ b/test_transformer.jl
@@ -0,0 +1,24 @@
+# test code
+
+true || include("transformers.jl")
+
+using Random
+
+const EMBED_DIM = 10
+const NUM_HEADS = 10
+
+rng = TaskLocalRNG()
+
+l = MultiheadAttention(EMBED_DIM, NUM_HEADS, vdim = 8)
+@info "layer" l
+
+ps = LuxCore.initialparameters(rng, l)
+st = LuxCore.initialstates(rng, l)
+@info "parameters and states" ps st
+
+q = rand(rng, Float32, (EMBED_DIM,))
+k = rand(rng, Float32, (EMBED_DIM,))
+v = rand(rng, Float32, (8,)) # length must equal vdim
+@info "q k v" summary.((q, k, v))
+
+l((; q, k, v), ps, st)
diff --git a/transformers.jl b/transformers.jl
new file mode 100644
index 0000000..65f3bb7
--- /dev/null
+++ b/transformers.jl
@@ -0,0 +1,103 @@
+# using LuxCore
+using Random: AbstractRNG
+using Lux
+
+"""
+# Fields
+- `embed_dim::Int`: Query vector length.
+- `kdim::Int`: Key vector length.
+- `vdim::Int`: Value vector length.
+- `num_heads::Int`: Number of heads.
+"""
+struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
+    embed_dim::Int
+    kdim::Int
+    vdim::Int
+    num_heads::Int
+    init_weight::F
+end
+
+"""
+    MultiheadAttention(embed_dim::Int, num_heads::Int; init_weight=glorot_uniform, kw...)
+
+Constructor.
+
+# Arguments
+  - `embed_dim::Int`
+  - `num_heads::Int`
+
+## Keyword Arguments
+- `init_weight`: weight initializer, called as `init_weight(rng, dims...)`. Default: `glorot_uniform`.
+- `kdim::Int`: Key vector length. Default: `embed_dim`.
+- `vdim::Int`: Value vector length. Default: `embed_dim`.
+
+# Parameters and states
+
+## Parameters
+
+"""
+function MultiheadAttention(
+    embed_dim::Int,
+    num_heads::Int;
+    init_weight = glorot_uniform,
+    kw...,
+)
+    MultiheadAttention{typeof(init_weight)}(
+        embed_dim,
+        haskey(kw, :kdim) ? kw[:kdim] : embed_dim,
+        haskey(kw, :vdim) ? kw[:vdim] : embed_dim,
+        num_heads,
+        init_weight,
+    )
+end
+
+function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
+    # see the original paper for weight dimensions (note that the q, k, v weights stack `num_heads` projection matrices)
+    (
+        weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.embed_dim),
+        weight_k = l.init_weight(rng, l.embed_dim * l.num_heads, l.kdim),
+        weight_v = l.init_weight(rng, l.embed_dim * l.num_heads, l.vdim),
+        weight_o = l.init_weight(rng, 10), # TODO
+    )
+end
+
+function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
+    NamedTuple()
+end
+
+function (l::MultiheadAttention)(
+    x::NamedTuple,
+    ps,
+    st::NamedTuple,
+)
+    if size(x.q, 1) != l.embed_dim
+        ArgumentError(
+            "Length of query must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
+        ) |> throw
+    end
+    if size(x.k, 1) != l.kdim
+        ArgumentError(
+            "Length of key must match the layer's kdim: size(k)[1] = $(size(x.k, 1)), kdim = $(l.kdim)",
+        ) |> throw
+    end
+    if size(x.v, 1) != l.vdim
+        ArgumentError(
+            "Length of value must match the layer's vdim: size(v)[1] = $(size(x.v, 1)), vdim = $(l.vdim)",
+        ) |> throw
+    end
+    # TODO
+
+    # qk_dim and v_dim are divisible by num_heads; qk_dim = embed_dim * num_heads
+    # [q] = (qk_dim, q_len, batch_size...)
+    q = ps.weight_q * x.q
+    # [k] = (qk_dim, kv_len, batch_size...)
+    k = ps.weight_k * x.k
+    # [v] = (v_dim, kv_len, batch_size...)
+    v = ps.weight_v * x.v # TODO: dimension?? 2025-01-28T21:59:56+09:00
+    # [y] = (v_dim, q_len, batch_size...)
+    y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
+end
+
+# struct TransformerEncoder <: LuxCore.AbstractLuxContainerLayer{}
+# end
+#
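
Not part of the patch: a minimal sketch of how the two remaining TODOs (the shape of `weight_o` and the end of the forward pass) could be resolved. It assumes the per-head dimension stays at `embed_dim`, so the concatenated head outputs have length `embed_dim * num_heads` to match the `weight_q`/`weight_k`/`weight_v` shapes above, that inputs are unbatched `(dim, len)` matrices as the existing plain matrix products already require, and that the layer returns `(output, state)` as Lux layers conventionally do. The method redefinitions and the `include` path are illustrative only.

# Sketch only: redefines the two WIP methods after loading the file from this patch.
include("transformers.jl")

using Random: AbstractRNG

function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
    (
        weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.embed_dim),
        weight_k = l.init_weight(rng, l.embed_dim * l.num_heads, l.kdim),
        weight_v = l.init_weight(rng, l.embed_dim * l.num_heads, l.vdim),
        # output projection: concatenated heads (embed_dim * num_heads) -> embed_dim
        weight_o = l.init_weight(rng, l.embed_dim, l.embed_dim * l.num_heads),
    )
end

function (l::MultiheadAttention)(x::NamedTuple, ps, st::NamedTuple)
    # input validation as in the patch omitted for brevity;
    # like the existing projections, this assumes q/k/v are (dim, len) matrices
    q = ps.weight_q * x.q   # (embed_dim * num_heads, q_len)
    k = ps.weight_k * x.k   # (embed_dim * num_heads, kv_len)
    v = ps.weight_v * x.v   # (embed_dim * num_heads, kv_len)
    # scaled dot-product attention over num_heads heads; attention weights discarded
    y, _ = dot_product_attention(q, k, v; nheads = l.num_heads)
    # project the concatenated heads back to embed_dim and return (output, state)
    o = ps.weight_o * y     # (embed_dim, q_len)
    return o, st
end

One design note: keeping the per-head dimension at `embed_dim` makes the projections `num_heads` times larger than in the original paper, where each head uses `embed_dim ÷ num_heads` so the total width stays at `embed_dim`; whichever is intended is worth stating in the docstring's Parameters section.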