# ml-test/transformers.jl

using LuxCore
using Random: AbstractRNG
using Lux
"""
# Fields
- `embed_dim`: Queue vector length. `q_len`
- `kdim::Int`: Key vector length.
- `vdim::Int`: Value vector length
- `num_heads`: Number of heads.
"""
struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
embed_dim::Int
kdim::Int
vdim::Int
num_heads::Int
init_weight::F
end
"""
MultiheadAttention(embed_dim::Int, num_heads::Int; init_weight=glorot_uniform, kw...)
Constructor.
# Arguments
- `embed_dim::Int`
- `num_heads::Int`
## Keyword Arguments
- `init_weight`: weight initialzer (rng generator)
- `kdim::Int`: Default: `embed_dim`
- `vdim::Int`: Default: `embed_dim`
# Parameters and states
## Parameters
"""
function MultiheadAttention(
    embed_dim::Int,
    num_heads::Int;
    init_weight = glorot_uniform,
    kw...,
)
    MultiheadAttention{typeof(init_weight)}(
        embed_dim,
        get(kw, :kdim, embed_dim),
        get(kw, :vdim, embed_dim),
        num_heads,
        init_weight,
    )
end
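# Example (illustrative sketch, not part of the layer definition): constructing a
# self-attention layer and one with separate key/value dimensions.
#
#     mha = MultiheadAttention(4, 2)                        # kdim = vdim = embed_dim = 4
#     xattn = MultiheadAttention(4, 2; kdim = 6, vdim = 6)  # custom key/value lengths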
function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
    # See the original paper for the weight dimensions; the q, k, v projections each
    # stack `num_heads` blocks of `embed_dim` rows.
    (
        weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.embed_dim),
        weight_k = l.init_weight(rng, l.embed_dim * l.num_heads, l.kdim),
        weight_v = l.init_weight(rng, l.embed_dim * l.num_heads, l.vdim),
        # output projection maps the concatenated head outputs back to `embed_dim`
        weight_o = l.init_weight(rng, l.embed_dim, l.embed_dim * l.num_heads),
    )
end
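# Shape check (illustrative sketch): `LuxCore.setup` calls the `initialparameters` and
# `initialstates` methods defined in this file.
#
#     using Random
#     ps, st = LuxCore.setup(Random.default_rng(), MultiheadAttention(4, 2))
#     size(ps.weight_q)  # (8, 4) == (embed_dim * num_heads, embed_dim)
#     size(ps.weight_o)  # (4, 8) == (embed_dim, embed_dim * num_heads)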
function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
NamedTuple()
end
# The layer is called with a NamedTuple `x` holding the query `q`, key `k`, and value `v`.
function (l::MultiheadAttention)(x::NamedTuple, ps, st::NamedTuple)
    if size(x.q, 1) != l.embed_dim
        throw(ArgumentError(
            "Length of query must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
        ))
    end
    if size(x.k, 1) != l.kdim
        throw(ArgumentError(
            "Length of key must match the layer's kdim: size(k)[1] = $(size(x.k, 1)), kdim = $(l.kdim)",
        ))
    end
    if size(x.v, 1) != l.vdim
        throw(ArgumentError(
            "Length of value must match the layer's vdim: size(v)[1] = $(size(x.v, 1)), vdim = $(l.vdim)",
        ))
    end
    # qk_dim and v_dim are divisible by num_heads; here qk_dim = v_dim = embed_dim * num_heads.
    # [q] = (qk_dim, q_len)
    q = ps.weight_q * x.q
    # [k] = (qk_dim, kv_len)
    k = ps.weight_k * x.k
    # [v] = (v_dim, kv_len)
    v = ps.weight_v * x.v
    # [y] = (v_dim, q_len)
    y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
    # project the concatenated heads back to the embedding dimension: [o] = (embed_dim, q_len)
    o = ps.weight_o * y
    return o, st
end
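# End-to-end sketch. This assumes unbatched (features, length) matrices, since the
# plain matrix products above do not broadcast over extra batch dimensions.
#
#     using Random
#     rng = Random.default_rng()
#     l = MultiheadAttention(4, 2)
#     ps, st = LuxCore.setup(rng, l)
#     q = randn(Float32, 4, 10)   # (embed_dim, q_len)
#     k = randn(Float32, 4, 12)   # (kdim, kv_len)
#     v = randn(Float32, 4, 12)   # (vdim, kv_len)
#     o, _ = l((q = q, k = k, v = v), ps, st)
#     size(o)                     # (4, 10) == (embed_dim, q_len)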
# struct TransformerEncoder <: LuxCore.AbstractLuxContainerLayer{}
# end
#