(WIP) implement Encoder and Decoder layers with built-in MultiHeadAttention
TODO: fix layer norm
This commit is contained in:
parent
5b21d5b034
commit
38dacdacf1
1 changed file with 125 additions and 0 deletions
125
transformers.jl
Normal file
@@ -0,0 +1,125 @@
using Lux
using NNlib: make_causal_mask  # causal mask for decoder self-attention (assumes NNlib is available in the environment)
struct TransformerEncoderLayer{LSA, LFFN} <:
       LuxCore.AbstractLuxContainerLayer{(:self_attention, :feed_forward_network)}
    self_attention::LSA
    feed_forward_network::LFFN
end

function TransformerEncoderLayer(
        model_dim::Int,
        nheads::Int,
        feedforward_dim::Int = 2048,
        dropout::T1 = 0.1,
        activation::F = relu,
        layer_norm_epsilon::T2 = 1.0f-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
    self_attention = let dropout = dropout
        self_attention = MultiHeadAttention(model_dim; nheads)
        # MultiHeadAttention returns (output, attention_weights); keep only the output
        get_output = WrappedFunction(x -> x[1])
        dropout = Dropout(dropout)
        res_conn = SkipConnection(Chain(; self_attention, get_output, dropout), +)
        # TODO: fix layer norm arg
        Chain(; res_conn, layer_norm = LayerNorm(model_dim; epsilon = layer_norm_epsilon))
    end
    feed_forward_network = let dropout = dropout
        feed_forward_network = Chain(
            Dense(model_dim => feedforward_dim, activation),
            Dropout(dropout),
            Dense(feedforward_dim => model_dim, activation),
            Dropout(dropout),
        )
        res_conn = SkipConnection(feed_forward_network, +)
        Chain(; res_conn, layer_norm = LayerNorm(model_dim; epsilon = layer_norm_epsilon))
    end
    TransformerEncoderLayer(self_attention, feed_forward_network)
end

function (encoder::TransformerEncoderLayer)(x, ps, st)
    x, st_sa = encoder.self_attention(x, ps.self_attention, st.self_attention)
    x, st_ffn =
        encoder.feed_forward_network(x, ps.feed_forward_network, st.feed_forward_network)
    return x, (self_attention = st_sa, feed_forward_network = st_ffn)
end
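
# A minimal usage sketch (not part of the layer implementation): it assumes the
# LayerNorm TODO above is resolved and that inputs follow Lux's MultiHeadAttention
# convention of (model_dim, seq_len, batch) arrays. The dimensions below are
# arbitrary example values.
#
#     using Random
#     enc = TransformerEncoderLayer(64, 4)
#     ps, st = Lux.setup(Random.default_rng(), enc)
#     x = randn(Float32, 64, 10, 2)      # (features, sequence, batch)
#     y, st_enc = enc(x, ps, st)         # y has the same shape as x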

struct TransformerDecoderLayer{LSA, LMHA, LFFN} <: LuxCore.AbstractLuxContainerLayer{(
    :self_attention,
    :multi_head_attention,
    :feed_forward_network,
)}
    self_attention::LSA
    multi_head_attention::LMHA
    feed_forward_network::LFFN
end

function TransformerDecoderLayer(
        model_dim::Int,
        nheads::Int,
        feedforward_dim::Int = 2048,
        dropout::T1 = 0.1,
        activation::F = relu,
        layer_norm_epsilon::T2 = 1.0f-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
    # shadow `dropout` in each block so reassigning it below does not clobber the argument
    self_attention = let dropout = dropout
        self_attention = MultiHeadAttention(model_dim; nheads)
        get_output = WrappedFunction(x -> x[1])
        dropout = Dropout(dropout)
        # Parallel distributes a tuple input across its branches: the first branch
        # receives (q, k, v, mask) for the masked self-attention, the second receives
        # the target array itself and acts as the residual connection.
        res_conn = Parallel(
            +,
            Chain(; self_attention, get_output, dropout),
            NoOpLayer(),
        )
        Chain(; res_conn, layer_norm = LayerNorm(model_dim; epsilon = layer_norm_epsilon))
    end
    multi_head_attention = let dropout = dropout
        multi_head_attention = MultiHeadAttention(model_dim; nheads)
        get_output = WrappedFunction(x -> x[1])
        dropout = Dropout(dropout)
        # same pattern for cross-attention: the first branch receives (q, k, v) built
        # from the target and the encoder memory, the second carries the residual
        res_conn = Parallel(
            +,
            Chain(; multi_head_attention, get_output, dropout),
            NoOpLayer(),
        )
        Chain(; res_conn, layer_norm = LayerNorm(model_dim; epsilon = layer_norm_epsilon))
    end
    feed_forward_network = let dropout = dropout
        feed_forward_network = Chain(
            Dense(model_dim => feedforward_dim, activation),
            Dropout(dropout),
            Dense(feedforward_dim => model_dim, activation),
            Dropout(dropout),
        )
        res_conn = SkipConnection(feed_forward_network, +)
        Chain(; res_conn, layer_norm = LayerNorm(model_dim; epsilon = layer_norm_epsilon))
    end
    TransformerDecoderLayer(self_attention, multi_head_attention, feed_forward_network)
end

function (decoder::TransformerDecoderLayer)(x, ps, st)
    if !(x isa Tuple && length(x) == 2)
        throw(
            ArgumentError(
                "input to TransformerDecoderLayer must be a tuple of (target, memory), where memory is the encoder output",
            ),
        )
    end

    tgt = x[1]
    memory = x[2]

    # masked self-attention: the first tuple element feeds the attention branch,
    # the second is the residual carried by the NoOpLayer branch
    sa_mask = make_causal_mask(tgt)
    y_sa, st_sa = decoder.self_attention(
        ((tgt, tgt, tgt, sa_mask), tgt),
        ps.self_attention,
        st.self_attention,
    )
    # cross-attention: queries come from the target, keys and values from the memory
    y_mha, st_mha = decoder.multi_head_attention(
        ((y_sa, memory, memory), y_sa),
        ps.multi_head_attention,
        st.multi_head_attention,
    )
    y_ffn, st_ffn = decoder.feed_forward_network(
        y_mha,
        ps.feed_forward_network,
        st.feed_forward_network,
    )

    return y_ffn,
    (self_attention = st_sa, multi_head_attention = st_mha, feed_forward_network = st_ffn)
end
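
# A minimal usage sketch for the decoder layer (same caveats as the encoder sketch:
# the LayerNorm TODO must be resolved first; dimensions are arbitrary examples).
# The decoder takes a (target, memory) tuple, where memory is the encoder output.
#
#     using Random
#     dec = TransformerDecoderLayer(64, 4)
#     ps, st = Lux.setup(Random.default_rng(), dec)
#     tgt    = randn(Float32, 64, 12, 2)    # (features, tgt_len, batch)
#     memory = randn(Float32, 64, 10, 2)    # (features, src_len, batch)
#     y, st_dec = dec((tgt, memory), ps, st)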