# ml-test/test_transformer.jl
# 2025-03-15 22:35:59 +09:00
# (file-listing metadata preserved from the original paste: 58 lines, 1.1 KiB, Julia)
# Smoke test for the Transformer encoder/decoder layers defined in transformers.jl:
# constructs each layer, initializes it via LuxCore, and runs unbatched and
# batched forward passes, logging inputs/outputs along the way.
#
# NOTE(review): `true || include(...)` short-circuits, so the include below NEVER
# executes — this script only works when transformers.jl (which presumably brings
# `LuxCore` and the layer types into scope — verify) has already been loaded by
# the caller. Flip `true` to `false` to run this file standalone.
true || include("transformers.jl")
using Random
const INPUT_DIM = 100  # feature dimension of each token vector
const NUM_HEADS = 20   # attention heads; INPUT_DIM is divisible by this — TODO confirm transformers.jl requires it
const NUM_BATCH = 10   # batch size for the batched forward pass
const INPUT_LEN = 5    # sequence length for the batched forward pass
rng = Random.default_rng()
# --- Encoder layer: construct, initialize, and run forward passes ---
model_encoder = TransformerEncoderLayer(INPUT_DIM, NUM_HEADS)
@info "model" model_encoder
# LuxCore models are stateless callables: `ps` holds trainable parameters and
# `st` holds non-trainable state; both come from setup with the shared rng.
ps, st = LuxCore.setup(rng, model_encoder)
@info "parameters" ps
@info "status" st
# Unbatched case: a single feature vector of length INPUT_DIM.
x = randn(rng, Float32, (INPUT_DIM,))
@info "input" size(x)
y, st = model_encoder(x, ps, st)  # forward pass; updated state threads back into `st`
@info "output" y st
memory_unbatched = y  # kept as the decoder's `memory` input for the unbatched case below
# Batched case: layout assumed (features, sequence, batch) — TODO confirm in transformers.jl.
x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
@info "input" size(x)
y, st = model_encoder(x, ps, st)
@info "output" y st
memory_batched = y  # decoder's `memory` input for the batched case
# --- Decoder layer: construct, initialize, and run forward passes against the
# encoder outputs captured above ---
model_decoder = TransformerDecoderLayer(INPUT_DIM, NUM_HEADS)
@info "decoder model" model_decoder
ps, st = LuxCore.setup(rng, model_decoder)
@info "parameters" ps  # fixed message: was "parameter", now matches the encoder section
@info "status" st
# Unbatched case: a single target vector attending over the unbatched encoder memory.
x = randn(rng, Float32, (INPUT_DIM,))
@info "input" size(x)  # fixed: log the shape, not the full array, matching the encoder section
y, st = model_decoder((; x, memory = memory_unbatched), ps, st)
@info "output" y st
# Batched case: (features, sequence, batch) target attending over the batched memory.
x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
@info "input" size(x)  # fixed: was logging the entire 100×5×10 array
y, st = model_decoder((; x, memory = memory_batched), ps, st)
@info "output" y st