fix: support batched input in MultiheadAttention using Dense
parent 5422c2fe83
commit ac73d9b834

5 changed files with 248 additions and 256 deletions

Manifest.toml (413 changed lines): file diff suppressed because it is too large
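The change replaces hand-written weight multiplications with Lux `Dense` sublayers, which accept trailing batch dimensions as-is, so the attention layer no longer needs special-casing for batched input. A minimal sketch of that property (not part of the diff; sizes and RNG are arbitrary, and the 3-D behaviour is exactly the assumption the new layer relies on):

using Random: Xoshiro
using Lux, LuxCore

rng = Xoshiro(0)
d = Dense(8 => 16; use_bias = false)
ps, st = LuxCore.setup(rng, d)

y2, _ = d(randn(rng, Float32, 8, 5), ps, st)     # expected (16, 5): (features, seq_len)
y3, _ = d(randn(rng, Float32, 8, 5, 3), ps, st)  # expected (16, 5, 3): the batch dim passes through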
@@ -3,14 +3,15 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LanguageServer = "2b0e0bc5-e4fd-59b4-8912-456d1b03d8d7"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -12,7 +12,7 @@ const KV_LEN = 50
rng = TaskLocalRNG()

-l = MultiheadAttention(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
+l = MultiheadAttention2(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
@info "layer" l

ps = LuxCore.initialparameters(rng, l)
@@ -28,7 +28,7 @@ v = rand(rng, Float32, (KV_DIM,))
# with mask

-l = MultiheadAttention(EMBED_DIM, NUM_HEADS)
+l = MultiheadAttention2(EMBED_DIM, NUM_HEADS)
@info "layer" l

ps = LuxCore.initialparameters(rng, l)
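The `# with mask` block exercises the optional `mask` field of the layer input. A sketch of building such a mask and of the shapes `dot_product_attention` returns, assuming NNlib's `make_causal_mask` helper (all sizes here are illustrative, not the test's constants):

using NNlib: dot_product_attention, make_causal_mask

q = rand(Float32, 8, 5, 1)          # (qk_dim, q_len, batch)
k = rand(Float32, 8, 5, 1)
v = rand(Float32, 8, 5, 1)
mask = make_causal_mask(q)          # (5, 5) Bool mask; true = keep, false = drop
y, α = dot_product_attention(q, k, v; nheads = 2, mask)
# expected: size(y) == (8, 5, 1), size(α) == (5, 5, 2, 1)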
@@ -19,7 +19,7 @@ ps, st = LuxCore.setup(rng, model_encoder)
@info "status" st

# Unbatched
-x = randn(rng, Float32, (INPUT_DIM,))
+x = randn(rng, Float32, (INPUT_DIM, 1, 1))
@info "input" size(x)

y, st = model_encoder(x, ps, st)
@@ -44,7 +44,7 @@ ps, st = LuxCore.setup(rng, model_decoder)
@info "status" st

# Unbatched
-x = randn(rng, Float32, (INPUT_DIM,))
+x = randn(rng, Float32, (INPUT_DIM, 1, 1))
@info "input" x

y, st = model_decoder((; x = x, memory = memory_unbatched), ps, st)
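In both test scripts the "unbatched" sample is now passed in the batched layout, i.e. with singleton sequence-length and batch dimensions, matching the `(features, len, batch...)` convention used in the attention code below. A one-line sketch of the equivalent reshape (the `4` stands in for `INPUT_DIM`):

x_vec = randn(Float32, 4)           # a single sample, (INPUT_DIM,)
x = reshape(x_vec, 4, 1, 1)         # (INPUT_DIM, 1, 1): seq_len = 1, batch = 1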
@@ -1,6 +1,45 @@
# using LuxCore
using Random: AbstractRNG
using Lux
+using Static
+
+struct MultiheadAttention2{WQ, WK, WV, WO} <:
+       LuxCore.AbstractLuxContainerLayer{(:weight_q, :weight_k, :weight_v, :weight_o)}
+    weight_q::WQ
+    weight_k::WK
+    weight_v::WV
+    weight_o::WO
+    nheads::Int
+end
+
+function MultiheadAttention2(embed_dim::Int, num_heads::Int; kw...)
+    qdim = get(kw, :qdim, embed_dim)
+    kvdim = get(kw, :kvdim, embed_dim)
+    v_embed_dim = get(kw, :v_embed_dim, embed_dim)
+    init_weight = get(kw, :init_weight, glorot_uniform)
+    weight_q = Dense(qdim => embed_dim * num_heads; init_weight, use_bias = False())
+    weight_k = Dense(kvdim => embed_dim * num_heads; init_weight, use_bias = False())
+    weight_v = Dense(kvdim => v_embed_dim * num_heads; init_weight, use_bias = False())
+    weight_o =
+        Dense(v_embed_dim * num_heads => v_embed_dim; init_weight, use_bias = False())
+    MultiheadAttention2(weight_q, weight_k, weight_v, weight_o, num_heads)
+end
+
+function (mha::MultiheadAttention2)(x::NT, ps, st::NamedTuple) where {NT <: NamedTuple}
+    q, st_weight_q = mha.weight_q(x.q, ps.weight_q, st.weight_q)
+    k, st_weight_k = mha.weight_k(x.k, ps.weight_k, st.weight_k)
+    v, st_weight_v = mha.weight_v(x.v, ps.weight_v, st.weight_v)
+    mask = hasproperty(x, :mask) ? x.mask : nothing
+    y, _α = dot_product_attention(q, k, v; nheads = mha.nheads, mask)
+    o, st_weight_o = mha.weight_o(y, ps.weight_o, st.weight_o)
+    st = (
+        weight_q = st_weight_q,
+        weight_k = st_weight_k,
+        weight_v = st_weight_v,
+        weight_o = st_weight_o,
+    )
+    o, st
+end

"""
    MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
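A usage sketch for the new container layer, assuming the definitions above are loaded (dimensions, lengths, and RNG are arbitrary). Because every projection is a `Dense`, the same call is expected to work for `(dim, len)` and `(dim, len, batch)` inputs:

using Random: Xoshiro
using Lux, LuxCore

rng = Xoshiro(0)
mha = MultiheadAttention2(8, 2)               # embed_dim = 8, num_heads = 2
ps, st = LuxCore.setup(rng, mha)

q = randn(rng, Float32, 8, 5, 3)              # (embed_dim, q_len, batch)
k = randn(rng, Float32, 8, 7, 3)              # (kvdim, kv_len, batch)
v = randn(rng, Float32, 8, 7, 3)
y, st = mha((; q, k, v), ps, st)              # expected size(y) == (8, 5, 3)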
@@ -31,7 +70,7 @@ So far, ``Q, K, V`` are inputs `x` for the layer,
``W^Q, W^K, W^V, W^O`` are parameters `ps`,
and the layer has no states `st`.
"""
-struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
+struct MultiheadAttention_{F} <: LuxCore.AbstractLuxLayer
    embed_dim::Int
    qdim::Int
    kvdim::Int
@@ -72,13 +111,13 @@ NamedTuple of these three variables.
- `k`: ``K``
- `v`: ``V``
"""
-function MultiheadAttention(
+function MultiheadAttention_(
    embed_dim::Int,
    num_heads::Int;
    init_weight = glorot_uniform,
    kw...,
)
-    MultiheadAttention{typeof(init_weight)}(
+    MultiheadAttention_{typeof(init_weight)}(
        embed_dim,
        haskey(kw, :qdim) ? kw[:qdim] : embed_dim,
        haskey(kw, :kvdim) ? kw[:kvdim] : embed_dim,
@@ -88,7 +127,7 @@ function MultiheadAttention(
)
end

-function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
+function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention_)
    # see the original paper for weight dimensions (note that q,k,v weights have `num_heads` of matrices)
    (
        weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.qdim),
@@ -98,11 +137,11 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
    )
end

-function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
+function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention_)
    NamedTuple()
end

-function LuxCore.parameterlength(l::MultiheadAttention)
+function LuxCore.parameterlength(l::MultiheadAttention_)
    dim_weight_q = l.embed_dim * l.num_heads * l.qdim
    dim_weight_k = l.embed_dim * l.num_heads * l.kvdim
    dim_weight_v = l.v_embed_dim * l.num_heads * l.kvdim
@@ -110,11 +149,11 @@ function LuxCore.parameterlength(l::MultiheadAttention)
    dim_weight_q + dim_weight_k + dim_weight_v + dim_weight_o
end

-function LuxCore.statelength(l::MultiheadAttention)
+function LuxCore.statelength(l::MultiheadAttention_)
    0
end

-function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
+function (l::MultiheadAttention_)(x::NamedTuple, ps, _st::NamedTuple)
    if size(x.q, 1) != l.embed_dim
        DimensionMismatch(
            "Length of query must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
@@ -162,17 +201,19 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
    # FIXME: expand this to multi dimensional matrix multiplication
    # qk_dim, v_dim is divisible by num_heads. qk_dim = embed_dim * num_heads
    # [q] = (qk_dim, q_len, batch_size...)
-    q = ps.weight_q * x.q
+    # need to apply weights for all batches
+    @info "" size(ps.weight_q) size(x.q)
+    q = ps.weight_q ⊠ x.q
    # [k] = (qk_dim, kv_len, batch_size...)
-    k = ps.weight_k * x.k
+    k = ps.weight_k ⊠ x.k
    # [v] = (v_dim, kv_len, batch_size...)
-    v = ps.weight_v * x.v
+    v = ps.weight_v ⊠ x.v
    # [mask] = (kv_len, q_len, nheads, batch_size)
    mask = hasproperty(x, :mask) ? x.mask : nothing
    # [y] = (v_dim, q_len, batch_size...)
    # [α] = (kv_len, q_len, nheads, batch_size...)
    y, α = dot_product_attention(q, k, v; nheads = l.num_heads, mask)
-    ps.weight_o * y, _st
+    ps.weight_o ⊠ y, _st
end

"""
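The plain `*` in the old code only accepts a matrix on the right-hand side, so it fails as soon as the input carries a batch dimension; `⊠` (NNlib's `batched_mul`) multiplies the weight into every batch slice instead. A small sketch of the difference (sizes are illustrative):

using NNlib: batched_mul, ⊠

W = randn(Float32, 16, 8)            # a projection weight, (out, in)
x = randn(Float32, 8, 5, 3)          # (in, len, batch)

y = W ⊠ x                            # same as batched_mul(W, x); expected size (16, 5, 3)
# W * x would throw a MethodError for the 3-D input
@assert y[:, :, 2] ≈ W * x[:, :, 2]  # each slice matches the ordinary matrix product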
@@ -252,7 +293,7 @@ function TransformerEncoderLayer(
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
    sublayer_self_attention = let
        layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
-        layer_self_attention = MultiheadAttention(model_dim, num_heads)
+        layer_self_attention = MultiheadAttention2(model_dim, num_heads)
        layer_dropout = Lux.Dropout(dropout)
        layer_residual_connection = Lux.SkipConnection(
            Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
@@ -282,12 +323,17 @@ function TransformerEncoderLayer(
end

function (encoder::TransformerEncoderLayer)(x, ps, st)
-    x, st_sublayer_self_attention = Lux.apply(
-        encoder.sublayer_self_attention,
+    x, st_sublayer_self_attention = encoder.sublayer_self_attention(
        x,
        ps.sublayer_self_attention,
        st.sublayer_self_attention,
    )
+    # x, st_sublayer_self_attention = Lux.apply(
+    #     encoder.sublayer_self_attention,
+    #     x,
+    #     ps.sublayer_self_attention,
+    #     st.sublayer_self_attention,
+    # )
    x, st_sublayer_feed_forward_network = Lux.apply(
        encoder.sublayer_feed_forward_network,
        x,
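Swapping `Lux.apply` for the call syntax should be behaviour-preserving, since for Lux layers the two forms are interchangeable; the old form is kept as a comment for reference. A tiny sketch of that equivalence on a stand-in `Dense` layer:

using Random: Xoshiro
using Lux, LuxCore

rng = Xoshiro(0)
layer = Dense(4 => 4)
ps, st = LuxCore.setup(rng, layer)
x = randn(rng, Float32, 4, 2)

y1, _ = Lux.apply(layer, x, ps, st)   # functional form, as before this change
y2, _ = layer(x, ps, st)              # call syntax, as after this change
@assert y1 ≈ y2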
@@ -327,7 +373,7 @@ function TransformerDecoderLayer(
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
    sublayer_self_attention = let
        layer_split = Lux.WrappedFunction(x -> (q = x.x, k = x.x, v = x.x, mask = x.mask))
-        layer_self_attention = MultiheadAttention(model_dim, num_heads)
+        layer_self_attention = MultiheadAttention2(model_dim, num_heads)
        layer_dropout = Lux.Dropout(dropout)
        layer_residual_connection = Lux.SkipConnection(
            Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
@@ -342,7 +388,7 @@ function TransformerDecoderLayer(
    sublayer_multihead_attention = let
        layer_split =
            Lux.WrappedFunction(x::NamedTuple -> (q = x.x, k = x.memory, v = x.memory))
-        layer_self_attention = MultiheadAttention(model_dim, num_heads)
+        layer_self_attention = MultiheadAttention2(model_dim, num_heads)
        layer_dropout = Lux.Dropout(dropout)
        layer_residual_connection = Lux.Parallel(
            +,
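For reference, the cross-attention wiring above can be exercised on its own: the `WrappedFunction` routes the decoder stream to `q` and the encoder memory to `k` and `v` before `MultiheadAttention2` sees them. A standalone sketch, assuming the layer definitions from this diff are loaded (sizes are arbitrary):

using Random: Xoshiro
using Lux, LuxCore

split = Lux.WrappedFunction(x -> (q = x.x, k = x.memory, v = x.memory))
cross_attention = Lux.Chain(split, MultiheadAttention2(8, 2))

rng = Xoshiro(0)
ps, st = LuxCore.setup(rng, cross_attention)
input = (; x = randn(rng, Float32, 8, 5, 3), memory = randn(rng, Float32, 8, 7, 3))
y, _ = cross_attention(input, ps, st)   # expected size(y) == (8, 5, 3)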