pytorch各种乘法,mm, matmul, dot, @, *, mul, multiply

  1. torch.mm

线代的矩阵乘法,要求输入都是矩阵

  1. torch.matmul

注意:torch.mm和torch.matmul不等价

根据输入不同执行不同的操作:

  • 输入都是二维矩阵,矩阵乘法,等同于torch.mm

  • 输入都是一维向量,计算向量内积,等同于torch.dot

  • 第一个参数是向量,第二个是矩阵,则将第一个参数变成(1,n)的矩阵,再执行矩阵乘法

  • 第一个参数是矩阵,第二个是向量,执行矩阵向量乘法,等同于torch.mv

  • 两个都是高维张量,自己看文档去

  1. torch.dot

向量点积(内积),输入必须都是一维的。向量点积计算公式:

\(\bold a=(a_1, a_2, a_3)\)

\(\bold b=(b_1, b_2, b_3)\)

\(\bold a \cdot \bold b=a_1b_1+a_2b_2+a_3b_3\)

因此向量内积是个标量

  1. torch.mul

按元素相乘,element-wise的乘法,也叫哈达玛积

  1. torch.multiply

torch.mul的别称

  1. *

torch.mul的简写

  1. @

torch.matmul的简写(注意不是torch.mat的简写)

  1. torch.outer

向量外积,输入向量维度分别为n和m,则输出(n, m)


pytorch各种乘法,mm, matmul, dot, @, *, mul, multiply
https://jcdu.top/2025/01/04/pytorch各种乘法,mm, matmul, dot, @, _, mul, multiply/
作者
horizon86
发布于
2025年1月4日
许可协议