新知一下
海量新知
5 9 8 1 5 4 1

5分钟玩转PyTorch | 张量广播计算的本质是什么?

Python绿色通道 | 学Python最好的地方 2021/11/26 15:35

PyTorch

中的张量具有和

NumPy

相同的广播特性,允许不同形状的张量之间进行计算。

广播的实质特性,其实是低维向量映射到高维之后,相同位置再进行相加。 我们重点要学会的就是低维向量如何向高维向量进行映射

相同形状的张量计算

虽然我们觉得不同形状之间的张量计算才是广播,但其实相同形状的张量计算本质上也是广播。

t1 = torch.arange(3)

t1

# tensor([0, 1, 2])

# 对应位置元素相加

t1 + t1

# tensor([0, 2, 4])

与Python对比

如果两个

list

相加,结果是什么?

a = [012]

a + a

# [0, 1, 2, 0, 1, 2]

不同形状的张量计算

广播的特性是不同形状的张量进行计算时,一个或多个张量通过隐式转化成相同形状的两个张量,从而完成计算。

但并非任意两个不同形状的张量都能进行广播,因此我们要掌握广播隐式转化的核心依据。

2.1 标量和任意形状的张量

标量(零维张量)可以和任意形状的张量进行计算,计算过程就是标量和张量的每一个元素进行计算。

# 标量与一维向量

t1 = torch.arange(3)

# tensor([0, 1, 2])

t1 + 1 # 等效于t1 + torch.tensor(1)

# tensor([1, 2, 3])

# 标量与二维向量

t2 = torch.zeros((34))

t2 + 1 # 等效于t2 + torch.tensor(1)

# tensor([[1., 1., 1., 1.],

#         [1., 1., 1., 1.],

#         [1., 1., 1., 1.]])

2.2 相同维度,不同形状张量之间的计算

我们以

t2

为例来探讨相同维度、不同形状的张量之间的广播规则。

t2 = torch.zeros(34)

t2

# tensor([[0., 0., 0., 0.],

#         [0., 0., 0., 0.],

#         [0., 0., 0., 0.]])

t21 = torch.ones(14)

t21

# tensor([[1., 1., 1., 1.]])

它们都是二维矩阵,

t21

的形状是

1×4

t2

的形状是

3×4

,它们在第一个分量上取值不同,但该分量上

t21

取值为1,因此可以进行广播计算:

t2 + t21

# tensor([[1., 1., 1., 1.],

#        [1., 1., 1., 1.],

#        [1., 1., 1., 1.]])

而t2和t21的实际计算过程如下: 新知达人, 5分钟玩转PyTorch | 张量广播计算的本质是什么? 可理解为

t21

的一行与

t2

的三行分别进行了相加。而底层原理为

t21

的形状由

1×4

拓展成了

t2

3×4

,然后二者对应位置进行了相加。

t22 = torch.ones(31)

t22

# tensor([[1.],

#         [1.],

#         [1.]])

t2 + t22

# tensor([[1., 1., 1., 1.],

#         [1., 1., 1., 1.],

#         [1., 1., 1., 1.]])

同理,

t22+t2

t21+t2

结果相同。如果矩阵的两个维度都不相同呢?

t23 = torch.arange(3).reshape(31)

t23

# tensor([[0],

#         [1],

#         [2]])

t24 = torch.arange(3).reshape(13)

# tensor([[0, 1, 2]])

t23 + t24

# tensor([[0, 1, 2],

#         [1, 2, 3],

#         [2, 3, 4]])

此时,

t23

的形状是3×1,而

t24

的形状是

1×3

,二者的形状在两个份量上均不同,但都有1存在,因此可以广播: 新知达人, 5分钟玩转PyTorch | 张量广播计算的本质是什么?

如果两个张量的维度对应数不同且都不为1,那么就无法广播。

t25 = torch.ones(24)

# t2的shape为3×4

t2 + t25

# RuntimeError

高维张量的广播

高维张量的广播原理与低维张量的广播原理一致:

t3 = torch.zeros(234)

t3

# tensor([[[0., 0., 0., 0.],

#          [0., 0., 0., 0.],

#          [0., 0., 0., 0.]],

#         [[0., 0., 0., 0.],

#         [0., 0., 0., 0.],

#         [0., 0., 0., 0.]]])

t31 = torch.ones(231)

t31

# tensor([[[1.],

#          [1.],

#          [1.]],

#         [[1.],

#          [1.],

#          [1.]]])

t3+t31

# tensor([[[1., 1., 1., 1.],

#          [1., 1., 1., 1.],

#          [1., 1., 1., 1.]],

#         [[1., 1., 1., 1.],

#          [1., 1., 1., 1.],

#          [1., 1., 1., 1.]]])

总结

维度相同时,如果对应分量不同,但有一个为1,就可以广播。

不同维度计算中的广播

对于不同维度的张量,我们首先可以将低维的张量升维,然后依据相同维度不同形状的张量广播规则进行广播。

低维向量的升维也非常简单,只需将更高维度方向的形状填充为1即可:

# 创建一个二维向量

t2 = torch.arange(4).reshape(22)

t2

# tensor([[0, 1],

#         [2, 3]])

# 创建一个三维向量

t3 = torch.zeros(322)

t3

t2 + t3

# tensor([[[0., 1.],

#          [2., 3.]],

#         [[0., 1.],

#          [2., 3.]],

#         [[0., 1.],

#          [2., 3.]]])

t3

t2

的相加,就相当于

1×2×2

3×2×2

的两个张量进行计算,广播规则与低维张量一致。

相信看完本节,你已经充分掌握了广播机制的运算规则:

  • 维度相同时,如果对应分量不同,但有一个为1,就可以广播

  • 维度不同时,只需将低维向量的更高维度方向的形状填充为1即可


更多“算法”相关内容

更多“算法”相关内容

新知精选

更多新知精选