update doc on MultiheadAttention
This commit is contained in:
parent
c2c2872b46
commit
de6634e3f7
1 changed file with 13 additions and 3 deletions
|
@ -54,8 +54,18 @@ Constructor.
|
||||||
# Parameters and states
|
# Parameters and states
|
||||||
|
|
||||||
## Parameters
|
## Parameters
|
||||||
- `weight_q`
|
- `weight_q`: ``W^Q``
|
||||||
TODO
|
- `weight_k`: ``W^K``
|
||||||
|
- `weight_v`: ``W^V``
|
||||||
|
- `weight_o`: ``W^O``
|
||||||
|
|
||||||
|
## States
|
||||||
|
MultiheadAttention has no states.
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
- `q`: ``Q``
|
||||||
|
- `k`: ``K``
|
||||||
|
- `v`: ``V``
|
||||||
"""
|
"""
|
||||||
function MultiheadAttention(
|
function MultiheadAttention(
|
||||||
embed_dim::Int,
|
embed_dim::Int,
|
||||||
|
@ -87,7 +97,7 @@ function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention)
|
||||||
NamedTuple()
|
NamedTuple()
|
||||||
end
|
end
|
||||||
|
|
||||||
function (l::MultiheadAttention)(x::NamedTuple, ps, st::NamedTuple)
|
function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple)
|
||||||
if size(x.q, 1) != l.embed_dim
|
if size(x.q, 1) != l.embed_dim
|
||||||
ArgumentError(
|
ArgumentError(
|
||||||
"Length of queue must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
|
"Length of queue must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)",
|
||||||
|
|
Loading…
Reference in a new issue