# test code

# Load the layer definitions unless they are already in scope
# (a bare `true || include(...)` would never run the include).
@isdefined(TransformerEncoderLayer) || include("transformers.jl")

using Random
using LuxCore
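
# Model and test dimensions. Multi-head attention typically requires that
# NUM_HEADS divide INPUT_DIM evenly (here 100 / 20 = 5 features per head).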
const INPUT_DIM = 100
const NUM_HEADS = 20
const NUM_BATCH = 10
const INPUT_LEN = 5
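
# Shared RNG for parameter initialization and random test inputs.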
rng = Random.default_rng()
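
# A single Transformer encoder layer as defined in transformers.jl.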
model_encoder = TransformerEncoderLayer(INPUT_DIM, NUM_HEADS)
@info "model" model_encoder
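
# LuxCore.setup initializes the trainable parameters `ps` and the
# non-trainable states `st` of the layer.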
ps, st = LuxCore.setup(rng, model_encoder)
@info "parameters" ps
@info "state" st

# Unbatched: a single feature vector of length INPUT_DIM.
x = randn(rng, Float32, (INPUT_DIM,))
@info "input" size(x)

y, st = model_encoder(x, ps, st)
@info "output" y st
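
# Keep the encoder output to feed the decoder as memory below.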
memory_unbatched = y

# Batched: (features, sequence length, batch) = (INPUT_DIM, INPUT_LEN, NUM_BATCH).
x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
@info "input" size(x)

y, st = model_encoder(x, ps, st)
@info "output" y st

memory_batched = y
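
# The decoder layer also consumes the encoder output as `memory`
# (cross-attention in a standard Transformer).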
model_decoder = TransformerDecoderLayer(INPUT_DIM, NUM_HEADS)
@info "decoder model" model_decoder

ps, st = LuxCore.setup(rng, model_decoder)
@info "parameters" ps
@info "state" st

# Unbatched: decoder input paired with the unbatched encoder memory.
x = randn(rng, Float32, (INPUT_DIM,))
@info "input" size(x)
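
# The decoder is called with a named tuple carrying the target input `x`
# and the encoder output `memory`.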
y, st = model_decoder((; x, memory = memory_unbatched), ps, st)
@info "output" y st

# Batched: decoder input paired with the batched encoder memory.
x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
@info "input" size(x)

y, st = model_decoder((; x, memory = memory_batched), ps, st)
@info "output" y st
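
# Shape sanity checks, assuming each layer preserves its input shape.
@assert size(memory_batched) == (INPUT_DIM, INPUT_LEN, NUM_BATCH)
@assert size(y) == (INPUT_DIM, INPUT_LEN, NUM_BATCH)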