(WIP): implement MultiheadAttention
parent 6a91ef35e1
commit c07d9c30cc
2 changed files with 127 additions and 0 deletions

test_transformer.jl (new file, 24 lines)
@@ -0,0 +1,24 @@
# test code

true || include("transformers.jl")

using Random

const EMBED_DIM = 10
const NUM_HEADS = 10

rng = TaskLocalRNG()

l = MultiheadAttention(EMBED_DIM, NUM_HEADS, vdim = 8)
@info "layer" l

ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
@info "parameters and states" ps st

q = rand(rng, Float32, (EMBED_DIM,))
k = rand(rng, Float32, (EMBED_DIM,))
v = rand(rng, Float32, (8,)) # must match the layer's vdim = 8
@info "q k v" summary.((q, k, v))

l((; q, k, v), ps, st)
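
# Note (illustrative): `Lux.setup(rng, l)` bundles the two initializer calls above and
# returns the same `(ps, st)` pair:
#
#   ps, st = Lux.setup(rng, l)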

transformers.jl (new file, 103 lines)
@@ -0,0 +1,103 @@
# using LuxCore
using Random: AbstractRNG
using Lux

"""
# Fields
- `embed_dim`: Query vector length. `q_len`
- `kdim::Int`: Key vector length.
- `vdim::Int`: Value vector length.
- `num_heads`: Number of heads.
"""
struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
    embed_dim::Int
    kdim::Int
    vdim::Int
    num_heads::Int
    init_weight::F
end
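
# Note: in Lux, the layer struct itself only stores hyperparameters; the trainable
# weights live in the NamedTuple returned by `LuxCore.initialparameters` below.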

"""
    MultiheadAttention(embed_dim::Int, num_heads::Int; init_weight=glorot_uniform, kw...)

Constructor.

# Arguments
- `embed_dim::Int`
- `num_heads::Int`

## Keyword Arguments
- `init_weight`: weight initializer (called with an RNG and the weight dimensions)
- `kdim::Int`: Default: `embed_dim`
- `vdim::Int`: Default: `embed_dim`

# Parameters and states

## Parameters

"""
function MultiheadAttention(
    embed_dim::Int,
    num_heads::Int;
    init_weight = glorot_uniform,
    kw...,
)
    MultiheadAttention{typeof(init_weight)}(
        embed_dim,
        haskey(kw, :kdim) ? kw[:kdim] : embed_dim,
        haskey(kw, :vdim) ? kw[:vdim] : embed_dim,
        num_heads,
        init_weight,
    )
end
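
# Usage sketch (illustrative): `kdim` and `vdim` fall back to `embed_dim` when not given.
#
#   MultiheadAttention(10, 10)            # embed_dim = kdim = vdim = 10, num_heads = 10
#   MultiheadAttention(10, 10, vdim = 8)  # the configuration used in test_transformer.jl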

function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
    # see the original paper for the weight dimensions (note that the q, k, v weights
    # each stack `num_heads` matrices)
    (
        weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.embed_dim),
        weight_k = l.init_weight(rng, l.embed_dim * l.num_heads, l.kdim),
        weight_v = l.init_weight(rng, l.embed_dim * l.num_heads, l.vdim),
        weight_o = l.init_weight(rng, 10), # TODO
    )
end
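
# Shape sketch (illustrative, using the values from test_transformer.jl:
# embed_dim = 10, num_heads = 10, kdim = 10, vdim = 8):
#
#   size(ps.weight_q) == (100, 10)  # (embed_dim * num_heads, embed_dim)
#   size(ps.weight_k) == (100, 10)  # (embed_dim * num_heads, kdim)
#   size(ps.weight_v) == (100, 8)   # (embed_dim * num_heads, vdim)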

function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
    NamedTuple()
end

function (l::MultiheadAttention)(
    x::NamedTuple,
    ps,
    st::NamedTuple,
)
    if size(x.q, 1) != l.embed_dim
        ArgumentError(
            "Length of query must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
        ) |> throw
    end
    if size(x.k, 1) != l.kdim
        ArgumentError(
            "Length of key must match the layer's kdim: size(k)[1] = $(size(x.k, 1)), kdim = $(l.kdim)",
        ) |> throw
    end
    if size(x.v, 1) != l.vdim
        ArgumentError(
            "Length of value must match the layer's vdim: size(v)[1] = $(size(x.v, 1)), vdim = $(l.vdim)",
        ) |> throw
    end
    # TODO

    # qk_dim and v_dim are divisible by num_heads; here qk_dim = embed_dim * num_heads
    # [q] = (qk_dim, q_len, batch_size...)
    q = ps.weight_q * x.q
    # [k] = (qk_dim, kv_len, batch_size...)
    k = ps.weight_k * x.k
    # [v] = (v_dim, kv_len, batch_size...)
    v = ps.weight_v * x.v # TODO: dimension?? 2025-01-28T21:59:56+09:00
    # [y] = (v_dim, q_len, batch_size...)
    y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
    (y, st) # a Lux layer call returns (output, state)
end
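
# Call sketch (illustrative): 2-D inputs without a batch dimension, shapes matching the
# checks above for the layer built in test_transformer.jl.
#
#   q_in = rand(Float32, 10, 5)  # (embed_dim, q_len)
#   k_in = rand(Float32, 10, 7)  # (kdim, kv_len)
#   v_in = rand(Float32, 8, 7)   # (vdim, kv_len)
#   y, _ = l((; q = q_in, k = k_in, v = v_in), ps, st)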

# struct TransformerEncoder <: LuxCore.AbstractLuxContainerLayer{}
# end

#