pytorch中的squeeze和unsqueeze
程序员文章站
2022-06-15 13:48:54
...
pytorch中的squeeze和unsqueeze
unsqueeze即在参数指定的维度位置,增加一个维度(就是在第几个“[”的位置增加一个“[”)
import torch
a = torch.arange(0,8)
print(a)
b = a.view(2,4)
print(b)
b = b.unsqueeze(1)
print(b)
tensor([0, 1, 2, 3, 4, 5, 6, 7])
tensor([[0, 1, 2, 3],
[4, 5, 6, 7]])
tensor([[[0, 1, 2, 3]],
[[4, 5, 6, 7]]])
```python
import torch
a = torch.arange(0,8)
print(a)
b = a.view(2,4)
print(b)
b = b.unsqueeze(0)
print(b)
tensor([0, 1, 2, 3, 4, 5, 6, 7])
tensor([[0, 1, 2, 3],
[4, 5, 6, 7]])
tensor([[[0, 1, 2, 3],
[4, 5, 6, 7]]])
squeeze即去除一个维度(这个维度只能为1)
import torch
a = torch.arange(0,8)
print(a)
b = a.view(1,2,4)
print(f"b's shape is {b.shape} \n {b}")
b = b.squeeze(-3)
print(f"b's shape is {b.shape} \n {b}")
tensor([0, 1, 2, 3, 4, 5, 6, 7])
b's shape is torch.Size([1, 2, 4])
tensor([[[0, 1, 2, 3],
[4, 5, 6, 7]]])
b's shape is torch.Size([2, 4])
tensor([[0, 1, 2, 3],
[4, 5, 6, 7]])
上一篇: 报错:Caused by: javax.net.ssl.SSLHandshakeException: java.security.cert.CertificateException: java.sec
下一篇: python报错问题
推荐阅读
-
解决Java中的强制类型转换和二进制表示问题
-
C#实现输入10个数存入到数组中并求max和min及平均数的方法示例
-
SQLserver中字符串查找功能patindex和charindex的区别
-
MySQL中的LOCATE和POSITION函数使用方法
-
理解SQL SERVER中的逻辑读,预读和物理读
-
从MySQL全库备份中恢复某个库和某张表的方法
-
android中Invalidate和postInvalidate的更新view区别
-
详解HTML5中div和section以及article的区别
-
老生常谈jquery中detach()和remove()的区别
-
浅谈JavaScript中的apply/call/bind和this的使用