fix: support batched input in MultiheadAttention using Dense

Wataru Otsubo 2025-04-19 19:17:40 +09:00 committed by qwjyh
parent 7bb7a67b0a
commit d2e12a8a00
5 changed files with 248 additions and 256 deletions

File diff suppressed because it is too large

@@ -3,14 +3,15 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LanguageServer = "2b0e0bc5-e4fd-59b4-8912-456d1b03d8d7"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

@@ -12,7 +12,7 @@ const KV_LEN = 50
rng = TaskLocalRNG()
l = MultiheadAttention(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
l = MultiheadAttention2(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
@info "layer" l
ps = LuxCore.initialparameters(rng, l)
@@ -28,7 +28,7 @@ v = rand(rng, Float32, (KV_DIM,))
# with mask
l = MultiheadAttention(EMBED_DIM, NUM_HEADS)
l = MultiheadAttention2(EMBED_DIM, NUM_HEADS)
@info "layer" l
ps = LuxCore.initialparameters(rng, l)

@@ -19,7 +19,7 @@ ps, st = LuxCore.setup(rng, model_encoder)
@info "status" st
# Unbatched
x = randn(rng, Float32, (INPUT_DIM,))
x = randn(rng, Float32, (INPUT_DIM, 1, 1))
@info "input" size(x)
y, st = model_encoder(x, ps, st)
@@ -44,7 +44,7 @@ ps, st = LuxCore.setup(rng, model_decoder)
@info "status" st
# Unbatched
x = randn(rng, Float32, (INPUT_DIM,))
x = randn(rng, Float32, (INPUT_DIM, 1, 1))
@info "input" x
y, st = model_decoder((; x = x, memory = memory_unbatched), ps, st)
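A note on the reshaped test inputs (my reading, not stated in the diff): the two added singleton dimensions turn the "unbatched" case into a (features, sequence length, batch) array, so it takes the same Dense-based code path as batched data. A shape-only sketch with hypothetical sizes:

x_unbatched = randn(Float32, 4, 1, 1)    # (INPUT_DIM, seq_len = 1, batch = 1)
x_batched   = randn(Float32, 4, 7, 32)   # (INPUT_DIM, seq_len = 7, batch = 32)
# Both share the layout the encoder/decoder now expect; only the trailing sizes differ.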

@@ -1,6 +1,45 @@
# using LuxCore
using Random: AbstractRNG
using Lux
using Static
struct MultiheadAttention2{WQ, WK, WV, WO} <:
LuxCore.AbstractLuxContainerLayer{(:weight_q, :weight_k, :weight_v, :weight_o)}
weight_q::WQ
weight_k::WK
weight_v::WV
weight_o::WO
nheads::Int
end
function MultiheadAttention2(embed_dim::Int, num_heads::Int; kw...)
qdim = get(kw, :qdim, embed_dim)
kvdim = get(kw, :kvdim, embed_dim)
v_embed_dim = get(kw, :v_embed_dim, embed_dim)
init_weight = get(kw, :init_weight, glorot_uniform)
weight_q = Dense(qdim => embed_dim * num_heads; init_weight, use_bias = False())
weight_k = Dense(kvdim => embed_dim * num_heads; init_weight, use_bias = False())
weight_v = Dense(kvdim => v_embed_dim * num_heads; init_weight, use_bias = False())
weight_o =
Dense(v_embed_dim * num_heads => v_embed_dim; init_weight, use_bias = False())
MultiheadAttention2(weight_q, weight_k, weight_v, weight_o, num_heads)
end
function (mha::MultiheadAttention2)(x::NT, ps, st::NamedTuple) where {NT <: NamedTuple}
q, st_weight_q = mha.weight_q(x.q, ps.weight_q, st.weight_q)
k, st_weight_k = mha.weight_k(x.k, ps.weight_k, st.weight_k)
v, st_weight_v = mha.weight_v(x.v, ps.weight_v, st.weight_v)
mask = hasproperty(x, :mask) ? x.mask : nothing
y, _α = dot_product_attention(q, k, v; nheads = mha.nheads, mask)
o, st_weight_o = mha.weight_o(y, ps.weight_o, st.weight_o)
st = (
weight_q = st_weight_q,
weight_k = st_weight_k,
weight_v = st_weight_v,
weight_o = st_weight_o,
)
o, st
end
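A minimal usage sketch for the new Dense-based layer (not part of the commit; the sizes below are made up). It assumes the definitions above are loaded and relies on Lux's Dense accepting arrays with trailing batch dimensions, which is what makes the batched case work:

using Lux, LuxCore, Random

rng = Random.Xoshiro(0)
mha = MultiheadAttention2(8, 2)                 # embed_dim = 8, num_heads = 2
ps, st = LuxCore.setup(rng, mha)

# Batched input laid out as (features, sequence length, batch size).
q = randn(rng, Float32, 8, 10, 4)
k = randn(rng, Float32, 8, 12, 4)
v = randn(rng, Float32, 8, 12, 4)

y, st = mha((q = q, k = k, v = v), ps, st)
size(y)                                         # (8, 10, 4): v_embed_dim × q_len × batch

The tests above reshape even the unbatched input to (INPUT_DIM, 1, 1), i.e. a batch of one, so everything stays 3-D on this path.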
"""
MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
@@ -31,7 +70,7 @@ So far, ``Q, K, V`` are inputs `x` for the layer,
``W^Q, W^K, W^V, W^O`` are parameters `ps`,
and the layer has no states `st`.
"""
struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
struct MultiheadAttention_{F} <: LuxCore.AbstractLuxLayer
embed_dim::Int
qdim::Int
kvdim::Int
@@ -72,13 +111,13 @@ NamedTuple of these three variables.
- `k`: ``K``
- `v`: ``V``
"""
function MultiheadAttention(
function MultiheadAttention_(
embed_dim::Int,
num_heads::Int;
init_weight = glorot_uniform,
kw...,
)
MultiheadAttention{typeof(init_weight)}(
MultiheadAttention_{typeof(init_weight)}(
embed_dim,
haskey(kw, :qdim) ? kw[:qdim] : embed_dim,
haskey(kw, :kvdim) ? kw[:kvdim] : embed_dim,
@@ -88,7 +127,7 @@ function MultiheadAttention(
)
end
function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention_)
# see the original paper for weight dimensions (note that q,k,v weights have `num_heads` of matrices)
(
weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.qdim),
@@ -98,11 +137,11 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
)
end
function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention_)
NamedTuple()
end
function LuxCore.parameterlength(l::MultiheadAttention)
function LuxCore.parameterlength(l::MultiheadAttention_)
dim_weight_q = l.embed_dim * l.num_heads * l.qdim
dim_weight_k = l.embed_dim * l.num_heads * l.kvdim
dim_weight_v = l.v_embed_dim * l.num_heads * l.kvdim
@@ -110,11 +149,11 @@ function LuxCore.parameterlength(l::MultiheadAttention)
dim_weight_q + dim_weight_k + dim_weight_v + dim_weight_o
end
function LuxCore.statelength(l::MultiheadAttention)
function LuxCore.statelength(l::MultiheadAttention_)
0
end
function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
function (l::MultiheadAttention_)(x::NamedTuple, ps, _st::NamedTuple)
if size(x.q, 1) != l.embed_dim
DimensionMismatch(
"Length of queue must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
@@ -162,17 +201,19 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
# FIXME: expand this to multi dimensional matrix multiplication
# qk_dim, v_dim is divisible by num_heads. qk_dim = embed_dim * num_heads
# [q] = (qk_dim, q_len, batch_size...)
q = ps.weight_q * x.q
# need to apply weights for all batches
@info "" size(ps.weight_q) size(x.q)
q = ps.weight_q ⊠ x.q
# [k] = (qk_dim, kv_len, batch_size...)
k = ps.weight_k * x.k
k = ps.weight_k ⊠ x.k
# [v] = (v_dim, kv_len, batch_size...)
v = ps.weight_v * x.v
v = ps.weight_v ⊠ x.v
# [mask] = (kv_len, q_len, nheads, batch_size)
mask = hasproperty(x, :mask) ? x.mask : nothing
# [y] = (v_dim, q_len, batch_size...)
# [α] = (kv_len, q_len, nheads, batch_size...)
y, α = dot_product_attention(q, k, v; nheads = l.num_heads, mask)
ps.weight_o * y, _st
ps.weight_o ⊠ y, _st
end
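For contrast, a sketch (not in the commit) of why the plain `*` in the path above only covers 2-D inputs, which is what the FIXME comment refers to; sizes are hypothetical:

W  = randn(Float32, 16, 8)                 # a projection weight, (out_dim, in_dim)
x2 = randn(Float32, 8, 10)                 # unbatched: (features, seq_len)
x3 = randn(Float32, 8, 10, 4)              # batched:   (features, seq_len, batch)

W * x2                                     # ordinary matrix product, works
# W * x3                                   # MethodError: `*` has no matrix × 3-D array method

# The batched case needs the trailing dimensions flattened and restored:
y3 = reshape(W * reshape(x3, 8, :), 16, 10, 4)
# As I understand it, Lux.Dense does this kind of reshaping internally,
# which is why MultiheadAttention2 above delegates the projections to Dense.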
"""
@@ -252,7 +293,7 @@ function TransformerEncoderLayer(
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
sublayer_self_attention = let
layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
layer_self_attention = MultiheadAttention(model_dim, num_heads)
layer_self_attention = MultiheadAttention2(model_dim, num_heads)
layer_dropout = Lux.Dropout(dropout)
layer_residual_connection = Lux.SkipConnection(
Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
@@ -282,12 +323,17 @@ end
end
function (encoder::TransformerEncoderLayer)(x, ps, st)
x, st_sublayer_self_attention = Lux.apply(
encoder.sublayer_self_attention,
x, st_sublayer_self_attention = encoder.sublayer_self_attention(
x,
ps.sublayer_self_attention,
st.sublayer_self_attention,
)
# x, st_sublayer_self_attention = Lux.apply(
# encoder.sublayer_self_attention,
# x,
# ps.sublayer_self_attention,
# st.sublayer_self_attention,
# )
x, st_sublayer_feed_forward_network = Lux.apply(
encoder.sublayer_feed_forward_network,
x,
@@ -327,7 +373,7 @@ function TransformerDecoderLayer(
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
sublayer_self_attention = let
layer_split = Lux.WrappedFunction(x -> (q = x.x, k = x.x, v = x.x, mask = x.mask))
layer_self_attention = MultiheadAttention(model_dim, num_heads)
layer_self_attention = MultiheadAttention2(model_dim, num_heads)
layer_dropout = Lux.Dropout(dropout)
layer_residual_connection = Lux.SkipConnection(
Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
@@ -342,7 +388,7 @@ sublayer_multihead_attention = let
sublayer_multihead_attention = let
layer_split =
Lux.WrappedFunction(x::NamedTuple -> (q = x.x, k = x.memory, v = x.memory))
layer_self_attention = MultiheadAttention(model_dim, num_heads)
layer_self_attention = MultiheadAttention2(model_dim, num_heads)
layer_dropout = Lux.Dropout(dropout)
layer_residual_connection = Lux.Parallel(
+,