compile time determined Chain length

This commit is contained in:
Wataru Otsubo 2025-04-26 21:59:46 +09:00
parent c86b8be5a2
commit c7772df7c2

View file

@ -68,8 +68,10 @@ y, st = model_decoder((x, memory_batched), ps, st)
# ====================================================================
# RepeatedLayer
# ====================================================================
model_repeated_encoder =
RepeatedLayer(TransformerEncoderLayer(INPUT_DIM, NUM_HEADS); repeats = Val(6))
model_repeated_encoder = let
layers = fill(:(TransformerEncoderLayer(INPUT_DIM, NUM_HEADS)), 6)
eval(Expr(:call, :Chain, layers...))
end
x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
ps, st = LuxCore.setup(rng, model_encoder)
ps, st = LuxCore.setup(rng, model_repeated_encoder)
@info "" model_repeated_encoder(x, ps, st)