diff --git a/test_multihead_attention.jl b/test_multihead_attention.jl
index f1874f1..00e27aa 100644
--- a/test_multihead_attention.jl
+++ b/test_multihead_attention.jl
@@ -4,9 +4,11 @@ true || include("transformers.jl")
 
 using Random
 
-const EMBED_DIM = 10
-const NUM_HEADS = 2
-const KV_DIM = 12
+const EMBED_DIM = 100
+const NUM_HEADS = 20
+const KV_DIM = 120
+const Q_LEN = 50
+const KV_LEN = 50
 
 rng = TaskLocalRNG()
 
@@ -22,4 +24,21 @@ k = rand(rng, Float32, (KV_DIM,))
 v = rand(rng, Float32, (KV_DIM,))
 @info "q k v" summary.((q, k, v))
 
-l((; q, k, v), ps, st)
+@info "" l((; q, k, v), ps, st)
+
+# with mask
+
+l = MultiheadAttention(EMBED_DIM, NUM_HEADS)
+@info "layer" l
+
+ps = LuxCore.initialparameters(rng, l)
+st = LuxCore.initialstates(rng, l)
+@info "parameters and states" ps st
+
+q = rand(rng, Float32, (EMBED_DIM, Q_LEN))
+k = rand(rng, Float32, (EMBED_DIM, KV_LEN))
+v = rand(rng, Float32, (EMBED_DIM, KV_LEN))
+mask = stack(make_causal_mask(k) for _ in 1:NUM_HEADS)
+@info "q k v mask" summary.((q, k, v, mask))
+
+@info "" l((; q, k, v, mask), ps, st)
diff --git a/transformers.jl b/transformers.jl
index eb2e905..7011c46 100644
--- a/transformers.jl
+++ b/transformers.jl
@@ -116,20 +116,47 @@ end
 
 function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
     if size(x.q, 1) != l.embed_dim
-        ArgumentError(
+        DimensionMismatch(
             "Length of queue must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
         ) |> throw
     end
     if size(x.k, 1) != l.kvdim
-        ArgumentError(
+        DimensionMismatch(
             "Length of key must match the layer's kvdim: size(k)[1] = $(size(x.k, 1)), kvdim = $(l.kvdim)",
         ) |> throw
     end
     if size(x.v, 1) != l.kvdim
-        ArgumentError(
+        DimensionMismatch(
             "Length of value must match the layer's kvdim: size(v)[1] = $(size(x.v, 1)), kvdim = $(l.kvdim)",
         ) |> throw
     end
+    if !(ndims(x.q) == ndims(x.k) == ndims(x.v))
+        DimensionMismatch(
+            "Query, Key, and Value must have the same number of dimensions",
+        ) |> throw
+    end
+    if ndims(x.k) > 1 && size(x.k)[2:end] != size(x.v)[2:end]
+        DimensionMismatch(
+            "Trailing dimensions of key and value must match: size(k)[2:end] = $(size(x.k)[2:end]), size(v)[2:end] = $(size(x.v)[2:end])",
+        ) |> throw
+    end
+    if haskey(x, :mask) && !isnothing(x.mask)
+        if size(x.mask, 1) != size(x.k, 2)
+            DimensionMismatch(
+                "First dimension of mask must match the second dimension of key: mask size $(size(x.mask)), key size $(size(x.k))",
+            ) |> throw
+        end
+        if size(x.mask, 2) != size(x.q, 2)
+            DimensionMismatch(
+                "Second dimension of mask must match the second dimension of query: mask size $(size(x.mask)), query size $(size(x.q))",
+            ) |> throw
+        end
+        if size(x.mask, 3) != l.num_heads
+            DimensionMismatch(
+                "Third dimension of mask must match the number of heads: mask size $(size(x.mask)), num_heads $(l.num_heads)",
+            ) |> throw
+        end
+    end
     # TODO
     # qk_dim, v_dim is divisible by num_heads. qk_dim = embed_dim * num_heads
 
@@ -139,9 +166,11 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
     k = ps.weight_k * x.k
     # [v] = (v_dim, kv_len, batch_size...)
     v = ps.weight_v * x.v
+    # [mask] = (kv_len, q_len, nheads, batch_size)
+    mask = hasproperty(x, :mask) ? x.mask : nothing
     # [y] = (v_dim, q_len, batch_size...)
     # [α] = (kv_len, q_len, nheads, batch_size...)
-    y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
+    y, α = dot_product_attention(q, k, v; nheads = l.num_heads, mask)
     ps.weight_o * y, _st
 end
 