compile time determined Chain length
This commit is contained in:
parent
c86b8be5a2
commit
c7772df7c2
1 changed files with 5 additions and 3 deletions
|
@ -68,8 +68,10 @@ y, st = model_decoder((x, memory_batched), ps, st)
|
|||
# ====================================================================
|
||||
# RepeatedLayer
|
||||
# ====================================================================
|
||||
model_repeated_encoder =
|
||||
RepeatedLayer(TransformerEncoderLayer(INPUT_DIM, NUM_HEADS); repeats = Val(6))
|
||||
model_repeated_encoder = let
|
||||
layers = fill(:(TransformerEncoderLayer(INPUT_DIM, NUM_HEADS)), 6)
|
||||
eval(Expr(:call, :Chain, layers...))
|
||||
end
|
||||
x = randn(rng, Float32, (INPUT_DIM, INPUT_LEN, NUM_BATCH))
|
||||
ps, st = LuxCore.setup(rng, model_encoder)
|
||||
ps, st = LuxCore.setup(rng, model_repeated_encoder)
|
||||
@info "" model_repeated_encoder(x, ps, st)
|
||||
|
|
Loading…
Add table
Reference in a new issue