ml-test/test_transformer.jl

32 lines
493 B
Julia
Raw Normal View History

2025-01-28 22:19:46 +09:00
# Smoke test for TransformerEncoderLayer: builds the layer, initialises its
# parameters and state, then runs one unbatched and one batched forward pass,
# logging inputs and outputs. It only logs; nothing is asserted.

# `true || ...` short-circuits, so this include never executes.
# NOTE(review): presumably kept so editor tooling can resolve names from
# transformers.jl, with the caller expected to have loaded
# TransformerEncoderLayer and LuxCore before including this file — confirm.
true || include("transformers.jl")
using Random

const INPUT_DIM = 10
const NUM_HEADS = 1
const BATCH_SIZE = 10

# Fixed seed so parameter initialisation and random inputs are reproducible
# across runs (an unseeded default RNG makes every run differ).
rng = Random.Xoshiro(0)

# NOTE(review): the constructor's second argument is 2 while NUM_HEADS is 1.
# If that argument is the head count, one of the two is wrong — confirm.
model = TransformerEncoderLayer(INPUT_DIM, 2)
@info "model" model

ps, st = LuxCore.setup(rng, model)
@info "parameters" ps
@info "status" st

# Unbatched: a single vector of length INPUT_DIM.
x = randn(rng, Float32, (INPUT_DIM,))
@info "input" size(x)

y, st = model(x, ps, st)
@info "output" y st

# Batched: an INPUT_DIM × BATCH_SIZE matrix (assumes features along dim 1 and
# batch along dim 2 — confirm). The state returned by the unbatched call is
# passed into this call.
x = randn(rng, Float32, (INPUT_DIM, BATCH_SIZE))
@info "input" size(x)
y, st = model(x, ps, st)
@info "output" y st