refactor: rename test_transformer to test_multihead_attention
This commit is contained in:
parent
de6634e3f7
commit
8c14f7bb64
1 changed files with 25 additions and 0 deletions
25
test_multihead_attention.jl
Normal file
25
test_multihead_attention.jl
Normal file
|
@ -0,0 +1,25 @@
|
|||
# test code
#
# Smoke test for the MultiheadAttention layer defined in transformers.jl:
# builds a layer, initializes Lux-style parameters/states, and runs a single
# forward pass on random q/k/v vectors, logging intermediates via @info.
|
||||
|
||||
# NOTE(review): `true || include(...)` short-circuits, so transformers.jl is
# NEVER included when this file runs standalone — presumably the definitions
# (MultiheadAttention, LuxCore) are already loaded by an outer test harness
# that includes this file; confirm, since `MultiheadAttention` below would
# otherwise be undefined.
true || include("transformers.jl")
|
||||
|
||||
# Random provides TaskLocalRNG used below.
using Random
|
||||
|
||||
# Query/output embedding dimension of the attention layer.
const EMBED_DIM = 10
|
||||
# Number of attention heads.
const NUM_HEADS = 2
|
||||
# Dimension of the key/value inputs (deliberately != EMBED_DIM to exercise
# the cross-attention `kvdim` code path).
const KV_DIM = 12
|
||||
|
||||
# NOTE(review): RNG is not seeded, so this test is not reproducible across
# runs — consider `Random.seed!` or a StableRNG if outputs are ever asserted.
rng = TaskLocalRNG()
|
||||
|
||||
# Construct the layer under test; `kvdim` sets the key/value input width.
l = MultiheadAttention(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
|
||||
@info "layer" l
|
||||
|
||||
# Lux convention: trainable parameters and non-trainable states are created
# separately from the (immutable) layer definition.
ps = LuxCore.initialparameters(rng, l)
|
||||
st = LuxCore.initialstates(rng, l)
|
||||
@info "parameters and states" ps st
|
||||
|
||||
# Single unbatched query of length EMBED_DIM; key/value of length KV_DIM.
q = rand(rng, Float32, (EMBED_DIM,))
|
||||
k = rand(rng, Float32, (KV_DIM,))
|
||||
v = rand(rng, Float32, (KV_DIM,))
|
||||
@info "q k v" summary.((q, k, v))
|
||||
|
||||
# Forward pass: the layer is called as a functor on a NamedTuple of inputs
# plus parameters and states (Lux calling convention). Result is discarded —
# this only checks the call completes without error.
l((; q, k, v), ps, st)
|
Loading…
Reference in a new issue