diff --git a/test_multihead_attention.jl b/test_multihead_attention.jl
new file mode 100644
index 0000000..f1874f1
--- /dev/null
+++ b/test_multihead_attention.jl
@@ -0,0 +1,28 @@
+# Smoke test for the MultiheadAttention layer defined in transformers.jl
+
+@isdefined(MultiheadAttention) || include("transformers.jl")
+
+using LuxCore
+using Random
+
+const EMBED_DIM = 10
+const NUM_HEADS = 2
+const KV_DIM = 12
+
+rng = TaskLocalRNG()
+
+l = MultiheadAttention(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
+@info "layer" l
+
+ps = LuxCore.initialparameters(rng, l)
+st = LuxCore.initialstates(rng, l)
+@info "parameters and states" ps st
+
+# Unbatched query/key/value vectors; the key/value dimension differs from the query dimension.
+q = rand(rng, Float32, (EMBED_DIM,))
+k = rand(rng, Float32, (KV_DIM,))
+v = rand(rng, Float32, (KV_DIM,))
+@info "q k v" summary.((q, k, v))
+
+# Forward pass: a Lux layer is called as layer(input, ps, st) and returns (output, new_state).
+l((; q, k, v), ps, st)
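
The final line discards the forward-pass result. A minimal sketch of how the test could also inspect it, assuming the layer follows the usual Lux calling convention of returning an `(output, new_state)` tuple and that the attention output has `EMBED_DIM` features (both are assumptions about the MultiheadAttention implementation in transformers.jl, not guarantees):

```julia
# Hypothetical extension of the smoke test above.
# Assumes `l(x, ps, st) -> (y, st_new)` and an output dimension of EMBED_DIM.
y, st_new = l((; q, k, v), ps, st)
@assert y isa AbstractVector{Float32}
@assert length(y) == EMBED_DIM
@info "output" summary(y)
```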