# test code
# Load the layer definitions unless they are already in scope
# (the original `true || include(...)` toggle never ran the include).
isdefined(@__MODULE__, :MultiheadAttention) || include("transformers.jl")

using Random
using LuxCore

const EMBED_DIM = 100
const NUM_HEADS = 20
const KV_DIM = 120
const Q_LEN = 50
const KV_LEN = 50

rng = TaskLocalRNG()

# Cross-attention setup: queries carry EMBED_DIM features, keys/values KV_DIM.
l = MultiheadAttention(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
@info "layer" l
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
@info "parameters and states" ps st

# Single (unbatched) query/key/value vectors.
q = rand(rng, Float32, (EMBED_DIM,))
k = rand(rng, Float32, (KV_DIM,))
v = rand(rng, Float32, (KV_DIM,))
@info "q k v" summary.((q, k, v))
@info "" l((; q, k, v), ps, st)

# with mask
# Self-attention over full sequences, with one causal mask per attention head.
l = MultiheadAttention(EMBED_DIM, NUM_HEADS)
@info "layer" l
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
@info "parameters and states" ps st

q = rand(rng, Float32, (EMBED_DIM, Q_LEN))
k = rand(rng, Float32, (EMBED_DIM, KV_LEN))
v = rand(rng, Float32, (EMBED_DIM, KV_LEN))
mask = stack(make_causal_mask(k) for _ in 1:NUM_HEADS)
@info "q k v mask" summary.((q, k, v, mask))
@info "" l((; q, k, v, mask), ps, st)
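
# A minimal sketch of turning the @info smoke checks above into assertions,
# assuming the usual Lux convention that calling a layer returns an
# `(output, state)` tuple and that the attention output keeps the embedding
# dimension and query length. Verify both against transformers.jl before
# relying on these checks.
using Test

@testset "MultiheadAttention smoke test" begin
    l = MultiheadAttention(EMBED_DIM, NUM_HEADS)
    ps = LuxCore.initialparameters(rng, l)
    st = LuxCore.initialstates(rng, l)
    q = rand(rng, Float32, (EMBED_DIM, Q_LEN))
    k = rand(rng, Float32, (EMBED_DIM, KV_LEN))
    v = rand(rng, Float32, (EMBED_DIM, KV_LEN))
    y, st_out = l((; q, k, v), ps, st)
    @test size(y, 1) == EMBED_DIM  # assumed: output feature dim matches embedding
    @test size(y, 2) == Q_LEN      # assumed: one output column per query position
end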