矩阵乘法求导

矩阵乘法求导

pyotrch中只能是标量对矩阵求导,所以矩阵乘法结束后加个sum \[ L = sum(\bm{WX}) \]

其中,\(\bm{W}\)\(\bm{X}\)都是矩阵,那么 \[ \frac{\partial L}{\partial\bm{W}}_{\cdot i}=\sum\bm{X}_{i\cdot} \]

梯度和W的形状相同,梯度中每列都是相同的,只要是第i列,梯度值就是\(\bm{X}\)的第i行的和。

用公式不太好表示,我们用pytorch代码来描述一下:

1
2
3
4
5
6
7
>>> w = torch.arange(0,50,dtype=torch.float32, requires_grad=True)
>>> nw = w.view(10, 5)
>>> x = torch.arange(50,100, dtype=torch.float32, requires_grad=True)
>>> nx = x.view(5,10)
>>> loss = torch.matmul(nw, nx)
>>> sum_loss = loss.sum()
>>> sum_loss.backward()
1
2
3
4
5
6
7
8
9
10
11
12
>>> print(w.grad.view_as(nw),'\n', nx.detach().sum(dim=1))
tensor([[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.],
[545., 645., 745., 845., 945.]])
tensor([545., 645., 745., 845., 945.])

这个推导过程也不是很复杂,可以自己举个例子试试。


矩阵乘法求导
https://jcdu.top/2023/06/14/矩阵乘法求导/
作者
horizon86
发布于
2023年6月14日
许可协议