最近在刷开源的Pytorch版动手学深度学习,里面谈到几个高级选择函数,如index_select,masked_select,gather等。这些函数大多很容易理解,可是对于gather函数,确实有些难理解,官方文档开始也看得一脸懵,感受不太直观。下面谈谈我对这几个函数的一些理解。数组
对于numpy和pytorch,其数组在作维度运算上刚开始可能会给人一种直观上的误解,以numpy求矩阵某个维度的最大值为例(pytorch的理解也是同样的)函数
import numpy as np a = np.arange(1, 13).reshape(3, 4) """ result: a = [[1, 2, 3, 4], [5, 6, 7, 8,], [9, 10, 11, 12]] """ # 对a维度0求最大值 a.max(axis = 0) """ result: [9, 10, 11, 12] """ # 对a维度1求最大值 a.max(axis = 1) """ result: [4, 8, 12] """
若是对a矩阵在维度0上找最大值,根据咱们直观上的经验应该是[4, 8, 12]。即从[1, 2, 3, 4]找到4,从[5, 6, 7, 8]找到8,从[9, 10, 11, 12]找到12。可是从上面结果来看,numpy运算却给了咱们直观上认为是列最大值的结果[9, 10, 11, 12]。
实际numpy(pytorch)运算应该理解为往给定的维度进行移动运算。仍是以维度0为例,维度0上有3个向量,分别为[1, 2, 3, 4],[5, 6, 7, 8]和[9, 10, 11, 12]。往维度0移动,即[1, 2, 3, 4]和[5, 6, 7, 8]逐元素计算最大值,获得[5, 6, 7, 8],再和[9, 10, 11, 12]运算获得结果[9, 10, 11, 12]。学习
pytorch和numpy中许多函数都涉及维度运算,gather
也不例外,可是它相对于其余函数更难理解。依然先来看一个例子code
import torch a = torch.arange(1, 16).reshape(5, 3) """ result: a = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]] """ # 定义两个index b = torch.tensor([[0, 1, 2], [2, 3, 4], [0, 2, 4]]) c = torch.tensor([[1, 2, 0, 2, 1], [1, 2, 1, 0, 0]]) # axis=0 output1 = a.gather(0, b) """ result: [[1, 5, 9], [7, 11, 15], [1, 8, 15]] """ # axis=1 output2 = a.gather(1, c) """ result: [[2, 3, 1, 3, 2], [5, 6, 5, 4, 4]] """
上面的例子看起来可能有点复杂,咱们来一步步的分析它,先从gather维度为0开始讲起。blog
a.gather(0, b)
分为3个部分,a
是须要被提取元素的矩阵,0
表明的是提取的维度为0,b
是提取元素的索引
0
除了表明往维度0的方向提取元素外,还有一个特权---提取结果output能够在这个维度上的长度与a不一样。打个比方,a如今的shape为(5, 3),那么提取结果output1的shape能够是(1,3),(2, 3),甚至(n, 3)。具体维度0的长度到底为多少由b来决定。0
的特权,致使了给定的b张量除了维度0外,其余的维度大小必须和a同样。其中张量b
实际上包含如下两个信息
其余的高级选择函数都比较容易理解,这里简单的提一下。torch.index_select主要是根据传入的tensor来往给定的axis方向来选取张量索引
import torch a = torch.arange(9).reshape(3, 3) torch.index_select(a, 0, torch.tensor([0, 2])) """ result: [[0, 1, 2], [6, 7, 8]] """
实际上就是经过掩码条件来选择元素,像torch.masked_select(x, x>0.5),其实是和x[x>0.5]等价的,最后返回的是一维张量文档
import torch a = torch.rand(5, 3) # 结果和a[a > 0.5]等价 torch.masked_select(a, a>0.5)
找到非零元素的index深度学习
import torch a = torch.eye(3) torch.nonzero(a) """ result: 对应着非零元素的index [[0, 0], [1, 1], [2, 2]] """