update doc on MultiheadAttention
parent c2c2872b46
commit de6634e3f7
1 changed file with 13 additions and 3 deletions
@@ -54,8 +54,18 @@ Constructor.
 
 # Parameters and states
 
-- `weight_q`
 
-TODO
+## Parameters
+
+- `weight_q`: ``W^Q``
+- `weight_k`: ``W^K``
+- `weight_v`: ``W^V``
+- `weight_o`: ``W^O``
+
+## States
+
+MultiheadAttention has no states.
+
+# Inputs
+
+- `q`: ``Q``
+- `k`: ``K``
+- `v`: ``V``
 """
 function MultiheadAttention(
     embed_dim::Int,
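For reference, the symbols documented above are the projection matrices of standard multi-head attention (Vaswani et al., 2017). A sketch of the usual formulation, assuming the layer's ``W^Q``, ``W^K``, ``W^V`` stack the per-head projections ``W_i^Q``, ``W_i^K``, ``W_i^V`` (the head split and the ``d_k`` scaling are not shown in this commit):

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V

\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)

\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O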
@@ -87,7 +97,7 @@ function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
     NamedTuple()
 end
 
-function (l::MultiheadAttention)(x::NamedTuple, ps, st::NamedTuple)
+function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
     if size(x.q, 1) != l.embed_dim
         ArgumentError(
             "Length of query must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
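Since `initialstates` returns an empty `NamedTuple`, the layer carries no state, which is why the new method ignores its state argument (`_st`). A minimal usage sketch under the LuxCore conventions visible in this diff; the single-argument constructor call and the `(output, state)` return pair are assumptions, since the commit truncates both:

using LuxCore, Random

# Hypothetical call: the diff cuts off the constructor after `embed_dim::Int`,
# so the remaining arguments are assumed to have defaults here.
l = MultiheadAttention(64)

rng = Random.default_rng()
ps = LuxCore.initialparameters(rng, l)  # NamedTuple with weight_q, weight_k, weight_v, weight_o
st = LuxCore.initialstates(rng, l)      # empty NamedTuple() -- the layer is stateless

# Inputs are passed as a NamedTuple; the embedding dimension is axis 1,
# which is what the `size(x.q, 1) != l.embed_dim` check enforces.
x = (q = randn(Float32, 64, 10),
     k = randn(Float32, 64, 12),
     v = randn(Float32, 64, 12))

# By Lux convention a layer call returns (output, new_state); since `_st`
# is ignored, the returned state should stay an empty NamedTuple.
y, _ = l(x, ps, st)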