
PyTorch Study Notes 2: Tensor Operations


Tensor Operations

1. Concatenating and Splitting Tensors

(1) Tensor concatenation

1. torch.cat(): concatenates a sequence of tensors along an existing dimension dim

Parameters: tensors: the sequence of tensors to join; dim: the dimension along which to concatenate


import torch

t = torch.ones(2, 3)
t_0 = torch.cat([t, t], dim=0)  # join along dim 0 (rows): (2, 3) -> (4, 3)
t_1 = torch.cat([t, t], dim=1)  # join along dim 1 (columns): (2, 3) -> (2, 6)

print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))

# Output
t_0:tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]]) shape:torch.Size([4, 3])
t_1:tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]]) shape:torch.Size([2, 6])

2. torch.stack(): joins tensors along a new dimension, so the result has one more dimension than each input

Parameters: tensors: the sequence of tensors to join; dim: the new dimension along which to stack

t = torch.ones(2, 3)
t_stack = torch.stack([t, t], dim=2)  # a new trailing dim is created: (2, 3) -> (2, 3, 2)

print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))

# Output
t_stack:
tensor([[[1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.]]]) 
shape:
torch.Size([2, 3, 2])

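To see the difference between cat and stack more directly, here is a small comparison along dim=0 (a sketch; t is just a (2, 3) example tensor):

t = torch.ones(2, 3)
# cat joins along an existing dimension: (2, 3) + (2, 3) -> (4, 3)
print(torch.cat([t, t], dim=0).shape)    # torch.Size([4, 3])
# stack first creates a new dimension 0, then joins: (2, 3) + (2, 3) -> (2, 2, 3)
print(torch.stack([t, t], dim=0).shape)  # torch.Size([2, 2, 3])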

(2) Tensor splitting

1. torch.chunk(): splits a tensor into equal chunks along the specified dimension dim; if the size is not evenly divisible, the last chunk is smaller than the others

Parameters: input: the tensor to split; chunks: the number of chunks; dim: the dimension to split along

a = torch.ones(2, 5)
list_of_tensors = torch.chunk(a, dim=1, chunks=2)  # 5 is not evenly divisible by 2: chunk sizes are 3 and 2

for idx, t in enumerate(list_of_tensors):
    print("tensor {}: {}, shape is {}".format(idx + 1, t, t.shape))


# Output

tensor 1: tensor([[1., 1., 1.],
        [1., 1., 1.]]), shape is torch.Size([2, 3])
tensor 2: tensor([[1., 1.],
        [1., 1.]]), shape is torch.Size([2, 2])

2. torch.split(): splits a tensor along a dimension; the chunk length can be given as a single int or as a list of lengths

Parameters: tensor: the tensor to split; split_size_or_sections: an int gives the length of every chunk, while a list splits into chunks whose lengths follow the list elements; dim: the dimension to split along

t = torch.ones(2, 5)
# list_of_tensors = torch.split(t, 2, dim=1)  # int form: every chunk has length 2 (last may be shorter)
list_of_tensors = torch.split(t, [2, 1, 1, 1], dim=1)  # list form: lengths must sum to the size of dim
for idx, t in enumerate(list_of_tensors):
    print("tensor {}: {}, shape is {}".format(idx + 1, t, t.shape))

# Output

tensor 1: tensor([[1., 1.],
        [1., 1.]]), shape is torch.Size([2, 2])
tensor 2: tensor([[1.],
        [1.]]), shape is torch.Size([2, 1])
tensor 3: tensor([[1.],
        [1.]]), shape is torch.Size([2, 1])
tensor 4: tensor([[1.],
        [1.]]), shape is torch.Size([2, 1])
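The int form commented out above behaves as follows (a sketch; each chunk has the given length, and the last chunk is shorter when the dimension is not evenly divisible):

t = torch.ones(2, 5)
for idx, chunk in enumerate(torch.split(t, 2, dim=1)):
    print("tensor {}: shape is {}".format(idx + 1, chunk.shape))
# tensor 1: shape is torch.Size([2, 2])
# tensor 2: shape is torch.Size([2, 2])
# tensor 3: shape is torch.Size([2, 1])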

2. Tensor Indexing

1. torch.index_select(): along dimension dim, selects the entries given by index and returns them concatenated into a new tensor; index must be of dtype torch.long

Parameters: input: the tensor to index; dim: the dimension to index along; index: the indices of the entries to select

t = torch.randint(0, 9, size=(3, 3))
idx = torch.tensor([0, 2], dtype=torch.long)        # indices must be long
t_select = torch.index_select(t, dim=0, index=idx)  # pick rows 0 and 2
print(idx)
print("t:\n{}\nt_select:\n{}".format(t, t_select))
print(t_select.shape)

# Output
tensor([0, 2])
t:
tensor([[0, 0, 1],
        [6, 3, 4],
        [7, 5, 1]])
t_select:
tensor([[0, 0, 1],
        [7, 5, 1]])
torch.Size([2, 3])
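The same call with dim=1 selects columns instead of rows; a small sketch reusing t and idx from the example above:

t_cols = torch.index_select(t, dim=1, index=idx)  # columns 0 and 2
print(t_cols.shape)  # torch.Size([3, 2])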

2. torch.masked_select(): selects the elements where mask is True and returns them as a 1-D tensor

Parameters: input: the tensor to index; mask: a boolean tensor with the same shape as input

t = torch.randint(0, 9, (3, 3))
mask = t.ge(5)  # True where t >= 5; gt/le/lt give >, <=, <
t_select = torch.masked_select(t, mask)
print("t:\n{}\nmask:\n{}\nselect:\n{}".format(t, mask, t_select))


# Output
t:
tensor([[7, 1, 4],
        [1, 3, 4],
        [2, 5, 5]])
mask:
tensor([[ True, False, False],
        [False, False, False],
        [False,  True,  True]])
select:
tensor([7, 5, 5])
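The comparison helpers mentioned in the comment (ge/gt/le/lt, i.e. >=, >, <=, <) all return boolean masks, and bracket indexing with a boolean mask is equivalent to masked_select; a quick sketch:

t = torch.randint(0, 9, (3, 3))
mask = t.gt(5)  # strictly greater than 5; le/lt work analogously
print(torch.equal(torch.masked_select(t, mask), t[mask]))  # True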

3. Tensor Transformations

1. torch.reshape(): changes the shape of a tensor; when the new shape is compatible with the input's memory layout, the returned tensor shares its underlying storage with the input

Parameters: input: the tensor to reshape; shape: the shape of the new tensor

t = torch.randperm(8)
t_reshape = torch.reshape(t, (2, 4))  # total element count must match: 1*8 = 2*4
# t_reshape = torch.reshape(t, (-1, 4))  # -1 infers that dimension's length automatically
print("t:\n{}\nt_reshape:\n{}".format(t, t_reshape))

# here the new tensor shares storage with the original, so changing t[0]
# also changes the first element of t_reshape
t[0] = 1024
print("t:\n{}\nt_reshape:{}\n".format(t, t_reshape))
print("address of t.data: {}".format(id(t.data)))  # id() is only a rough check; data_ptr() is more reliable
print("address of t_reshape.data: {}".format(id(t_reshape.data)))

# Output

t:
tensor([2, 4, 5, 3, 7, 6, 0, 1])
t_reshape:
tensor([[2, 4, 5, 3],
        [7, 6, 0, 1]])
t:
tensor([1024,    4,    5,    3,    7,    6,    0,    1])
t_reshape:tensor([[1024,    4,    5,    3],
        [   7,    6,    0,    1]])

address of t.data: 140201125808720
address of t_reshape.data: 140201125808720

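Note that the sharing is not guaranteed: torch.reshape returns a view when possible and a copy otherwise, for example when the input is non-contiguous. A sketch, using data_ptr() as a more reliable check than id():

t = torch.randperm(8).reshape(2, 4)
t_nc = t.t()                         # transposed view, non-contiguous
t_flat = torch.reshape(t_nc, (8,))   # cannot be expressed as a view -> reshape copies
print(t_nc.is_contiguous())                  # False
print(t_flat.data_ptr() == t_nc.data_ptr())  # False: new storage was allocated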

2. torch.transpose(): swaps two dimensions of a tensor, dim0 and dim1

Parameters: input: the tensor to transform; dim0, dim1: the two dimensions to swap

t = torch.rand((2, 3, 4))
t_transpose = torch.transpose(t, dim0=1, dim1=2)  # swap dims 1 and 2: (2, 3, 4) -> (2, 4, 3), e.g. c*h*w -> c*w*h
# print("t:\n{}\nt_transpose:\n{}".format(t, t_transpose))
print("t.shape:{}\nt_transpose.shape:{}".format(t.shape, t_transpose.shape))


# Output
t.shape:torch.Size([2, 3, 4])
t_transpose.shape:torch.Size([2, 4, 3])
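transpose moves no data: it returns a view with swapped strides, so the result shares storage with the input but is generally no longer contiguous (a small sketch):

t = torch.rand((2, 3, 4))
t_tr = torch.transpose(t, 1, 2)
print(t_tr.data_ptr() == t.data_ptr())  # True: same underlying storage
print(t_tr.is_contiguous())             # False: call .contiguous() if a compact layout is needed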

3. torch.t(): transposes a 2-D tensor; for a matrix it is equivalent to torch.transpose(input, 0, 1)

t = torch.rand((2, 4))
t_t = torch.t(t)
print("t:\n{}\nt_t:\n{}".format(t, t_t))
print("t.shape:{}\nt_t.shape:{}".format(t.shape, t_t.shape))

# Output
t:
tensor([[0.4730, 0.9863, 0.2210, 0.2282],
        [0.7951, 0.5498, 0.4361, 0.4551]])
t_t:
tensor([[0.4730, 0.7951],
        [0.9863, 0.5498],
        [0.2210, 0.4361],
        [0.2282, 0.4551]])
t.shape:torch.Size([2, 4])
t_t.shape:torch.Size([4, 2])

4. torch.squeeze(): removes dimensions of length 1, e.g. [1, 2, 3, 1] -> [2, 3]. When dim is specified, only that dimension is considered: it is removed if its length is 1, otherwise the tensor is left unchanged

t = torch.rand((1, 2, 3, 1))
t_sq = torch.squeeze(t)        # remove every length-1 dim: (1, 2, 3, 1) -> (2, 3)
t_0 = torch.squeeze(t, dim=0)  # dim 0 has length 1, so it is removed
t_1 = torch.squeeze(t, dim=1)  # dim 1 has length 2, so nothing changes

print(t.shape)
print(t_sq.shape)
print(t_0.shape)
print(t_1.shape)

# Output
torch.Size([1, 2, 3, 1])
torch.Size([2, 3])
torch.Size([2, 3, 1])
torch.Size([1, 2, 3, 1])

5. torch.unsqueeze(): inserts a dimension of length 1 at the position specified by dim

t = torch.rand((3, 2))
t_unsq = torch.unsqueeze(t, dim=0)  # new dim at position 0: (3, 2) -> (1, 3, 2)
t_2 = torch.unsqueeze(t, dim=1)     # new dim at position 1: (3, 2) -> (3, 1, 2)
print("t: {}\nt_unsq: {}\nt_2: {}".format(t.shape, t_unsq.shape, t_2.shape))

# Output
t: torch.Size([3, 2])
t_unsq: torch.Size([1, 3, 2])
t_2: torch.Size([3, 1, 2])

4. Tensor Math Operations

PyTorch's math operations fall into three groups: (1) addition, subtraction, multiplication, and division; (2) logarithm, exponential, and power functions; (3) trigonometric functions. The common ones can be looked up when needed (a brief sampler is sketched below); after that, three special fused operations are introduced.
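A brief sampler of the common element-wise functions (a sketch; the values are chosen arbitrarily for illustration):

x = torch.tensor([1.0, 2.0, 4.0])
print(torch.log(x))     # natural logarithm
print(torch.exp(x))     # exponential
print(torch.pow(x, 2))  # element-wise power
print(torch.sin(x))     # trigonometric functions: sin/cos/tan, etc.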

# torch.add: computes input + alpha * other
# torch.addcdiv: addition combined with division, out = input + value * (tensor1 / tensor2)
# torch.addcmul: addition combined with multiplication, out = input + value * tensor1 * tensor2
# (the older positional forms, e.g. torch.add(input, alpha, other), are deprecated)
t_0 = torch.randn((3, 3))
t_1 = torch.ones_like(t_0)
t_add = torch.add(t_0, t_1, alpha=10)          # t_0 + 10 * t_1
t_ad = torch.addcdiv(t_0, t_0, t_1, value=10)  # t_0 + 10 * (t_0 / t_1)
t_am = torch.addcmul(t_0, t_0, t_1, value=10)  # t_0 + 10 * t_0 * t_1
print("t_0:\n{}\nt_1:\n{}\nt_add:\n{}".format(t_0, t_1, t_add))
print("t_0:\n{}\nt_1:\n{}\nt_ad:\n{}".format(t_0, t_1, t_ad))
print("t_0:\n{}\nt_1:\n{}\nt_am:\n{}".format(t_0, t_1, t_am))
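As a quick sanity check that the three results match the formulas above (a sketch reusing the tensors from the previous snippet):

assert torch.allclose(t_add, t_0 + 10 * t_1)
assert torch.allclose(t_ad, t_0 + 10 * (t_0 / t_1))
assert torch.allclose(t_am, t_0 + 10 * t_0 * t_1)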
Original article: https://blog.csdn.net/zbr794866300/article/details/110929285
