From 8c14f7bb6400c95cd7c63f2666ca1138d8b7016b Mon Sep 17 00:00:00 2001
From: qwjyh
Date: Sat, 15 Feb 2025 18:12:03 +0900
Subject: [PATCH] refactor: rename test_transformer to test_multihead_attention

---
 test_multihead_attention.jl | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)
 create mode 100644 test_multihead_attention.jl

diff --git a/test_multihead_attention.jl b/test_multihead_attention.jl
new file mode 100644
index 0000000..f1874f1
--- /dev/null
+++ b/test_multihead_attention.jl
@@ -0,0 +1,25 @@
+# test code
+
+true || include("transformers.jl")
+
+using Random
+
+const EMBED_DIM = 10
+const NUM_HEADS = 2
+const KV_DIM = 12
+
+rng = TaskLocalRNG()
+
+l = MultiheadAttention(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
+@info "layer" l
+
+ps = LuxCore.initialparameters(rng, l)
+st = LuxCore.initialstates(rng, l)
+@info "parameters and states" ps st
+
+q = rand(rng, Float32, (EMBED_DIM,))
+k = rand(rng, Float32, (KV_DIM,))
+v = rand(rng, Float32, (KV_DIM,))
+@info "q k v" summary.((q, k, v))
+
+l((; q, k, v), ps, st)