# test code
#
# Smoke-test script: builds one TransformerEncoderLayer and one
# TransformerDecoderLayer, initialises each with LuxCore.setup, and calls each
# on a length-10 vector and on a 10x10 matrix, logging every intermediate
# value with @info. The encoder outputs are passed to the decoder as `memory`.

# `true || ...` short-circuits, so this include is never evaluated at runtime.
# Presumably it is kept so editor tooling can resolve the layer definitions
# -- TODO confirm.
true || include("transformers.jl")

using Random

# NOTE(review): neither constant is referenced below; both layers are built
# with the literals 10 and 2. Confirm whether these were meant to be used
# (NUM_HEADS = 1 does not match the literal 2).
const INPUT_DIM = 10
const NUM_HEADS = 1

rng = Random.default_rng()

# ---------------------------------------------------------------- encoder ---
model_encoder = TransformerEncoderLayer(10, 2)
@info "model" model_encoder

ps, st = LuxCore.setup(rng, model_encoder)
@info "parameters" ps
@info "status" st

# Unbatched
x = randn(rng, Float32, 10)
@info "input" size(x)
y, st = model_encoder(x, ps, st)
@info "output" y st

# Kept so the unbatched decoder call below can use it as `memory`.
memory_unbatched = y

# Batched
x = randn(rng, Float32, 10, 10)
@info "input" size(x)
y, st = model_encoder(x, ps, st)
@info "output" y st

# Kept so the batched decoder call below can use it as `memory`.
memory_batched = y

# ---------------------------------------------------------------- decoder ---
model_decoder = TransformerDecoderLayer(10, 2)
@info "decoder model" model_decoder

# `ps` and `st` are rebound here; from this point on they belong to the decoder.
ps, st = LuxCore.setup(rng, model_decoder)
@info "parameter" ps
@info "status" st

# Unbatched
x = randn(rng, Float32, 10)
@info "input" x
y, st = model_decoder((; x, memory = memory_unbatched), ps, st)
@info "output" y st

# Batched
x = randn(rng, Float32, 10, 10)
@info "input" x
y, st = model_decoder((x = x, memory = memory_batched), ps, st)
@info "output" y st