惯性聚合 高效追踪和阅读你感兴趣的博客、新闻、科技资讯
阅读原文 在惯性聚合中打开

推荐订阅源

GbyAI
GbyAI
博客园_首页
OSCHINA 社区最新新闻
OSCHINA 社区最新新闻
阮一峰的网络日志
阮一峰的网络日志
酷 壳 – CoolShell
酷 壳 – CoolShell
博客园 - 司徒正美
V
V2EX
Cloudbric
Cloudbric
Hugging Face - Blog
Hugging Face - Blog
腾讯CDC
量子位
博客园 - 三生石上(FineUI控件)
博客园 - 叶小钗
K
Kaspersky official blog
博客园 - 【当耐特】
T
Tenable Blog
L
Lohrmann on Cybersecurity
The Cloudflare Blog
S
Schneier on Security
A
Arctic Wolf
Latest news
Latest news
C
Cyber Attacks, Cyber Crime and Cyber Security
罗磊的独立博客
T
The Exploit Database - CXSecurity.com
Cisco Talos Blog
Cisco Talos Blog
小众软件
小众软件
P
Privacy & Cybersecurity Law Blog
WordPress大学
WordPress大学
Simon Willison's Weblog
Simon Willison's Weblog
雷峰网
雷峰网
NISL@THU
NISL@THU
人人都是产品经理
人人都是产品经理
月光博客
月光博客
J
Java Code Geeks
V
Visual Studio Blog
S
Security Affairs
博客园 - Franky
T
Tailwind CSS Blog
Apple Machine Learning Research
Apple Machine Learning Research
H
Heimdal Security Blog
有赞技术团队
有赞技术团队
V2EX - 技术
V2EX - 技术
AWS News Blog
AWS News Blog
G
GRAHAM CLULEY
T
Troy Hunt's Blog
SecWiki News
SecWiki News
Spread Privacy
Spread Privacy
宝玉的分享
宝玉的分享
www.infosecurity-magazine.com
www.infosecurity-magazine.com
博客园 - 聂微东

土法炼钢兴趣小组的算法知识备份

国密算法与国密 TLS 系列索引 【系统架构设计】架构质量属性:不只是"高可用高性能" 【系统架构设计百科】告警策略:如何避免"狼来了" 【系统架构设计】CQRS:读写分离的架构哲学 【系统架构设计】空间架构:极端扩展场景的解法 【系统架构设计】微服务架构深度审视:优势、代价与适用边界 【系统架构设计】扩展性原理:水平、垂直与对角扩展 【系统架构设计】无状态设计:扩展的第一步也是最难的一步 【系统架构设计】缓存架构:从本地到分布式的多级缓存体系 【系统架构设计】管道与过滤器:Unix 哲学的架构表达 【系统架构设计】复杂性管理:架构的核心战场 【系统架构设计】消息队列架构:异步解耦的设计与陷阱 【系统架构设计】CDN 架构:全球加速的设计原理 【系统架构设计】连接池设计:被忽视的性能杀手 【系统架构设计】弹性设计模式:熔断器、舱壁与超时 【系统架构设计】高可用设计模式:冗余、故障转移与仲裁 【系统架构设计】容量规划:从拍脑袋到数据驱动 【系统架构设计】数据库扩展:分库分表的工程实践与替代方案 【系统架构设计】SLO 工程:可靠性的量化管理 【系统架构设计】性能建模:用数学思维分析系统瓶颈 【系统架构设计】混沌工程:主动验证系统的韧性 【系统架构设计】零拷贝与内存映射:数据搬运的极致优化 【系统架构设计】线程模型:从 thread-per-request 到协程 【系统架构设计】容灾架构:多活与灾备设计 【系统架构设计】数据库性能模式:索引、查询与连接管理 【系统架构设计】数据建模:从关系范式到文档模型的真实权衡 【系统架构设计】吞吐量优化:批处理、流水线与并发模型 【系统架构设计】流处理架构:从批处理到实时的范式迁移 【系统架构设计】搜索引擎架构:倒排索引之上的系统设计 【系统架构设计】时序数据架构:监控与 IoT 的存储设计 【系统架构设计】数据迁移与版本化:在线不停机的数据演进 【系统架构设计】数据湖与数据仓库:分析架构的演进路线 【系统架构设计】API 网关设计:入口层的职责边界 【系统架构设计】应用层数据一致性模式:在正确性与性能之间走钢丝 【系统架构设计】多模数据库选型:Polyglot Persistence 的工程实践 【系统架构设计】服务发现与注册:动态拓扑的基础设施 【系统架构设计】配置管理架构:从配置文件到配置中心 【系统架构设计】全链路压测:大规模系统的性能验证 【系统架构设计】幂等性设计:分布式环境下的安全重试 【系统架构设计】契约测试与 Schema 演进:服务间的信任协议 【系统架构设计】长连接与推送架构:WebSocket、SSE 与 MQTT 【系统架构设计】延迟分析:从 P50 到 P999 的全链路追踪 【系统架构设计百科】DDD 战术模式:聚合、实体与值对象 【系统架构设计百科】防腐层与开放主机服务:系统集成的 DDD 方案 【系统架构设计百科】领域事件与事件风暴:从业务到架构的桥梁 【系统架构设计百科】CQRS + Event Sourcing 完整实战:从领域建模到部署 【系统架构设计百科】DDD 与微服务:用领域模型划分服务边界 【系统架构设计】DDD 战略设计:限界上下文与上下文映射 【系统架构设计百科】认证架构:从 Session 到 JWT 到 OIDC 【系统架构设计】API 设计哲学:REST vs GraphQL vs gRPC 的真实权衡 排序算法专题:从 TimSort 到并行排序 【密码学百科】国密算法体系:SM2/SM3/SM4/SM9 全景解读 【密码学百科】承诺方案:Pedersen 承诺、向量承诺与多项式承诺 【密码学百科】不经意传输与隐私信息检索:OT、OT 扩展与 PIR 【密码学百科】门限密码学:门限签名、门限解密与分布式密钥生成 完美哈希:从理论到 gperf 实践 【密码学百科】安全多方计算:从 Yao 的混淆电路到实用 MPC 【密码学百科】同态加密:从 Paillier 到全同态加密(FHE) 【密码学百科】零知识证明系统:zk-SNARKs、zk-STARKs 与 Bulletproofs 【密码学百科】概率论与密码分析:生日攻击、差分分析与线性分析 【密码学百科】计算复杂性与归约:密码安全性证明的基石 【密码学百科】秘密共享:Shamir 方案、VSS 与安全多方计算入口 【密码学百科】椭圆曲线代数:Weierstrass 方程、点群运算与曲线选择 【密码学百科】离散对数与配对密码学:从 DLP 到 BLS 签名 【密码学百科】格密码数学基础:SVP、LWE 与格基约化 【密码学百科】抽象代数:群、环、域的密码学视角 【密码学百科】有限域算术:GF(2^n) 运算与在 AES/ECC 中的应用 【密码学百科】数论进阶:二次剩余、椭圆曲线上的 Weil 配对 【密码学百科】密码学简史:从凯撒密码到量子时代 【密码学百科】威胁模型与安全目标:CIA 三要素之外 【密码学百科】Kerckhoffs 原则与现代密码设计哲学 【密码学百科】随机性:密码学的基石 【密码学百科】信息论入门:熵、完美保密与 Shannon 定理 【密码学百科】分组密码原理:Feistel 网络与 SPN 结构 【密码学百科】AES 逐步拆解:SubBytes 到 MixColumns 的数学 【密码学百科】分组密码工作模式全览:ECB/CBC/CTR/OFB/CFB 【密码学百科】流密码:RC4 的兴衰与 ChaCha20 的崛起 【密码学百科】密码学哈希函数:MD5→SHA-2→SHA-3 的进化之路 【密码学百科】MAC 与 HMAC:消息认证的正确姿势 【密码学百科】认证加密(AEAD):GCM、ChaCha20-Poly1305 与 OCB 【密码学百科】密钥派生函数:HKDF、PBKDF2、Argon2 与密码存储 【密码学百科】公钥密码的数论基础:模运算、群、原根 【密码学百科】RSA 从原理到攻击:教科书 RSA 为什么不安全 【密码学百科】Diffie-Hellman 密钥交换与离散对数问题 【密码学百科】椭圆曲线密码学(ECC):从几何直觉到点群运算 【密码学百科】数字签名:ECDSA、EdDSA 与 Schnorr 签名 【密码学百科】现代密钥交换:X25519、ECDHE 与前向保密 【密码学百科】混合加密与 KEM/DEM 范式:ECIES 与 HPKE 【密码学百科】填充方案:PKCS#1 v1.5、OAEP 与 PSS 【密码学百科】TLS 协议全解析:从握手到 0-RTT 【密码学百科】PKI 与数字证书:信任链的构建与崩塌 【密码学百科】密码认证协议:从 SRP 到 OPAQUE 【密码学百科】零知识证明入门:如何证明你知道而不泄露 【密码学百科】安全信道构造:Noise 协议框架与 Signal 协议 【密码学百科】密钥管理工程:HSM、KMS 与密钥生命周期 【密码学百科】侧信道攻击:从时序攻击到功耗分析 【密码学百科】密码学实现陷阱:三层漏洞分类、审计工具链与系统性预防 密码敏捷性:如何设计可升级的密码系统 【密码学百科】OpenSSL/BoringSSL 架构剖析:ENGINE、Provider 与 FIPS 模块 排序基准测试:用数据说话
【Transformer 与注意力机制】03 矩阵乘法的两种视角
Liao Tonglang · 2026-04-15 · via 土法炼钢兴趣小组的算法知识备份

上一篇我们花了很长时间讲点积。这一篇要做一件看似平淡但极其重要的事:把单次点积扩展为批量点积,也就是矩阵乘法。在这个扩展里隐藏着 Transformer 高效的根本秘密:所有的 attention 计算都可以归约成几次大型矩阵乘法,而矩阵乘法是 GPU 上几十年磨出来的最快算子


一、为什么要从「单点积」走到「矩阵乘法」

回想上一篇的最终画面:一个 Q @ K.T 一次性算出 \(n \times n\) 个点积。@ 这一个符号看起来普通,但它是注意力机制能够在 GPU 上高效跑起来的核心。

如果我们用 for 循环写 attention:

scores = torch.zeros(T, T)
for i in range(T):
    for j in range(T):
        scores[i, j] = (Q[i] * K[j]).sum()

这里的「写 attention」,并不是说完整的注意力机制只有这 4 行,而是说:attention 最核心、最贵的那一步,就是先把所有 query 和所有 key 两两做一次点积

更具体一点,假设序列里有 \(T\) 个 token。\(Q[i]\) 是第 \(i\) 个 token 的查询向量(query),\(K[j]\) 是第 \(j\) 个 token 的键向量(key)。代码里的 (Q[i] * K[j]).sum() 就是 \(\mathbf{q}_i \cdot \mathbf{k}_j\),也就是「第 \(i\) 个 token 对第 \(j\) 个 token 打多少分」。这个分数越大,表示第 \(i\) 个 token 越应该去关注第 \(j\) 个 token。

把所有 \((i, j)\) 都算一遍,就得到一个 \(T \times T\) 的分数矩阵 scores。这正是 attention 公式里 \(QK^\top\) 的含义:第 \((i, j)\) 个元素就是查询 \(i\) 和键 \(j\) 的点积。后面再对每一行做 softmax,得到注意力权重;再拿这些权重去加权求和 \(V\),才是完整的 attention 输出。所以这段 for 循环虽然短,但它已经把 attention 的「打分阶段」完整地写出来了。

正确,但慢得不能用。在 T=2048 的序列上这段代码可能要跑几秒钟,而 Q @ K.T 只要几毫秒。差距来自两个原因:

第一,Python 循环本身极慢(解释开销)。

第二,单个点积无法利用 GPU 的并行硬件。GPU 有几千个核心,等着一次同时算几千个点积。逐个算就只用了一个核心,浪费了 99% 的算力。

矩阵乘法的本质就是「把同样的运算重复很多次,让硬件同时做」。所以我们必须从「单次点积」升级到「批量点积」,也就是矩阵乘法。

二、矩阵的最朴素定义

矩阵(matrix) 是一个二维数组。一个 \(n \times m\) 的矩阵有 \(n\)\(m\) 列:

\[ A = \begin{pmatrix} a_{11} & a_{12} & \cdots & a_{1m} \\ a_{21} & a_{22} & \cdots & a_{2m} \\ \vdots & \vdots & \ddots & \vdots \\ a_{n1} & a_{n2} & \cdots & a_{nm} \end{pmatrix} \]

每个元素用 \(A_{ij}\) 表示,下标 \(i\) 是行号,\(j\) 是列号。在代码里通常 0-indexed(A[0][0] 是左上角),数学公式里通常 1-indexed。两种约定可以用上下文区分,但写代码时一定要问自己「我现在用的是 0 还是 1」——这是工程 bug 的常见来源。

矩阵的「形状(shape)」是它的 \((n, m)\)。一个 \(3 \times 4\) 的矩阵不能直接和一个 \(5 \times 4\) 的矩阵相加,因为形状不同。形状是矩阵运算里最常见的报错原因。

矩阵可以看成「一组列向量横向并排」或者「一组行向量纵向堆叠」。这两种视角后面会反复出现。

矩阵的行、列与形状

三、矩阵的转置

转置(transpose) 是矩阵最简单的运算。把行变列、列变行:

\[ A^\top_{ij} = A_{ji} \]

如果 \(A\)\(n \times m\),那么 \(A^\top\)\(m \times n\)。形状交换。

转置把行和列交换

转置在 attention 里随处可见:K.transpose(-2, -1) 就是把 (B, T, D) 转成 (B, D, T),让后面的 Q @ K^T 能匹配 D 维度。

转置满足几个性质:

  • \((A^\top)^\top = A\)
  • \((A + B)^\top = A^\top + B^\top\)
  • \((\alpha A)^\top = \alpha A^\top\)
  • \((AB)^\top = B^\top A^\top\)(注意顺序反了!这条最容易记错)。

最后一条特别重要,请暂停几秒看一下。

为什么 \((AB)^\top = B^\top A^\top\) 而不是 \(A^\top B^\top\)?因为 \(AB\) 的形状是 \(A_{n \times k} \times B_{k \times m} = C_{n \times m}\),转置后是 \(m \times n\)。而 \(B^\top\)\(m \times k\)\(A^\top\)\(k \times n\),相乘正好得到 \(m \times n\)。如果是 \(A^\top B^\top\),形状 \(k \times n\)\(m \times k\) 根本不合法(除非 \(n = m\) 且我们运气好)。

所以转置改变乘法顺序——这件事在推导 attention 反向传播时会用到,也在 RoPE 的「等价旋转」推导里出现。

四、矩阵乘法的形式定义

矩阵乘法(matrix multiplication):给定 \(A\)\(n \times k\)\(B\)\(k \times m\),乘积 \(C = AB\)\(n \times m\),其中:

\[ C_{ij} = \sum_{l=1}^{k} A_{il} B_{lj} \]

读作:「\(C\) 的第 \((i, j)\) 个元素,等于 \(A\) 的第 \(i\) 行和 \(B\) 的第 \(j\) 列做点积」。

这就是矩阵乘法的本质——\(C\) 的每个元素都是一次点积。如果 \(A\)\(n\) 行、\(B\)\(m\) 列,那总共有 \(n \times m\) 次点积,每次点积是 \(k\) 次乘加。所以矩阵乘法的计算量是 \(O(n \cdot k \cdot m)\)

这也解释了「为什么 \(A\) 的列数必须等于 \(B\) 的行数」——只有这样,每次点积才能配对地把 \(A\) 的一行(\(k\) 个数)和 \(B\) 的一列(\(k\) 个数)相乘求和。维度不匹配是工程里最常见的报错。

形状记忆口诀:「(n×k) × (k×m) → (n×m)」。中间的 \(k\) 消掉了,外面的 \(n\)\(m\) 留下。

五、第一种视角:行 × 列的点积视角

最朴素的矩阵乘法理解是「点积视角」。把 \(A\) 的每一行单独拿出来,把 \(B\) 的每一列单独拿出来,每对组合做一次点积,填入对应位置。

具体例子:

\[ A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}, \quad B = \begin{pmatrix} 5 & 6 \\ 7 & 8 \end{pmatrix} \]

计算 \(C = AB\)

  • \(C_{11}\)\(A\) 第 1 行 \((1, 2)\)\(B\) 第 1 列 \((5, 7)\) 的点积 \(= 1 \cdot 5 + 2 \cdot 7 = 19\)
  • \(C_{12}\)\(A\) 第 1 行 \((1, 2)\)\(B\) 第 2 列 \((6, 8)\) 的点积 \(= 1 \cdot 6 + 2 \cdot 8 = 22\)
  • \(C_{21}\)\(A\) 第 2 行 \((3, 4)\)\(B\) 第 1 列 \((5, 7)\) 的点积 \(= 3 \cdot 5 + 4 \cdot 7 = 43\)
  • \(C_{22}\)\(A\) 第 2 行 \((3, 4)\)\(B\) 第 2 列 \((6, 8)\) 的点积 \(= 3 \cdot 6 + 4 \cdot 8 = 50\)

\[ C = \begin{pmatrix} 19 & 22 \\ 43 & 50 \end{pmatrix} \]

这种「逐元素填表」的视角对计算最方便,也最容易写代码。在 attention 里,每个 attention score 就是一次行与列的点积\(Q\) 的第 \(i\) 行(第 \(i\) 个 token 的查询向量)与 \(K^\top\) 的第 \(j\) 列(第 \(j\) 个 token 的键向量)做点积,得到「token \(i\) 关注 token \(j\) 的分数」。

矩阵乘法的两种视角

六、第二种视角:列的线性组合视角

但是「点积视角」不是唯一的视角。同一个矩阵乘法 \(C = AB\),还可以看成:

\(C\) 的每一列,是 \(A\) 的列向量的线性组合,组合系数来自 \(B\) 的对应列。

形式上:

\[ C_{:,j} = \sum_{l=1}^k B_{lj} \cdot A_{:,l} \]

读作:「\(C\) 的第 \(j\) 列,等于 \(A\) 的列向量按 \(B\)\(j\) 列的系数线性组合」。

用上面的例子验证。\(A\) 的第 1 列是 \((1, 3)\),第 2 列是 \((2, 4)\)\(B\) 的第 1 列是 \((5, 7)\)

按线性组合视角,\(C\) 的第 1 列 \(= 5 \cdot (1, 3) + 7 \cdot (2, 4) = (5, 15) + (14, 28) = (19, 43)\)

和前面算出来的 \(C\) 第 1 列 \((19, 43)\) 一致。

列的线性组合视角

为什么这种视角重要?因为它揭示了矩阵乘法 = 一组「向量被某种规则重新组合」\(A\) 的列是「基本单元」,\(B\) 是「组合规则」,\(C\) 是「组合结果」。

这种视角在理解线性变换时特别有力。比如 PCA:把数据 \(A\) 投影到主成分 \(B\),结果 \(C\) 的每一列就是数据沿某个主方向的坐标。再比如神经网络的全连接层 \(\mathbf{y} = W \mathbf{x}\),可以理解为「\(\mathbf{x}\) 是组合系数,\(\mathbf{y}\)\(W\) 的列向量按 \(\mathbf{x}\) 的系数组合得到的新向量」。

七、第三种(隐含)视角:行的线性组合

类似地,还有:

\(C\) 的每一行,是 \(B\) 的行向量的线性组合,组合系数来自 \(A\) 的对应行。

形式上:

\[ C_{i,:} = \sum_{l=1}^k A_{il} \cdot B_{l,:} \]

这是「列线性组合视角」的镜像。它在 attention 输出 \(\mathrm{softmax}(QK^\top)V\) 里特别有用:把 attention 输出看成 \(V\) 的行向量按 attention 权重的线性组合。每个 token 的输出 = 所有 token 的 V 向量按其注意权重加权求和。这是注意力「混合信息」的核心机制。

行的线性组合视角

在第 13 篇缩放点积注意力里,我会再次回来强调这个视角。

八、第四种视角:外积之和

数学上还有一个更高级的视角:

\(AB\) 等于「\(A\) 的列与 \(B\) 的行做外积」之和。

形式上:

\[ AB = \sum_{l=1}^{k} A_{:,l} \cdot B_{l,:}^\top \]

这里 \(A_{:,l}\)\(n\) 维列向量,\(B_{l,:}^\top\)\(m\) 维行向量(写成列),它们的外积是 \(n \times m\) 矩阵。把 \(k\) 个这样的外积加起来,就得到 \(C\)

外积之和视角

这个视角在「低秩分解」「张量分解」「高效注意力」等话题中特别有用。LoRA(Low-Rank Adaptation)的核心思想就是「把一个大矩阵 \(W\) 写成两个小矩阵的外积之和 \(BA\)」,从而只训练 \(B\)\(A\) 而不动 \(W\)

我把这个视角放在第四,是因为对初学者它最不直观。但在第 56 篇 PEFT 与 LoRA 里它会成为主角。先记住「外积之和」这个说法存在。

九、四个视角等价吗

是的,四个视角描述的是完全相同的运算,只是从不同方向切。可以用代数严格证明它们等价(任何一本好的线代教材都会做这件事)。

那为什么需要四个视角?因为不同问题适合不同视角

  • 算具体数值时用「点积视角」(行 × 列)。
  • 推导线性变换性质时用「列线性组合视角」。
  • 推导注意力输出时用「行线性组合视角」。
  • 做低秩分析时用「外积之和视角」。

数学家的本事之一就是「对同一个对象有多个视角,需要哪个调哪个」。深度学习也一样。本篇把这四个视角都列出来,是希望你以后看到任何矩阵乘法时,都能从最方便的角度切入。

十、矩阵乘法的代数性质

矩阵乘法的几条基本性质:

结合律\((AB)C = A(BC)\)。乘法可以重新组合,但不能重新排序

分配律\(A(B + C) = AB + AC\)\((A + B)C = AC + BC\)。和加法相容。

不满足交换律:一般 \(AB \ne BA\)。这是矩阵和数最大的区别。\(2 \times 3 = 3 \times 2\),但矩阵乘法不行。即使 \(AB\)\(BA\) 形状都合法(都是方阵且同样大小),结果通常也不同。

单位元\(I A = A I = A\)\(I\) 是单位矩阵(对角线为 1,其余为 0)。

零矩阵\(0 A = A 0 = 0\)

与转置的关系\((AB)^\top = B^\top A^\top\)

与逆的关系(如果可逆):\((AB)^{-1} = B^{-1} A^{-1}\)

请特别记住「矩阵乘法不可交换」这条。在推导 attention 公式或反向传播时,乘法顺序错了答案就错了。这一条在公式推导里出错的概率最高。

十一、矩阵乘法的几何意义

如果把 \(\mathbf{x}\) 看成一个向量,\(W\) 看成一个矩阵,那 \(\mathbf{y} = W \mathbf{x}\) 是什么?

线性变换\(W\) 把向量 \(\mathbf{x}\) 映射成向量 \(\mathbf{y}\)。这种映射有两条性质:

  • \(W (\alpha \mathbf{x}) = \alpha (W \mathbf{x})\)(齐次性)。
  • \(W (\mathbf{x}_1 + \mathbf{x}_2) = W \mathbf{x}_1 + W \mathbf{x}_2\)(可加性)。

只满足这两条的变换叫线性变换,矩阵正好是描述线性变换的工具。

线性变换的几何效果包括:旋转、缩放、剪切、投影、镜像。任何这些组合(不包含平移)都可以用矩阵乘法表达。

举一个具体例子。设

\[ W = \begin{pmatrix} 2 & 1 \\ 0 & 1 \end{pmatrix}, \quad \mathbf{x} = \begin{pmatrix} 1 \\ 1 \end{pmatrix} \]

那么

\[ W\mathbf{x} = \begin{pmatrix} 2 & 1 \\ 0 & 1 \end{pmatrix} \begin{pmatrix} 1 \\ 1 \end{pmatrix} = \begin{pmatrix} 3 \\ 1 \end{pmatrix} \]

这条式子不只是“把数字乘一乘”。几何上看,它做了两件事:

  • 基向量 \(\mathbf{e}_1 = (1, 0)\) 被送到 \((2, 0)\),说明 x 方向被拉长了。
  • 基向量 \(\mathbf{e}_2 = (0, 1)\) 被送到 \((1, 1)\),说明原来竖直的方向被“推斜”了。

于是原来的单位正方形不再是正方形,而会变成一个平行四边形;向量 \(\mathbf{x} = (1, 1)\) 也被送到新的位置 \((3, 1)\)。这就是“矩阵乘法在几何上改造整个空间”的含义。

矩阵作为线性变换

平移呢?平移不是线性的(\(\mathbf{x} \mapsto \mathbf{x} + \mathbf{b}\) 不满足齐次性)。所以神经网络里要加平移就要用 \(\mathbf{y} = W \mathbf{x} + \mathbf{b}\),叫仿射变换(affine transformation)

比如如果再加上偏置 \(\mathbf{b} = (1, -1)\),那么整个图形会在变形之后再整体平移。这一步“整体挪动”不是线性的,所以必须从 \(W\mathbf{x}\) 升级成 \(W\mathbf{x} + \mathbf{b}\)

理解「矩阵 = 线性变换」是从代数升级到几何的关键。3Blue1Brown 的「Essence of Linear Algebra」第 3 集和第 4 集把这件事讲得最透。如果你看完本篇还觉得抽象,去看那两集视频。

十二、几种特殊矩阵

单位矩阵 \(I\):对角线全 1,其余全 0。\(IA = A\)。例如

\[ I = \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}, \quad I \begin{pmatrix} 3 \\ 2 \end{pmatrix} = \begin{pmatrix} 3 \\ 2 \end{pmatrix} \]

它什么都不改变,所以是矩阵乘法里的“恒等操作”。在代码里 torch.eye(n) 创建。

对角矩阵:只有对角线非零。它乘以一个向量,等于对各坐标分别缩放。例如

\[ D = \begin{pmatrix} 2 & 0 \\ 0 & 0.5 \end{pmatrix}, \quad D \begin{pmatrix} 3 \\ 4 \end{pmatrix} = \begin{pmatrix} 6 \\ 2 \end{pmatrix} \]

几何上就是 x 方向拉长 2 倍、y 方向压缩到一半。

对称矩阵\(A^\top = A\)。例如

\[ S = \begin{pmatrix} 2 & 1 \\ 1 & 3 \end{pmatrix} \]

因为左上到右下对角线两侧完全对称,所以它是对称矩阵。协方差矩阵就是最常见的例子。attention 分数 \(QK^\top\) 一般不是对称的(因为 \(Q \ne K\)),但有些论文会强行对称化。

正交矩阵\(Q^\top Q = I\)。例如 90° 旋转矩阵

\[ Q = \begin{pmatrix} 0 & -1 \\ 1 & 0 \end{pmatrix}, \quad Q \begin{pmatrix} 1 \\ 0 \end{pmatrix} = \begin{pmatrix} 0 \\ 1 \end{pmatrix} \]

它改变方向,但不改变长度和夹角。RoPE 的位置编码用的就是这种“保长保角”的性质。

置换矩阵:每行每列恰好一个 1,其余为 0。比如

\[ P = \begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}, \quad P \begin{pmatrix} a \\ b \end{pmatrix} = \begin{pmatrix} b \\ a \end{pmatrix} \]

它的作用不是变形,而是“交换顺序”。更高维时可以表示任意重排。

稀疏矩阵:大部分元素为 0。比如

\[ M = \begin{pmatrix} 1 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & 2 \end{pmatrix} \]

只有 2 个非零元素,其余都是 0。这类矩阵可以用专门的稀疏存储格式节省内存。在大型语言模型的 MoE(专家混合)路由里经常出现。

低秩矩阵\(\mathrm{rank}(A) \ll \min(n, m)\)。例如

\[ L = \begin{pmatrix} 1 & 2 \\ 2 & 4 \end{pmatrix} = \begin{pmatrix} 1 \\ 2 \end{pmatrix} \begin{pmatrix} 1 & 2 \end{pmatrix} \]

第二行只是第一行的 2 倍,所以它只有 1 个独立方向,是一个 rank-1 矩阵。LoRA 用的正是这种“用两个小矩阵乘起来近似一个大矩阵”的思想。

每种特殊矩阵都对应一种「优化机会」,工程实现里可以专门处理。

十三、矩阵乘法的复杂度

朴素的矩阵乘法是 \(O(n^3)\)(如果 \(A, B\) 都是 \(n \times n\))。具体:每个输出元素需要 \(n\) 次乘加,总共 \(n^2\) 个元素,所以是 \(n^3\)

这听起来不快——\(n=1000\) 就是 \(10^9\) 次乘法。但好消息是:

第一,乘法之间互相独立,可以完全并行。GPU 可以一次同时做几千次乘法。

第二,乘法的数据局部性很好。BLAS 库(cuBLAS、MKL)做了几十年优化,能把缓存利用率推到极限。

第三,理论上有 \(O(n^{2.37})\) 的算法(Strassen、Coppersmith-Winograd 系列),但实际中常数太大,不实用。NVIDIA 的 cuBLAS 用的是优化过的朴素 \(O(n^3)\) 算法。

所以矩阵乘法虽然理论复杂度是 \(n^3\),但工程实现极快。在 H100 上一次 4096×4096×4096 的 FP16 矩阵乘法只需要约 0.3 毫秒。

十四、批量矩阵乘法(Batched MatMul)

在深度学习里,几乎从来不是「一次矩阵乘法」,而是「一批矩阵乘法」。比如:

  • batch size = 32,每个样本要做一次注意力。
  • multi-head 注意力,head 数 = 16。

这样实际上每个 forward 要做 \(32 \times 16 = 512\) 次矩阵乘法。

PyTorch 的 torch.matmul 自动处理 batch 维度。对于形状 (B, T, D) 的 Q 和 (B, D, T) 的 K^T,Q @ K^T 输出 (B, T, T)。前面的维度(B)作为 batch,后两维做矩阵乘法。

cuBLAS 提供 cublasGemmStridedBatchedEx 等接口,专门优化批量矩阵乘法。它能做到「把 512 次小型 GEMM 融合成一次大型 GEMM」,效率大大优于循环单独调用。

理解「Transformer 里几乎每一层都是批量矩阵乘法」之后,你就能感性理解为什么 GPU 是 Transformer 的天作之合。

十五、attention 里的矩阵乘法

在 attention 里有几次关键的矩阵乘法:

Q = X @ W_Q          # (B, T, D) @ (D, D) → (B, T, D)
K = X @ W_K          # 同上
V = X @ W_V          # 同上

scores = Q @ K.transpose(-2, -1)   # (B, T, D) @ (B, D, T) → (B, T, T)
attn = softmax(scores)             # 形状不变 (B, T, T)
out = attn @ V                     # (B, T, T) @ (B, T, D) → (B, T, D)

一共 5 次矩阵乘法(实际多头时还要乘以头数)。其中最贵的是 Q @ K^Tattn @ V,因为它们的中间维度是 T(序列长度),可能很大。

这就是 attention 计算的全貌。所有的「智能」都隐藏在 W_Q, W_K, W_V 三个矩阵里——它们决定了 Q、K、V 是怎么从 X 变出来的。模型学的就是这三个矩阵。

第 13 篇会把这一节展开到完整的注意力公式分析。本篇先到 「形状对,方向对」这一层。

十六、\(QK^\top\) 的形状故事

很多读者第一次看到 \(QK^\top\) 时会卡壳:「为什么是 \(K\) 转置?为什么不是 \(K\) 本身?」

答案在形状里。

\(Q\) 是 (T, D):T 个查询向量,每个 D 维。

\(K\) 是 (T, D):T 个键向量,每个 D 维。

如果直接 Q @ K,形状是 (T, D) @ (T, D),不合法——D ≠ T。

我们想要的输出是 (T, T),即「每个查询对每个键的分数」。要让形状对上,必须把 \(K\) 转置成 (D, T):

\[ Q \cdot K^\top: (T, D) \times (D, T) \to (T, T) \]

形状对上了。每个 \((i, j)\) 元素 \(= \sum_l Q_{il} K_{jl} = \mathbf{q}_i \cdot \mathbf{k}_j\),正是查询 \(i\) 与键 \(j\) 的点积。

所以 \(K^\top\) 不是数学上多么深奥的操作,就是「把 K 翻一下让形状能匹配」。一旦形状对上,每个输出元素自然变成查询和键的点积。

QK^T 形状

十七、softmax 沿着哪个维度

scores 的形状是 (B, T, T)。第二维(中间的 T)代表「查询位置」,第三维(最右的 T)代表「键位置」。

softmax 应该沿着第三维做:

attn = F.softmax(scores, dim=-1)

这表示「对每个查询,独立地把它对所有键的分数归一化」。每一行加起来是 1。

如果错误地沿着第二维做,就变成「对每个键,把它收到的所有查询的分数归一化」。这没有意义——查询是独立的,不应该相互归一化。

这是新手最容易出错的地方之一。pytorch 的 dim=-1 通常是对的,但只要你不确定就停下来确认形状。

十八、多头里的 reshape 戏法

multi-head attention 里有个经典操作:「把 D 维拆成 H 个 d_h 维的头」。在代码里:

B, T, D = X.shape
H = 8       # 头数
d_h = D // H

Q = X @ W_Q                     # (B, T, D)
Q = Q.reshape(B, T, H, d_h)     # (B, T, H, d_h)
Q = Q.transpose(1, 2)           # (B, H, T, d_h)
# 现在每个 head 独立地做 attention,head 维像 batch 维一样并行

对每个 head 独立做 attention:

scores = Q @ K.transpose(-2, -1)   # (B, H, T, T)
attn = F.softmax(scores / d_h**0.5, dim=-1)
out = attn @ V                     # (B, H, T, d_h)

最后把 H 维 merge 回去:

out = out.transpose(1, 2).reshape(B, T, D)
out = out @ W_O

整个过程的关键在于「reshape + transpose 让 head 维成为 batch 维」。这样 GPU 一次 matmul 就并行算完所有 head,不需要循环。这是工程上能跑得起多头的原因。

十九、einsum:另一种写法

PyTorch / NumPy 的 einsum 提供了一种更显式的矩阵运算写法。

# 标准写法
scores = Q @ K.transpose(-2, -1)

# einsum 写法
scores = torch.einsum('btd,bsd->bts', Q, K)

'btd,bsd->bts' 含义是:

  • 第一个张量的下标是 (b, t, d)。
  • 第二个张量的下标是 (b, s, d)。
  • 输出张量的下标是 (b, t, s)。
  • 共同出现但不在输出中的下标(这里是 d)会被求和。

这种写法的好处:显式表达了「哪些维度对应、哪些维度被求和」。当形状变得复杂时(比如多头 + 多 batch + 多 query),einsum 比 matmul 加 reshape 更易读。

缺点:在某些情况下 einsum 不如 matmul 快(虽然现代 PyTorch 的 einsum 已经能在简单情况下自动优化为 matmul)。

我个人在写新模型时偏好 einsum,因为它让我能在代码里清晰看到「每个维度的角色」。在性能关键路径上再换成 matmul。

二十、关键概念回顾

矩阵乘法是一组点积\(C_{ij}\)\(A\) 的第 \(i\) 行与 \(B\) 的第 \(j\) 列的点积。这是最朴素也最实用的视角,写代码、debug 形状都靠它。

矩阵乘法的四种视角。点积视角(行 × 列)、列线性组合视角(C 的列是 A 的列的组合)、行线性组合视角(C 的行是 B 的行的组合)、外积之和视角(A 的列与 B 的行的外积求和)。四个视角等价,但适用场景不同。

矩阵乘法不可交换\(AB \ne BA\) 一般情况下。这是矩阵和标量的本质区别,公式推导时随处会用到。

形状是工程的灵魂\((n, k) \times (k, m) = (n, m)\)。中间维度必须匹配,外面维度决定输出。99% 的代码错误来自形状不匹配。

\(QK^\top\) 是 attention 的核心。两次最重要的矩阵乘法:\(QK^\top\) 算注意力分数,\(\mathrm{attn} \cdot V\) 算最终输出。整个 Transformer 的智能都隐藏在 \(W_Q, W_K, W_V\) 这三个学得的矩阵里。

批量矩阵乘法是 Transformer 高效的根本。前面的维度(batch、head)作为 batch 维,最后两维真正做矩阵乘法。GPU 的 cuBLAS 把这种「一批小矩阵乘」优化到了硬件极限。

二十一、常见误解

误解一:「矩阵乘法就是逐元素相乘。」 不是。逐元素相乘叫 Hadamard 乘积或 element-wise multiply,符号通常是 \(\odot\) 或代码里的 *。矩阵乘法是「行 × 列做点积」,符号是 \(\cdot\)@,结果形状一般和输入不同。混淆这两个是新手最常见的错误。

误解二:「(A B)^T = A^T B^T」 错。正确的是 \((AB)^\top = B^\top A^\top\),顺序反转。这条公式我自己也错过几次,每次错都付出过 debug 的时间代价。建议背下来。

误解三:「矩阵乘法是 \(O(n^3)\),所以注意力是 \(O(T^2 D)\),慢得不能用。」 理论复杂度对,但实际中 GPU 把矩阵乘法跑到了硬件极限。\(T = 4096, D = 512\) 的 attention 在 H100 上只要几百微秒。「理论复杂度高」和「实际速度慢」之间隔着工程优化。

误解四:「@matmul 是不同的。」 在 PyTorch / NumPy 里 @matmul 的语法糖,完全等价。区别只在可读性。

误解五:「attention 的 \(QK^\top\) 一定是方阵。」 不一定。只有「self-attention」里 Q 和 K 来自同一序列时是 (T, T) 方阵。「cross-attention」(如 encoder-decoder)里 Q 来自一个序列,K 来自另一个,形状是 (T_q, T_k) 不一定相等。

误解六:「multi-head 是把 D 维拆成多块独立运算,所以参数比单头多。」 错。multi-head 的总参数量和单头相同(\(W_Q\) 仍然是 \(D \times D\),只不过被「视作」拆成 H 块小的 \(D \times d_h\))。多头的好处是「让不同子空间学不同的关系模式」,不是参数更多。第 15 篇会展开。

误解七:「矩阵乘法慢,所以应该尽量避免。」 反过来。在 GPU 上矩阵乘法是最快的算子之一。要避免的是「逐元素操作」(element-wise),因为它们 memory-bound。能用矩阵乘法表达的运算,反而应该追求。

二十二、矩阵乘法的历史小掌故

矩阵乘法的标准定义并不是一开始就「显然」的。

18 世纪。Lagrange 和 Laplace 在解线性方程组时已经在用「矩阵」的概念,但还没有正式的矩阵乘法。

19 世纪中期。英国数学家 Arthur Cayley 在 1858 年的论文《A Memoir on the Theory of Matrices》里首次系统定义了矩阵乘法。他选择「行 × 列」的定义,是为了让线性变换的复合能用矩阵乘法表达——也就是 \(f(g(\mathbf{x})) = (FG)\mathbf{x}\)

这个动机非常重要:矩阵乘法的定义不是任意选的,是为了让「复合变换 = 矩阵乘积」成立。理解这点之后,你就能感性理解为什么矩阵乘法是「行 × 列」而不是别的——因为它对应着函数复合。

20 世纪初。Hermann Weyl、John von Neumann 等人把矩阵和线性算子统一在「内积空间 + 算子」的框架下,这是泛函分析的基础。

1969 年。Volker Strassen 给出了第一个亚立方算法 \(O(n^{\log_2 7}) \approx O(n^{2.807})\)。这是一个数学突破,但工程上常数太大,n 很小时不实用。

21 世纪。Coppersmith-Winograd 系列把指数推到 2.37 左右,Le Gall 2014 推到 2.3728639,2024 年又有微小推进。但这些都是理论结果,工程几乎不用。

实际工程。所有主流深度学习框架的矩阵乘法都基于 cuBLAS / cuDNN / MKL 等库,使用优化过的 \(O(n^3)\) 算法(带 cache blocking、SIMD、Tensor Core 等加速)。

所以矩阵乘法的故事是:「1858 年定义,几十年里走遍线代、量子力学、统计、计算机;1969 年理论突破但实践无用21 世纪工程把朴素算法推到硬件极限」。理论和工程在这个问题上分道扬镳,各做各的。

二十三、GPU 上的矩阵乘法实现

矩阵乘法在 GPU 上跑得快,是几代工程师 + NVIDIA 硬件设计共同作用的结果。一些关键技术:

Cache blocking(分块缓存)。把大矩阵切成小块,每次只把当前需要的块加载到快速缓存(L1/L2 cache 或 shared memory)。这样减少了对慢速 global memory 的访问。

Tensor Core。NVIDIA Volta(V100,2017)开始引入的专用矩阵乘法硬件单元,一次能做 4×4×4 的 FP16 矩阵乘法。Ampere(A100)扩展到 BF16、TF32;Hopper(H100)扩展到 FP8。每代 Tensor Core 都是「让矩阵乘法更快」的硬件升级。

SIMD(Single Instruction Multiple Data)。一条指令同时操作多个数据。CPU 的 AVX、GPU 的 SIMT 都是这种思路。

Streaming Multiprocessor(SM)。GPU 的基本计算单元。H100 有 132 个 SM,每个 SM 有几百个 CUDA 核心和 4 个 Tensor Core。矩阵乘法可以充分利用所有 SM。

Persistent kernel 和 Split-K。当矩阵形状不规则时(如 K 维度很大但 M、N 很小),用「Split-K」把 K 维度切片并行,再用「persistent kernel」让 SM 不闲。这些是 cuBLAS 内部的优化技巧。

FlashAttention。是 attention 上的一个特殊优化:把 \(QK^\top\)、softmax、\(\cdot V\) 三步融合成一个 kernel,避免中间结果写回 HBM。第 49 篇会专门讲。

理解「GPU 上矩阵乘法不是简单的循环,而是几代优化的结晶」之后,你就能感性理解「为什么神经网络只用矩阵乘法 + 简单非线性 + softmax」——因为这些都是 GPU 已经优化到极致的算子。任何想替代它们的新算子都要面对「优化几十年」的对手,胜算很低。

二十四、矩阵乘法的反向传播

如果你写过深度学习框架,会遇到「矩阵乘法的反向传播」公式。这里给个完整推导。

前向\(Y = X W\),其中 \(X\)\(n \times d_{in}\)\(W\)\(d_{in} \times d_{out}\)\(Y\)\(n \times d_{out}\)

设损失 \(L\),已知 \(\partial L / \partial Y\)(形状和 \(Y\) 相同)。要求 \(\partial L / \partial X\)\(\partial L / \partial W\)

对 W 求导

\[ \frac{\partial L}{\partial W} = X^\top \frac{\partial L}{\partial Y} \]

形状:\((d_{in}, n) \times (n, d_{out}) = (d_{in}, d_{out})\),和 \(W\) 一致。

对 X 求导

\[ \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W^\top \]

形状:\((n, d_{out}) \times (d_{out}, d_{in}) = (n, d_{in})\),和 \(X\) 一致。

记忆诀窍:「前向是 \(XW\),反向对 X 是 \(\delta W^\top\),反向对 W 是 \(X^\top \delta\)」。前向乘什么,反向就用它的转置乘对应的梯度。

这个公式在自定义层时反复用到。PyTorch 的 autograd 会自动算,但理解原理能帮你 debug 数值问题。

二十五、一个完整的 attention 形状追踪

为了让你彻底掌握 attention 里矩阵的形状,下面给一个完整追踪:

# 输入
B = 32        # batch size
T = 128       # 序列长度
D = 512       # hidden size (embedding dim)
H = 8         # 头数
d_h = D // H  # 每头的维度 = 64

X = torch.randn(B, T, D)   # 输入 (B, T, D)

# 投影到 Q, K, V
W_Q = nn.Linear(D, D, bias=False)
W_K = nn.Linear(D, D, bias=False)
W_V = nn.Linear(D, D, bias=False)

Q = W_Q(X)   # (B, T, D)
K = W_K(X)   # (B, T, D)
V = W_V(X)   # (B, T, D)

# 拆头
Q = Q.reshape(B, T, H, d_h).transpose(1, 2)   # (B, H, T, d_h)
K = K.reshape(B, T, H, d_h).transpose(1, 2)   # (B, H, T, d_h)
V = V.reshape(B, T, H, d_h).transpose(1, 2)   # (B, H, T, d_h)

# 注意力分数
scores = Q @ K.transpose(-2, -1)              # (B, H, T, T)
scores = scores / d_h**0.5                    # 形状不变

# 因果掩码(仅 decoder 用)
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
scores = scores.masked_fill(mask, float('-inf'))

# Softmax
attn = F.softmax(scores, dim=-1)              # (B, H, T, T)

# 加权求和
out = attn @ V                                 # (B, H, T, d_h)

# 合头
out = out.transpose(1, 2).reshape(B, T, D)    # (B, T, D)

# 输出投影
W_O = nn.Linear(D, D, bias=False)
out = W_O(out)                                # (B, T, D)

每一行的形状变化都标注出来了。如果你能跟着每一行点头说「形状对,逻辑也对」,那你已经掌握了 multi-head attention 的全部基本骨架。

二十六、矩阵乘法的内存与计算 trade-off

attention 计算的关键瓶颈其实不在 FLOPs(浮点运算量),而在内存访问

矩阵乘法 (M, K) × (K, N) 的:

  • 计算量:\(2MKN\) FLOPs。
  • 内存量:\(MK + KN + MN\) 个浮点数。

「计算量 / 内存量」叫算术强度(arithmetic intensity),单位是 FLOPs/byte。强度高,说明每 byte 内存访问能做更多计算,GPU 利用率高。

矩阵乘法的算术强度大约是 \(\min(M, N, K)\)。当矩阵足够大(如 1024×1024),强度高,GPU 利用率高(>50% 理论峰值)。当矩阵小(如 64×64),强度低,GPU 大部分时间在等内存,利用率低(<10%)。

这就是为什么大模型偏好「大矩阵」:把 hidden size 设到 4096、12288,让每次矩阵乘法都「大」,提高 GPU 利用率。如果用很多小矩阵,反而慢。

也是为什么 multi-head attention 的「每个头的 d_h」不能太小:太小会让每次矩阵乘法的 K 维很小,算术强度低,GPU 跑不满。这是 GQA、MLA 等变种试图解决的问题之一。

理解「矩阵越大,GPU 越爽」之后,你能看穿很多模型设计选择背后的硬件考量。

二十七、Tensor Core 的 4×4×4 单元

NVIDIA 的 Tensor Core 一次能做 4×4×4 的矩阵乘加(D = A·B + C,其中 A, B, C, D 都是 4×4 矩阵)。

这意味着如果你的矩阵不能被 4(或 8、16,取决于具体硬件)整除,Tensor Core 就无法用满,会有 padding 浪费。

实际工程里,模型设计常常对齐 8 或 16 的倍数:hidden size 是 128, 256, 512 等 2 的幂,head 数是 8, 16, 32 等。这些都是为了硬件友好。

如果你看 Llama 2 的配置:hidden_size = 4096, num_heads = 32, head_dim = 128。所有数都对齐 16 的倍数。这不是巧合。

第 50 篇量化推理时会再次回到这件事——量化用 INT8 时硬件单元更大(NVIDIA Hopper 的 INT8 Tensor Core 是 16×16×32),对齐要求更严格。

二十八、矩阵乘法的并行模式

在分布式训练里,矩阵乘法可以按多种方式拆分到多卡:

数据并行(Data Parallel)。每张卡有完整的模型,处理不同的 batch。最简单。

张量并行(Tensor Parallel)。把单个矩阵乘法切到多张卡上。比如 \(Y = X W\) 中,把 \(W\) 按列切成 \([W_1, W_2]\),每卡各算 \(X W_i\),最后拼起来。这是 Megatron-LM 的核心技巧。

流水线并行(Pipeline Parallel)。把不同层放到不同卡,前向后向像流水线一样。

序列并行(Sequence Parallel)。把序列长度维度切到不同卡。和张量并行常常配合用。

每种并行都对矩阵乘法做了不同的拆分。第 53 篇分布式训练会详谈。本篇先记住「矩阵乘法天然适合并行,怎么拆是工程问题」。

二十九、稀疏与低秩

不是所有矩阵都需要「完整稠密」存储。

稀疏矩阵。大部分元素为 0。可以用 CSR、COO 等格式存储,节省内存。在 MoE 模型里,每个专家只服务一部分 token,attention 只对部分 key 算分数(sparse attention),都用到稀疏。

低秩矩阵\(\mathrm{rank}(W) = r\)\(r \ll \min(d_{in}, d_{out})\)。可以分解成 \(W = AB\),其中 \(A\)\(d_{in} \times r\)\(B\)\(r \times d_{out}\)。参数量从 \(d_{in} d_{out}\) 降到 \(r(d_{in} + d_{out})\)

低秩分解是 LoRA 的根基。LoRA 微调一个大模型时,把 \(\Delta W\)(权重变化量)约束为低秩,只训练 \(A\)\(B\),原模型 \(W\) 不动。这能让一个 70B 模型用几十 MB 的额外参数适配新任务。

第 56 篇会专门讲 LoRA。本篇先种下「矩阵可以分解、稀疏存」的种子。

三十、einsum 的更多用法

einsum 是矩阵运算的「瑞士军刀」。除了基本矩阵乘法,还能做很多事:

# 标准矩阵乘法
torch.einsum('ik,kj->ij', A, B)

# 转置
torch.einsum('ij->ji', A)

# 求和(reduce)
torch.einsum('ij->', A)        # 全部求和
torch.einsum('ij->i', A)       # 每行求和

# Hadamard 积(逐元素相乘)
torch.einsum('ij,ij->ij', A, B)

# 外积
torch.einsum('i,j->ij', a, b)

# 点积
torch.einsum('i,i->', a, b)

# Batch 矩阵乘法
torch.einsum('bik,bkj->bij', A, B)

# Multi-head attention 分数
torch.einsum('bhid,bhjd->bhij', Q, K)

einsum 的可读性优势在复杂场景下特别明显。我建议至少把上面这几个用法记下来,遇到 PyTorch 文档里复杂的形状操作可以参考。

三十一、术语对照表

中文 English 备注
矩阵 matrix 二维数组
矩阵乘法 matrix multiplication / matmul 行 × 列
矩阵转置 transpose 行列互换
Hadamard 积 Hadamard product / element-wise multiply 逐元素乘
外积 outer product 向量 × 向量 → 矩阵
内积 inner product 向量 × 向量 → 标量
单位矩阵 identity matrix 对角 1 其余 0
对角矩阵 diagonal matrix 仅对角非零
对称矩阵 symmetric matrix \(A = A^T\)
正交矩阵 orthogonal matrix \(Q^T Q = I\)
置换矩阵 permutation matrix 重排行列
稀疏矩阵 sparse matrix 大部分为零
低秩矩阵 low-rank matrix rank ≪ min(n, m)
仿射变换 affine transformation 线性 + 平移
GEMM general matrix-matrix multiply BLAS 中矩阵乘的标准术语
BLAS Basic Linear Algebra Subprograms 线代基本运算库
cuBLAS CUDA BLAS NVIDIA 的 BLAS 实现
Tensor Core Tensor Core NVIDIA 的矩阵乘硬件
算术强度 arithmetic intensity FLOPs/byte
张量并行 tensor parallel 切矩阵到多卡
流水线并行 pipeline parallel 切层到多卡
LoRA Low-Rank Adaptation 低秩适配
MoE Mixture of Experts 专家混合

三十二、FAQ

Q1:矩阵乘法在 PyTorch 里有几种写法?哪种最快?

A:@torch.matmultorch.bmmtorch.einsumA.mm(B) 等都能做矩阵乘法。@matmul 在大多数情况下走同一个底层 cuBLAS 调用,速度相同。einsum 在简单情况下也会被优化为 matmul。bmm 是「batched mm」,需要 3D 输入。

Q2:怎么判断 attention 计算是否走了 Tensor Core?

A:使用 torch.profiler 或 nsight 工具看 kernel name。Tensor Core kernels 通常带 tensoropwmma 字样。FP32 不能走 Tensor Core,必须 FP16/BF16/TF32/FP8。所以混合精度训练(autocast)几乎是大模型的标配。

Q3:为什么 attention 里 Q 和 K 不能共享权重?

A:可以共享(有些早期实验做过),但效果通常不如独立。原因是 Q 和 K 在语义上承担不同角色:Q 表示「我要找什么」,K 表示「我代表什么」。独立的权重让两者各自学习最合适的表达。

Q4:multi-head 是不是把同一个矩阵乘法做了 H 次?

A:等价于一次大矩阵乘法。\(W_Q\) 整体是 \(D \times D\),被「视作」H 个 \(D \times d_h\) 的小矩阵。前向只调一次大 GEMM,然后 reshape 成多头。

Q5:FlashAttention 怎么避免 \(QK^\top\) 这个 (T, T) 的中间结果?

A:分块计算。把 Q 切成块,K、V 也切成块,每次只算一块的局部注意力,把结果累加到输出。中间结果不写回 HBM,留在 SRAM。第 49 篇详谈。

Q6:sparse attention 怎么实现?

A:先决定哪些 (i, j) 位置需要算,把它们的索引存好,只算这些位置的点积。框架支持有限(torch.sparse),实际工程里多用自定义 CUDA kernel。Longformer、BigBird 等是代表作。

Q7:矩阵乘法在 CPU 上为什么也快?

A:MKL(Intel)、OpenBLAS 等 CPU 版 BLAS 库做了几十年优化,能把矩阵乘法跑到 CPU 理论峰值的 80%+。但 CPU 的总算力远低于 GPU,所以大模型基本不在 CPU 上训。

Q8:为什么 attention 的复杂度是 \(O(T^2 D)\) 不是 \(O(T^2 D^2)\)

A:\(QK^\top\)\((T, D) \times (D, T) = (T, T)\),计算量 \(2T^2 D\)。再乘 \(V\)\((T, T) \times (T, D) = (T, D)\),计算量 \(2T^2 D\)。两次都是 \(O(T^2 D)\),加起来还是 \(O(T^2 D)\)\(D\) 只出现一次。

Q9:序列长度 T 翻倍,attention 计算时间翻几倍?

A:理想情况翻 4 倍(\(T^2\))。实际上由于内存带宽瓶颈,可能更糟(5-8 倍)。FlashAttention 让它接近理论的 4 倍。

Q10:把 attention 换成 RNN 能不能省时间?

A:理论上 RNN 是 \(O(T D^2)\),T 大时比 attention 的 \(O(T^2 D)\) 快。但 RNN 的递归结构无法并行,GPU 利用率低,实际不如 attention 快。Mamba 等 state-space 模型在试图找平衡点。

三十三、坑点合集

坑 1:维度不匹配。99% 的 PyTorch 报错。养成「写完一行就 print 形状」的习惯。

坑 2:transpose 后忘了 contiguous。某些操作(特别是 view)需要 contiguous 内存。x = x.transpose(1, 2).contiguous() 是常见写法。

坑 3:reshape 和 view 的差别。view 要求 contiguous,reshape 不要求(但可能多一次 copy)。在性能敏感路径用 view + contiguous,可读性敏感路径用 reshape。

坑 4:bias 项的形状Linear(D, D) 的 bias 是 (D,),会被广播到所有 batch。如果你手写线性层,记得 bias 形状对齐。

坑 5:FP16 下大矩阵乘法溢出。中间累加用 FP32,最后转回 FP16。torch.cuda.amp.autocast 自动处理。

坑 6:mask 的 broadcast。attention mask 通常是 (T, T),需要广播到 (B, H, T, T)。PyTorch 自动广播,但要确认形状能正确广播(前面用 1 填充)。

坑 7:除以 sqrt(d_k) 写错位置。应该在 mask 之前。如果在 mask 后除,mask 的 -inf 不受影响(因为 -inf / 任何正数 = -inf),但其他细节会错。

坑 8:multi-head 里头的拆分顺序。先 reshape 到 (B, T, H, d_h) 再 transpose 到 (B, H, T, d_h)。顺序错了会拆错维度。

坑 9:合头时忘了 contiguous。out.transpose(1, 2) 后 reshape 可能报错,因为不是 contiguous。要加 .contiguous() 或用 .reshape() 而非 .view()。

坑 10:忘记输出投影 W_O。multi-head attention 的最后一步是 W_O。这是合头之后的「整合」步骤,不能省。

矩阵乘法不只是 attention 的核心,还贯穿整个深度学习栈。

全连接层(Linear / Dense layer)\(\mathbf{y} = W\mathbf{x} + \mathbf{b}\) 就是一次矩阵乘法加一次平移。最简单也最常用的层。

卷积层(Conv layer)。卷积本质上也是矩阵乘法——把输入展开(im2col)成大矩阵,把卷积核展开成另一个大矩阵,相乘。这是 cuDNN 的标准做法。所以即使是 CNN,底层也是大量 GEMM。

RNN / LSTM。每一步的状态更新 \(\mathbf{h}_t = \tanh(W \mathbf{h}_{t-1} + U \mathbf{x}_t)\) 是两次矩阵乘法。总计算量 \(O(T D^2)\),T 是步数,D 是隐藏维度。

Embedding 查表nn.Embedding(V, D) 表面上是「按索引查表」,底层等价于「one-hot 向量乘以 V × D 的矩阵」。所以连 embedding 也是矩阵乘法的特例。

Output head(输出层)。语言模型的输出 \(\mathbf{logits} = W_{out} \mathbf{h}\),把 hidden 投影到 vocab 大小。\(W_{out}\)\(D \times V\),对大词表(V=100K)这是一次大型矩阵乘法。

LayerNorm / RMSNorm。归一化本身不是矩阵乘法(是逐元素操作),但其前置的均值/方差计算可以表达成「向量与全 1 向量的内积」。

所以整个 Transformer 的算力 95% 以上花在矩阵乘法上。其他算子(softmax、激活、norm、dropout)虽然必要,但占比小。理解矩阵乘法 = 理解 Transformer 计算的主要部分。

三十五、矩阵乘法的数值稳定性

矩阵乘法虽然简单,但在长链反复乘时会出现数值问题。

梯度爆炸 / 消失。深网络里,反向传播是连续的矩阵乘法。如果矩阵的奇异值(最大值)大于 1,多次相乘后梯度会指数爆炸;如果小于 1,会消失。这是为什么神经网络初始化要小心(Xavier、Kaiming 初始化都是为这件事设计的)。

累加误差。FP16/BF16 的精度有限,长向量点积会累加误差。Tensor Core 用 FP32 累加缓解。

条件数。矩阵的条件数(最大奇异值/最小奇异值)大时,乘以它对输入的小扰动会变成大输出。这在求解线性方程组时是大问题,在前向传播里是小问题。

反归一化的梯度。LayerNorm 的反向传播包含 1/std 这种除法,std 接近 0 时梯度爆炸。所以 LayerNorm 内部要 clamp 或加 epsilon。

这些数值问题大部分被现代框架自动处理,但你应该知道它们存在。当你看到训练 loss 突然爆炸或 NaN 时,第一反应应该是「数值稳定性出问题了」。

三十六、关于「矩阵」这个词的概念延伸

随着深度学习的发展,「矩阵」的概念扩展到了更高维度。

张量(tensor)。任意维度的数组。0 维是标量,1 维是向量,2 维是矩阵,3 维及以上叫张量。PyTorch 里所有变量都是 tensor。

张量乘法torch.matmul 在前面有 batch 维时自动按 batch 处理。torch.einsum 提供更灵活的张量运算。torch.tensordot 沿任意维度做收缩。

张量分解。CP 分解、Tucker 分解、Tensor Train 分解等。在压缩大模型时偶尔用到。

张量网络。物理学(量子多体)和机器学习共同的工具。MPS、PEPS、TT 都是张量网络的特例。

本系列基本停留在 2D 矩阵 + batch 的层面,但你要意识到「矩阵 = 张量的 2D 特例」,更高维的世界存在。

三十七、回到 attention:再次看 \(QK^\top\)

经过本篇的讨论,我们再回看 attention 公式:

\[ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V \]

这条公式里的每个矩阵乘法你现在应该都能解释:

  • \(QK^\top\):(T, D) × (D, T) = (T, T),每个元素是查询和键的点积。
  • 除以 \(\sqrt{d_k}\):标量除法,让分数方差稳定。
  • \(\mathrm{softmax}\):沿最后一维归一化。
  • \(\cdot V\):(T, T) × (T, D) = (T, D),每行是 V 各行的加权和。

整条公式有两次矩阵乘法。如果加 multi-head 和 batch,就是 (B, H, T, D/H) × (B, H, D/H, T) → (B, H, T, T) 和 (B, H, T, T) × (B, H, T, D/H) → (B, H, T, D/H)。前面的 batch 和 head 维让 cuBLAS 一次性处理几十个矩阵乘法。

你现在应该能说出每个维度的意义、每次乘法的形状、为什么要 transpose、为什么 softmax 沿 dim=-1。如果能做到,本篇的目标达成了。

三十八、本篇的小结

我们从「单次点积」走到了「批量矩阵乘法」。这个升级看起来只是「把循环写成矩阵」,实际上是 Transformer 高效的根本:

  • 算法层:大量重复的小操作合并成大型矩阵乘法。
  • 硬件层:GPU 的 Tensor Core、cache hierarchy、SIMD 都是为矩阵乘法设计的。
  • 工程层:cuBLAS 把矩阵乘法跑到硬件极限。
  • 架构层:Transformer 的每一层都是矩阵乘法 + 简单非线性,全栈优化空间大。

这种「算法形式 + 硬件特性匹配」是工程的最高境界。当算法形式选得好时,硬件、库、工具链都会向它聚集,形成正反馈。Transformer 在 GPU 时代之所以击败 RNN,本质就是因为它的算法形式更适合 GPU。

下一篇我们离开「线性代数」回到「函数」的世界,讨论「神经网络是什么、为什么需要堆叠多层、为什么需要非线性」。看似离 attention 远了一些,实际上是为了打地基——理解了「函数复合 + 非线性」之后,attention 的 W_Q, W_K, W_V 才能真正归位到「学出来的线性投影」这一抽象。

三十九、再来几个具体的形状练习

为了让你彻底固化「形状直觉」,下面是几个练习场景。请在脑子里(或纸上)算出每一步的形状。

场景 A:批量大小 4,序列长度 512,hidden 768,头数 12。

  • W_Q 形状?答:(768, 768)。
  • 单头维度 d_h?答:64。
  • Q 的形状(拆头前后)?答:拆前 (4, 512, 768),拆后 (4, 12, 512, 64)。
  • \(QK^\top\) 形状?答:(4, 12, 512, 512)。
  • 这一矩阵存储多少 byte(FP16)?答:\(4 \cdot 12 \cdot 512 \cdot 512 \cdot 2 = 25\,165\,824\) bytes ≈ 24 MB。
  • 一层 attention 的 FLOPs(不含 V)?答:\(2 \cdot 4 \cdot 12 \cdot 512 \cdot 512 \cdot 64 \approx 1.6 \times 10^9\) FLOPs。

场景 B:序列长度 32K,hidden 4096,头数 32。

  • \(QK^\top\) 形状(不含 batch)?答:(32, 32K, 32K)。
  • FP16 存储?答:\(32 \cdot 32768^2 \cdot 2 \approx 67\) GB。

是的,你没看错。仅一个 attention layer 的 attention map 就要 67 GB。这就是为什么长上下文需要 FlashAttention 之类的优化——不能把整个 attention map 实例化。

场景 C:cross-attention,T_q = 16(解码时一个 token),T_k = 4096(KV cache)。

  • \(QK^\top\) 形状?答:(H, 16, 4096)。
  • 比 self-attention 在同等长度下小很多。这就是为什么解码(一个 token 一个 token 生成)相对廉价——Q 的长度只有 1。

通过这些具体场景,你能感性理解「为什么长上下文这么难」「为什么解码比训练快」「为什么 attention map 不能实例化」等工程话题。

四十、一份「形状自查清单」

每次写新模型代码时,过一遍这个清单:

  1. 输入张量的形状是什么?每一维代表什么?
  2. 每次 reshape 之后形状变成什么?变化合理吗?
  3. 每次 transpose 之后哪些维度被交换?需不需要 contiguous?
  4. 每次 matmul 的左右两个张量形状能否相乘?输出形状对吗?
  5. 每次 broadcast 是否符合预期?是不是把不该 broadcast 的维度 broadcast 了?
  6. 每次 softmax 的 dim 参数对吗?归一化的轴是不是「类别轴」?
  7. mask 的形状能广播到 scores 吗?mask 的位置是 1 还是 0 表示「保留」?
  8. 输出形状和期望的下一层输入形状一致吗?

这八个问题在每次写新代码时问一遍,能预防大多数 bug。

四十一、本篇与后续的连接

本篇站在「线性代数 + 工程」的交界处。后面:

  • 04 篇 函数到神经网络:把矩阵乘法纳入「线性变换 + 非线性激活」的框架,理解 W_Q, W_K, W_V 是「学得的线性投影」。
  • 05 篇 激活函数:解释为什么纯矩阵乘法不够,需要非线性。
  • 08 篇 Softmax:把 attention 分数变成概率分布的桥梁。
  • 13 篇 缩放点积注意力:把所有矩阵乘法 + softmax 拼成完整公式。
  • 15 篇 多头注意力:解释为什么多头比单头好。
  • 49 篇 FlashAttention:把矩阵乘法 + softmax + 矩阵乘法融合成一个 kernel。
  • 53 篇 分布式训练:把矩阵乘法切分到多卡。
  • 56 篇 LoRA / PEFT:用低秩分解微调模型。

矩阵乘法这个工具贯穿全系列。读到这些后续章节时,请回想本篇的「四种视角」「形状追踪」「算术强度」「Tensor Core」等概念——它们会帮你把零散的工程知识串起来。

四十二、给读者的一个练习

读完本篇请尝试以下事情。如果都能做到,本篇的目标达成了。

练习 1:在草稿纸上手算一个 3×3 的矩阵乘以一个 3×2 的矩阵。用「点积视角」做一次,再用「列线性组合视角」做一次,验证两种结果一致。

练习 2:用 NumPy 写一个 matmul_naive(A, B) 函数,只用 for 循环。然后比较它和 np.dot(A, B) 在 1000×1000 矩阵上的速度差距。预期差距 100 倍以上。

练习 3:用 PyTorch 实现一个 multi_head_attention(X, W_Q, W_K, W_V, W_O, H) 函数,要求带详细的形状注释。在 (B=4, T=128, D=512, H=8) 输入下跑通。

练习 4:把上一题的 attention 改成「causal」(只能看到自己和之前的位置)。用 torch.tril 生成下三角 mask。

练习 5:用 torch.einsum 重写练习 3 的 attention,对比可读性。

练习 6:估算一个 GPT-2(12 层,hidden 768,12 头,T=1024)单次前向 attention 部分的 FLOPs 和峰值内存(attention map 部分)。

练习 7:阅读 nanoGPT 的 model.pyCausalSelfAttention 类,把每一行和本篇对应起来。哪些行是矩阵乘法?哪些行是 reshape / transpose?哪些行是数值技巧?

这些练习的目的不是让你「掌握每个细节」,而是让「形状」「矩阵乘法」「einsum」这些概念在你脑子里变成可操作的工具。

四十三、一个轻松的画面收尾

想象一个工厂车间。流水线上有几千个工位,每个工位的工人都在做同一件简单的事——拿起一行数和一列数,按位置相乘,再加起来。

整个工厂的运转就是矩阵乘法。

你说「让这个工厂跑得更快」?方法不是让每个工人更聪明,而是:让流水线更宽(并行度)、让原材料离工位更近(cache)、让工人们手里的工具更好(Tensor Core)、让整个工厂的物流更顺畅(memory bandwidth)。

这就是 GPU 在做的事,过去十年 NVIDIA 的工程进步,几乎全部用来「让矩阵乘法工厂跑得更快」。

而 Transformer 是「把所有工序都设计成这种工厂友好的形式」的架构。它和 GPU 互相成就,共同定义了 2020 年代的 AI 时代。

下一篇我们暂时离开车间,去想另一个问题:「这个工厂在生产什么?」——也就是「神经网络作为函数」的视角。


下一步

下一篇 04 函数到神经网络 会把矩阵乘法纳入「神经网络 = 函数复合」的框架,理解 W_Q, W_K, W_V 这些线性变换在做什么、为什么要堆叠、为什么需要非线性。

如果你已经熟悉神经网络,可以直接跳到 05 激活函数,那一篇会讲为什么没有非线性的话再多层都是一层。

如果你想直接看 attention 怎么把所有这些矩阵乘法组合起来,去 13 缩放点积注意力 鸟瞰,再回来补 04-12 的细节。


参考文献

经典教材

  • Gilbert Strang. Introduction to Linear Algebra (5th ed.). Wellesley-Cambridge Press, 2016. 第 2 章矩阵乘法的几种视角讲得最清楚。
  • Sheldon Axler. Linear Algebra Done Right (3rd ed.). Springer, 2015. 抽象向量空间的视角,对线性变换的本质理解有帮助。
  • Stephen Boyd, Lieven Vandenberghe. Introduction to Applied Linear Algebra. Cambridge University Press, 2018. 工程导向,有大量代码示例。

论文

  • Ashish Vaswani et al. “Attention Is All You Need.” NeurIPS, 2017. attention 的多次矩阵乘法是工程效率的根本。
  • Tri Dao et al. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” NeurIPS, 2022. 把 attention 里的几次矩阵乘法融合成一次大型分块计算的工作。
  • Edward Hu et al. “LoRA: Low-Rank Adaptation of Large Language Models.” ICLR, 2022. 用「外积之和视角」做低秩适配的代表作。

博客与可视化


上一篇02 向量与点积的几何直觉  下一篇04 函数到神经网络

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。

2026-06-09 · transformer

【Transformer 与注意力机制】59|推理退化:为什么大模型会输出乱码、死循环和无意义文本

大模型推理时偶尔会突然陷入死循环、输出乱码或连续无意义数字,这不是随机 bug,而是注意力机制、Causal Mask、解码策略和数值精度在自回归生成中共同作用的结果。本文从 QKV 计算坍塌出发,解释 Attention Sink、Softmax 马太效应、Causal Mask 的退路切断、FP16 溢出路径和 KV Cache 污染,并给出从架构到运行时的多层防线。

2026-04-15 · transformer

【Transformer 与注意力机制】系列总览

从《Attention Is All You Need》出发把 Transformer 注意力机制、Q/K/V、多头注意力、位置编码、Causal Mask、Softmax、FFN、训练范式、模型变体、推理工程、可解释性、未来架构以及推理退化防御串成 59 篇深度博客。

2026-04-15 · transformer

【Transformer 与注意力机制】01|为什么要从这里开始

这是【Transformer 与注意力机制】系列的第一篇,承担两件事:一是把这套五十多篇文章为谁写、解决什么问题、彼此之间是什么关系交代清楚;二是为完全没基础的读者画出一条从向量、点积、矩阵乘法走到自注意力、再走到大语言模型的爬升路径,让你在投入时间之前先知道终点在哪、路上要经过哪些坎、读完之后你会、还不会做什么事。

2026-04-15 · transformer

【Transformer 与注意力机制】42|FlashAttention:注意力计算的硬件级重写

FlashAttention 的关键不是近似注意力,也不是把公式改掉,而是重新安排标准 attention 在 GPU 内存层级里的计算路径。本文解释为什么标准 attention 的瓶颈常常是 HBM 读写,FlashAttention 如何用 tiling 和 online softmax 避免物化完整注意力矩阵,以及它为什么省显存、提吞吐,却没有消除 O(n²) 的根本复杂度。