2025-01-28 22:19:46 +09:00
|
|
|
# Smoke test: build a TransformerEncoderLayer, then run it on unbatched and batched random input.
|
|
|
|
|
|
|
|
# Load the layer definitions (presumably `transformers.jl` defines
# `TransformerEncoderLayer`) only when they are not already in scope.
# The previous `true || include(...)` short-circuited and never ran the include,
# so the script only worked in a session that had already loaded the file.
# Guarding on `isdefined` keeps that REPL workflow (no re-include) while letting
# the script also run standalone.
# NOTE(review): `LuxCore` is used later but not imported in this file; assumes
# transformers.jl (or the session) brings it into scope — confirm.
isdefined(@__MODULE__, :TransformerEncoderLayer) || include("transformers.jl")
|
|
|
|
|
|
|
|
using Random
|
|
|
|
|
2025-02-15 18:12:55 +09:00
|
|
|
# Hyper-parameters shared by the model construction and the test inputs below.
const INPUT_DIM = 10
# NOTE(review): this constant was 1 while the constructor was called with a
# literal 2. It is now 2 so the constant matches what the model actually receives.
# Assumes the second constructor argument is the head count — confirm against
# transformers.jl.
const NUM_HEADS = 2

# Seed the RNG so parameter initialisation and the random inputs are reproducible.
rng = Random.default_rng()
Random.seed!(rng, 0)

model = TransformerEncoderLayer(INPUT_DIM, NUM_HEADS)
@info "model" model

# LuxCore.setup produces the initial (parameters, state) pair for `model`.
# NOTE(review): LuxCore is not imported in this file; it must already be in scope.
ps, st = LuxCore.setup(rng, model)
@info "parameters" ps
@info "status" st
|
2025-01-28 22:19:46 +09:00
|
|
|
|
2025-02-15 18:12:55 +09:00
|
|
|
# Unbatched: a single input vector of length INPUT_DIM.
x = randn(rng, Float32, (INPUT_DIM,))
@info "input" size(x)

# The layer is called with (input, parameters, state) and returns
# (output, new state); the new state replaces `st` for subsequent calls.
y, st = model(x, ps, st)
@info "output" y st
|
|
|
|
|
|
|
|
# Batched: an INPUT_DIM × 10 matrix. The leading dimension matches the length of
# the unbatched vector above.
# NOTE(review): assumes the trailing dimension (10) is the batch dimension, per the
# "Batched" label; for a transformer it may instead be sequence length — confirm
# against transformers.jl.
x = randn(rng, Float32, (INPUT_DIM, 10))
@info "input" size(x)

# Same (x, ps, st) -> (y, st) call as the unbatched case; the state is threaded through.
y, st = model(x, ps, st)
@info "output" y st
|