

























torch.where(condition, x, y) 是一个三元运算符:如果条件为真,取 x 的值;如果条件为假,保持 y(即原本的 weights)不变。
if self.cls_wise_reg_weights is not None:
for cls, weight in self.cls_wise_reg_weights.items():
weights = torch.where(
(cls_target[i] == cls)[:, None],
weights.new_tensor(weight),
weights
)
instance_reg_weights.append(weights)
weights的shape是[5,10], 全1;
self.cls_wise_reg_weights是{9:[1,1,1,1,1,1,0,0,1,1]}
cls_target[8,7,9,1,4]
已知:
weights.shape == [5, 10]
weights 全是 1
self.cls_wise_reg_weights = {
9: [1,1,1,1,1,1,0,0,1,1]
}
cls_target = [8,7,9,1,4]
代码:
for cls, weight in self.cls_wise_reg_weights.items():
weights = torch.where(
(cls_target == cls)[:, None],
weights.new_tensor(weight),
weights
)
逐步展开。
因为是全 1:
weights =
tensor([
[1,1,1,1,1,1,1,1,1,1], # 对应 cls_target=8
[1,1,1,1,1,1,1,1,1,1], # 对应 cls_target=7
[1,1,1,1,1,1,1,1,1,1], # 对应 cls_target=9
[1,1,1,1,1,1,1,1,1,1], # 对应 cls_target=1
[1,1,1,1,1,1,1,1,1,1], # 对应 cls_target=4
])
这里只有一个:
cls = 9
weight =
[1,1,1,1,1,1,0,0,1,1]
(cls_target == cls)cls_target == 9
得到:
tensor([False, False, True, False, False])
再:
[:, None]
变成 shape [5,1]:
tensor([
[False],
[False],
[ True],
[False],
[False]
])
torch.where(condition, A, B)
这里:
condition.shape = [5,1]
A.shape = [10]
B.shape = [5,10]
广播后:
condition ->
[
[F,F,F,F,F,F,F,F,F,F],
[F,F,F,F,F,F,F,F,F,F],
[T,T,T,T,T,T,T,T,T,T],
[F,F,F,F,F,F,F,F,F,F],
[F,F,F,F,F,F,F,F,F,F],
]
A 也广播成:
[
[1,1,1,1,1,1,0,0,1,1],
[1,1,1,1,1,1,0,0,1,1],
...
]
只有 cls_target==9 的那一行被替换:
weights =
tensor([
[1,1,1,1,1,1,1,1,1,1],
[1,1,1,1,1,1,1,1,1,1],
[1,1,1,1,1,1,0,0,1,1], # 被替换
[1,1,1,1,1,1,1,1,1,1],
[1,1,1,1,1,1,1,1,1,1],
])
本质上:
对类别 9 的样本,把它对应的 regression weight 改成
[1,1,1,1,1,1,0,0,1,1]
也就是:
我们用你给的具体数值来肉眼“跑”一遍代码。你的 cls_target 包含了 5 个样本的类别标签,其中索引为 2 的样本类别恰好是 9。
下面是极其精细的单步推导过程:
weights 的形状是 [5, 10],初始全为 1:$$weights = \begin{pmatrix} 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \ 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \ 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \ 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \ 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \end{pmatrix}$$
cls_target 为 [8, 7, 9, 1, 4],形状为 [5]。cls = 9,weight = [1, 1, 1, 1, 1, 1, 0, 0, 1, 1](注意第 6、7 位是 0,其余是 1)。(cls_target == cls)将标签数组里的每一个值和 9 做对比:
8 == 9 $\rightarrow$ False7 == 9 $\rightarrow$ False9 == 9 $\rightarrow$ True (索引为 2 的样本命中了)1 == 9 $\rightarrow$ False4 == 9 $\rightarrow$ False此时得到的布尔张量形状为 [5]:[False, False, True, False, False]。
[:, None]将形状从 [5] 升维变成 [5, 1]:
# condition 的形状:[5, 1]
condition = [[False],
[False],
[True ],
[False],
[False]]
torch.where 隐式广播与替换在执行 torch.where 时,PyTorch 会对 condition 和 weight 进行广播(矩阵复制对齐):
condition 从 [5, 1] 广播成 [5, 10](每一列都复制成一样)。weight 从 [10] 广播成 [5, 10](每一行都复制成一样)。最终替换逻辑按行展开如下:
False $\rightarrow$ 保持原样(全 1)。True $\rightarrow$ 替换为目标类别权重 [1, 1, 1, 1, 1, 1, 0, 0, 1, 1]。经过这一轮循环后,weights 的值更新为:
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # 类别8,保持原样
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # 类别7,保持原样
[1, 1, 1, 1, 1, 1, 0, 0, 1, 1], # 类别9,成功被修改!(最后两个0前面是6个1)
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # 类别1,保持原样
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) # 类别4,保持原样
其形状依然保持 [5, 10] 不变。也就是说,只有命中类别 9 的那个样本(第3个),其回归损失权重被精细地改写了。
此内容由惯性聚合(RSS阅读器)自动聚合整理,仅供阅读参考。 原文来自 — 版权归原作者所有。