Handle mask in MultiheadAttention
parent f507418bc6
commit 67ffd00de8

2 changed files with 56 additions and 8 deletions
@@ -4,9 +4,11 @@ true || include("transformers.jl")
 using Random
 
-const EMBED_DIM = 10
-const NUM_HEADS = 2
-const KV_DIM = 12
+const EMBED_DIM = 100
+const NUM_HEADS = 20
+const KV_DIM = 120
+const Q_LEN = 50
+const KV_LEN = 50
 
 rng = TaskLocalRNG()
 
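A quick sanity check on the new sizes: each attention head gets EMBED_DIM ÷ NUM_HEADS = 100 ÷ 20 = 5 features, so the enlarged constants still satisfy the divisibility requirement noted in the layer's comments, and the mask built below comes out as KV_LEN × Q_LEN × NUM_HEADS = 50 × 50 × 20.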
@@ -22,4 +24,21 @@ k = rand(rng, Float32, (KV_DIM,))
 v = rand(rng, Float32, (KV_DIM,))
 @info "q k v" summary.((q, k, v))
 
-l((; q, k, v), ps, st)
+@info "" l((; q, k, v), ps, st)
+
+# with mask
+
+l = MultiheadAttention(EMBED_DIM, NUM_HEADS)
+@info "layer" l
+
+ps = LuxCore.initialparameters(rng, l)
+st = LuxCore.initialstates(rng, l)
+@info "parameters and states" ps st
+
+q = rand(rng, Float32, (EMBED_DIM, Q_LEN))
+k = rand(rng, Float32, (EMBED_DIM, KV_LEN))
+v = rand(rng, Float32, (EMBED_DIM, KV_LEN))
+mask = stack(make_causal_mask(k) for _ in 1:NUM_HEADS)
+@info "q k v mask" summary.((q, k, v, mask))
+
+@info "" l((; q, k, v, mask), ps, st)
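For context: NNlib's make_causal_mask(k) returns a square Bool matrix of side size(k, 2) with m[i, j] == (i <= j), so query position j may only attend to key positions at or before it, and the stack call above replicates that matrix once per head to get the (kv_len, q_len, nheads) layout the layer validates below. A minimal shape sketch (standalone, not part of the commit; the sequence length is shrunk to 4 for readability):

    using NNlib: make_causal_mask

    k = rand(Float32, 120, 4)        # (KV_DIM, KV_LEN)
    m = make_causal_mask(k)          # 4×4 Bool matrix, m[i, j] == (i <= j):
                                     # 1 1 1 1
                                     # 0 1 1 1
                                     # 0 0 1 1
                                     # 0 0 0 1
    mask = stack(m for _ in 1:2)     # (4, 4, 2): one copy per head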
@@ -116,20 +116,47 @@ end
 function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
     if size(x.q, 1) != l.embed_dim
-        ArgumentError(
+        DimensionMismatch(
             "Length of query must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
         ) |> throw
     end
     if size(x.k, 1) != l.kvdim
-        ArgumentError(
+        DimensionMismatch(
             "Length of key must match the layer's kvdim: size(k)[1] = $(size(x.k, 1)), kvdim = $(l.kvdim)",
         ) |> throw
     end
     if size(x.v, 1) != l.kvdim
-        ArgumentError(
+        DimensionMismatch(
             "Length of value must match the layer's kvdim: size(v)[1] = $(size(x.v, 1)), kvdim = $(l.kvdim)",
         ) |> throw
     end
+    if !(ndims(x.q) == ndims(x.k) == ndims(x.v))
+        DimensionMismatch(
+            "Query, key, and value must have the same number of dimensions",
+        ) |> throw
+    end
+    if ndims(x.k) > 1 && size(x.k)[2:end] != size(x.v)[2:end]
+        DimensionMismatch(
+            "Trailing dimensions of key and value must match: size(k)[2:end] = $(size(x.k)[2:end]), size(v)[2:end] = $(size(x.v)[2:end])",
+        ) |> throw
+    end
+    if haskey(x, :mask) && !isnothing(x.mask)
+        if size(x.mask, 1) != size(x.k, 2)
+            DimensionMismatch(
+                "First dimension of mask must match the second dimension of key: mask size $(size(x.mask)), key size $(size(x.k))",
+            ) |> throw
+        end
+        if size(x.mask, 2) != size(x.q, 2)
+            DimensionMismatch(
+                "Second dimension of mask must match the second dimension of query: mask size $(size(x.mask)), query size $(size(x.q))",
+            ) |> throw
+        end
+        if size(x.mask, 3) != l.num_heads
+            DimensionMismatch(
+                "Third dimension of mask must match the number of heads: mask size $(size(x.mask)), num_heads = $(l.num_heads)",
+            ) |> throw
+        end
+    end
     # TODO
 
     # qk_dim, v_dim is divisible by num_heads. qk_dim = embed_dim * num_heads
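These checks turn mask and shape mistakes that would otherwise surface as broadcast errors deep inside dot_product_attention into early, descriptive DimensionMismatch failures. A sketch of a test one could add on top of this commit (hypothetical, not part of the diff; it reuses the constants and rng from the script above):

    using Test

    l = MultiheadAttention(EMBED_DIM, NUM_HEADS)
    ps = LuxCore.initialparameters(rng, l)
    st = LuxCore.initialstates(rng, l)
    q = rand(rng, Float32, (EMBED_DIM, Q_LEN))
    k = rand(rng, Float32, (EMBED_DIM, KV_LEN))
    v = rand(rng, Float32, (EMBED_DIM, KV_LEN))

    # One head too many: tripped by the third mask check above.
    bad = stack(make_causal_mask(k) for _ in 1:(NUM_HEADS + 1))
    @test_throws DimensionMismatch l((; q, k, v, mask = bad), ps, st)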
@@ -139,9 +166,11 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
     k = ps.weight_k * x.k
     # [v] = (v_dim, kv_len, batch_size...)
     v = ps.weight_v * x.v
+    # [mask] = (kv_len, q_len, nheads, batch_size)
+    mask = hasproperty(x, :mask) ? x.mask : nothing
     # [y] = (v_dim, q_len, batch_size...)
     # [α] = (kv_len, q_len, nheads, batch_size...)
-    y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
+    y, α = dot_product_attention(q, k, v; nheads = l.num_heads, mask)
     ps.weight_o * y, _st
 end
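On the semantics of the new keyword: in NNlib, entries where the mask is false have their attention logits replaced with typemin (effectively -Inf) before the softmax, and mask = nothing, the value chosen above when no mask is passed, skips masking entirely. The mask only needs to be broadcastable to (kv_len, q_len, nheads, batch_size), so a single (kv_len, q_len) matrix also works without the per-head stack. A standalone sketch of the underlying call (illustrative sizes, not from the commit):

    using NNlib: dot_product_attention, make_causal_mask

    qk_dim, v_dim, len, nheads = 8, 8, 5, 2
    q = rand(Float32, qk_dim, len)   # (qk_dim, q_len)
    k = rand(Float32, qk_dim, len)   # (qk_dim, kv_len)
    v = rand(Float32, v_dim, len)    # (v_dim, kv_len)
    mask = make_causal_mask(k)       # (kv_len, q_len), broadcast across heads
    y, α = dot_product_attention(q, k, v; nheads, mask)
    size(y), size(α)                 # ((8, 5), (5, 5, 2))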