# ml-test/transformers.jl

using Lux
using NNlib: make_causal_mask  # used to build the decoder's causal self-attention mask

struct TransformerEncoderLayer{LSA, LFFN} <:
       LuxCore.AbstractLuxContainerLayer{(:self_attention, :feed_forward_network)}
    self_attention::LSA
    feed_forward_network::LFFN
end

function TransformerEncoderLayer(
    model_dim::Int,
    nheads::Int,
    feedforward_dim::Int = 2048,
    dropout::T1 = 0.1,
    activation::F = relu,
    layer_norm_epsilon::T2 = 1.0f-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
    # Self-attention sublayer: MHA -> dropout, wrapped in a residual connection and
    # followed by layer normalization (post-norm, as in the original Transformer).
    self_attention = let dropout = dropout
        self_attention = MultiHeadAttention(model_dim; nheads)
        get_output = WrappedFunction(x -> x[1])  # MHA returns (output, attention_weights)
        dropout = Dropout(dropout)
        res_conn = SkipConnection(Chain(; self_attention, get_output, dropout), +)
        # LayerNorm expects a shape tuple; (model_dim, 1) matches the
        # (features, sequence, batch) input layout.
        Chain(; res_conn, layer_norm = LayerNorm((model_dim, 1); epsilon = layer_norm_epsilon))
    end
    # Position-wise feed-forward sublayer with its own residual connection and layer norm.
    feed_forward_network = let dropout = dropout
        feed_forward_network = Chain(
            Dense(model_dim => feedforward_dim, activation),
            Dropout(dropout),
            Dense(feedforward_dim => model_dim),  # no activation on the output projection
            Dropout(dropout),
        )
        res_conn = SkipConnection(feed_forward_network, +)
        Chain(; res_conn, layer_norm = LayerNorm((model_dim, 1); epsilon = layer_norm_epsilon))
    end
    TransformerEncoderLayer(self_attention, feed_forward_network)
end

function (encoder::TransformerEncoderLayer)(x, ps, st)
    # x is expected to have shape (model_dim, sequence_length, batch_size).
    x, st_sa = encoder.self_attention(x, ps.self_attention, st.self_attention)
    x, st_ffn =
        encoder.feed_forward_network(x, ps.feed_forward_network, st.feed_forward_network)
    return x, (self_attention = st_sa, feed_forward_network = st_ffn)
end

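# A minimal usage sketch (not part of the layer definitions): it assumes the standard Lux
# workflow of `Lux.setup` followed by calling the layer, and uses hypothetical dimensions
# and the `_encoder_layer_demo` name purely for illustration.
using Random
function _encoder_layer_demo()
    model_dim, seq_len, batch_size = 64, 10, 4
    encoder = TransformerEncoderLayer(model_dim, 8)
    ps, st = Lux.setup(Random.default_rng(), encoder)
    st = Lux.testmode(st)  # disable dropout for a deterministic forward pass
    x = randn(Float32, model_dim, seq_len, batch_size)
    y, st = encoder(x, ps, st)
    @assert size(y) == (model_dim, seq_len, batch_size)
    return y
end
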
struct TransformerDecoderLayer{LSA, LMHA, LFFN} <: LuxCore.AbstractLuxContainerLayer{(
    :self_attention,
    :multi_head_attention,
    :feed_forward_network,
)}
    self_attention::LSA
    multi_head_attention::LMHA
    feed_forward_network::LFFN
end

function TransformerDecoderLayer(
    model_dim::Int,
    nheads::Int,
    feedforward_dim::Int = 2048,
    dropout::T1 = 0.1,
    activation::F = relu,
    layer_norm_epsilon::T2 = 1.0f-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
    # Masked self-attention sublayer. The forward pass feeds it (tgt, tgt, tgt, mask),
    # so the residual connection adds back the query (the first tuple element).
    self_attention = let dropout = dropout
        self_attention = MultiHeadAttention(model_dim; nheads)
        get_output = WrappedFunction(x -> x[1])  # MHA returns (output, attention_weights)
        dropout = Dropout(dropout)
        res_conn = SkipConnection(
            Chain(; self_attention, get_output, dropout),
            (y, x) -> y + first(x),
        )
        Chain(; res_conn, layer_norm = LayerNorm((model_dim, 1); epsilon = layer_norm_epsilon))
    end
    # Cross-attention sublayer over (query = decoder stream, key/value = encoder memory).
    # The residual again adds back the query.
    multi_head_attention = let dropout = dropout
        multi_head_attention = MultiHeadAttention(model_dim; nheads)
        get_output = WrappedFunction(x -> x[1])
        dropout = Dropout(dropout)
        res_conn = SkipConnection(
            Chain(; multi_head_attention, get_output, dropout),
            (y, x) -> y + first(x),
        )
        Chain(; res_conn, layer_norm = LayerNorm((model_dim, 1); epsilon = layer_norm_epsilon))
    end
    # Position-wise feed-forward sublayer with its own residual connection and layer norm.
    feed_forward_network = let
        feed_forward_network = Chain(
            Dense(model_dim => feedforward_dim, activation),
            Dropout(dropout),
            Dense(feedforward_dim => model_dim),  # no activation on the output projection
            Dropout(dropout),
        )
        res_conn = SkipConnection(feed_forward_network, +)
        Chain(; res_conn, layer_norm = LayerNorm((model_dim, 1); epsilon = layer_norm_epsilon))
    end
    TransformerDecoderLayer(self_attention, multi_head_attention, feed_forward_network)
end

function (decoder::TransformerDecoderLayer)(x, ps, st)
    if !(x isa Tuple && length(x) == 2)
        throw(
            ArgumentError(
                "input to TransformerDecoderLayer must be a tuple of target and memory (from the encoder)",
            ),
        )
    end
    tgt, memory = x
    # Causal mask so each target position only attends to itself and earlier positions.
    sa_mask = make_causal_mask(tgt)
    y_sa, st_sa = decoder.self_attention(
        (tgt, tgt, tgt, sa_mask),
        ps.self_attention,
        st.self_attention,
    )
    # Cross-attention: queries come from the target stream, keys/values from the encoder memory.
    y_mha, st_mha = decoder.multi_head_attention(
        (y_sa, memory),
        ps.multi_head_attention,
        st.multi_head_attention,
    )
    y_ffn, st_ffn = decoder.feed_forward_network(
        y_mha,
        ps.feed_forward_network,
        st.feed_forward_network,
    )
    return y_ffn,
        (self_attention = st_sa, multi_head_attention = st_mha, feed_forward_network = st_ffn)
end
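
# A minimal end-to-end sketch wiring an encoder layer into a decoder layer. As in the
# encoder demo above (whose `using Random` import it reuses), the dimensions and the
# `_decoder_layer_demo` name are illustrative assumptions, not part of the layer definitions.
function _decoder_layer_demo()
    model_dim, src_len, tgt_len, batch_size = 64, 12, 10, 4
    encoder = TransformerEncoderLayer(model_dim, 8)
    decoder = TransformerDecoderLayer(model_dim, 8)
    rng = Random.default_rng()
    enc_ps, enc_st = Lux.setup(rng, encoder)
    dec_ps, dec_st = Lux.setup(rng, decoder)
    src = randn(Float32, model_dim, src_len, batch_size)
    tgt = randn(Float32, model_dim, tgt_len, batch_size)
    memory, enc_st = encoder(src, enc_ps, enc_st)
    y, dec_st = decoder((tgt, memory), dec_ps, dec_st)
    @assert size(y) == (model_dim, tgt_len, batch_size)
    return y
end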