Handle mask in MultiheadAttention
This commit is contained in:
parent
f507418bc6
commit
67ffd00de8
2 changed files with 56 additions and 8 deletions
|
@ -4,9 +4,11 @@ true || include("transformers.jl")
|
|||
|
||||
using Random
|
||||
|
||||
const EMBED_DIM = 10
|
||||
const NUM_HEADS = 2
|
||||
const KV_DIM = 12
|
||||
const EMBED_DIM = 100
|
||||
const NUM_HEADS = 20
|
||||
const KV_DIM = 120
|
||||
const Q_LEN = 50
|
||||
const KV_LEN = 50
|
||||
|
||||
rng = TaskLocalRNG()
|
||||
|
||||
|
@ -22,4 +24,21 @@ k = rand(rng, Float32, (KV_DIM,))
|
|||
v = rand(rng, Float32, (KV_DIM,))
|
||||
@info "q k v" summary.((q, k, v))
|
||||
|
||||
l((; q, k, v), ps, st)
|
||||
@info "" l((; q, k, v), ps, st)
|
||||
|
||||
# with mask
|
||||
|
||||
l = MultiheadAttention(EMBED_DIM, NUM_HEADS)
|
||||
@info "layer" l
|
||||
|
||||
ps = LuxCore.initialparameters(rng, l)
|
||||
st = LuxCore.initialstates(rng, l)
|
||||
@info "parameters and states" ps st
|
||||
|
||||
q = rand(rng, Float32, (EMBED_DIM, Q_LEN))
|
||||
k = rand(rng, Float32, (EMBED_DIM, KV_LEN))
|
||||
v = rand(rng, Float32, (EMBED_DIM, KV_LEN))
|
||||
mask = stack(make_causal_mask(k) for _ in 1:NUM_HEADS)
|
||||
@info "q k v mask" summary.((q, k, v, mask))
|
||||
|
||||
@info "" l((; q, k, v, mask), ps, st)
|
||||
|
|
|
@ -116,20 +116,47 @@ end
|
|||
|
||||
function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
|
||||
if size(x.q, 1) != l.embed_dim
|
||||
ArgumentError(
|
||||
DimensionMismatch(
|
||||
"Length of queue must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
|
||||
) |> throw
|
||||
end
|
||||
if size(x.k, 1) != l.kvdim
|
||||
ArgumentError(
|
||||
DimensionMismatch(
|
||||
"Length of key must match the layer's kvdim: size(k)[1] = $(size(x.k, 1)), kvdim = $(l.kvdim)",
|
||||
) |> throw
|
||||
end
|
||||
if size(x.v, 1) != l.kvdim
|
||||
ArgumentError(
|
||||
DimensionMismatch(
|
||||
"Length of value must match the layer's kvdim: size(v)[1] = $(size(x.v, 1)), kvdim = $(l.kvdim)",
|
||||
) |> throw
|
||||
end
|
||||
if !(ndims(x.q) == ndims(x.k) == ndims(x.v))
|
||||
DimensionMismatch(
|
||||
"Queue and Key and Value must have the same number of dimensions",
|
||||
) |> throw
|
||||
end
|
||||
if ndims(x.k) > 1 && size(x.k)[2:end] != size(x.v)[2:end]
|
||||
DimensionMismatch(
|
||||
"Some length of key and value must match: size(k)[2:end] = $(size(x.k)[2:end]), size(v)[2:end] = $(size(x.v)[2:end])",
|
||||
) |> throw
|
||||
end
|
||||
if haskey(x, :mask) && !isnothing(x.mask)
|
||||
if size(x.mask, 1) != size(x.k, 2)
|
||||
DimensionMismatch(
|
||||
"First dimension of mask must be the same as the second size of key: mask size $(size(x.mask)), key size $(size(x.k))",
|
||||
) |> throw
|
||||
end
|
||||
if size(x.mask, 2) != size(x.q, 2)
|
||||
DimensionMismatch(
|
||||
"Second dimension of mask must be the same as the second size of key: mask size $(size(x.mask)), key size $(size(x.q))",
|
||||
) |> throw
|
||||
end
|
||||
if size(x.mask, 3) != l.num_heads
|
||||
DimensionMismatch(
|
||||
"Third dimension of mask must be the same as the number of heads: mask $(size(x.mask)), num_heads $(l.num_heads)",
|
||||
) |> throw
|
||||
end
|
||||
end
|
||||
# TODO
|
||||
|
||||
# qk_dim, v_dim is divisible by num_heads. qk_dim = embed_dim * num_heads
|
||||
|
@ -139,9 +166,11 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
|
|||
k = ps.weight_k * x.k
|
||||
# [v] = (v_dim, kv_len, batch_size...)
|
||||
v = ps.weight_v * x.v
|
||||
# [mask] = (kv_len, q_len, nheads, batch_size)
|
||||
mask = hasproperty(x, :mask) ? x.mask : nothing
|
||||
# [y] = (v_dim, q_len, batch_size...)
|
||||
# [α] = (kv_len, q_len, nheads, batch_size...)
|
||||
y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
|
||||
y, α = dot_product_attention(q, k, v; nheads = l.num_heads, mask)
|
||||
ps.weight_o * y, _st
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in a new issue