# Source listing metadata: Julia, 44 lines, 1,001 B.
# test code
#
# Smoke test for the `MultiheadAttention` layer: builds the layer, initialises its
# parameters and states through LuxCore, and logs one forward pass for each of two
# scenarios (vector inputs with a separate key/value width, then matrix inputs with a
# mask).
#
# NOTE(review): `MultiheadAttention`, `LuxCore` and `make_causal_mask` are not brought
# into scope by this script; presumably "transformers.jl" has already been loaded into
# the session before this file is run — confirm.

# Short-circuited on purpose: the `include` below is never evaluated.
# NOTE(review): presumably kept so editor tooling can resolve the names defined in
# "transformers.jl" — confirm.
true || include("transformers.jl")

using Random

const EMBED_DIM = 100
const NUM_HEADS = 20
const KV_DIM = 120
const Q_LEN = 50
const KV_LEN = 50

# Seed the generator so that the parameters, inputs and outputs logged below are the
# same on every run.
rng = TaskLocalRNG()
Random.seed!(rng, 0)

# Scenario 1: queries of width EMBED_DIM, keys/values of width KV_DIM, 1-d inputs.
l = MultiheadAttention(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
@info "layer" l

ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
@info "parameters and states" ps st

q = rand(rng, Float32, (EMBED_DIM,))
k = rand(rng, Float32, (KV_DIM,))
v = rand(rng, Float32, (KV_DIM,))
@info "q k v" summary.((q, k, v))

@info "" l((; q, k, v), ps, st)

# with mask
#
# Scenario 2: queries, keys and values all of width EMBED_DIM, 2-d inputs, plus a mask.
l = MultiheadAttention(EMBED_DIM, NUM_HEADS)
@info "layer" l

ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
@info "parameters and states" ps st

q = rand(rng, Float32, (EMBED_DIM, Q_LEN))
k = rand(rng, Float32, (EMBED_DIM, KV_LEN))
v = rand(rng, Float32, (EMBED_DIM, KV_LEN))

# One mask per head, stacked along a new trailing dimension.
# NOTE(review): the mask is derived from `k` alone; presumably it is square in KV_LEN,
# in which case this scenario relies on Q_LEN == KV_LEN — confirm before changing
# either constant.
mask = stack(make_causal_mask(k) for _ in 1:NUM_HEADS)
@info "q k v mask" summary.((q, k, v, mask))

@info "" l((; q, k, v, mask), ps, st)