# ml-test/test_transformer.jl

# Smoke test for `TransformerEncoderLayer`: build the layer, initialise its
# parameters and state with `LuxCore.setup`, then run one unbatched and one
# batched forward pass, logging inputs and outputs.

# `true || …` short-circuits, so the `include` is never evaluated. This acts as
# a manual toggle: the script assumes `TransformerEncoderLayer` and `LuxCore`
# are already in scope (e.g. `transformers.jl` was included earlier in the
# session). Change `true` to `false` to load the definitions when running
# this file standalone.
true || include("transformers.jl")
using Random

const INPUT_DIM = 10
# NOTE(review): NUM_HEADS is not used below; the layer is constructed with a
# literal second argument of 2, which differs from this value. Confirm what
# that argument means in `TransformerEncoderLayer` before wiring NUM_HEADS in.
const NUM_HEADS = 1
const BATCH_SIZE = 10
const SEED = 0

# Explicitly seeded RNG so parameter initialisation and random inputs are
# reproducible from run to run.
rng = Random.Xoshiro(SEED)

model = TransformerEncoderLayer(INPUT_DIM, 2)
@info "model" model

# `LuxCore.setup` returns the layer's parameters and its state.
ps, st = LuxCore.setup(rng, model)
@info "parameters" ps
@info "state" st

# Unbatched: a single vector of length INPUT_DIM.
x = randn(rng, Float32, (INPUT_DIM,))
@info "input" size(x)
y, st = model(x, ps, st)
@info "output" y st

# Batched: INPUT_DIM × BATCH_SIZE. Assumes the usual Lux convention of features
# along the first dimension and batch along the last — TODO confirm against
# the layer's implementation.
x = randn(rng, Float32, (INPUT_DIM, BATCH_SIZE))
@info "input" size(x)
y, st = model(x, ps, st)
@info "output" y st