理解pytorch几个高级选择函数(如gather)

1. 引言

  最近在刷开源的Pytorch版动手学深度学习,里面谈到几个高级选择函数,如index_select,masked_select,gather等。这些函数大多很容易理解,可是对于gather函数,确实有些难理解,官方文档开始也看得一脸懵,感受不太直观。下面谈谈我对这几个函数的一些理解。数组

2. 维度的理解

  对于numpy和pytorch,其数组在作维度运算上刚开始可能会给人一种直观上的误解,以numpy求矩阵某个维度的最大值为例(pytorch的理解也是同样的)函数

import numpy as np
a = np.arange(1, 13).reshape(3, 4)
"""
result:
a = [[1, 2, 3, 4],
      [5, 6, 7, 8,],
      [9, 10, 11, 12]]
"""

# 对a维度0求最大值
a.max(axis = 0)
"""
result:
[9, 10, 11, 12]
"""

# 对a维度1求最大值
a.max(axis = 1)
"""
result:
[4, 8, 12]
"""

  若是对a矩阵在维度0上找最大值,根据咱们直观上的经验应该是[4, 8, 12]。即从[1, 2, 3, 4]找到4,从[5, 6, 7, 8]找到8,从[9, 10, 11, 12]找到12。可是从上面结果来看,numpy运算却给了咱们直观上认为是列最大值的结果[9, 10, 11, 12]。
  实际numpy(pytorch)运算应该理解为往给定的维度进行移动运算。仍是以维度0为例,维度0上有3个向量,分别为[1, 2, 3, 4],[5, 6, 7, 8]和[9, 10, 11, 12]。往维度0移动,即[1, 2, 3, 4]和[5, 6, 7, 8]逐元素计算最大值,获得[5, 6, 7, 8],再和[9, 10, 11, 12]运算获得结果[9, 10, 11, 12]。学习

维度运算图1
  另外,对于维度为3的数组,在numpy和pytorch中,应该把维度0理解为通道数,维度1和维度2才是对应高和宽。若是是3维数组对应着用于多输入通道和单输出通道的卷积核(维度为U x V x D),那么4维数组就对应着用于多输入通道和多输出通道的卷积核(维度为U x V x D x P),此时,维度0则为多通道卷积核数量的方向,维度1为通道数,维度2和3才是分别对应高和宽。
维度运算图2

3. gather函数

pytorch和numpy中许多函数都涉及维度运算,gather也不例外,可是它相对于其余函数更难理解。依然先来看一个例子code

import torch
a = torch.arange(1, 16).reshape(5, 3)
"""
result:
a = [[1, 2, 3],
      [4, 5, 6],
      [7, 8, 9],
      [10, 11, 12],
      [13, 14, 15]]
"""

# 定义两个index
b = torch.tensor([[0, 1, 2], [2, 3, 4], [0, 2, 4]])
c = torch.tensor([[1, 2, 0, 2, 1], [1, 2, 1, 0, 0]])

# axis=0
output1 = a.gather(0, b)
"""
result:
[[1, 5, 9],
[7, 11, 15],
[1, 8, 15]]
"""

# axis=1
output2 = a.gather(1, c)
"""
result:
[[2, 3, 1, 3, 2],
[5, 6, 5, 4, 4]]
"""

上面的例子看起来可能有点复杂,咱们来一步步的分析它,先从gather维度为0开始讲起。blog

  1. a.gather(0, b)分为3个部分,a是须要被提取元素的矩阵,0表明的是提取的维度为0,b是提取元素的索引
    • 其中规定b和a是同维张量,即a是2维张量,b也必须是2维张量
  2. 0除了表明往维度0的方向提取元素外,还有一个特权---提取结果output能够在这个维度上的长度与a不一样。打个比方,a如今的shape为(5, 3),那么提取结果output1的shape能够是(1,3),(2, 3),甚至(n, 3)。具体维度0的长度到底为多少由b来决定。
  3. 根据0的特权,致使了给定的b张量除了维度0外,其余的维度大小必须和a同样。其中张量b实际上包含如下两个信息
    • b能够利用除用于gather的维度(此处为维度0)外的维度来定位出惟一一个向量,也就是a[:, ?](三维度也是同理的,有a[:, ?1, ?2]),?的取值范围为a同维度的index。
    • 对于上述定位出的向量,经过b中的元素来定位提取向量中的哪个元素。
    • 上面说得可能有点抽象,实际上b中的每一个元素都能在a中提取出一个元素。举个具体点的例子,按照上面所说的,b[0, 0]能够提取a中的一个元素。对于b[0,0],除了维度0外,能够经过维度1来定位出惟一一个向量a[:, 0]。由于b[0, 0]的元素为0,即提取的是a[:, 0]的第0个元素---1,并将其做为output1[0, 0]的提取结果。
      下图给出了维度0和维度1,gather运算的图示
gather 2维度
对于3维或者更高维度的张量gather的原理也是同样的
gather 2维度

4. index_select函数

其余的高级选择函数都比较容易理解,这里简单的提一下。torch.index_select主要是根据传入的tensor来往给定的axis方向来选取张量索引

import torch
a = torch.arange(9).reshape(3, 3)
torch.index_select(a, 0, torch.tensor([0, 2]))
"""
result:
[[0, 1, 2],
[6, 7, 8]]
"""

5. masked_select函数

实际上就是经过掩码条件来选择元素,像torch.masked_select(x, x>0.5),其实是和x[x>0.5]等价的,最后返回的是一维张量文档

import torch
a = torch.rand(5, 3)

# 结果和a[a > 0.5]等价
torch.masked_select(a, a>0.5)

6. nonzero函数

找到非零元素的index深度学习

import torch
a = torch.eye(3)
torch.nonzero(a)

"""
result: 对应着非零元素的index
[[0, 0],
[1, 1],
[2, 2]]
"""