From de6634e3f76eccedfad514a69784ccff5cb82c6e Mon Sep 17 00:00:00 2001
From: qwjyh
Date: Mon, 3 Feb 2025 22:12:34 +0900
Subject: [PATCH] update doc on MultiheadAttention

---
 transformers.jl | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/transformers.jl b/transformers.jl
index bb7b17f..ef0a9f4 100644
--- a/transformers.jl
+++ b/transformers.jl
@@ -54,8 +54,18 @@ Constructor.
 
 # Parameters and states
 ## Parameters
-- `weight_q`
-TODO
+- `weight_q`: ``W^Q``
+- `weight_k`: ``W^K``
+- `weight_v`: ``W^V``
+- `weight_o`: ``W^O``
+
+## States
+MultiheadAttention has no states.
+
+# Inputs
+- `q`: ``Q``
+- `k`: ``K``
+- `v`: ``V``
 """
 function MultiheadAttention(
     embed_dim::Int,
@@ -87,7 +97,7 @@ function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
     NamedTuple()
 end
 
-function (l::MultiheadAttention)(x::NamedTuple, ps, st::NamedTuple)
+function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
     if size(x.q, 1) != l.embed_dim
         ArgumentError(
             "Length of queue must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
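
For reference, here is a minimal usage sketch of the layer the new docstring describes. It is not part of the patch: it assumes the constructor's second argument is the number of heads (only `embed_dim::Int` is visible in the hunk) and that the layer follows the standard LuxCore `setup`/apply convention; the call signature with a NamedTuple of `q`, `k`, `v` and the empty state are taken from the patched docstring and code.

```julia
using LuxCore, Random

embed_dim = 64
# Assumption: second positional argument is the number of attention heads.
layer = MultiheadAttention(embed_dim, 4)

rng = Random.default_rng()
# Parameters hold weight_q / weight_k / weight_v / weight_o; states are an empty NamedTuple.
ps, st = LuxCore.setup(rng, layer)

# Queries, keys, and values carry the feature dimension first,
# matching the `size(x.q, 1) != l.embed_dim` check in the layer.
q = rand(Float32, embed_dim, 10)
k = rand(Float32, embed_dim, 12)
v = rand(Float32, embed_dim, 12)

y, _ = layer((; q, k, v), ps, st)
```

The rename of the state argument to `_st` in the second hunk matches the documented "MultiheadAttention has no states": the unused-variable naming convention makes it explicit that the state is never read.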