# Test code: build a MultiheadAttention layer, initialise its parameters and
# states, and run a single forward call on random query/key/value vectors.

# `true || ...` short-circuits, so this `include` is never executed at runtime.
# NOTE(review): presumably kept so editor tooling can resolve the definitions in
# transformers.jl — confirm before removing.
true || include("transformers.jl")

using Random

# Layer configuration used by this script.
const EMBED_DIM = 10
const NUM_HEADS = 2
const KV_DIM = 12

rng = Random.TaskLocalRNG()

# Construct the layer; keys/values use their own width via the `kvdim` keyword.
l = MultiheadAttention(EMBED_DIM, NUM_HEADS; kvdim=KV_DIM)
@info "layer" l

# Parameters and states are both initialised from the same RNG.
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
@info "parameters and states" ps st

# Random Float32 vector inputs: `q` has length EMBED_DIM, `k` and `v` have
# length KV_DIM.
q = rand(rng, Float32, EMBED_DIM)
k = rand(rng, Float32, KV_DIM)
v = rand(rng, Float32, KV_DIM)
@info "q k v" summary.((q, k, v))

# Forward call: inputs are passed as a NamedTuple with fields `q`, `k`, `v`.
l((q=q, k=k, v=v), ps, st)