Compare commits
2 commits: f507418bc6 ... 5422c2fe83

Commits in this comparison: 67ffd00de8, 5422c2fe83

3 changed files with 83 additions and 24 deletions
@@ -4,9 +4,11 @@ true || include("transformers.jl")
 using Random

-const EMBED_DIM = 10
-const NUM_HEADS = 2
-const KV_DIM = 12
+const EMBED_DIM = 100
+const NUM_HEADS = 20
+const KV_DIM = 120
+const Q_LEN = 50
+const KV_LEN = 50

 rng = TaskLocalRNG()
@@ -22,4 +24,21 @@ k = rand(rng, Float32, (KV_DIM,))
 v = rand(rng, Float32, (KV_DIM,))
 @info "q k v" summary.((q, k, v))

-l((; q, k, v), ps, st)
+@info "" l((; q, k, v), ps, st)
+
+# with mask
+
+l = MultiheadAttention(EMBED_DIM, NUM_HEADS)
+@info "layer" l
+
+ps = LuxCore.initialparameters(rng, l)
+st = LuxCore.initialstates(rng, l)
+@info "parameters and states" ps st
+
+q = rand(rng, Float32, (EMBED_DIM, Q_LEN))
+k = rand(rng, Float32, (EMBED_DIM, KV_LEN))
+v = rand(rng, Float32, (EMBED_DIM, KV_LEN))
+mask = stack(make_causal_mask(k) for _ in 1:NUM_HEADS)
+@info "q k v mask" summary.((q, k, v, mask))
+
+@info "" l((; q, k, v, mask), ps, st)
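A note on the mask built above: make_causal_mask and dot_product_attention follow NNlib's convention, where the attention mask is broadcast against scores of shape (kv_len, q_len, nheads, batch...). A minimal shape sketch using the same constants, assuming the NNlib definitions and Base.stack (Julia 1.9+):

using NNlib: make_causal_mask

k = rand(Float32, 100, 50)             # (EMBED_DIM, KV_LEN), as in the script
causal = make_causal_mask(k)           # square Bool causal mask of side KV_LEN
mask = stack(causal for _ in 1:20)     # one copy per head: (KV_LEN, KV_LEN, NUM_HEADS)
@assert size(mask) == (50, 50, 20)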
@@ -4,12 +4,14 @@ true || include("transformers.jl")
 using Random

-const INPUT_DIM = 10
-const NUM_HEADS = 1
+const INPUT_DIM = 100
+const NUM_HEADS = 20
+const NUM_BATCH = 10
+const INPUT_LEN = 5

 rng = Random.default_rng()

-model_encoder = TransformerEncoderLayer(10, 2)
+model_encoder = TransformerEncoderLayer(INPUT_DIM, NUM_HEADS)
 @info "model" model_encoder

 ps, st = LuxCore.setup(rng, model_encoder)
@@ -17,7 +19,7 @@ ps, st = LuxCore.setup(rng, model_encoder)
 @info "status" st

 # Unbatched
-x = randn(rng, Float32, (10,))
+x = randn(rng, Float32, (INPUT_DIM,))
 @info "input" size(x)

 y, st = model_encoder(x, ps, st)
@@ -26,7 +28,7 @@ y, st = model_encoder(x, ps, st)
 memory_unbatched = y

 # Batched
-x = randn(rng, Float32, (10, 10))
+x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
 @info "input" size(x)

 y, st = model_encoder(x, ps, st)
@@ -34,7 +36,7 @@ y, st = model_encoder(x, ps, st)

 memory_batched = y

-model_decoder = TransformerDecoderLayer(10, 2)
+model_decoder = TransformerDecoderLayer(INPUT_DIM, NUM_HEADS)
 @info "decoder model" model_decoder

 ps, st = LuxCore.setup(rng, model_decoder)
@@ -42,14 +44,14 @@ ps, st = LuxCore.setup(rng, model_decoder)
 @info "status" st

 # Unbatched
-x = randn(rng, Float32, (10,))
+x = randn(rng, Float32, (INPUT_DIM,))
 @info "input" x

 y, st = model_decoder((; x = x, memory = memory_unbatched), ps, st)
 @info "output" y st

 # Batched
-x = randn(rng, Float32, (10, 10))
+x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
 @info "input" x

 y, st = model_decoder((; x, memory = memory_batched), ps, st)
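In this script the batched inputs use the (features, sequence length, batch) layout, so the decoder input and the encoder memory share their trailing dimensions. A purely illustrative shape check with the same constants:

x      = randn(Float32, 100, 5, 10)    # (INPUT_DIM, INPUT_LEN, NUM_BATCH)
memory = randn(Float32, 100, 5, 10)    # encoder output with the same trailing dims
@assert size(x)[2:end] == size(memory)[2:end]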
@@ -116,22 +116,50 @@ end

 function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
     if size(x.q, 1) != l.embed_dim
-        ArgumentError(
+        DimensionMismatch(
             "Length of query must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
         ) |> throw
     end
     if size(x.k, 1) != l.kvdim
-        ArgumentError(
+        DimensionMismatch(
             "Length of key must match the layer's kvdim: size(k)[1] = $(size(x.k, 1)), kvdim = $(l.kvdim)",
         ) |> throw
     end
     if size(x.v, 1) != l.kvdim
-        ArgumentError(
+        DimensionMismatch(
             "Length of value must match the layer's kvdim: size(v)[1] = $(size(x.v, 1)), kvdim = $(l.kvdim)",
         ) |> throw
     end
+    if !(ndims(x.q) == ndims(x.k) == ndims(x.v))
+        DimensionMismatch(
+            "Query, key, and value must have the same number of dimensions",
+        ) |> throw
+    end
+    if ndims(x.k) > 1 && size(x.k)[2:end] != size(x.v)[2:end]
+        DimensionMismatch(
+            "Trailing dimensions of key and value must match: size(k)[2:end] = $(size(x.k)[2:end]), size(v)[2:end] = $(size(x.v)[2:end])",
+        ) |> throw
+    end
+    if haskey(x, :mask) && !isnothing(x.mask)
+        if size(x.mask, 1) != size(x.k, 2)
+            DimensionMismatch(
+                "First dimension of mask must be the same as the second size of key: mask size $(size(x.mask)), key size $(size(x.k))",
+            ) |> throw
+        end
+        if size(x.mask, 2) != size(x.q, 2)
+            DimensionMismatch(
+                "Second dimension of mask must be the same as the second size of query: mask size $(size(x.mask)), query size $(size(x.q))",
+            ) |> throw
+        end
+        if size(x.mask, 3) != l.num_heads
+            DimensionMismatch(
+                "Third dimension of mask must be the same as the number of heads: mask $(size(x.mask)), num_heads $(l.num_heads)",
+            ) |> throw
+        end
+    end
     # TODO

     # FIXME: expand this to multi dimensional matrix multiplication
     # qk_dim, v_dim is divisible by num_heads. qk_dim = embed_dim * num_heads
     # [q] = (qk_dim, q_len, batch_size...)
     q = ps.weight_q * x.q
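As a quick, purely illustrative check of the change (MultiheadAttention and LuxCore assumed in scope, as in the example script above), a query whose feature dimension does not match embed_dim now surfaces as a DimensionMismatch rather than an ArgumentError:

using Random
rng = Random.default_rng()

l = MultiheadAttention(100, 20)          # embed_dim = 100, num_heads = 20, as in the example script
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)

bad_q = rand(Float32, 10, 50)            # feature dimension 10 ≠ embed_dim 100
k = rand(Float32, 100, 50)
v = rand(Float32, 100, 50)

err = try
    l((; q = bad_q, k, v), ps, st)
    nothing
catch e
    e
end
@assert err isa DimensionMismatch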
@@ -139,9 +167,11 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
     k = ps.weight_k * x.k
     # [v] = (v_dim, kv_len, batch_size...)
     v = ps.weight_v * x.v
+    # [mask] = (kv_len, q_len, nheads, batch_size)
+    mask = hasproperty(x, :mask) ? x.mask : nothing
     # [y] = (v_dim, q_len, batch_size...)
     # [α] = (kv_len, q_len, nheads, batch_size...)
-    y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
+    y, α = dot_product_attention(q, k, v; nheads = l.num_heads, mask)
     ps.weight_o * y, _st
 end
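For context, this is the keyword form of NNlib's dot_product_attention, where a Boolean mask is broadcast against scores of shape (kv_len, q_len, nheads, batch...). A self-contained sketch with illustrative sizes, assuming the NNlib definitions:

using NNlib: dot_product_attention, make_causal_mask

nheads, qk_dim, v_dim, len, batch = 4, 32, 16, 6, 2
q = rand(Float32, qk_dim, len, batch)
k = rand(Float32, qk_dim, len, batch)
v = rand(Float32, v_dim, len, batch)
mask = make_causal_mask(k)               # (len, len) Bool, broadcast over heads and batch

y, α = dot_product_attention(q, k, v; nheads, mask)
@assert size(y) == (v_dim, len, batch)
@assert size(α) == (len, len, nheads, batch)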
@@ -296,12 +326,12 @@ function TransformerDecoderLayer(
     layer_norm_epsilon::T2 = 1.0f-5,
 ) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
     sublayer_self_attention = let
-        layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
+        layer_split = Lux.WrappedFunction(x -> (q = x.x, k = x.x, v = x.x, mask = x.mask))
         layer_self_attention = MultiheadAttention(model_dim, num_heads)
         layer_dropout = Lux.Dropout(dropout)
         layer_residual_connection = Lux.SkipConnection(
             Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
-            +,
+            (x, res) -> x + res.x,
         )
         Lux.Chain(;
             layer_residual_connection,
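Because the self-attention sub-chain now receives a NamedTuple (; x, mask) instead of a plain array, the skip connection can no longer use +; the combiner pulls the residual back out of the original input's x field. A minimal, self-contained sketch of that pattern (the doubling function is hypothetical; only the combiner mirrors the code above):

using Lux, Random

combine = (out, input) -> out + input.x          # add the residual from the NamedTuple's x field
skip = Lux.SkipConnection(Lux.WrappedFunction(nt -> 2 .* nt.x), combine)

rng = Random.default_rng()
ps, st = Lux.setup(rng, skip)
x = ones(Float32, 3)
y, _ = skip((; x, mask = nothing), ps, st)
@assert y == 3 .* x                              # 2x from the branch plus x from the skip path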
@@ -347,21 +377,29 @@ function TransformerDecoderLayer(
 end

 function (decoder::TransformerDecoderLayer)(x, ps, st)
-    # hasproperty?
-    haskey(x, :x) || throw(
+    hasproperty(x, :x) || throw(
         ArgumentError(
-            "input to the decoder must have keyword argument \"x\"(output input)",
+            "input to the decoder must have keyword argument \"x\" (output input) and \"memory\" (output from encoder)",
         ),
     )
-    haskey(x, :memory) || throw(
+    hasproperty(x, :memory) || throw(
         ArgumentError(
-            "input to the decoder must have keyword argument \"memory\"(output from encoder)",
+            "input to the decoder must have keyword argument \"x\" (output input) and \"memory\" (output from encoder)",
         ),
     )

+    num_heads =
+        decoder.sublayer_self_attention.layers.layer_residual_connection.layers.layers.layer_self_attention.num_heads
+    self_attention_mask = if ndims(x.x) == 1
+        stack(make_causal_mask(x.x) for _ in 1:(num_heads))
+    else
+        # batched
+        stack(stack(make_causal_mask(x.x) for _ in 1:(num_heads)) for _ in 1:size(x.x, 3))
+    end
+
     x_sa, st_sublayer_self_attention = Lux.apply(
         decoder.sublayer_self_attention,
-        x.x,
+        (; x = x.x, mask = self_attention_mask),
         ps.sublayer_self_attention,
         st.sublayer_self_attention,
     )
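The batched branch builds one causal mask per head and per batch element by nested stacking. A shape-only sketch, assuming NNlib's make_causal_mask and Base.stack, with sizes mirroring the decoder example script:

using NNlib: make_causal_mask

input_dim, input_len, num_batch, num_heads = 100, 5, 10, 20
x = randn(Float32, input_dim, input_len, num_batch)

mask = stack(stack(make_causal_mask(x) for _ in 1:num_heads) for _ in 1:size(x, 3))
@assert size(mask) == (input_len, input_len, num_heads, num_batch)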