using LuxCore
using Random: AbstractRNG
using Lux

"""
    MultiheadAttention{F} <: LuxCore.AbstractLuxLayer

Multi-head attention layer.

# Fields
- `embed_dim::Int`: Query vector length (feature dimension of the query).
- `kdim::Int`: Key vector length.
- `vdim::Int`: Value vector length.
- `num_heads::Int`: Number of attention heads.
- `init_weight::F`: Weight initializer.
"""
struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
    embed_dim::Int
    kdim::Int
    vdim::Int
    num_heads::Int
    init_weight::F
end

"""
    MultiheadAttention(embed_dim::Int, num_heads::Int; init_weight=glorot_uniform, kw...)

Construct a `MultiheadAttention` layer.

# Arguments
- `embed_dim::Int`: Query vector length.
- `num_heads::Int`: Number of attention heads.

## Keyword Arguments
- `init_weight`: Weight initializer, called as `init_weight(rng, dims...)`. Default: `glorot_uniform`.
- `kdim::Int`: Key vector length. Default: `embed_dim`.
- `vdim::Int`: Value vector length. Default: `embed_dim`.

# Parameters and states

## Parameters
- `weight_q`: Query projection weight of size `(embed_dim * num_heads, embed_dim)`.
- `weight_k`: Key projection weight of size `(embed_dim * num_heads, kdim)`.
- `weight_v`: Value projection weight of size `(embed_dim * num_heads, vdim)`.
- `weight_o`: Output projection weight of size `(embed_dim, embed_dim * num_heads)`.

## States
This layer has no states.
"""
function MultiheadAttention(
    embed_dim::Int,
    num_heads::Int;
    init_weight = glorot_uniform,
    kw...,
)
    MultiheadAttention{typeof(init_weight)}(
        embed_dim,
        haskey(kw, :kdim) ? kw[:kdim] : embed_dim,
        haskey(kw, :vdim) ? kw[:vdim] : embed_dim,
        num_heads,
        init_weight,
    )
end

function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
    # See the original paper for the weight dimensions; the q/k/v projections stack
    # the per-head matrices, so each has `embed_dim * num_heads` output rows.
    (
        weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.embed_dim),
        weight_k = l.init_weight(rng, l.embed_dim * l.num_heads, l.kdim),
        weight_v = l.init_weight(rng, l.embed_dim * l.num_heads, l.vdim),
        # Output projection maps the concatenated heads back to `embed_dim`.
        weight_o = l.init_weight(rng, l.embed_dim, l.embed_dim * l.num_heads),
    )
end

function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
    NamedTuple()
end

function (l::MultiheadAttention)(
    x::NamedTuple,
    ps,
    st::NamedTuple,
)
    if size(x.q, 1) != l.embed_dim
        ArgumentError(
            "Length of query must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
        ) |> throw
    end
    if size(x.k, 1) != l.kdim
        ArgumentError(
            "Length of key must match the layer's kdim: size(k)[1] = $(size(x.k, 1)), kdim = $(l.kdim)",
        ) |> throw
    end
    if size(x.v, 1) != l.vdim
        ArgumentError(
            "Length of value must match the layer's vdim: size(v)[1] = $(size(x.v, 1)), vdim = $(l.vdim)",
        ) |> throw
    end
    # qk_dim = v_dim = embed_dim * num_heads, so both are divisible by num_heads,
    # as required by `dot_product_attention`.
    # [q] = (qk_dim, q_len, batch_size...)
    q = ps.weight_q * x.q
    # [k] = (qk_dim, kv_len, batch_size...)
    k = ps.weight_k * x.k
    # [v] = (v_dim, kv_len, batch_size...)
    v = ps.weight_v * x.v
    # [y] = (v_dim, q_len, batch_size...)
    y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
    # Project the concatenated heads back to the embedding dimension.
    # [o] = (embed_dim, q_len, batch_size...)
    o = ps.weight_o * y
    (o, st)
end

# struct TransformerEncoder <: LuxCore.AbstractLuxContainerLayer{}
# end
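
# Minimal usage sketch for the layer above. The sizes (embed_dim = 4, num_heads = 2,
# kdim = 6, vdim = 5, sequence lengths 10 and 12) are arbitrary example values, and it
# assumes `dot_product_attention` (an NNlib function) is in scope via `using Lux`.
# Inputs are plain (features, seq_len) matrices with no batch dimension, because the
# projections above use ordinary matrix products. The block only runs when this file
# is executed as a script.
if abspath(PROGRAM_FILE) == @__FILE__
    using Random

    rng = Random.default_rng()
    mha = MultiheadAttention(4, 2; kdim = 6, vdim = 5)
    ps, st = Lux.setup(rng, mha)

    q = rand(Float32, 4, 10)  # (embed_dim, q_len)
    k = rand(Float32, 6, 12)  # (kdim, kv_len)
    v = rand(Float32, 5, 12)  # (vdim, kv_len)

    o, _ = mha((q = q, k = k, v = v), ps, st)
    @show size(o)  # (embed_dim, q_len) == (4, 10)
end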