欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch__my_doc

程序员文章站 2024-01-30 10:51:22
...

1. torch.cross为叉乘,输出垂直与两个向量的向量。a*b是点乘,生成一个数。

pytorch__my_doc
https://blog.csdn.net/dcrmg/article/details/52416832

2. torch.sign

torch.sign(input, out=None)
说明:符号函数,返回一个新张量,包含输入input张量每个元素的正负(大于0的元素对应1,小于0的元素对应-1,0还是0)
参数:
input(Tensor) – 输入张量
out(Tensor,可选) – 输出张量

>>> a = torch.randn(4)
>>> a
tensor([ 0.7734,  0.5677, -0.3896,  1.9878])
>>> torch.sign(a)
tensor([ 1.,  1., -1.,  1.])

3. torch.masked_fill(self, mask: Tensor, value: Number)

import torch
a=torch.tensor([1,0,2,3])
b=torch.tensor([1,0,3,5])
c=a.masked_fill(mask = torch.ByteTensor([1,1,0,0]), value=torch.tensor(88888))
print('c:',c)
print('a:',a)

d=a.masked_fill_(b<2, 99999)
print('d:',d)
print('a:',a)

###########################输出#############################################
c: tensor([88888, 88888,     2,     3])
a: tensor([1, 0, 2, 3])
d: tensor([99999, 99999,     2,     3])
a: tensor([99999, 99999,     2,     3])

mask为True(1)时,在a张量中的对应索引处曲value值。
但是第一种情况不改变a
第二种情况改变a,不知道为什么

4. torch张量数据类型转换

https://blog.csdn.net/weixin_36670529/article/details/110293966

import torch
 
tensor = torch.randn(2, 2)
print(tensor.type())
 
# torch.long() 将tensor转换为long类型
long_tensor = tensor.long()
print(long_tensor.type())
 
# torch.half()将tensor转换为半精度浮点类型
half_tensor = tensor.half()
print(half_tensor.type())
 
# torch.int()将该tensor转换为int类型
int_tensor = tensor.int()
print(int_tensor.type())
 
# torch.double()将该tensor转换为double类型
double_tensor = tensor.double()
print(double_tensor.type())
 
# torch.float()将该tensor转换为float类型
float_tensor = tensor.float()
print(float_tensor.type())
 
# torch.char()将该tensor转换为char类型
char_tensor = tensor.char()
print(char_tensor.type())
 
# torch.byte()将该tensor转换为byte类型
byte_tensor = tensor.byte()
print(byte_tensor.type())
 
# torch.short()将该tensor转换为short类型
short_tensor = tensor.short()
print(short_tensor.type())
 
 
torch.FloatTensor
torch.LongTensor
torch.HalfTensor
torch.IntTensor
torch.DoubleTensor
torch.FloatTensor
torch.CharTensor
torch.ByteTensor
torch.ShortTensor

5. torch.where

torch.where(condition, x, y):
condition:判断条件
x:若满足条件,则取x中元素
y:若不满足条件,则取y中元素

推荐阅读