update doc on MultiheadAttention

Author: qwjyh
Date: 2025-02-03 22:12:34 +09:00
parent c2c2872b46
commit de6634e3f7


@@ -54,8 +54,18 @@ Constructor.
 # Parameters and states
 ## Parameters
-- `weight_q`
-TODO
+- `weight_q`: ``W^Q``
+- `weight_k`: ``W^K``
+- `weight_v`: ``W^V``
+- `weight_o`: ``W^O``
+## States
+MultiheadAttention has no states.
+# Inputs
+- `q`: ``Q``
+- `k`: ``K``
+- `v`: ``V``
 """
 function MultiheadAttention(
     embed_dim::Int,
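
For context, the four weights documented above correspond to the standard multi-head attention of Vaswani et al. (2017); a sketch of that formulation (the layer's head split and array layout are not shown in this diff and may differ):

```math
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
```

```math
\mathrm{head}_i = \mathrm{Attention}(QW_i^Q,\; KW_i^K,\; VW_i^V), \qquad
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O
```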
@@ -87,7 +97,7 @@ function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
     NamedTuple()
 end
-function (l::MultiheadAttention)(x::NamedTuple, ps, st::NamedTuple)
+function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
     if size(x.q, 1) != l.embed_dim
         ArgumentError(
             "Length of query must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",