(WIP) implement Encoder and Decoder layer with built-in MultiHeadAttention
TODO: fix layer norm
This commit is contained in:
parent
5b21d5b034
commit
38dacdacf1
1 changed file with 125 additions and 0 deletions
125
transformers.jl
Normal file
@@ -0,0 +1,125 @@
using Lux
using LuxCore  # for AbstractLuxContainerLayer

# Encoder layer: a self-attention block followed by a position-wise
# feed-forward block, each wrapped in a residual connection and a layer norm.
struct TransformerEncoderLayer{LSA, LFFN} <:
       LuxCore.AbstractLuxContainerLayer{(:self_attention, :feed_forward_network)}
    self_attention::LSA
    feed_forward_network::LFFN
end

function TransformerEncoderLayer(
        model_dim::Int,
        nheads::Int,
        feedforward_dim::Int = 2048,
        dropout::T1 = 0.1,
        activation::F = relu,
        layer_norm_epsilon::T2 = 1.0f-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
    # Self-attention block: MultiHeadAttention returns (output, attention_weights),
    # so keep only the output, apply dropout, add the residual input, then layer-norm.
    self_attention = let dropout = dropout
        self_attention = MultiHeadAttention(model_dim; nheads)
        get_output = WrappedFunction(x -> x[1])
        dropout = Dropout(dropout)
        res_conn = SkipConnection(Chain(; self_attention, get_output, dropout), +)
        # TODO: fix layer norm arg
        Chain(; res_conn, layer_norm = LayerNorm(model_dim; epsilon = layer_norm_epsilon))
    end

    # Position-wise feed-forward block, also wrapped in a residual connection and layer norm.
    feed_forward_network = let dropout = dropout
        feed_forward_network = Chain(
            Dense(model_dim => feedforward_dim, activation),
            Dropout(dropout),
            Dense(feedforward_dim => model_dim, activation),
            Dropout(dropout),
        )
        res_conn = SkipConnection(feed_forward_network, +)
        Chain(; res_conn, layer_norm = LayerNorm(model_dim; epsilon = layer_norm_epsilon))
    end

    TransformerEncoderLayer(self_attention, feed_forward_network)
end

function (encoder::TransformerEncoderLayer)(x, ps, st)
    x, st_sa = encoder.self_attention(x, ps.self_attention, st.self_attention)
    x, st_ffn =
        encoder.feed_forward_network(x, ps.feed_forward_network, st.feed_forward_network)
    return x, (self_attention = st_sa, feed_forward_network = st_ffn)
end
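
# Usage sketch (illustrative only, not part of this commit; assumes the layer-norm
# TODO above is resolved): the layer plugs into the usual Lux setup/apply workflow
# on input of shape (model_dim, seq_len, batch). Names (`enc`, `ps`, `st`) and the
# sizes below are arbitrary assumptions.
#
#   using Random
#   enc = TransformerEncoderLayer(512, 8)
#   ps, st = Lux.setup(Random.default_rng(), enc)
#   x = randn(Float32, 512, 16, 4)   # (model_dim, seq_len, batch)
#   y, st_new = enc(x, ps, st)       # y has the same shape as x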

# Decoder layer: masked self-attention over the target, cross-attention over the
# encoder memory, and a position-wise feed-forward block.
struct TransformerDecoderLayer{LSA, LMHA, LFFN} <: LuxCore.AbstractLuxContainerLayer{(
    :self_attention,
    :multi_head_attention,
    :feed_forward_network,
)}
    self_attention::LSA
    multi_head_attention::LMHA
    feed_forward_network::LFFN
end

function TransformerDecoderLayer(
        model_dim::Int,
        nheads::Int,
        feedforward_dim::Int = 2048,
        dropout::T1 = 0.1,
        activation::F = relu,
        layer_norm_epsilon::T2 = 1.0f-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
    # Masked self-attention block (same structure as in the encoder layer).
    self_attention = let dropout = dropout
        self_attention = MultiHeadAttention(model_dim; nheads)
        get_output = WrappedFunction(x -> x[1])
        dropout = Dropout(dropout)
        res_conn = SkipConnection(Chain(; self_attention, get_output, dropout), +)
        Chain(; res_conn, layer_norm = LayerNorm(model_dim; epsilon = layer_norm_epsilon))
    end

    # Cross-attention block over the encoder memory, with a residual connection
    # and a layer norm.
    multi_head_attention = let dropout = dropout
        multi_head_attention = MultiHeadAttention(model_dim; nheads)
        get_output = WrappedFunction(x -> x[1])
        dropout = Dropout(dropout)
        res_conn = Parallel(
            +,
            Chain(; multi_head_attention, get_output, dropout),
            WrappedFunction(x -> x[1]),
        )
        Chain(; res_conn, layer_norm = LayerNorm(model_dim; epsilon = layer_norm_epsilon))
    end

    # Position-wise feed-forward block with residual connection and layer norm.
    feed_forward_network = let dropout = dropout
        feed_forward_network = Chain(
            Dense(model_dim => feedforward_dim, activation),
            Dropout(dropout),
            Dense(feedforward_dim => model_dim, activation),
            Dropout(dropout),
        )
        res_conn = SkipConnection(feed_forward_network, +)
        Chain(; res_conn, layer_norm = LayerNorm(model_dim; epsilon = layer_norm_epsilon))
    end

    TransformerDecoderLayer(self_attention, multi_head_attention, feed_forward_network)
end

function (decoder::TransformerDecoderLayer)(x, ps, st)
    if !(x isa Tuple && length(x) == 2)
        throw(
            ArgumentError(
                "input to TransformerDecoderLayer must be a tuple of target and memory (from encoder)",
            ),
        )
    end

    tgt = x[1]
    memory = x[2]

    # Causal mask so each target position can only attend to itself and earlier positions.
    sa_mask = make_causal_mask(tgt)
    y_sa, st_sa = decoder.self_attention(
        (tgt, tgt, tgt, sa_mask),
        ps.self_attention,
        st.self_attention,
    )
    # Cross-attention: queries come from the target stream, keys/values from the memory.
    y_mha, st_mha = decoder.multi_head_attention(
        (y_sa, memory),
        ps.multi_head_attention,
        st.multi_head_attention,
    )
    y_ffn, st_ffn = decoder.feed_forward_network(
        y_mha,
        ps.feed_forward_network,
        st.feed_forward_network,
    )

    return y_ffn,
        (self_attention = st_sa, multi_head_attention = st_mha, feed_forward_network = st_ffn)
end
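
# Usage sketch (illustrative only, not part of this commit): the decoder layer takes a
# (target, memory) tuple, where memory is the encoder output. Names and sizes below
# are arbitrary assumptions.
#
#   using Random
#   enc = TransformerEncoderLayer(512, 8)
#   dec = TransformerDecoderLayer(512, 8)
#   ps_e, st_e = Lux.setup(Random.default_rng(), enc)
#   ps_d, st_d = Lux.setup(Random.default_rng(), dec)
#   src = randn(Float32, 512, 20, 4)          # (model_dim, src_len, batch)
#   tgt = randn(Float32, 512, 16, 4)          # (model_dim, tgt_len, batch)
#   memory, st_e = enc(src, ps_e, st_e)
#   y, st_d = dec((tgt, memory), ps_d, st_d)  # y has the same shape as tgt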