diff --git a/transformers.jl b/transformers.jl
index bb7b17f..ef0a9f4 100644
--- a/transformers.jl
+++ b/transformers.jl
@@ -54,8 +54,18 @@ Constructor.
 # Parameters and states
 
 ## Parameters
-- `weight_q`
-TODO
+- `weight_q`: ``W^Q``
+- `weight_k`: ``W^K``
+- `weight_v`: ``W^V``
+- `weight_o`: ``W^O``
+
+## States
+MultiheadAttention has no states.
+
+# Inputs
+- `q`: ``Q``
+- `k`: ``K``
+- `v`: ``V``
 """
 function MultiheadAttention(
     embed_dim::Int,
@@ -87,7 +97,7 @@ function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
     NamedTuple()
 end
 
-function (l::MultiheadAttention)(x::NamedTuple, ps, st::NamedTuple)
+function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
     if size(x.q, 1) != l.embed_dim
         ArgumentError(
             "Length of queue must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
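
For context, here is a minimal usage sketch of the interface documented above. It assumes the usual Lux conventions (`LuxCore.setup` returning `(ps, st)` and the layer call returning `(output, state)`); the `num_heads` constructor argument and the `(embed_dim, seq_len)` input shape are illustrative assumptions, since the diff only shows the `embed_dim` argument and the first-dimension check on `q`.

```julia
using Random
using LuxCore               # the diff shows the layer implements the LuxCore interface

include("transformers.jl")  # assumes MultiheadAttention from the patched file is in scope

# Hypothetical constructor call: only `embed_dim` is visible in the diff;
# the second argument (number of heads) is assumed here for illustration.
embed_dim = 64
layer = MultiheadAttention(embed_dim, 8)

rng = Random.default_rng()
ps, st = LuxCore.setup(rng, layer)  # st is an empty NamedTuple per `initialstates`

# Inputs Q, K, V: the layer checks that the first dimension equals `embed_dim`,
# so each array is assumed to be (embed_dim, sequence_length).
seq_len = 10
q = randn(Float32, embed_dim, seq_len)
k = randn(Float32, embed_dim, seq_len)
v = randn(Float32, embed_dim, seq_len)

# The layer is called with a NamedTuple holding `q`, `k`, `v`, matching the
# signature `(l::MultiheadAttention)(x::NamedTuple, ps, st)`; the `(output, state)`
# return shape follows the standard Lux convention and is an assumption here.
y, _ = layer((q = q, k = k, v = v), ps, st)
```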