From ac73d9b8340983ea00121f6955eaa1577b9d5874 Mon Sep 17 00:00:00 2001 From: qwjyh Date: Sat, 19 Apr 2025 19:17:40 +0900 Subject: [PATCH] fix: support batched input in MultiheadAttention using Dense --- Manifest.toml | 413 ++++++++++++++++-------------------- Project.toml | 3 +- test_multihead_attention.jl | 4 +- test_transformer.jl | 4 +- transformers.jl | 80 +++++-- 5 files changed, 248 insertions(+), 256 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 6a8f594..aec47ba 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,13 +1,13 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.11.3" +julia_version = "1.11.5" manifest_format = "2.0" -project_hash = "ab2197e930896350f4ba6508edd2db23590d6902" +project_hash = "156b0ed82f8777cb358f12321b5a6ff6c6c0aa80" [[deps.ADTypes]] -git-tree-sha1 = "fb97701c117c8162e84dfcf80215caa904aef44f" +git-tree-sha1 = "e2478490447631aedba0823d4d7a80b2cc8cdb32" uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -version = "1.13.0" +version = "1.14.0" weakdeps = ["ChainRulesCore", "ConstructionBase", "EnzymeCore"] [deps.ADTypes.extensions] @@ -57,9 +57,9 @@ version = "0.1.42" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "cd8b948862abee8f3d3e9b73a102a9ca924debb0" +git-tree-sha1 = "f7817e2e585aa6d924fd714df1e2a84be7896c60" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.2.0" +version = "4.3.0" weakdeps = ["SparseArrays", "StaticArrays"] [deps.Adapt.extensions] @@ -84,9 +84,9 @@ uuid = "27a7e980-b3e6-11e9-2bcd-0b925532e340" version = "0.4.2" [[deps.ArgCheck]] -git-tree-sha1 = "680b3b8759bd4c54052ada14e52355ab69e07876" +git-tree-sha1 = "f9e9a66c9b7be1ad7372bbd9b062d9230c30c5ce" uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.4.0" +version = "2.5.0" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" @@ -165,10 +165,10 @@ uuid = "39de3d68-74b9-583c-8d2d-e117c070f3a9" version = "0.4.7" [[deps.BFloat16s]] -deps = ["LinearAlgebra", "Printf", "Random", "Test"] -git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" +deps = ["LinearAlgebra", "Printf", "Random"] +git-tree-sha1 = "3b642331600250f592719140c60cf12372b82d66" uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -version = "0.5.0" +version = "0.5.1" [[deps.BangBang]] deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra"] @@ -234,17 +234,11 @@ git-tree-sha1 = "e329286945d0cfc04456972ea732551869af1cfc" uuid = "4e9b3aee-d8a1-5a3d-ad8b-7d824db253f0" version = "1.0.1+0" -[[deps.CSTParser]] -deps = ["Tokenize"] -git-tree-sha1 = "0157e592151e39fa570645e2b2debcdfb8a0f112" -uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" -version = "3.4.3" - [[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics", "demumble_jll"] -git-tree-sha1 = "7be665c420b5d16059b1ba00b1dbb4e85012fa65" +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "GPUToolbox", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics", "demumble_jll"] +git-tree-sha1 = "049d804a037ed39300722bcad4b7a538eabe1e47" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "5.6.1" +version = "5.7.1" weakdeps = ["ChainRulesCore", "EnzymeCore", "SpecialFunctions"] [deps.CUDA.extensions] @@ -254,9 +248,9 @@ weakdeps = ["ChainRulesCore", "EnzymeCore", "SpecialFunctions"] [[deps.CUDA_Driver_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "14996d716a2eaaeccfc8d7bc854dd87fde720ac1" +git-tree-sha1 = "f69205592dbd3721a156245b6dd837206786a848" uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" -version = "0.10.4+0" +version = "0.12.1+1" [[deps.CUDA_Runtime_Discovery]] deps = ["Libdl"] @@ -266,9 +260,9 @@ version = "0.3.5" [[deps.CUDA_Runtime_jll]] deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "17f1536c600133f7c4113bae0a2d98dbf27c7ebc" +git-tree-sha1 = "99f1c6b659c14bbb3492246791bb4928a40ceb84" uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" -version = "0.15.5+0" +version = "0.16.1+0" [[deps.CUDNN_jll]] deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] @@ -290,9 +284,9 @@ version = "0.13.2" [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "009060c9a6168704143100f36ab08f06c2af4642" +git-tree-sha1 = "2ac646d71d0d24b44f3f8c84da8c9f4d70fb67df" uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" -version = "1.18.2+1" +version = "1.18.4+0" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] @@ -330,9 +324,9 @@ version = "3.29.0" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "c7acce7a7e1078a20a285211dd73cd3941a871d6" +git-tree-sha1 = "67e11ee83a43eb71ddc950302c53bf33f0690dfe" uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.12.0" +version = "0.12.1" weakdeps = ["StyledStrings"] [deps.ColorTypes.extensions] @@ -354,12 +348,6 @@ git-tree-sha1 = "64e15186f0aa277e174aa81798f7eb8598e0157e" uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" version = "0.13.0" -[[deps.CommonMark]] -deps = ["Crayons", "PrecompileTools"] -git-tree-sha1 = "3faae67b8899797592335832fccf4b3c80bb04fa" -uuid = "a80b9123-70ca-4bc0-993e-6e3bcb318db6" -version = "0.8.15" - [[deps.CommonSolve]] git-tree-sha1 = "0eee5eb66b1cf62cd6ad1b460238e60e4b09400c" uuid = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" @@ -463,9 +451,9 @@ version = "1.7.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +git-tree-sha1 = "4e1fe97fdaed23e9dc21d4d664bea76b65fc50a0" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.20" +version = "0.18.22" [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -508,9 +496,9 @@ version = "1.15.1" [[deps.DifferentiationInterface]] deps = ["ADTypes", "LinearAlgebra"] -git-tree-sha1 = "258fa016b2d03f19e4d0d1cd8e30c84907af1528" +git-tree-sha1 = "70e500f6d5d50091d87859251de7b8cd060c1cce" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -version = "0.6.42" +version = "0.6.50" [deps.DifferentiationInterface.extensions] DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore" @@ -522,9 +510,10 @@ version = "0.6.42" DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] DifferentiationInterfaceGTPSAExt = "GTPSA" DifferentiationInterfaceMooncakeExt = "Mooncake" - DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff" + DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"] DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] DifferentiationInterfaceSparseArraysExt = "SparseArrays" + DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer" DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings" DifferentiationInterfaceStaticArraysExt = "StaticArrays" DifferentiationInterfaceSymbolicsExt = "Symbolics" @@ -546,6 +535,7 @@ version = "0.6.42" PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" @@ -570,9 +560,9 @@ version = "1.11.0" [[deps.Distributions]] deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] -git-tree-sha1 = "03aa5d44647eaec98e1920635cdfed5d5560a8b9" +git-tree-sha1 = "0b4190661e8a4e51a842070e7dd4fae440ddb7f4" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.117" +version = "0.25.118" [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" @@ -585,10 +575,9 @@ version = "0.25.117" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +git-tree-sha1 = "e7b7e6f178525d17c720ab9c081e4ef04429f860" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" +version = "0.9.4" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] @@ -602,9 +591,9 @@ uuid = "5ae413db-bbd1-5e63-b57d-d24a61df00f5" version = "2.2.4+0" [[deps.EnumX]] -git-tree-sha1 = "bdb1942cd4c45e3c678fd11569d5cccd80976237" +git-tree-sha1 = "bddad79635af6aec424f53ed8aad5d7555dc6f00" uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" -version = "1.0.4" +version = "1.0.5" [[deps.EnzymeCore]] git-tree-sha1 = "0cdb7af5c39e92d78a0ee8d0a447d32f7593137e" @@ -655,10 +644,10 @@ uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" version = "1.8.1" [[deps.FFTW_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "4d81ed14783ec49ce9f2e168208a12ce1815aa25" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "6d6219a004b8cf1e0b4dbe27a2860b8e04eba0be" uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" -version = "3.3.10+3" +version = "3.3.11+0" [[deps.FLoops]] deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] @@ -679,9 +668,9 @@ version = "0.3.2" [[deps.FileIO]] deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "2dd20384bf8c6d411b5c7370865b1e9b26cb2ea3" +git-tree-sha1 = "b66970a70db13f45b7e57fbda1736e1cf72174ea" uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.16.6" +version = "1.17.0" [deps.FileIO.extensions] HTTPExt = "HTTP" @@ -697,9 +686,9 @@ version = "0.8.3" [[deps.FilePathsBase]] deps = ["Compat", "Dates"] -git-tree-sha1 = "2ec417fc319faa2d768621085cc1feebbdee686b" +git-tree-sha1 = "3bab2c5aa25e7840a4b065805c0cdfc01f3068d2" uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.23" +version = "0.9.24" weakdeps = ["Mmap", "Test"] [deps.FilePathsBase.extensions] @@ -730,9 +719,9 @@ version = "0.8.5" [[deps.Fontconfig_jll]] deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Zlib_jll"] -git-tree-sha1 = "21fac3c77d7b5a9fc03b0ec503aa1a6392c34d2b" +git-tree-sha1 = "301b5d5d731a0654825f1f2e906990f7141a106b" uuid = "a3f928ae-7b40-5064-980b-68af3947d34b" -version = "2.15.0+0" +version = "2.16.0+0" [[deps.Format]] git-tree-sha1 = "9c68794ef81b08086aeb32eeaf33531668d5f5fc" @@ -757,9 +746,9 @@ version = "4.1.1" [[deps.FreeType2_jll]] deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "786e968a8d2fb167f2e4880baba62e0e26bd8e4e" +git-tree-sha1 = "2c5512e11c791d1baed2049c5652441b28fc6a31" uuid = "d7e528f0-a631-5988-bf34-fe36492bcfd7" -version = "2.13.3+1" +version = "2.13.4+0" [[deps.FreeTypeAbstraction]] deps = ["ColorVectorSpace", "Colors", "FreeType", "GeometryBasics"] @@ -769,9 +758,9 @@ version = "0.10.6" [[deps.FriBidi_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "846f7026a9decf3679419122b49f8a1fdb48d2d5" +git-tree-sha1 = "7a214fdac5ed5f59a22c2d9a885a16da1c74bbc7" uuid = "559328eb-81f9-559d-9380-de523a88c83c" -version = "1.0.16+0" +version = "1.0.17+0" [[deps.FunctionWrappers]] git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" @@ -809,9 +798,14 @@ version = "0.2.0" [[deps.GPUCompiler]] deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "PrecompileTools", "Preferences", "Scratch", "Serialization", "TOML", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "199f213e40a7982e9138bc9edc3299419d510390" +git-tree-sha1 = "b08c164134dd0dbc76ff54e45e016cf7f30e16a4" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "1.2.0" +version = "1.3.2" + +[[deps.GPUToolbox]] +git-tree-sha1 = "15d8b0f5a6dca9bf8c02eeaf6687660dafa638d0" +uuid = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc" +version = "0.2.0" [[deps.GeoFormatTypes]] git-tree-sha1 = "8e233d5167e63d708d41f87597433f59a0f213fe" @@ -826,9 +820,9 @@ version = "1.4.1" [[deps.GeometryBasics]] deps = ["EarCut_jll", "Extents", "GeoInterface", "IterTools", "LinearAlgebra", "PrecompileTools", "Random", "StaticArrays"] -git-tree-sha1 = "3ba0e2818cc2ff79a5989d4dca4bc63120a98bd9" +git-tree-sha1 = "65e3f5c519c3ec6a4c59f4c3ba21b6ff3add95b0" uuid = "5c1252a2-5f33-56bf-86c9-59e7332b4326" -version = "0.5.5" +version = "0.5.7" [[deps.Gettext_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] @@ -848,11 +842,6 @@ git-tree-sha1 = "b0036b392358c80d2d2124746c2bf3d48d457938" uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" version = "2.82.4+0" -[[deps.Glob]] -git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" -uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" -version = "1.3.1" - [[deps.Graphics]] deps = ["Colors", "LinearAlgebra", "NaNMath"] git-tree-sha1 = "a641238db938fff9b2f60d08ed9030387daf428c" @@ -860,10 +849,10 @@ uuid = "a2bd30eb-e257-5431-a919-1863eab51364" version = "1.1.3" [[deps.Graphite2_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "01979f9b37367603e2848ea225918a3b3861b606" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "8a6dbda1fd736d60cc477d99f2e7a042acfa46e8" uuid = "3b182d85-2403-5c21-9c21-1e1f0cc25472" -version = "1.3.14+1" +version = "1.3.15+0" [[deps.GridLayoutBase]] deps = ["GeometryBasics", "InteractiveUtils", "Observables"] @@ -887,33 +876,17 @@ git-tree-sha1 = "2eaa69a7cab70a52b9687c8bf950a5a93ec895ae" uuid = "076d061b-32b6-4027-95e0-9a2c6f6d7e74" version = "0.2.0" -[[deps.Hwloc]] -deps = ["CEnum", "Hwloc_jll", "Printf"] -git-tree-sha1 = "6a3d80f31ff87bc94ab22a7b8ec2f263f9a6a583" -uuid = "0e44f5e4-bd66-52a0-8798-143a42290a1d" -version = "3.3.0" -weakdeps = ["AbstractTrees"] - - [deps.Hwloc.extensions] - HwlocTrees = "AbstractTrees" - -[[deps.Hwloc_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f93a9ce66cd89c9ba7a4695a47fd93b4c6bc59fa" -uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" -version = "2.12.0+0" - [[deps.HypergeometricFunctions]] deps = ["LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] -git-tree-sha1 = "2bd56245074fab4015b9174f24ceba8293209053" +git-tree-sha1 = "68c173f4f449de5b438ee67ed0c9c748dc31a2ec" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.27" +version = "0.3.28" [[deps.IJulia]] -deps = ["Base64", "Conda", "Dates", "InteractiveUtils", "JSON", "Libdl", "Logging", "Markdown", "MbedTLS", "Pkg", "Printf", "REPL", "Random", "SoftGlobalScope", "Test", "UUIDs", "ZMQ"] -git-tree-sha1 = "1b1299f7d6617291f3d260e9f5b0250afdaac8c0" +deps = ["Base64", "Conda", "Dates", "InteractiveUtils", "JSON", "Logging", "Markdown", "MbedTLS", "Pkg", "Printf", "REPL", "Random", "SoftGlobalScope", "UUIDs", "ZMQ"] +git-tree-sha1 = "be30be76e25b0aff2c9a85930ed3ac34c5f10c83" uuid = "7073ff75-c697-5162-941a-fcdaad2a7d2a" -version = "1.26.0" +version = "1.27.0" [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools"] @@ -1012,10 +985,10 @@ weakdeps = ["Unitful"] InterpolationsUnitfulExt = "Unitful" [[deps.IntervalArithmetic]] -deps = ["CRlibm_jll", "LinearAlgebra", "MacroTools", "RoundingEmulator"] -git-tree-sha1 = "0fcf2079f918f68c6412cab5f2679822cbd7357f" +deps = ["CRlibm_jll", "LinearAlgebra", "MacroTools", "OpenBLASConsistentFPCSR_jll", "RoundingEmulator"] +git-tree-sha1 = "5aad168b75fc3b6b25e99feb1e6e3168d41e4c08" uuid = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253" -version = "0.22.23" +version = "0.22.28" weakdeps = ["DiffRules", "ForwardDiff", "IntervalSets", "RecipesBase"] [deps.IntervalArithmetic.extensions] @@ -1073,9 +1046,9 @@ version = "1.0.0" [[deps.JLD2]] deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"] -git-tree-sha1 = "91d501cb908df6f134352ad73cde5efc50138279" +git-tree-sha1 = "1059c071429b4753c0c869b75c859c44ba09a526" uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.5.11" +version = "0.5.12" [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] @@ -1089,12 +1062,6 @@ git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" version = "0.21.4" -[[deps.JSONRPC]] -deps = ["JSON", "UUIDs"] -git-tree-sha1 = "3928eaef5261194e95e8e99b1405e069e82b981e" -uuid = "b9b8584e-8fd3-41f9-ad0c-7255d428e418" -version = "1.4.2" - [[deps.Jieko]] deps = ["ExproniconLite"] git-tree-sha1 = "2f05ed29618da60c06a87e9c033982d4f71d0b6c" @@ -1103,9 +1070,9 @@ version = "0.2.1" [[deps.JpegTurbo]] deps = ["CEnum", "FileIO", "ImageCore", "JpegTurbo_jll", "TOML"] -git-tree-sha1 = "fa6d0bcff8583bac20f1ffa708c3913ca605c611" +git-tree-sha1 = "9496de8fb52c224a2e3f9ff403947674517317d9" uuid = "b835a17e-a41a-41e7-81f0-2f016b05efe0" -version = "0.1.5" +version = "0.1.6" [[deps.JpegTurbo_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -1113,12 +1080,6 @@ git-tree-sha1 = "eac1206917768cb54957c65a615460d87b455fc1" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" version = "3.1.1+0" -[[deps.JuliaFormatter]] -deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "PrecompileTools", "TOML", "Tokenize"] -git-tree-sha1 = "59cf7ad64f1b0708a4fa4369879d33bad3239b56" -uuid = "98e50ef6-434e-11e9-1051-2b60c6c9e899" -version = "1.0.62" - [[deps.JuliaNVTXCallbacks_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "af433a10f3942e882d3c671aacb203e006a5808f" @@ -1211,12 +1172,6 @@ git-tree-sha1 = "dda21b8cbd6a6c40d9d02a73230f9d70fed6918c" uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" version = "1.4.0" -[[deps.LanguageServer]] -deps = ["CSTParser", "JSON", "JSONRPC", "JuliaFormatter", "Logging", "Markdown", "Pkg", "PrecompileTools", "REPL", "StaticLint", "SymbolServer", "TestItemDetection", "Tokenize", "URIs", "UUIDs"] -git-tree-sha1 = "6d9ff7a24c4334e25f80b61f1188cac1e58d2c28" -uuid = "2b0e0bc5-e4fd-59b4-8912-456d1b03d8d7" -version = "4.5.1" - [[deps.LayoutPointers]] deps = ["ArrayInterface", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface"] git-tree-sha1 = "a9eaadb366f5493a5654e843864c13d8b107548c" @@ -1276,9 +1231,9 @@ version = "3.2.2+2" [[deps.Libgcrypt_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll"] -git-tree-sha1 = "8be878062e0ffa2c3f67bb58a595375eda5de80b" +git-tree-sha1 = "d77592fa54ad343c5043b6f38a03f1a3c3959ffe" uuid = "d4300ac3-e22c-5743-9152-c294e39db1e4" -version = "1.11.0+0" +version = "1.11.1+0" [[deps.Libglvnd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libX11_jll", "Xorg_libXext_jll"] @@ -1300,9 +1255,9 @@ version = "1.18.0+0" [[deps.Libmount_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "89211ea35d9df5831fca5d33552c02bd33878419" +git-tree-sha1 = "a31572773ac1b745e0343fe5e2c8ddda7a37e997" uuid = "4b2f31a3-9ecc-558c-b454-b3730dcb73e9" -version = "2.40.3+0" +version = "2.41.0+0" [[deps.Libtiff_jll]] deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"] @@ -1312,9 +1267,9 @@ version = "4.7.1+0" [[deps.Libuuid_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e888ad02ce716b319e6bdb985d2ef300e7089889" +git-tree-sha1 = "321ccef73a96ba828cd51f2ab5b9f917fa73945a" uuid = "38a345b3-de98-5d2b-a5d3-14cd9215e700" -version = "2.40.3+0" +version = "2.41.0+0" [[deps.LinearAlgebra]] deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] @@ -1348,10 +1303,10 @@ uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" version = "1.1.0" [[deps.Lux]] -deps = ["ADTypes", "Adapt", "ArgCheck", "ArrayInterface", "ChainRulesCore", "Compat", "ConcreteStructs", "DispatchDoctor", "EnzymeCore", "FastClosures", "ForwardDiff", "Functors", "GPUArraysCore", "LinearAlgebra", "LuxCore", "LuxLib", "MLDataDevices", "MacroTools", "Markdown", "NNlib", "Optimisers", "Preferences", "Random", "Reexport", "SIMDTypes", "Setfield", "Static", "StaticArraysCore", "Statistics", "WeightInitializers"] -git-tree-sha1 = "cdd363655fbde007c36558206d8a7e12e8a0aae1" +deps = ["ADTypes", "Adapt", "ArgCheck", "ArrayInterface", "ChainRulesCore", "Compat", "ConcreteStructs", "DiffResults", "DispatchDoctor", "EnzymeCore", "FastClosures", "ForwardDiff", "Functors", "GPUArraysCore", "LinearAlgebra", "LuxCore", "LuxLib", "MLDataDevices", "MacroTools", "Markdown", "NNlib", "Optimisers", "Preferences", "Random", "Reexport", "SIMDTypes", "Setfield", "Static", "StaticArraysCore", "Statistics", "WeightInitializers"] +git-tree-sha1 = "ae83c84e8f6cae3b51d7847b75077c5102ef246d" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" -version = "1.8.0" +version = "1.12.1" [deps.Lux.extensions] LuxComponentArraysExt = "ComponentArrays" @@ -1361,7 +1316,7 @@ version = "1.8.0" LuxMLUtilsExt = "MLUtils" LuxMPIExt = "MPI" LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"] - LuxReactantExt = ["Enzyme", "Reactant"] + LuxReactantExt = ["Enzyme", "Reactant", "ReactantCore"] LuxReverseDiffExt = ["FunctionWrappers", "ReverseDiff"] LuxSimpleChainsExt = "SimpleChains" LuxTrackerExt = "Tracker" @@ -1378,6 +1333,7 @@ version = "1.8.0" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" + ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -1391,9 +1347,9 @@ version = "0.3.3" [[deps.LuxCore]] deps = ["Compat", "DispatchDoctor", "Random"] -git-tree-sha1 = "32fb4c311f024e5f9cab95e12b8ed5e82d094a8b" +git-tree-sha1 = "58425c34dd82a5ee5d8fe9cf996cd5f43a26b85a" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" -version = "1.2.2" +version = "1.2.4" [deps.LuxCore.extensions] LuxCoreArrayInterfaceReverseDiffExt = ["ArrayInterface", "ReverseDiff"] @@ -1418,10 +1374,10 @@ version = "1.2.2" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [[deps.LuxLib]] -deps = ["ArrayInterface", "ChainRulesCore", "Compat", "CpuId", "DispatchDoctor", "EnzymeCore", "FastClosures", "ForwardDiff", "Functors", "Hwloc", "KernelAbstractions", "LinearAlgebra", "LuxCore", "MLDataDevices", "Markdown", "NNlib", "Polyester", "Preferences", "Random", "Reexport", "Static", "StaticArraysCore", "Statistics"] -git-tree-sha1 = "a95eb6684f2eb3bec17ac0c12e6260273363b977" +deps = ["ArrayInterface", "CPUSummary", "ChainRulesCore", "Compat", "DispatchDoctor", "EnzymeCore", "FastClosures", "ForwardDiff", "Functors", "KernelAbstractions", "LinearAlgebra", "LuxCore", "MLDataDevices", "Markdown", "NNlib", "Polyester", "Preferences", "Random", "Reexport", "Static", "StaticArraysCore", "Statistics"] +git-tree-sha1 = "37dc4c73f361dbaae35c677c5a74189a24576e53" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" -version = "1.6.1" +version = "1.7.2" [deps.LuxLib.extensions] LuxLibAppleAccelerateExt = "AppleAccelerate" @@ -1431,7 +1387,6 @@ version = "1.6.1" LuxLibLoopVectorizationExt = "LoopVectorization" LuxLibMKLExt = "MKL" LuxLibOctavianExt = ["Octavian", "LoopVectorization"] - LuxLibReactantExt = "Reactant" LuxLibReverseDiffExt = "ReverseDiff" LuxLibSLEEFPiratesExt = "SLEEFPirates" LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] @@ -1447,7 +1402,6 @@ version = "1.6.1" LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" - Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -1467,9 +1421,9 @@ version = "1.0.0" [[deps.MLDataDevices]] deps = ["Adapt", "Compat", "Functors", "Preferences", "Random"] -git-tree-sha1 = "cd12511c75cac31bc6257b302db2b83d983fe598" +git-tree-sha1 = "1326836c4c845cfabc542b658c8686f0c31a9911" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" -version = "1.6.11" +version = "1.9.1" [deps.MLDataDevices.extensions] MLDataDevicesAMDGPUExt = "AMDGPU" @@ -1518,9 +1472,9 @@ version = "0.4.17" [[deps.MLUtils]] deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "MLCore", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] -git-tree-sha1 = "6963295133aaa789f5fb18a6dd276c420793cf43" +git-tree-sha1 = "a772d8d1987433538a5c226f79393324b55f7846" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.4.7" +version = "0.4.8" [[deps.MacroTools]] git-tree-sha1 = "72aebe0b5051e5143a079a4685a46da330a40472" @@ -1604,10 +1558,10 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2023.12.12" [[deps.NNlib]] -deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "bdc9d30f151590aca0af22690f5ab7dc18a551cb" +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Random", "ScopedValues", "Statistics"] +git-tree-sha1 = "4abc63cdd8dd9dd925d8e879cda280bedc8013ca" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.27" +version = "0.9.30" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" @@ -1616,6 +1570,7 @@ version = "0.9.27" NNlibEnzymeCoreExt = "EnzymeCore" NNlibFFTWExt = "FFTW" NNlibForwardDiffExt = "ForwardDiff" + NNlibSpecialFunctionsExt = "SpecialFunctions" [deps.NNlib.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" @@ -1623,13 +1578,14 @@ version = "0.9.27" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.NVTX]] deps = ["Colors", "JuliaNVTXCallbacks_jll", "Libdl", "NVTX_jll"] -git-tree-sha1 = "6a6f8bfaa91bb2e40ff562ab9f30dc827741daef" +git-tree-sha1 = "1a24c3430fa2ef3317c4c97fa7e431ef45793bd2" uuid = "5da4648a-3479-48b8-97b9-01cb529c0a1f" -version = "0.3.5" +version = "1.0.0" [[deps.NVTX_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -1639,9 +1595,9 @@ version = "3.1.1+0" [[deps.NaNMath]] deps = ["OpenLibm_jll"] -git-tree-sha1 = "cc0a5deefdb12ab3a096f00a6d42133af4560d71" +git-tree-sha1 = "9b8215b1ee9e78a293f99797cd31375471b2bcae" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.1.2" +version = "1.1.3" [[deps.NameResolution]] deps = ["PrettyPrint"] @@ -1665,9 +1621,9 @@ uuid = "510215fc-4207-5dde-b226-833fc4488ee2" version = "0.5.5" [[deps.OffsetArrays]] -git-tree-sha1 = "5e1897147d1ff8d98883cda2be2187dcf57d8f0c" +git-tree-sha1 = "a414039192a155fb38c4599a60110f0018c6ec82" uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -version = "1.15.0" +version = "1.16.0" weakdeps = ["Adapt"] [deps.OffsetArrays.extensions] @@ -1679,6 +1635,12 @@ git-tree-sha1 = "887579a3eb005446d514ab7aeac5d1d027658b8f" uuid = "e7412a2a-1a6e-54c0-be00-318e2571c051" version = "1.3.5+1" +[[deps.OpenBLASConsistentFPCSR_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl"] +git-tree-sha1 = "567515ca155d0020a45b05175449b499c63e7015" +uuid = "6cdc7f73-28fd-5e50-80fb-958a8875b1af" +version = "0.3.29+0" + [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" @@ -1699,7 +1661,7 @@ version = "3.2.4+0" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" +version = "0.8.5+0" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -1715,26 +1677,31 @@ version = "0.5.6+0" [[deps.Optimisers]] deps = ["ChainRulesCore", "ConstructionBase", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "c57a1a58e29a017a2b07e78d075385b981942430" +git-tree-sha1 = "131dc319e7c58317e8c6d5170440f6bdaee0a959" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.4.5" -weakdeps = ["Adapt", "EnzymeCore"] +version = "0.4.6" [deps.Optimisers.extensions] OptimisersAdaptExt = ["Adapt"] OptimisersEnzymeCoreExt = "EnzymeCore" + OptimisersReactantExt = "Reactant" + + [deps.Optimisers.weakdeps] + Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" [[deps.Optimization]] deps = ["ADTypes", "ArrayInterface", "ConsoleProgressMonitor", "DocStringExtensions", "LBFGSB", "LinearAlgebra", "Logging", "LoggingExtras", "OptimizationBase", "Printf", "ProgressLogging", "Reexport", "SciMLBase", "SparseArrays", "TerminalLoggers"] -git-tree-sha1 = "df361b5dc1f91ffb601700a2bc4bfdcd4cc584ef" +git-tree-sha1 = "e72af10f1c6ffe2f295455c4c35534d713be62bb" uuid = "7f7a1694-90dd-40f0-9382-eb1efda571ba" -version = "4.1.1" +version = "4.1.2" [[deps.OptimizationBase]] deps = ["ADTypes", "ArrayInterface", "DifferentiationInterface", "DocStringExtensions", "FastClosures", "LinearAlgebra", "PDMats", "Reexport", "Requires", "SciMLBase", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings"] -git-tree-sha1 = "9e8569bc1c511c425fdc63f7ee41f2da057f8662" +git-tree-sha1 = "070d2c33da5f0b33d57b61f7f601c4ea6185af15" uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" -version = "2.4.0" +version = "2.5.0" [deps.OptimizationBase.extensions] OptimizationEnzymeExt = "Enzyme" @@ -1776,9 +1743,9 @@ version = "10.42.0+1" [[deps.PDMats]] deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "966b85253e959ea89c53a9abebbf2e964fbf593b" +git-tree-sha1 = "48566789a6d5f6492688279e22445002d171cf76" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.32" +version = "0.11.33" [[deps.PNGFiles]] deps = ["Base64", "CEnum", "ImageCore", "IndirectArrays", "OffsetArrays", "libpng_jll"] @@ -1812,9 +1779,9 @@ version = "2.8.1" [[deps.Pixman_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "Libdl"] -git-tree-sha1 = "35621f10a7531bc8fa58f74610b1bfb70a3cfc6b" +git-tree-sha1 = "db76b1ecd5e9715f3d043cec13b2ec93ce015d53" uuid = "30392449-352a-5448-841d-b1acce4e97dc" -version = "0.43.4+0" +version = "0.44.2+0" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"] @@ -1896,9 +1863,9 @@ version = "0.1.4" [[deps.ProgressMeter]] deps = ["Distributed", "Printf"] -git-tree-sha1 = "8f6bc219586aef8baf0ff9a5fe16ee9c70cb65e4" +git-tree-sha1 = "13c5103482a8ed1536a54c08d0e742ae3dca2d42" uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "1.10.2" +version = "1.10.4" [[deps.PtrArrays]] git-tree-sha1 = "1d36ef11a9aaf1e8b74dacc6a731dd1de8fd493d" @@ -1974,9 +1941,9 @@ version = "1.3.4" [[deps.RecursiveArrayTools]] deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "e96b644f7bfbf1015f8e42a7c7abfae2a48fafbf" +git-tree-sha1 = "112c876cee36a5784df19098b55db2b238afc36a" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" -version = "3.31.0" +version = "3.31.2" [deps.RecursiveArrayTools.extensions] RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" @@ -2013,9 +1980,9 @@ version = "1.0.1" [[deps.Requires]] deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +git-tree-sha1 = "62389eeff14780bfe55195b7204c0d8738436d64" uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" +version = "1.3.1" [[deps.Rmath]] deps = ["Random", "Rmath_jll"] @@ -2057,9 +2024,9 @@ version = "0.1.0" [[deps.SciMLBase]] deps = ["ADTypes", "Accessors", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "Moshi", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface"] -git-tree-sha1 = "ee305515b0946db5f56af699e8b5804fee04146c" +git-tree-sha1 = "6f3987e7fed3239d06985a4752670ca5ff25c695" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -version = "2.75.1" +version = "2.82.0" [deps.SciMLBase.extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" @@ -2069,7 +2036,7 @@ version = "2.75.1" SciMLBasePyCallExt = "PyCall" SciMLBasePythonCallExt = "PythonCall" SciMLBaseRCallExt = "RCall" - SciMLBaseZygoteExt = "Zygote" + SciMLBaseZygoteExt = ["Zygote", "ChainRulesCore"] [deps.SciMLBase.weakdeps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" @@ -2084,9 +2051,9 @@ version = "2.75.1" [[deps.SciMLOperators]] deps = ["Accessors", "ArrayInterface", "DocStringExtensions", "LinearAlgebra", "MacroTools"] -git-tree-sha1 = "6149620767866d4b0f0f7028639b6e661b6a1e44" +git-tree-sha1 = "1c4b7f6c3e14e6de0af66e66b86d525cae10ecb4" uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" -version = "0.3.12" +version = "0.3.13" weakdeps = ["SparseArrays", "StaticArraysCore"] [deps.SciMLOperators.extensions] @@ -2095,9 +2062,9 @@ weakdeps = ["SparseArrays", "StaticArraysCore"] [[deps.SciMLStructures]] deps = ["ArrayInterface"] -git-tree-sha1 = "0444a37a25fab98adbd90baa806ee492a3af133a" +git-tree-sha1 = "566c4ed301ccb2a44cbd5a27da5f885e0ed1d5df" uuid = "53ae85a6-f571-4167-b2af-e1d143709226" -version = "1.6.1" +version = "1.7.0" [[deps.ScopedValues]] deps = ["HashArrayMappedTries", "Logging"] @@ -2190,9 +2157,9 @@ version = "1.11.0" [[deps.SparseConnectivityTracer]] deps = ["ADTypes", "DocStringExtensions", "FillArrays", "LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "6651f4663027f3b30a31429d257185f56a571184" +git-tree-sha1 = "9603842a7a68464a066b5754e89fc7f810db8ae7" uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" -version = "0.6.13" +version = "0.6.15" [deps.SparseConnectivityTracer.extensions] SparseConnectivityTracerDataInterpolationsExt = "DataInterpolations" @@ -2215,15 +2182,19 @@ uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" version = "0.1.2" [[deps.SparseMatrixColorings]] -deps = ["ADTypes", "DataStructures", "DocStringExtensions", "LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "97092c0a40d6033b7da27ea15bcf75fd5b446254" +deps = ["ADTypes", "DocStringExtensions", "LinearAlgebra", "Random", "SparseArrays"] +git-tree-sha1 = "d59566cf03c67733edce6d80e0fb17e183ab31ba" uuid = "0a514795-09f3-496d-8182-132a7b665d35" -version = "0.4.13" -weakdeps = ["Colors"] +version = "0.4.16" [deps.SparseMatrixColorings.extensions] + SparseMatrixColoringsCliqueTreesExt = "CliqueTrees" SparseMatrixColoringsColorsExt = "Colors" + [deps.SparseMatrixColorings.weakdeps] + CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8" + Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" + [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] git-tree-sha1 = "64cca0c26b4f31ba18f13f6c12af7c85f478cfde" @@ -2254,9 +2225,9 @@ version = "0.1.1" [[deps.Static]] deps = ["CommonWorldInvalidations", "IfElse", "PrecompileTools"] -git-tree-sha1 = "87d51a3ee9a4b0d2fe054bdd3fc2436258db2603" +git-tree-sha1 = "f737d444cb0ad07e61b3c1bef8eb91203c321eff" uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" -version = "1.1.1" +version = "1.2.0" [[deps.StaticArrayInterface]] deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "PrecompileTools", "Static"] @@ -2271,9 +2242,9 @@ weakdeps = ["OffsetArrays", "StaticArrays"] [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "e3be13f448a43610f978d29b7adf78c76022467a" +git-tree-sha1 = "0feb6b9031bd5c51f9072393eb5ab3efd31bf9e4" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.12" +version = "1.9.13" weakdeps = ["ChainRulesCore", "Statistics"] [deps.StaticArrays.extensions] @@ -2285,12 +2256,6 @@ git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" version = "1.4.3" -[[deps.StaticLint]] -deps = ["CSTParser", "Serialization", "SymbolServer"] -git-tree-sha1 = "36732c098f291ee3b867718bb9933e8b67ab4798" -uuid = "b3cc710f-9c33-5bdb-a03d-a94903873e97" -version = "8.2.2" - [[deps.Statistics]] deps = ["LinearAlgebra"] git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0" @@ -2315,9 +2280,9 @@ version = "0.34.4" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "b423576adc27097764a90e163157bcfc9acf0f46" +git-tree-sha1 = "35b09e80be285516e52c9054792c884b9216ae3c" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.3.2" +version = "1.4.0" weakdeps = ["ChainRulesCore", "InverseFunctions"] [deps.StatsFuns.extensions] @@ -2338,9 +2303,9 @@ version = "0.4.1" [[deps.StructArrays]] deps = ["ConstructionBase", "DataAPI", "Tables"] -git-tree-sha1 = "5a3a31c41e15a1e042d60f2f4942adccba05d3c9" +git-tree-sha1 = "8ad2e38cbb812e29348719cc63580ec1dfeb9de4" uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.7.0" +version = "0.7.1" weakdeps = ["Adapt", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "SparseArrays", "StaticArrays"] [deps.StructArrays.extensions] @@ -2363,12 +2328,6 @@ deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" version = "7.7.0+0" -[[deps.SymbolServer]] -deps = ["InteractiveUtils", "LibGit2", "Markdown", "Pkg", "REPL", "SHA", "Serialization", "Sockets", "UUIDs"] -git-tree-sha1 = "adcc6a2335e5448adc05939f67d382fb8d17a367" -uuid = "cf896787-08d5-524d-9de7-132aaa0cb996" -version = "7.4.0" - [[deps.SymbolicIndexingInterface]] deps = ["Accessors", "ArrayInterface", "RuntimeGeneratedFunctions", "StaticArraysCore"] git-tree-sha1 = "d6c04e26aa1c8f7d144e1a8c47f1c73d3013e289" @@ -2414,12 +2373,6 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" version = "1.11.0" -[[deps.TestItemDetection]] -deps = ["CSTParser"] -git-tree-sha1 = "c63abb8bf01ba3f0e5421760454d578ee9bd12ca" -uuid = "76b0de8b-5c4b-48ef-a724-914b33ca988d" -version = "0.2.0" - [[deps.ThreadingUtilities]] deps = ["ManualMemory"] git-tree-sha1 = "eda08f7e9818eb53661b3deb74e3159460dfbc27" @@ -2434,9 +2387,9 @@ version = "0.11.3" [[deps.TimerOutputs]] deps = ["ExprTools", "Printf"] -git-tree-sha1 = "3832505b94c1868baea47764127e6d36b5c9f29e" +git-tree-sha1 = "f57facfd1be61c42321765d3551b3df50f7e09f6" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.27" +version = "0.5.28" [deps.TimerOutputs.extensions] FlameGraphsExt = "FlameGraphs" @@ -2444,11 +2397,6 @@ version = "0.5.27" [deps.TimerOutputs.weakdeps] FlameGraphs = "08572546-2f56-4bcf-ba4e-bab62c3a3f89" -[[deps.Tokenize]] -git-tree-sha1 = "468b4685af4abe0e9fd4d7bf495a6554a6276e75" -uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" -version = "0.5.29" - [[deps.TranscodingStreams]] git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" @@ -2481,11 +2429,6 @@ git-tree-sha1 = "4d4ed7f294cda19382ff7de4c137d24d16adc89b" uuid = "981d1d27-644d-49a2-9326-4793e63143c3" version = "0.1.0" -[[deps.URIs]] -git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" -uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.5.1" - [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" @@ -2570,15 +2513,15 @@ version = "2.13.6+1" [[deps.XSLT_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"] -git-tree-sha1 = "7d1671acbe47ac88e981868a078bd6b4e27c5191" +git-tree-sha1 = "82df486bfc568c29de4a207f7566d6716db6377c" uuid = "aed1982a-8fda-507f-9586-7b0439959a61" -version = "1.1.42+0" +version = "1.1.43+0" [[deps.XZ_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "56c6604ec8b2d82cc4cfe01aa03b00426aac7e1f" +git-tree-sha1 = "fee71455b0aaa3440dfdd54a9a36ccef829be7d4" uuid = "ffd25f8a-64ca-5728-b0f7-c24cf3aae800" -version = "5.6.4+1" +version = "5.8.1+0" [[deps.Xorg_libX11_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxcb_jll", "Xorg_xtrans_jll"] @@ -2624,9 +2567,9 @@ version = "1.17.0+3" [[deps.Xorg_xtrans_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "6dba04dbfb72ae3ebe5418ba33d087ba8aa8cb00" +git-tree-sha1 = "a63799ff68005991f9d9491b6e95bd3478d783cb" uuid = "c5fb5394-a638-5e4d-96e5-b29de1b5cf10" -version = "1.5.1+0" +version = "1.6.0+0" [[deps.ZMQ]] deps = ["FileWatching", "PrecompileTools", "Sockets", "ZeroMQ_jll"] @@ -2636,9 +2579,9 @@ version = "1.4.0" [[deps.ZeroMQ_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "libsodium_jll"] -git-tree-sha1 = "f02ce8f0fda1ed40f4d0d59a2ad05e35e8ac9b0e" +git-tree-sha1 = "766d90db2817565b667c1cc9cc420d668f2e8dba" uuid = "8f1865be-045e-5c20-9c9f-bfbfb0764568" -version = "4.3.5+3" +version = "4.3.6+0" [[deps.Zlib_jll]] deps = ["Libdl"] @@ -2652,17 +2595,19 @@ uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" version = "1.5.7+1" [[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "0b3c944f5d2d8b466c5d20a84c229c17c528f49e" +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "207d714f3514b0d564e3a08f9e9f753bf6566c2d" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.75" +version = "0.7.6" [deps.Zygote.extensions] + ZygoteAtomExt = "Atom" ZygoteColorsExt = "Colors" ZygoteDistancesExt = "Distances" ZygoteTrackerExt = "Tracker" [deps.Zygote.weakdeps] + Atom = "c52e3926-4ff0-5f6e-af25-54175e0327b1" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -2675,9 +2620,9 @@ version = "0.2.7" [[deps.cuDNN]] deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"] -git-tree-sha1 = "efb5a4aa0aea8151a4eb21700e8b2c9990c45b0f" +git-tree-sha1 = "1faded2a0800a6c2c329d77bf75d91685e212222" uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" -version = "1.4.1" +version = "1.4.2" [[deps.demumble_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -2716,9 +2661,9 @@ version = "2.0.3+0" [[deps.libpng_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "055a96774f383318750a1a5e10fd4151f04c29c5" +git-tree-sha1 = "068dfe202b0a05b8332f1e8e6b4080684b9c7700" uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" -version = "1.6.46+0" +version = "1.6.47+0" [[deps.libsixel_jll]] deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "Libdl", "libpng_jll"] @@ -2727,10 +2672,10 @@ uuid = "075b6546-f08a-558a-be8f-8157d0f608a5" version = "1.10.5+0" [[deps.libsodium_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "f76d682d87eefadd3f165d8d9fda436464213142" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "011b0a7331b41c25524b64dc42afc9683ee89026" uuid = "a9144af2-ca23-56d9-984f-0d03f7b5ccf8" -version = "1.0.20+3" +version = "1.0.21+0" [[deps.libvorbis_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"] diff --git a/Project.toml b/Project.toml index 812ab5b..5e7eaba 100644 --- a/Project.toml +++ b/Project.toml @@ -3,14 +3,15 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -LanguageServer = "2b0e0bc5-e4fd-59b4-8912-456d1b03d8d7" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test_multihead_attention.jl b/test_multihead_attention.jl index 00e27aa..4b6036b 100644 --- a/test_multihead_attention.jl +++ b/test_multihead_attention.jl @@ -12,7 +12,7 @@ const KV_LEN = 50 rng = TaskLocalRNG() -l = MultiheadAttention(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM) +l = MultiheadAttention2(EMBED_DIM, NUM_HEADS, kvdim = KV_DIM) @info "layer" l ps = LuxCore.initialparameters(rng, l) @@ -28,7 +28,7 @@ v = rand(rng, Float32, (KV_DIM,)) # with mask -l = MultiheadAttention(EMBED_DIM, NUM_HEADS) +l = MultiheadAttention2(EMBED_DIM, NUM_HEADS) @info "layer" l ps = LuxCore.initialparameters(rng, l) diff --git a/test_transformer.jl b/test_transformer.jl index 57600fe..4d4a043 100644 --- a/test_transformer.jl +++ b/test_transformer.jl @@ -19,7 +19,7 @@ ps, st = LuxCore.setup(rng, model_encoder) @info "status" st # Unbatched -x = randn(rng, Float32, (INPUT_DIM,)) +x = randn(rng, Float32, (INPUT_DIM, 1, 1)) @info "input" size(x) y, st = model_encoder(x, ps, st) @@ -44,7 +44,7 @@ ps, st = LuxCore.setup(rng, model_decoder) @info "status" st # Unbatched -x = randn(rng, Float32, (INPUT_DIM,)) +x = randn(rng, Float32, (INPUT_DIM, 1, 1)) @info "input" x y, st = model_decoder((; x = x, memory = memory_unbatched), ps, st) diff --git a/transformers.jl b/transformers.jl index a731555..65e7b10 100644 --- a/transformers.jl +++ b/transformers.jl @@ -1,6 +1,45 @@ # using LuxCore using Random: AbstractRNG using Lux +using Static + +struct MultiheadAttention2{WQ, WK, WV, WO} <: + LuxCore.AbstractLuxContainerLayer{(:weight_q, :weight_k, :weight_v, :weight_o)} + weight_q::WQ + weight_k::WK + weight_v::WV + weight_o::WO + nheads::Int +end + +function MultiheadAttention2(embed_dim::Int, num_heads::Int; kw...) + qdim = get(kw, :qdim, embed_dim) + kvdim = get(kw, :kvdim, embed_dim) + v_embed_dim = get(kw, :v_embed_dim, embed_dim) + init_weight = get(kw, :init_weight, glorot_uniform) + weight_q = Dense(qdim => embed_dim * num_heads; init_weight, use_bias = False()) + weight_k = Dense(kvdim => embed_dim * num_heads; init_weight, use_bias = False()) + weight_v = Dense(kvdim => v_embed_dim * num_heads; init_weight, use_bias = False()) + weight_o = + Dense(v_embed_dim * num_heads => v_embed_dim; init_weight, use_bias = False()) + MultiheadAttention2(weight_q, weight_k, weight_v, weight_o, num_heads) +end + +function (mha::MultiheadAttention2)(x::NT, ps, st::NamedTuple) where {NT <: NamedTuple} + q, st_weight_q = mha.weight_q(x.q, ps.weight_q, st.weight_q) + k, st_weight_k = mha.weight_k(x.k, ps.weight_k, st.weight_k) + v, st_weight_v = mha.weight_v(x.v, ps.weight_v, st.weight_v) + mask = hasproperty(x, :mask) ? x.mask : nothing + y, _α = dot_product_attention(q, k, v; nheads = mha.nheads, mask) + o, st_weight_o = mha.weight_o(y, ps.weight_o, st.weight_o) + st = ( + weight_q = st_weight_q, + weight_k = st_weight_k, + weight_v = st_weight_v, + weight_o = st_weight_o, + ) + o, st +end """ MultiheadAttention{F} <: LuxCore.AbstractLuxLayer @@ -31,7 +70,7 @@ So far, ``Q, K, V`` are inputs `x` for the layer, ``W^Q, W^K, W^V, W^O`` are parameters `ps`, and the layer has no states `st`. """ -struct MultiheadAttention{F} <: LuxCore.AbstractLuxLayer +struct MultiheadAttention_{F} <: LuxCore.AbstractLuxLayer embed_dim::Int qdim::Int kvdim::Int @@ -72,13 +111,13 @@ NamedTuple of these three variables. - `k`: ``K`` - `v`: ``V`` """ -function MultiheadAttention( +function MultiheadAttention_( embed_dim::Int, num_heads::Int; init_weight = glorot_uniform, kw..., ) - MultiheadAttention{typeof(init_weight)}( + MultiheadAttention_{typeof(init_weight)}( embed_dim, haskey(kw, :qdim) ? kw[:qdim] : embed_dim, haskey(kw, :kvdim) ? kw[:kvdim] : embed_dim, @@ -88,7 +127,7 @@ function MultiheadAttention( ) end -function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention) +function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention_) # see the original paper for weight dimensions (note that q,k,v weights have `num_heads` of matrices) ( weight_q = l.init_weight(rng, l.embed_dim * l.num_heads, l.qdim), @@ -98,11 +137,11 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::MultiheadAttention) ) end -function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention) +function LuxCore.initialstates(::AbstractRNG, ::MultiheadAttention_) NamedTuple() end -function LuxCore.parameterlength(l::MultiheadAttention) +function LuxCore.parameterlength(l::MultiheadAttention_) dim_weight_q = l.embed_dim * l.num_heads * l.qdim dim_weight_k = l.embed_dim * l.num_heads * l.kvdim dim_weight_v = l.v_embed_dim * l.num_heads * l.kvdim @@ -110,11 +149,11 @@ function LuxCore.parameterlength(l::MultiheadAttention) dim_weight_q + dim_weight_k + dim_weight_v + dim_weight_o end -function LuxCore.statelength(l::MultiheadAttention) +function LuxCore.statelength(l::MultiheadAttention_) 0 end -function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple) +function (l::MultiheadAttention_)(x::NamedTuple, ps, _st::NamedTuple) if size(x.q, 1) != l.embed_dim DimensionMismatch( "Length of queue must match the layer's embed_dim: size(q)[1] = $(size(x.q, 1)), embed_dim = $(l.embed_dim)", @@ -162,17 +201,19 @@ function (l::MultiheadAttention)(x::NamedTuple, ps, _st::NamedTuple) # FIXME: expand this to multi dimensional matrix multiplication # qk_dim, v_dim is divisible by num_heads. qk_dim = embed_dim * num_heads # [q] = (qk_dim, q_len, batch_size...) - q = ps.weight_q * x.q + # need to apply weights for all batches + @info "" size(ps.weight_q) size(x.q) + q = ps.weight_q ⊠ x.q # [k] = (qk_dim, kv_len, batch_size...) - k = ps.weight_k * x.k + k = ps.weight_k ⊠ x.k # [v] = (v_dim, kv_len, batch_size...) - v = ps.weight_v * x.v + v = ps.weight_v ⊠ x.v # [mask] = (kv_len, q_len, nheads, batch_size) mask = hasproperty(x, :mask) ? x.mask : nothing # [y] = (v_dim, q_len, batch_size...) # [α] = (kv_len, q_len, nheads, batch_size...) y, α = dot_product_attention(q, k, v; nheads = l.num_heads, mask) - ps.weight_o * y, _st + ps.weight_o ⊠ y, _st end """ @@ -252,7 +293,7 @@ function TransformerEncoderLayer( ) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat} sublayer_self_attention = let layer_split = Lux.WrappedFunction(x -> (q = x, k = x, v = x)) - layer_self_attention = MultiheadAttention(model_dim, num_heads) + layer_self_attention = MultiheadAttention2(model_dim, num_heads) layer_dropout = Lux.Dropout(dropout) layer_residual_connection = Lux.SkipConnection( Lux.Chain(; layer_split, layer_self_attention, layer_dropout), @@ -282,12 +323,17 @@ function TransformerEncoderLayer( end function (encoder::TransformerEncoderLayer)(x, ps, st) - x, st_sublayer_self_attention = Lux.apply( - encoder.sublayer_self_attention, + x, st_sublayer_self_attention = encoder.sublayer_self_attention( x, ps.sublayer_self_attention, st.sublayer_self_attention, ) + # x, st_sublayer_self_attention = Lux.apply( + # encoder.sublayer_self_attention, + # x, + # ps.sublayer_self_attention, + # st.sublayer_self_attention, + # ) x, st_sublayer_feed_forward_network = Lux.apply( encoder.sublayer_feed_forward_network, x, @@ -327,7 +373,7 @@ function TransformerDecoderLayer( ) where {F, T1 <: AbstractFloat, T2 <: AbstractFloat} sublayer_self_attention = let layer_split = Lux.WrappedFunction(x -> (q = x.x, k = x.x, v = x.x, mask = x.mask)) - layer_self_attention = MultiheadAttention(model_dim, num_heads) + layer_self_attention = MultiheadAttention2(model_dim, num_heads) layer_dropout = Lux.Dropout(dropout) layer_residual_connection = Lux.SkipConnection( Lux.Chain(; layer_split, layer_self_attention, layer_dropout), @@ -342,7 +388,7 @@ function TransformerDecoderLayer( sublayer_multihead_attention = let layer_split = Lux.WrappedFunction(x::NamedTuple -> (q = x.x, k = x.memory, v = x.memory)) - layer_self_attention = MultiheadAttention(model_dim, num_heads) + layer_self_attention = MultiheadAttention2(model_dim, num_heads) layer_dropout = Lux.Dropout(dropout) layer_residual_connection = Lux.Parallel( +,