pytorch中tensor操作维度为负的情况
最近在研究nlp,使用pytorch时对torch.nn这个东西的内容有点兴趣,网上搜了下找到这么一篇博文torch.nn到底是什么?这篇虽然标的是他自己的原创,实际上是翻译的pytorch的官方指导WHAT IS TORCH.NN REALLY?。这篇博文介绍了构建神经网络的主要环节,对理解神经网络的实现还是很有帮助的。
在这篇博文中提到了自己构建**函数log_softmax的方法,代码如下
def log_softmax(x):
return x - x.exp().sum(-1).log().unsqueeze(-1)
刚看的时候我也是一脸懵逼,尤其是对sum和unsqueeze这两个函数,一开始完全不明白其含义,后来查过pytorch官方api后大致有些了解,其中x.sum(a)的含义是对张量x在a这个维度上求和,x.unsqueeze(a)很多博文中表述是给指定位置加上维数为一的维度,这个理解比较抽象,我举个例子
比如,有个一维张量,也就是个向量:
tensor([0.0998, 0.2481, 0.1337, 0.6232])
这个张量是一维的,第1维上的元素是0.0998,0.2481,0.1337,0.6232,但其实在pytorch中,他还有个第0维,这个第0维上的元素就是张量自身,也就是[0.0998, 0.2481, 0.1337, 0.6232]这个东西。
所以我实际上可以对一个一维张量在0和1两个维度上进行扩展,代码如下:
a=torch.rand(4)
print(a)
print(a.unsqueeze(0))
print(a.unsqueeze(1))
执行结果是
tensor([0.7183, 0.6250, 0.5066, 0.5037])
tensor([[0.7183, 0.6250, 0.5066, 0.5037]])
tensor([[0.7183],
[0.6250],
[0.5066],
[0.5037]])
首先我们看a.unsqueeze(0)
这个结果tensor([[0.7183, 0.6250, 0.5066, 0.5037]])
,上边提到第0维的元素就是张量自身,而unsqeeze这个函数是在第0维上添加一个维度,直观的理解是将这个维度上的每个元素外边加一层[]
,由原一阶张量[0.7183, 0.6250, 0.5066, 0.5037]
升阶成二阶张量[[0.7183, 0.6250, 0.5066, 0.5037]]
。
而a.unsqueeze(1)
这个操作则是将第一维中的每个元素外加[]
,从而提升维度,也就是[[0.7183],[0.6250],[0.5066],[0.5037]]
。
同时我也明白了为什么unsqueeze
这个函数中,传入的维度不能超过张量的最大维度,因为n维张量不存在超过n的维度。
需要注意的是,在pytorch中tensor的维度是从外向内依次增大的。
sum
函数是对某一维上的元素进行求和,维度处理逻辑和unsqueeze
相似,这里不再赘述。
下面进入主题,sum
和squeeze
这两个函数中,维度参数为负是什么意思。实际上我并不想弄明白负维度的物理含义,我只是想知道,这些函数是怎么处理负维度的。
为了搞明白这个问题,我去GitHub上下载了pytorch的源码,可惜很遗憾,这里并不能看到torch.sum()的内部实现
可能是用的torch的方法?但torch本身使用lua语言,再去看源码太费劲了。
偶然间我找到这篇博文torch.unsqueeze() 和 torch.squeeze(),里边提到一句:
这就很有意思了,我验证了下,以unsqueeze
为例:
a=torch.rand(2,4)
print(a)
print(a.unsqueeze(0))
print(a.unsqueeze(1))
print(a.unsqueeze(2))
print(a.unsqueeze(-1))
执行结果是
tensor([[0.0257, 0.2360, 0.4513, 0.8321],
[0.2121, 0.2098, 0.7794, 0.4347]])
tensor([[[0.0257, 0.2360, 0.4513, 0.8321],
[0.2121, 0.2098, 0.7794, 0.4347]]])
tensor([[[0.0257, 0.2360, 0.4513, 0.8321]],
[[0.2121, 0.2098, 0.7794, 0.4347]]])
tensor([[[0.0257],
[0.2360],
[0.4513],
[0.8321]],
[[0.2121],
[0.2098],
[0.7794],
[0.4347]]])
tensor([[[0.0257],
[0.2360],
[0.4513],
[0.8321]],
[[0.2121],
[0.2098],
[0.7794],
[0.4347]]])
按照上边公式,dim=-1,input.dim()=2,所以-1+2+1=2,可以看到squeeze(-1)
和squeeze(2)
的结果是一样的。
但是如果维数为负的情况下超过张量维数+1,就会报错
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got -10)
sum
函数也类似,有兴趣可以自行验证。
不过到头来我还是没明白为什么开头提到的那篇博文里要用负维度,想破头也没想明白。
上一篇: mybatis学习及原理解析(三)
下一篇: pytorch的tensor除法