new: implement transformer decoder layer

parent 47944b830b
commit f507418bc6

2 changed files with 146 additions and 5 deletions
@@ -9,10 +9,10 @@ const NUM_HEADS = 1

 rng = Random.default_rng()

-model = TransformerEncoderLayer(10, 2)
-@info "model" model
+model_encoder = TransformerEncoderLayer(10, 2)
+@info "model" model_encoder

-ps, st = LuxCore.setup(rng, model)
+ps, st = LuxCore.setup(rng, model_encoder)
 @info "parameters" ps
 @info "status" st

@@ -20,12 +20,37 @@ ps, st = LuxCore.setup(rng, model)
 x = randn(rng, Float32, (10,))
 @info "input" size(x)

-y, st = model(x, ps, st)
+y, st = model_encoder(x, ps, st)
 @info "output" y st

+memory_unbatched = y
+
 # Batched
 x = randn(rng, Float32, (10, 10))
 @info "input" size(x)

-y, st = model(x, ps, st)
+y, st = model_encoder(x, ps, st)
 @info "output" y st
+
+memory_batched = y
+
+model_decoder = TransformerDecoderLayer(10, 2)
+@info "decoder model" model_decoder
+
+ps, st = LuxCore.setup(rng, model_decoder)
+@info "parameter" ps
+@info "status" st
+
+# Unbatched
+x = randn(rng, Float32, (10,))
+@info "input" x
+
+y, st = model_decoder((; x = x, memory = memory_unbatched), ps, st)
+@info "output" y st
+
+# Batched
+x = randn(rng, Float32, (10, 10))
+@info "input" x
+
+y, st = model_decoder((; x, memory = memory_batched), ps, st)
+@info "output" y st
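For reference, a self-contained version of the updated example looks roughly like the sketch below. The script's preamble (package imports and the inclusion of transformers.jl) is not part of this hunk, so those lines are assumptions; the layer calls themselves mirror the diff above, with the batched encoder output reused as the decoder's memory.

# Sketch only: preamble is assumed, not shown in this diff.
using Random
import Lux, LuxCore
# include("transformers.jl")   # assumed: defines TransformerEncoderLayer / TransformerDecoderLayer / MultiheadAttention

rng = Random.default_rng()

model_encoder = TransformerEncoderLayer(10, 2)   # (model_dim, num_heads), as in the decoder constructor below
model_decoder = TransformerDecoderLayer(10, 2)

ps_enc, st_enc = LuxCore.setup(rng, model_encoder)
ps_dec, st_dec = LuxCore.setup(rng, model_decoder)

x = randn(rng, Float32, (10, 10))                         # batched input: (model_dim, batch)
memory, st_enc = model_encoder(x, ps_enc, st_enc)         # encoder output becomes the decoder memory
y, st_dec = model_decoder((; x, memory), ps_dec, st_dec)  # decoder consumes a NamedTuple (; x, memory)
@info "decoder output" size(y)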
transformers.jl (+116 lines)
@@ -271,3 +271,119 @@ function (encoder::TransformerEncoderLayer)(x, ps, st)
         sublayer_feed_forward_network = st_sublayer_feed_forward_network,
     )
 end
+
+"""
+    TransformerDecoderLayer <: LuxCore.AbstractLuxContainerLayer
+
+One decoder layer of the Transformer model.
+"""
+struct TransformerDecoderLayer{LSA, LMHA, LFFN} <: LuxCore.AbstractLuxContainerLayer{(
+    :sublayer_self_attention,
+    :sublayer_multihead_attention,
+    :sublayer_feed_forward_network,
+)}
+    sublayer_self_attention::LSA
+    sublayer_multihead_attention::LMHA
+    sublayer_feed_forward_network::LFFN
+end
+
+function TransformerDecoderLayer(
+    model_dim::Int,
+    num_heads::Int,
+    feedforward_dim::Int = 2048,
+    dropout::T1 = 0.1,
+    activation::F = relu,
+    layer_norm_epsilon::T2 = 1.0f-5,
+) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
+    sublayer_self_attention = let
+        layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
+        layer_self_attention = MultiheadAttention(model_dim, num_heads)
+        layer_dropout = Lux.Dropout(dropout)
+        layer_residual_connection = Lux.SkipConnection(
+            Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
+            +,
+        )
+        Lux.Chain(;
+            layer_residual_connection,
+            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
+            name = "SelfAttentionSubLayer",
+        )
+    end
+    sublayer_multihead_attention = let
+        layer_split =
+            Lux.WrappedFunction(x::NamedTuple -> (q = x.x, k = x.memory, v = x.memory))
+        layer_self_attention = MultiheadAttention(model_dim, num_heads)
+        layer_dropout = Lux.Dropout(dropout)
+        layer_residual_connection = Lux.Parallel(
+            +,
+            Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
+            Lux.WrappedFunction(x::NamedTuple -> x.x),
+        )
+        Lux.Chain(;
+            layer_residual_connection,
+            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
+            name = "MultiheadAttentionSubLayer",
+        )
+    end
+    sublayer_feed_forward_network = let
+        layer_feed_forward_network = Lux.Chain(
+            Lux.Dense(model_dim => feedforward_dim, activation),
+            Lux.Dropout(dropout),
+            Lux.Dense(feedforward_dim => model_dim, activation),
+            Lux.Dropout(dropout),
+        )
+        layer_residual_connection = Lux.SkipConnection(layer_feed_forward_network, +)
+        Lux.Chain(;
+            layer_residual_connection,
+            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
+            name = "FeedForwardNetworkSubLayer",
+        )
+    end
+    TransformerDecoderLayer(
+        sublayer_self_attention,
+        sublayer_multihead_attention,
+        sublayer_feed_forward_network,
+    )
+end
+
+function (decoder::TransformerDecoderLayer)(x, ps, st)
+    # hasproperty?
+    haskey(x, :x) || throw(
+        ArgumentError(
+            "input to the decoder must have key \"x\" (the decoder-side input)",
+        ),
+    )
+    haskey(x, :memory) || throw(
+        ArgumentError(
+            "input to the decoder must have key \"memory\" (the encoder output)",
+        ),
+    )
+
+    x_sa, st_sublayer_self_attention = Lux.apply(
+        decoder.sublayer_self_attention,
+        x.x,
+        ps.sublayer_self_attention,
+        st.sublayer_self_attention,
+    )
+    x_mha, st_sublayer_multihead_attention = Lux.apply(
+        decoder.sublayer_multihead_attention,
+        (x = x_sa, memory = x.memory),
+        ps.sublayer_multihead_attention,
+        st.sublayer_multihead_attention,
+    )
+    y, st_sublayer_feed_forward_network = Lux.apply(
+        decoder.sublayer_feed_forward_network,
+        x_mha,
+        ps.sublayer_feed_forward_network,
+        st.sublayer_feed_forward_network,
+    )
+
+    return (
+        y,
+        (
+            sublayer_self_attention = st_sublayer_self_attention,
+            sublayer_multihead_attention = st_sublayer_multihead_attention,
+            sublayer_feed_forward_network = st_sublayer_feed_forward_network,
+        ),
+    )
+end
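A note on the cross-attention residual connection above: Lux.SkipConnection would add the whole NamedTuple input back onto the attention output, so that sublayer instead uses Lux.Parallel(+, ...) with a WrappedFunction that extracts only x.x. The stand-alone sketch below shows the same pattern in isolation; the Lux.Dense layer is a stand-in for MultiheadAttention, purely for illustration.

# Sketch: residual connection over a NamedTuple input, as in the cross-attention sublayer.
import Lux, LuxCore
using Random

branch = Lux.Chain(
    Lux.WrappedFunction(nt -> nt.x),    # stand-in for the (split -> attention -> dropout) chain
    Lux.Dense(10 => 10),
)
residual = Lux.Parallel(+, branch, Lux.WrappedFunction(nt -> nt.x))  # add only nt.x back

rng = Random.default_rng()
ps_r, st_r = LuxCore.setup(rng, residual)
nt = (; x = randn(rng, Float32, 10, 4), memory = randn(rng, Float32, 10, 4))
out, _ = residual(nt, ps_r, st_r)
@info "residual output" size(out)       # (10, 4): branch(nt) + nt.x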
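As a consequence of the haskey guards in the new forward method, calling the decoder with a NamedTuple that lacks the memory key raises an ArgumentError rather than failing deeper in the chain. A small sketch, reusing model_decoder, ps, st, and rng from the example script above:

# Sketch: forgetting the memory key trips the second ArgumentError guard.
try
    model_decoder((; x = randn(rng, Float32, (10,))), ps, st)
catch err
    @info "expected failure" err   # ArgumentError about the missing "memory" key
end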