ml-test/test_transformer.jl

25 lines
497 B
Julia

# Smoke test for `MultiheadAttention`: build the layer, initialise its parameters and
# states through LuxCore, and run a single call on random Float32 inputs.
#
# `true || ...` short-circuits, so the `include` below never executes.
# NOTE(review): presumably it is kept only so editor tooling can resolve
# `MultiheadAttention`; this file loads neither `LuxCore` nor the layer definition, so
# both must already be in scope when it runs — confirm against how it is launched.
true || include("transformers.jl")
using Random

const EMBED_DIM = 10
const NUM_HEADS = 2
const KV_DIM = 12

# Fixed seed so parameters and inputs are identical on every run, which makes any
# failure reproducible. `TaskLocalRNG()` is the current task's RNG, so seeding it also
# affects later `rand` calls made in the same task.
rng = TaskLocalRNG()
Random.seed!(rng, 0)

l = MultiheadAttention(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
@info "layer" l

ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
@info "parameters and states" ps st

# The query has length EMBED_DIM; key and value have length KV_DIM (the layer's `kvdim`).
q = rand(rng, Float32, (EMBED_DIM,))
k = rand(rng, Float32, (KV_DIM,))
v = rand(rng, Float32, (KV_DIM,))
@info "q k v" summary.((q, k, v))

# Keep the result instead of discarding it, so a script run shows what the layer
# produced. NOTE(review): the structure of the returned value is defined by
# `MultiheadAttention` in transformers.jl and is not visible here.
result = l((; q, k, v), ps, st)
@info "forward pass result" result

# Final expression: `include`-ing this file still evaluates to the layer's result.
result