fix: support batched input in MultiheadAttention using Dense
parent 5422c2fe83
commit ac73d9b834
5 changed files with 248 additions and 256 deletions
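What the change is after, in brief: the existing MultiheadAttention (renamed MultiheadAttention_ in this commit) multiplied its projection matrices with plain `*`, which only handles 2-D inputs, whereas the new MultiheadAttention2 wraps each projection in a Lux `Dense` layer, and `Dense` treats every dimension after the first as a batch dimension. A minimal usage sketch, not part of the commit; the sizes are illustrative and assume the file defining `MultiheadAttention2` in the diff below has been loaded:

    using Lux, LuxCore
    using Random: TaskLocalRNG

    rng = TaskLocalRNG()
    mha = MultiheadAttention2(16, 4)     # embed_dim = 16, num_heads = 4
    ps, st = LuxCore.setup(rng, mha)

    q = rand(Float32, 16, 10, 8)         # (embed_dim, q_len, batch)
    k = rand(Float32, 16, 20, 8)         # (kvdim, kv_len, batch); kvdim defaults to embed_dim
    v = rand(Float32, 16, 20, 8)
    y, st = mha((; q, k, v), ps, st)     # size(y) == (16, 10, 8)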
Manifest.toml (413 changed lines)
File diff suppressed because it is too large
@@ -3,14 +3,15 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
 IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
-LanguageServer = "2b0e0bc5-e4fd-59b4-8912-456d1b03d8d7"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
@@ -12,7 +12,7 @@ const KV_LEN = 50
 
 rng = TaskLocalRNG()
 
-l = MultiheadAttention(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
+l = MultiheadAttention2(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
 @info "layer" l
 
 ps = LuxCore.initialparameters(rng, l)
@@ -28,7 +28,7 @@ v = rand(rng, Float32, (KV_DIM,))
 
 # with mask
 
-l = MultiheadAttention(EMBED_DIM, NUM_HEADS)
+l = MultiheadAttention2(EMBED_DIM, NUM_HEADS)
 @info "layer" l
 
 ps = LuxCore.initialparameters(rng, l)
@@ -19,7 +19,7 @@ ps, st = LuxCore.setup(rng, model_encoder)
 @info "status" st
 
 # Unbatched
-x = randn(rng, Float32, (INPUT_DIM,))
+x = randn(rng, Float32, (INPUT_DIM, 1, 1))
 @info "input" size(x)
 
 y, st = model_encoder(x, ps, st)
@@ -44,7 +44,7 @@ ps, st = LuxCore.setup(rng, model_decoder)
 @info "status" st
 
 # Unbatched
-x = randn(rng, Float32, (INPUT_DIM,))
+x = randn(rng, Float32, (INPUT_DIM, 1, 1))
 @info "input" x
 
 y, st = model_decoder((; x = x, memory = memory_unbatched), ps, st)
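A batched call would differ from the "Unbatched" cases above only in the trailing dimensions. A hedged fragment extending the test script shown in these hunks (SEQ_LEN and BATCH are illustrative names, not taken from the file):

    SEQ_LEN, BATCH = 10, 4
    x = randn(rng, Float32, (INPUT_DIM, SEQ_LEN, BATCH))
    @info "input" size(x)
    y, st = model_encoder(x, ps, st)
    @info "output" size(y)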
@@ -1,6 +1,45 @@
 # using LuxCore
 using Random: AbstractRNG
 using Lux
+using Static
 
+struct MultiheadAttention2{WQ, WK, WV, WO} <:
+       LuxCore.AbstractLuxContainerLayer{(:weight_q, :weight_k, :weight_v, :weight_o)}
+    weight_q::WQ
+    weight_k::WK
+    weight_v::WV
+    weight_o::WO
+    nheads::Int
+end
+
+function MultiheadAttention2(embed_dim::Int, num_heads::Int; kw...)
+    qdim = get(kw, :qdim, embed_dim)
+    kvdim = get(kw, :kvdim, embed_dim)
+    v_embed_dim = get(kw, :v_embed_dim, embed_dim)
+    init_weight = get(kw, :init_weight, glorot_uniform)
+    weight_q = Dense(qdim => embed_dim * num_heads; init_weight, use_bias = False())
+    weight_k = Dense(kvdim => embed_dim * num_heads; init_weight, use_bias = False())
+    weight_v = Dense(kvdim => v_embed_dim * num_heads; init_weight, use_bias = False())
+    weight_o =
+        Dense(v_embed_dim * num_heads => v_embed_dim; init_weight, use_bias = False())
+    MultiheadAttention2(weight_q, weight_k, weight_v, weight_o, num_heads)
+end
+
+function (mha::MultiheadAttention2)(x::NT, ps, st::NamedTuple) where {NT <: NamedTuple}
+    q, st_weight_q = mha.weight_q(x.q, ps.weight_q, st.weight_q)
+    k, st_weight_k = mha.weight_k(x.k, ps.weight_k, st.weight_k)
+    v, st_weight_v = mha.weight_v(x.v, ps.weight_v, st.weight_v)
+    mask = hasproperty(x, :mask) ? x.mask : nothing
+    y, _α = dot_product_attention(q, k, v; nheads = mha.nheads, mask)
+    o, st_weight_o = mha.weight_o(y, ps.weight_o, st.weight_o)
+    st = (
+        weight_q = st_weight_q,
+        weight_k = st_weight_k,
+        weight_v = st_weight_v,
+        weight_o = st_weight_o,
+    )
+    o, st
+end
+
 """
     MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
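Why `Dense` gives batching essentially for free (a sketch of my reading of Lux's `Dense`, not part of the commit): a bias-free `Dense` applied to a 3-D array multiplies every slice along the trailing dimensions by the same weight matrix, which is also what the `⊠` (batched multiplication) calls in the older layer below are meant to achieve:

    using Lux, LuxCore
    using Random: TaskLocalRNG

    rng = TaskLocalRNG()
    d = Dense(4 => 8; use_bias = false)  # plain Bool here; the commit uses Static.False()
    ps, st = LuxCore.setup(rng, d)

    x = rand(Float32, 4, 10, 3)          # (in_dims, seq_len, batch)
    y, _ = d(x, ps, st)                  # size(y) == (8, 10, 3)
    all(y[:, :, b] ≈ ps.weight * x[:, :, b] for b in 1:3)   # per-slice matmul, so this holds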
@@ -31,7 +70,7 @@ So far, ``Q, K, V`` are inputs `x` for the layer,
 ``W^Q, W^K, W^V, W^O`` are parameters `ps`,
 and the layer has no states `st`.
 """
-struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer
+struct MultiheadAttention_{F} <: LuxCore.AbstractLuxLayer
     embed_dim::Int
     qdim::Int
     kvdim::Int
@@ -72,13 +111,13 @@ NamedTuple of these three variables.
 - `k`: ``K``
 - `v`: ``V``
 """
-function MultiheadAttention(
+function MultiheadAttention_(
     embed_dim::Int,
     num_heads::Int;
     init_weight = glorot_uniform,
     kw...,
 )
-    MultiheadAttention{typeof(init_weight)}(
+    MultiheadAttention_{typeof(init_weight)}(
         embed_dim,
         haskey(kw, :qdim) ? kw[:qdim] : embed_dim,
         haskey(kw, :kvdim) ? kw[:kvdim] : embed_dim,
@@ -88,7 +127,7 @@ function MultiheadAttention(
     )
 end
 
-function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
+function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention_)
     # see the original paper for weight dimensions (note that q,k,v weights have `num_heads` of matrices)
     (
         weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.qdim),
@@ -98,11 +137,11 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention)
     )
 end
 
-function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
+function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention_)
     NamedTuple()
 end
 
-function LuxCore.parameterlength(l::MultiheadAttention)
+function LuxCore.parameterlength(l::MultiheadAttention_)
     dim_weight_q = l.embed_dim * l.num_heads * l.qdim
     dim_weight_k = l.embed_dim * l.num_heads * l.kvdim
     dim_weight_v = l.v_embed_dim * l.num_heads * l.kvdim
@@ -110,11 +149,11 @@ function LuxCore.parameterlength(l::MultiheadAttention)
     dim_weight_q + dim_weight_k + dim_weight_v + dim_weight_o
 end
 
-function LuxCore.statelength(l::MultiheadAttention)
+function LuxCore.statelength(l::MultiheadAttention_)
     0
 end
 
-function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
+function (l::MultiheadAttention_)(x::NamedTuple, ps, _st::NamedTuple)
     if size(x.q, 1) != l.embed_dim
         DimensionMismatch(
             "Length of queue must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
@@ -162,17 +201,19 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
     # FIXME: expand this to multi dimensional matrix multiplication
     # qk_dim, v_dim is divisible by num_heads. qk_dim = embed_dim * num_heads
     # [q] = (qk_dim, q_len, batch_size...)
-    q = ps.weight_q * x.q
+    # need to apply weights for all batches
+    @info "" size(ps.weight_q) size(x.q)
+    q = ps.weight_q ⊠ x.q
     # [k] = (qk_dim, kv_len, batch_size...)
-    k = ps.weight_k * x.k
+    k = ps.weight_k ⊠ x.k
     # [v] = (v_dim, kv_len, batch_size...)
-    v = ps.weight_v * x.v
+    v = ps.weight_v ⊠ x.v
     # [mask] = (kv_len, q_len, nheads, batch_size)
     mask = hasproperty(x, :mask) ? x.mask : nothing
     # [y] = (v_dim, q_len, batch_size...)
     # [α] = (kv_len, q_len, nheads, batch_size...)
     y, α = dot_product_attention(q, k, v; nheads = l.num_heads, mask)
-    ps.weight_o * y, _st
+    ps.weight_o ⊠ y, _st
 end
 
 """
@@ -252,7 +293,7 @@ function TransformerEncoderLayer(
 ) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
     sublayer_self_attention = let
         layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
-        layer_self_attention = MultiheadAttention(model_dim, num_heads)
+        layer_self_attention = MultiheadAttention2(model_dim, num_heads)
         layer_dropout = Lux.Dropout(dropout)
         layer_residual_connection = Lux.SkipConnection(
             Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
@@ -282,12 +323,17 @@
 end
 
 function (encoder::TransformerEncoderLayer)(x, ps, st)
-    x, st_sublayer_self_attention = Lux.apply(
-        encoder.sublayer_self_attention,
+    x, st_sublayer_self_attention = encoder.sublayer_self_attention(
         x,
         ps.sublayer_self_attention,
         st.sublayer_self_attention,
     )
+    # x, st_sublayer_self_attention = Lux.apply(
+    #     encoder.sublayer_self_attention,
+    #     x,
+    #     ps.sublayer_self_attention,
+    #     st.sublayer_self_attention,
+    # )
     x, st_sublayer_feed_forward_network = Lux.apply(
         encoder.sublayer_feed_forward_network,
         x,
@@ -327,7 +373,7 @@ function TransformerDecoderLayer(
 ) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
     sublayer_self_attention = let
         layer_split = Lux.WrappedFunction(x -> (q = x.x, k = x.x, v = x.x, mask = x.mask))
-        layer_self_attention = MultiheadAttention(model_dim, num_heads)
+        layer_self_attention = MultiheadAttention2(model_dim, num_heads)
         layer_dropout = Lux.Dropout(dropout)
         layer_residual_connection = Lux.SkipConnection(
             Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
@@ -342,7 +388,7 @@ function TransformerDecoderLayer(
     sublayer_multihead_attention = let
         layer_split =
             Lux.WrappedFunction(x::NamedTuple -> (q = x.x, k = x.memory, v = x.memory))
-        layer_self_attention = MultiheadAttention(model_dim, num_heads)
+        layer_self_attention = MultiheadAttention2(model_dim, num_heads)
         layer_dropout = Lux.Dropout(dropout)
         layer_residual_connection = Lux.Parallel(
             +,