# Smoke tests for the transformer layers defined in transformers.jl.
#
# NOTE(review): `true || include(...)` never executes the include (short-circuit);
# it is the LanguageServer.jl hint idiom, so the definitions used below
# (TransformerEncoderLayer, TransformerDecoderLayer, LuxCore) must be loaded by
# the test runner before this file runs — TODO confirm that is intended.
true || include("transformers.jl")

using Random

# Model hyperparameters for this file.
const INPUT_DIM = 10
# Fixed: was 1, but every layer below is constructed with 2 attention heads.
const NUM_HEADS = 2

# Shared RNG: parameter init and test inputs draw from the task-default stream.
rng = Random.default_rng()
# --- Encoder smoke test ------------------------------------------------------
# One encoder layer: INPUT_DIM features, 2 attention heads.
# (Was a magic 10; INPUT_DIM is already defined as 10 above.)
model_encoder = TransformerEncoderLayer(INPUT_DIM, 2)
@info "model" model_encoder

# Lux-style: the model is stateless; parameters and state live outside it.
ps, st = LuxCore.setup(rng, model_encoder)
@info "parameters" ps
@info "status" st

# Unbatched: a single feature vector of length INPUT_DIM.
x = randn(rng, Float32, (INPUT_DIM,))
@info "input" size(x)

y, st = model_encoder(x, ps, st)
@info "output" y st

# Kept as decoder memory for the unbatched decoder test below.
memory_unbatched = y

# Batched: INPUT_DIM features per sample; trailing dim is presumably the
# batch (size 10) — TODO confirm against the layer's expected layout.
x = randn(rng, Float32, (INPUT_DIM, 10))
@info "input" size(x)

y, st = model_encoder(x, ps, st)
@info "output" y st

# Kept as decoder memory for the batched decoder test below.
memory_batched = y
# --- Decoder smoke test ------------------------------------------------------
# One decoder layer, same width/heads as the encoder (was a magic 10).
model_decoder = TransformerDecoderLayer(INPUT_DIM, 2)
@info "decoder model" model_decoder

ps, st = LuxCore.setup(rng, model_decoder)
# Fixed: label was "parameter"; "parameters" matches the encoder section.
@info "parameters" ps
@info "status" st

# Unbatched: target vector plus the unbatched encoder output as memory.
x = randn(rng, Float32, (INPUT_DIM,))
@info "input" x

y, st = model_decoder((; x = x, memory = memory_unbatched), ps, st)
@info "output" y st

# Batched: target matrix plus the batched encoder output as memory.
x = randn(rng, Float32, (INPUT_DIM, 10))
@info "input" x

y, st = model_decoder((; x, memory = memory_batched), ps, st)
@info "output" y st