惯性聚合 高效追踪和阅读你感兴趣的博客、新闻、科技资讯
阅读原文 在惯性聚合中打开

推荐订阅源

Simon Willison's Weblog
Simon Willison's Weblog
P
Privacy International News Feed
www.infosecurity-magazine.com
www.infosecurity-magazine.com
T
Troy Hunt's Blog
Hacker News - Newest:
Hacker News - Newest: "LLM"
Attack and Defense Labs
Attack and Defense Labs
S
Secure Thoughts
V2EX - 技术
V2EX - 技术
cs.AI updates on arXiv.org
cs.AI updates on arXiv.org
O
OpenAI News
Cloudbric
Cloudbric
Google Online Security Blog
Google Online Security Blog
Schneier on Security
Schneier on Security
cs.CV updates on arXiv.org
cs.CV updates on arXiv.org
Help Net Security
Help Net Security
Cyberwarzone
Cyberwarzone
G
GRAHAM CLULEY
L
Lohrmann on Cybersecurity
Threat Intelligence Blog | Flashpoint
Threat Intelligence Blog | Flashpoint
Spread Privacy
Spread Privacy
NISL@THU
NISL@THU
N
News and Events Feed by Topic
T
Tenable Blog
S
Security @ Cisco Blogs
N
News and Events Feed by Topic
The Hacker News
The Hacker News
C
CXSECURITY Database RSS Feed - CXSecurity.com
宝玉的分享
宝玉的分享
月光博客
月光博客
酷 壳 – CoolShell
酷 壳 – CoolShell
美团技术团队
奇客Solidot–传递最新科技情报
奇客Solidot–传递最新科技情报
Google DeepMind News
Google DeepMind News
钛媒体:引领未来商业与生活新知
钛媒体:引领未来商业与生活新知
T
Tailwind CSS Blog
V
Visual Studio Blog
P
Proofpoint News Feed
Webroot Blog
Webroot Blog
让小产品的独立变现更简单 - ezindie.com
让小产品的独立变现更简单 - ezindie.com
博客园 - 三生石上(FineUI控件)
cs.CL updates on arXiv.org
cs.CL updates on arXiv.org
Jina AI
Jina AI
雷峰网
雷峰网
T
The Blog of Author Tim Ferriss
Hugging Face - Blog
Hugging Face - Blog
腾讯CDC
L
LangChain Blog
The Register - Security
The Register - Security
OSCHINA 社区最新新闻
OSCHINA 社区最新新闻
博客园 - 聂微东

博客园 - longbigfish

部署(https证书) https证书问题(本地) Ubuntu 24安装Neo4j详细教程 protect 紧急 手机 刷脏页的两种模式 python中的多线程陷阱与pytorch分布式执行机制 git之复合指令和submodule rpc编程示例 mpi编程 cifs远程挂载 使用脚本进入一个命令行控制台,并预设执行的命令列表 cifs挂载远程文件出现 No such device or address错误 longtable 跨越多个页面时,如何在跨页时自动断行并加上横线及去掉页眉 matplotlib中文显示-微软雅黑 latex编译过程-关于嵌入所有字体 python做图笔记 linux启动全过程 连接并同步windows下的git仓库 反向ssh
参数更新
longbigfish · 2025-06-23 · via 博客园 - longbigfish

1. loss

是一个单值

假设输入的词元id是[0, 1]

目标词元id是[1, 2]

 也就是根据输入得到两个预测输出,

注意上面的是id,每个id实际上是一个嵌入向量,比如768维向量,

假设词汇表是3,实际词汇表可能是5w

通过模型矩阵计算后,对于输入的每一个位置,都会输出一个3维度的向量,对齐进行softmax选择最大的概率作为预测输出,

这里输入序列有两个词元,因此会预测出两个结果,实际上是两个3维度的概率向量,比如[[0.320, 0.333, 0.347], [0.301, 0.332, 0.367]],

这两个概率向量表明,预测输出都是2的概率最大

但实际上目标值第一个是1,第二个是2,

计算损失实际上是根据目标词元id,对预测结果中对应位置的概率求负对数

位置0: -log(0.333) ≈ 1.100
位置1: -log(0.367) ≈ 1.003
平均损失 = (1.100 + 1.003)/2 = 1.0515

其意义是,如果目标词元位置的概率很大,说明预测的准,那么这个熵损失值就很小,概率趋于1损失就趋于0,如果预测的不准就是概率小,那么损失值就很大,概率趋于0,损失值就趋于无穷大,

2. 梯度

2.1 损失对logits的梯度

 交叉熵损失梯度公式:∂L/∂logits = softmax(logits) - one_hot(target)

位置0(目标=1)
one_hot(1) = [0, 1, 0]
∂L/∂logits0 = [0.320, 0.333, 0.347] - [0, 1, 0] = [0.320, -0.667, 0.347]

位置1(目标=2):
one_hot(2) = [0, 0, 1]
∂L/∂logits1 = [0.301, 0.332, 0.367] - [0, 0, 1] = [0.301, 0.332, -0.633]

可以看到,目标位置的梯度为负,且预测的越准的话,这个梯度绝对值就越小,那么在进行梯度下降是动作就要“轻微”点

非目标位置的梯度是正值,且非目标位置如果概率越大,表明越不准确,那么进行梯度下降时这个地方要“剧烈”点

 2.2 损失对参数的梯度

∂L/∂W = 嵌入^T × (∂L/∂logits)

假设输入嵌入矩阵是n*d,那么其转置是d*n,那么转置的每一行表示了n个位置每个位置的一部分,

(∂L/∂logits)是n*w矩阵,表示对于输入的n个位置,每个位置对于词汇表每个词汇的预测概率相应的损失,

这两矩阵相乘,结果是d*w矩阵,

可以用第一个值举例,这个值是由n个输入向量取每个第一个值,同时对n个输出概率向量每个取对词汇表第一个词汇的梯度值,进行相乘得到一个标量值,

那么结果的d*w矩阵,第i行第j列,包含了每个位置预测结果中第j个词元的概率梯度综合,以及输入序列嵌入矩阵每个输入的第i个值,

实际上,这个d*w矩阵就是参数矩阵

2.3 

我们再回忆下参数流

嵌入矩阵:w*d 

      输入:n*d  --h3

#忽略QKV

#QKV矩阵:d*d

#      QK得到n*n自注意力矩阵

#      再与V得到n*d矩阵

FFN网络矩阵1:d*f    --W3

      得到n*f矩阵  -h2

FFN网络矩阵2: f*d  --W2

      得到n*d矩阵    --h1

输出层矩阵:d*w    --W1

      得到n*w矩阵

2.3.1

根据n*w矩阵,我们得到对logits的梯度,也就是n*w矩阵,

下面我们反向一步步得到每个参数矩阵的梯度,

对输出层矩阵参数d*w,因为我们根据n*d矩阵和d*w矩阵相乘得到n*w矩阵,那么需要n*d的转置与n*w相乘就可得到d*w参数梯度。

其中n*d在有些教程中表示为h(隐藏矩阵),是一种中间态数据,需要报存在显存

2.3.2

第一步,计算∂L/∂logits, 这个通过预测结果softmax矩阵与目标序列one-hot矩阵相减得到,是一个n*w矩阵,

第二步,计算∂L/∂W1 = h1(T) * (∂L/∂logits) ,是一个d*w矩阵,这个矩阵形状和W1一样,    -- h1*W1 = n*w

第三步,计算∂L/∂h1 =  (∂L/∂logits) * W1(T), 是一个n*d矩阵,和h1形状一样,  

第四步,计算∂L/∂W2 = h2(T) * ∂L/∂h1, 是一个f*d矩阵,形状和W2一样,        -- h2*W2 = h1

第五步,计算∂L/∂h2 =  ∂L/∂h1 * W2(T), 是一个n*f矩阵,

第六步,计算∂L/∂W3 = h3(T)*  ∂L/∂h2, 是一个d*f矩阵 ,                     -- h3*W3 = h2, h3就是inputs

第七步,计算∂L/∂h3 =  ∂L/∂h2 * W3(T), 是一个n*d矩阵,

3.  参数更新

根据上面计算出的梯度,对每一个实体矩阵(嵌入矩阵以及参数矩阵)进行更新,中间态矩阵不用更新。

假设lr=0.01

3.1

更新inputs,也即是h3,也即是词嵌入

对于输入序列n个词的第i个嵌入,

E[i] = E[i] - 0.01*∂L/∂h3 [i]

或者整体上

E = E - 0.01

3.2 更新 W3

使用与其形状相同的梯度矩阵乘以学习率,然后将W3与结果相减,得到新的W3

4. 优化器

上面的还没有提到优化器,现在我们加入优化器

优化器包含两个状态m和v,或者叫动量和方差,对于每个参数值,都有一一对应的动量和方差,也就是说,每一个参数值同时对应两个优化器值,

所有的优化器状态初始化为0,

另外,还有几个值以及初始化举例如下,

学习率 lr=0.01

beta1=0.9, beta2=0.999, 即 β1,β2

epsilon=1e-8, 即  ε

时间步 t=1 (初始)

也就是说,这些值是优化器的一部分,可以看到,学习率成为了优化器的一部分,

使用优化器对每一个参数值进行更新,以W3举例,

首先已经得到W3的梯度矩阵,和W3的形状一样,设为W3Grad,现在对W3[0,0]进行更新,公式

m = β1·m + (1-β1)·grad
v = β2·v + (1-β2)·grad²
m̂ = m / (1 - β1ᵗ)
v̂ = v / (1 - β2ᵗ)

param = param - lr·m̂ / (√v̂ + ε)

其中grad就是W3Grad[0,0], m,v也是对应的优化器状态值,目前初始值0,

这里我们更新的param 就是W3[0,0]

更新完成后,对应的m,v都变为新的值了,全部更新完后,时间t加1

也就是说,每次更新参数时,先将对应优化器值进行更新(根据给定的β1、β2以及计算出来的梯度值和时间步),然后,使用更新后的m、v,lr以及t对参数进行更新,

使用β1对m进行更新的意义,大部分还是保持m当前的值(使用β1乘以当前m,β1值比较接近1),少部分根据梯度值增加,也就是说,如果梯度越大,那么这个动量m变化也越大,梯度可以为负,所以动量可以往小变,一般来说,目标词元处的梯度为负,其他为正,

同理,梯度越大,方差v变化也越大,这个方差始终往大变

完了以后,这俩公式

m̂ = m / (1 - β1ᵗ)
v̂ = v / (1 - β2ᵗ)

将m和v的值按比例放大,时间步越大,这个放大比例越小,

最后就是更新参数了,方差大的话,更新的幅度小,动量大的话,更新的幅度大。

5 总结

从输入序列a矩阵到最后输出z矩阵中间会有x个参数矩阵,总共有x-1个中间态矩阵,或者叫临时矩阵

我们将临时矩阵标注为h1,,,h(x-1)

将参数矩阵标注为p1,,,px

a * p1 = h1,

h1* p2  = h2,

h(x-1)* px = z

最后的 z 是一个经过softmax后是一个概率矩阵,如果输入序列是n * d维度,那么z 就是 n * w维度,其中n是词元数,d是嵌入维度,w是词汇量

根据目标序列,将每个目标词元转为w维度的one-hot向量,组成一个n*w矩阵,与z进行减法运算,计算出梯度,实际的意义就是在目标词元处的概率如果大,梯度就小,如果概率小,梯度就大。

然后进行反向传播,

从z开始,依次计算px、h(x-1)、p(x-1),h(x-2),,,,,一直到p1,a的梯度,

最后根据梯度进行参数更新,注意,a实际上对应的是嵌入式表中输入词元对应的行,

我们用h0表示a,hx表示z,

要计算pi的梯度,前提是hi梯度已经得到,

根据公式h(i-1) * pi = hi,

得到pi = T(h(i-1)) * hi, 将hi的梯度带入,就得到pi的梯度,

同理h(i-1) = hi * T(pi) ,将hi的梯度带入得到h(i-1)的梯度。