From 38dacdacf1610454137b2c1c93d41d463be63b06 Mon Sep 17 00:00:00 2001
From: qwjyh
Date: Sat, 19 Apr 2025 20:12:39 +0900
Subject: [PATCH] (WIP) implement Encoder and Decoder layer with built-in
 MultiHeadAttention

TODO: fix layer norm
---
 transformers.jl | 125 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 125 insertions(+)
 create mode 100644 transformers.jl

diff --git a/transformers.jl b/transformers.jl
new file mode 100644
index 0000000..2daef5f
--- /dev/null
+++ b/transformers.jl
@@ -0,0 +1,125 @@
+using Lux
+
+struct TransformerEncoderLayer{LSA, LFFN} <:
+       LuxCore.AbstractLuxContainerLayer{(:self_attention, :feed_forward_network)}
+    self_attention::LSA
+    feed_forward_network::LFFN
+end
+
+function TransformerEncoderLayer(
+    model_dim::Int,
+    nheads::Int,
+    feedforward_dim::Int = 2048,
+    dropout::T1 = 0.1,
+    activation::F = relu,
+    layer_norm_epsilon::T2 = 1.0f-5,
+) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
+    self_attention = let dropout = dropout
+        self_attention = MultiHeadAttention(model_dim; nheads)
+        get_output = WrappedFunction(x -> x[1])
+        dropout = Dropout(dropout)
+        res_conn = SkipConnection(Chain(; self_attention, get_output, dropout), +)
+        # TODO: fix layer norm arg
+        Chain(; res_conn, layer_norm = LayerNorm(model_dim; epsilon = layer_norm_epsilon))
+    end
+    feed_forward_network = let dropout = dropout
+        feed_forward_network = Chain(
+            Dense(model_dim => feedforward_dim, activation),
+            Dropout(dropout),
+            Dense(feedforward_dim => model_dim, activation),
+            Dropout(dropout),
+        )
+        res_conn = SkipConnection(feed_forward_network, +)
+        Chain(; res_conn, layer_norm = LayerNorm(model_dim; epsilon = layer_norm_epsilon))
+    end
+    TransformerEncoderLayer(self_attention, feed_forward_network)
+end
+
+function (encoder::TransformerEncoderLayer)(x, ps, st)
+    x, st_sa = encoder.self_attention(x, ps.self_attention, st.self_attention)
+    x, st_ffn =
+        encoder.feed_forward_network(x, ps.feed_forward_network, st.feed_forward_network)
+    return x, (self_attention = st_sa, feed_forward_network = st_ffn)
+end
+
+struct TransformerDecoderLayer{LSA, LMHA, LFFN} <: LuxCore.AbstractLuxContainerLayer{(
+    :self_attention,
+    :multi_head_attention,
+    :feed_forward_network,
+)}
+    self_attention::LSA
+    multi_head_attention::LMHA
+    feed_forward_network::LFFN
+end
+
+function TransformerDecoderLayer(
+    model_dim::Int,
+    nheads::Int,
+    feedforward_dim::Int = 2048,
+    dropout::T1 = 0.1,
+    activation::F = relu,
+    layer_norm_epsilon::T2 = 1.0f-5,
+) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
+    self_attention = let dropout = dropout
+        self_attention = MultiHeadAttention(model_dim; nheads)
+        get_output = WrappedFunction(x -> x[1])
+        dropout = Dropout(dropout)
+        res_conn = SkipConnection(Chain(; self_attention, get_output, dropout), +)
+        Chain(; res_conn, layer_norm = LayerNorm(model_dim; epsilon = layer_norm_epsilon))
+    end
+    multi_head_attention = let dropout = dropout
+        multi_head_attention = MultiHeadAttention(model_dim; nheads)
+        get_output = WrappedFunction(x -> x[1])
+        dropout = Dropout(dropout)
+        res_conn = Parallel(
+            +,
+            Chain(; multi_head_attention, get_output, dropout),
+            WrappedFunction(x -> x[1]),
+        )
+        Chain(; res_conn, layer_norm = LayerNorm(model_dim; epsilon = layer_norm_epsilon))
+    end
+    feed_forward_network = let
+        feed_forward_network = Chain(
+            Dense(model_dim => feedforward_dim, activation),
+            Dropout(dropout),
+            Dense(feedforward_dim => model_dim, activation),
+            Dropout(dropout),
+        )
+        res_conn = SkipConnection(feed_forward_network, +)
+        Chain(; res_conn, layer_norm = LayerNorm(model_dim; epsilon = layer_norm_epsilon))
+    end
+    TransformerDecoderLayer(self_attention, multi_head_attention, feed_forward_network)
+end
+
+function (decoder::TransformerDecoderLayer)(x, ps, st)
+    if !(x isa Tuple && length(x) == 2)
+        throw(
+            ArgumentError(
+                "input to TransformerDecoderLayer must be a tuple of target and memory (from encoder)",
+            ),
+        )
+    end
+
+    tgt = x[1]
+    memory = x[2]
+
+    sa_mask = make_causal_mask(tgt)
+    y_sa, st_sa = decoder.self_attention(
+        (tgt, tgt, tgt, sa_mask),
+        ps.self_attention,
+        st.self_attention,
+    )
+    y_mha, st_mha = decoder.multi_head_attention(
+        (y_sa, memory),
+        ps.multi_head_attention,
+        st.multi_head_attention,
+    )
+    y_ffn, st_ffn = decoder.feed_forward_network(
+        y_mha,
+        ps.feed_forward_network,
+        st.feed_forward_network,
+    )
+
+    return y_ffn,
+    (self_attention = st_sa, multi_head_attention = st_mha, feed_forward_network = st_ffn)
+end
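
Below is a minimal usage sketch of the intended call pattern, not part of the patch itself. It assumes Lux >= 1.0, that make_causal_mask is provided by NNlib, and that inputs are laid out as (model_dim, sequence_length, batch) as MultiHeadAttention expects; the model sizes are illustrative. Since the layer-norm argument is still marked TODO above, treat this as a sketch of the API rather than a verified end-to-end run.

# Usage sketch (not part of the patch). Assumptions: Lux >= 1.0, NNlib supplies
# make_causal_mask, inputs are shaped (model_dim, seq_len, batch); sizes are illustrative.
using Lux, NNlib, Random

model_dim, nheads = 64, 8
enc = TransformerEncoderLayer(model_dim, nheads)
dec = TransformerDecoderLayer(model_dim, nheads)

rng = Random.default_rng()
ps_enc, st_enc = Lux.setup(rng, enc)
ps_dec, st_dec = Lux.setup(rng, dec)

src = randn(Float32, model_dim, 16, 4)  # (features, source length, batch)
tgt = randn(Float32, model_dim, 12, 4)  # (features, target length, batch)

memory, st_enc = enc(src, ps_enc, st_enc)       # encoder output serves as the "memory"
y, st_dec = dec((tgt, memory), ps_dec, st_dec)  # decoder expects (target, memory)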