Handle mask in MultiheadAttention

This commit is contained in:
qwjyh 2025-03-15 21:56:23 +09:00
parent f507418bc6
commit 67ffd00de8
2 changed files with 56 additions and 8 deletions

View file

@ -4,9 +4,11 @@ true || include("transformers.jl")
using Random
const EMBED_DIM = 10
const NUM_HEADS = 2
const KV_DIM = 12
const EMBED_DIM = 100
const NUM_HEADS = 20
const KV_DIM = 120
const Q_LEN = 50
const KV_LEN = 50
rng = TaskLocalRNG()
@ -22,4 +24,21 @@ k = rand(rng, Float32, (KV_DIM,))
v = rand(rng, Float32, (KV_DIM,))
@info "q k v" summary.((q, k, v))
l((; q, k, v), ps, st)
@info "" l((; q, k, v), ps, st)
# with mask
l = MultiheadAttention(EMBED_DIM, NUM_HEADS)
@info "layer" l
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
@info "parameters and states" ps st
q = rand(rng, Float32, (EMBED_DIM, Q_LEN))
k = rand(rng, Float32, (EMBED_DIM, KV_LEN))
v = rand(rng, Float32, (EMBED_DIM, KV_LEN))
mask = stack(make_causal_mask(k) for _ in 1:NUM_HEADS)
@info "q k v mask" summary.((q, k, v, mask))
@info "" l((; q, k, v, mask), ps, st)

View file

@ -116,20 +116,47 @@ end
function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
if size(x.q, 1) != l.embed_dim
ArgumentError(
DimensionMismatch(
"Length of queue must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
) |> throw
end
if size(x.k, 1) != l.kvdim
ArgumentError(
DimensionMismatch(
"Length of key must match the layer's kvdim: size(k)[1] = $(size(x.k, 1)), kvdim = $(l.kvdim)",
) |> throw
end
if size(x.v, 1) != l.kvdim
ArgumentError(
DimensionMismatch(
"Length of value must match the layer's kvdim: size(v)[1] = $(size(x.v, 1)), kvdim = $(l.kvdim)",
) |> throw
end
if !(ndims(x.q) == ndims(x.k) == ndims(x.v))
DimensionMismatch(
"Queue and Key and Value must have the same number of dimensions",
) |> throw
end
if ndims(x.k) > 1 && size(x.k)[2:end] != size(x.v)[2:end]
DimensionMismatch(
"Some length of key and value must match: size(k)[2:end] = $(size(x.k)[2:end]), size(v)[2:end] = $(size(x.v)[2:end])",
) |> throw
end
if haskey(x, :mask) && !isnothing(x.mask)
if size(x.mask, 1) != size(x.k, 2)
DimensionMismatch(
"First dimension of mask must be the same as the second size of key: mask size $(size(x.mask)), key size $(size(x.k))",
) |> throw
end
if size(x.mask, 2) != size(x.q, 2)
DimensionMismatch(
"Second dimension of mask must be the same as the second size of key: mask size $(size(x.mask)), key size $(size(x.q))",
) |> throw
end
if size(x.mask, 3) != l.num_heads
DimensionMismatch(
"Third dimension of mask must be the same as the number of heads: mask $(size(x.mask)), num_heads $(l.num_heads)",
) |> throw
end
end
# TODO
# qk_dim, v_dim is divisible by num_heads. qk_dim = embed_dim * num_heads
@ -139,9 +166,11 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
k = ps.weight_k * x.k
# [v] = (v_dim, kv_len, batch_size...)
v = ps.weight_v * x.v
# [mask] = (kv_len, q_len, nheads, batch_size)
mask = hasproperty(x, :mask) ? x.mask : nothing
# [y] = (v_dim, q_len, batch_size...)
# [α] = (kv_len, q_len, nheads, batch_size...)
y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
y, α = dot_product_attention(q, k, v; nheads = l.num_heads, mask)
ps.weight_o * y, _st
end