new(transformer): implement TransformerEncoderLayer and add test
parent e3765c80d7
commit 5bdcb56725
2 changed files with 164 additions and 18 deletions
@@ -4,22 +4,28 @@ true || include("transformers.jl")

using Random

const EMBED_DIM = 10
const NUM_HEADS = 2
const KV_DIM = 12
const INPUT_DIM = 10
const NUM_HEADS = 1

rng = TaskLocalRNG()
rng = Random.default_rng()

l = MultiheadAttention(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM)
@info "layer" l
model = TransformerEncoderLayer(10, 2)
@info "model" model

ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
@info "parameters and states" ps st
ps, st = LuxCore.setup(rng, model)
@info "parameters" ps
@info "status" st

q = rand(rng, Float32, (EMBED_DIM,))
k = rand(rng, Float32, (KV_DIM,))
v = rand(rng, Float32, (KV_DIM,))
@info "q k v" summary.((q, k, v))
# Unbatched
x = randn(rng, Float32, (10,))
@info "input" size(x)

l((; q, k, v), ps, st)
y, st = model(x, ps, st)
@info "output" y st

# Batched
x = randn(rng, Float32, (10, 10))
@info "input" size(x)

y, st = model(x, ps, st)
@info "output" y st
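
The test above drives everything through `@info` for manual inspection. A minimal assertion-based variant, as a sketch that assumes the `Test` stdlib plus the `model`, `ps`, `st`, and `x` already set up in that script:

```julia
using Test

# The encoder layer maps (model_dim, batch...) arrays to arrays of the same
# size, so a basic shape check is enough for a smoke test.
y, _ = model(x, ps, st)
@test size(y) == size(x)
```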

148  transformers.jl

@@ -3,6 +3,10 @@ using Random: AbstractRNG
using Lux

"""
    MultiheadAttention{F} <: LuxCore.AbstractLuxLayer

Multi-head attention layer used in the Transformer model.

# Fields
- `embed_dim`: Query vector length. `q_len`
- `kvdim::Int`: Key vector and value vector length.
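
For concreteness, the test script in this commit constructs this layer with 10-dimensional queries and 12-dimensional keys/values; a minimal construction sketch:

```julia
# embed_dim = 10, num_heads = 2, keys and values of length kvdim = 12
l = MultiheadAttention(10, 2, kvdim = 12)
```
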
@@ -63,6 +67,7 @@ Constructor.
MultiheadAttention has no states.

# Inputs
A NamedTuple with these three fields:
- `q`: ``Q``
- `k`: ``K``
- `v`: ``V``
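
A sketch of a full unbatched call with that NamedTuple input, following the test script above (it assumes `transformers.jl` and its `Lux`/`LuxCore` imports are loaded):

```julia
using Random

rng = Random.default_rng()
l = MultiheadAttention(10, 2, kvdim = 12)
ps, st = LuxCore.setup(rng, l)

q = rand(rng, Float32, (10,))  # length embed_dim
k = rand(rng, Float32, (12,))  # length kvdim
v = rand(rng, Float32, (12,))  # length kvdim
y, _ = l((; q, k, v), ps, st)  # NamedTuple in, (output, state) out
```
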
@@ -97,6 +102,18 @@ function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
    NamedTuple()
end

function LuxCore.parameterlength(l::MultiheadAttention)
    dim_weight_q = l.embed_dim * l.num_heads * l.qdim
    dim_weight_k = l.embed_dim * l.num_heads * l.kvdim
    dim_weight_v = l.v_embed_dim * l.num_heads * l.kvdim
    dim_weight_o = l.embed_dim * l.v_embed_dim * l.num_heads
    dim_weight_q + dim_weight_k + dim_weight_v + dim_weight_o
end

function LuxCore.statelength(l::MultiheadAttention)
    0
end
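
As a worked example of `parameterlength`, here is the count for the configuration used in the test, under the assumption (not shown in this diff) that `qdim` and `v_embed_dim` default to `embed_dim`:

```julia
l = MultiheadAttention(10, 2, kvdim = 12)  # embed_dim = 10, num_heads = 2, kvdim = 12

# Assuming qdim == v_embed_dim == 10:
#   weight_q: 10 * 2 * 10 = 200
#   weight_k: 10 * 2 * 12 = 240
#   weight_v: 10 * 2 * 12 = 240
#   weight_o: 10 * 10 * 2 = 200
LuxCore.parameterlength(l)  # 880 under those assumptions
```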

function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
    if size(x.q, 1) != l.embed_dim
        ArgumentError(
@@ -125,9 +142,132 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
    # [y] = (v_dim, q_len, batch_size...)
    # [α] = (kv_len, q_len, nheads, batch_size...)
    y, α = dot_product_attention(q, k, v; nheads = l.num_heads)
    ps.weight_o * y
    ps.weight_o * y, _st
end
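
Returning `ps.weight_o * y, _st` makes the layer follow the `(output, state)` convention expected of Lux layers. For reference, a standalone sketch of the shape convention of `NNlib.dot_product_attention` used here, with arbitrary illustrative sizes:

```julia
using NNlib, Random

rng = Random.default_rng()
q = randn(rng, Float32, 8, 5, 3)  # (qk_dim, q_len, batch); qk_dim divisible by nheads
k = randn(rng, Float32, 8, 7, 3)  # (qk_dim, kv_len, batch)
v = randn(rng, Float32, 6, 7, 3)  # (v_dim, kv_len, batch); v_dim divisible by nheads
y, α = NNlib.dot_product_attention(q, k, v; nheads = 2)
size(y)  # (6, 5, 3)    = (v_dim, q_len, batch)
size(α)  # (7, 5, 2, 3) = (kv_len, q_len, nheads, batch)
```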

# struct TransformerEncoder <: LuxCore.AbstractLuxContainerLayer{}
# end
#
"""
    TransformerEncoderLayer <: LuxCore.AbstractLuxLayer

One encoder block (layer) of the Transformer model.

# Structure

It consists of two sublayers: a `MultiheadAttention` layer and a feed-forward network (two `Dense` layers).
Both are wrapped in a residual connection and followed by [`Lux.Dropout`](@ref) and [`Lux.LayerNorm`](@ref).

They are combined using [`Chain`](@ref) and [`SkipConnection`](@ref).

See the constructor of this layer, or display the layer itself, to see the structure.

## MultiheadAttention sublayer

```
========= SkipConnection ==========
-> MultiheadAttention -> Dropout -->
                                    -- +(add) -> LayerNorm
-> -------------------------------->
```

## FeedForwardNetworkSubLayer

The core is two chained `Dense` layers with internal dimension `feedforward_dim`.
This is then wrapped in a `SkipConnection` and followed by a `LayerNorm`.

# Parameters & States

Since this layer is implemented as a [`LuxCore.AbstractLuxContainerLayer`](@ref),
both parameters and states are structured according to the container layer's fields.
See the "Structure" section for more detail.

# Inputs

An `AbstractArray` with `size(x, 1) == model_dim` (same as `Dense`).

"""
struct TransformerEncoderLayer{LSA, LFFN} <: LuxCore.AbstractLuxContainerLayer{(
    :sublayer_self_attention,
    :sublayer_feed_forward_network,
)}
    sublayer_self_attention::LSA
    sublayer_feed_forward_network::LFFN
end
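
Because the struct is declared as an `AbstractLuxContainerLayer` over those two field names, parameters and states come back as NamedTuples keyed by them. A sketch of inspecting that structure (using the constructor documented below and assuming the file's `LuxCore` import):

```julia
using Random

rng = Random.default_rng()
model = TransformerEncoderLayer(10, 2)
ps, st = LuxCore.setup(rng, model)
keys(ps)  # (:sublayer_self_attention, :sublayer_feed_forward_network)
keys(st)  # (:sublayer_self_attention, :sublayer_feed_forward_network)
```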

"""
    TransformerEncoderLayer(
        model_dim::Int,
        num_heads::Int;
        feedforward_dim::Int = 2048,
        dropout::T1 = 0.1,
        activation::F = relu,
        layer_norm_epsilon::T2 = 1.0f-5,
    ) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}

Constructor of [`TransformerEncoderLayer`](@ref).

# Arguments
- `model_dim::Int`: model input size
- `num_heads::Int`: number of heads of `MultiheadAttention`
- `feedforward_dim::Int = 2048`: dimension of the feed-forward network
- `dropout::T1 = 0.1`: dropout rate for all `Dropout` layers
- `activation::F = relu`: activation used in the feed-forward network
- `layer_norm_epsilon::T2 = 1.0f-5`: epsilon of `LayerNorm`
"""
function TransformerEncoderLayer(
    model_dim::Int,
    num_heads::Int;
    feedforward_dim::Int = 2048,
    dropout::T1 = 0.1,
    activation::F = relu,
    layer_norm_epsilon::T2 = 1.0f-5,
) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat}
    sublayer_self_attention = let
        layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x))
        layer_self_attention = MultiheadAttention(model_dim, num_heads)
        layer_dropout = Lux.Dropout(dropout)
        layer_residual_connection = Lux.SkipConnection(
            Lux.Chain(; layer_split, layer_self_attention, layer_dropout),
            +,
        )
        Lux.Chain(;
            layer_residual_connection,
            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
            name = "SelfAttentionSubLayer",
        )
    end
    sublayer_feed_forward_network = let
        layer_feed_forward_network = Lux.Chain(
            Lux.Dense(model_dim => feedforward_dim, activation),
            Lux.Dropout(dropout),
            Lux.Dense(feedforward_dim => model_dim, activation),
            Lux.Dropout(dropout),
        )
        layer_residual_connection = Lux.SkipConnection(layer_feed_forward_network, +)
        Lux.Chain(;
            layer_residual_connection,
            layer_layer_norm = Lux.LayerNorm(model_dim; epsilon = layer_norm_epsilon),
            name = "FeedForwardNetworkSubLayer",
        )
    end
    TransformerEncoderLayer(sublayer_self_attention, sublayer_feed_forward_network)
end
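
A usage sketch of this constructor with non-default keywords; the sizes are chosen only for illustration:

```julia
# Small encoder layer: model_dim = 10, 2 heads, a 32-unit feed-forward
# hidden layer, and dropout disabled.
model = TransformerEncoderLayer(10, 2; feedforward_dim = 32, dropout = 0.0)
```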

function (encoder::TransformerEncoderLayer)(x, ps, st)
    x, st_sublayer_self_attention = Lux.apply(
        encoder.sublayer_self_attention,
        x,
        ps.sublayer_self_attention,
        st.sublayer_self_attention,
    )
    x, st_sublayer_feed_forward_network = Lux.apply(
        encoder.sublayer_feed_forward_network,
        x,
        ps.sublayer_feed_forward_network,
        st.sublayer_feed_forward_network,
    )

    return x,
    (
        sublayer_self_attention = st_sublayer_self_attention,
        sublayer_feed_forward_network = st_sublayer_feed_forward_network,
    )
end
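
Both sublayers contain `Dropout`, so the forward pass above is stochastic while the states are in training mode. A sketch of switching to inference mode, assuming `model`, `ps`, `st`, and `x` as in the test script:

```julia
st_test = Lux.testmode(st)    # states with Dropout disabled
y, _ = model(x, ps, st_test)  # deterministic output
```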