pytorch 01 关于分割任务中 onehot 编码转换的问题
程序员文章站
2022-04-03 10:13:18
在分割任务中,我们拿到的label通常是由数字类别组成的,但是在应用某些损失函数时,我们需要把label转换成 one—hot编码的形式。例如:原始label维度 224*224*1(由数字0-2组成) ,为一个三类别的分割任务,在onehot编码后维度为 224*224*3,(可以看成3张224*224*1的切片)。代码:一:当维度为 N 1 *one-hot后 N C *def make_one_hot(input, num_classes): """Convert ....
在分割任务中,我们拿到的label通常是由数字类别组成的,但是在应用某些损失函数时,我们需要把label转换成 one—hot编码的形式。
例如:原始label维度 224*224*1(由数字0-2组成) ,为一个三类别的分割任务,在onehot编码后维度为 224*224*3,(可以看成3张224*224*1的切片)。
代码:
一:当维度为 N 1 *
one-hot后 N C *
def make_one_hot(input, num_classes):
"""Convert class index tensor to one hot encoding tensor.
Args:
input: A tensor of shape [N, 1, *]
num_classes: An int of number of class
Returns:
A tensor of shape [N, num_classes, *]
"""
shape = np.array(input.shape)
shape[1] = num_classes
shape = tuple(shape)
result = torch.zeros(shape)
result = result.scatter_(1, torch.LongTensor(input), 1)
return result
二:当维度为 1 *
one_hot后 N *
def make_one_hot(input, num_classes):
"""Convert class index tensor to one hot encoding tensor.
Args:
input: A tensor of shape [N, 1, *]
num_classes: An int of number of class
Returns:
A tensor of shape [N, num_classes, *]
"""
shape = np.array(input.shape)
shape[0] = num_classes
shape = tuple(shape)
result = torch.zeros(shape)
result = result.scatter_(0, torch.LongTensor(input), 1)
return result
* 代表图像大小 例如 224 x 224
本文地址:https://blog.csdn.net/wwwww_bw/article/details/107643179