欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

pytorch 01 关于分割任务中 onehot 编码转换的问题

程序员文章站 2022-04-03 10:13:18
在分割任务中,我们拿到的label通常是由数字类别组成的,但是在应用某些损失函数时,我们需要把label转换成 one—hot编码的形式。例如:原始label维度 224*224*1(由数字0-2组成) ,为一个三类别的分割任务,在onehot编码后维度为 224*224*3,(可以看成3张224*224*1的切片)。代码:一:当维度为 N 1 *one-hot后 N C *def make_one_hot(input, num_classes): """Convert ....

在分割任务中,我们拿到的label通常是由数字类别组成的,但是在应用某些损失函数时,我们需要把label转换成 one—hot编码的形式。

例如:原始label维度 224*224*1(由数字0-2组成) ,为一个三类别的分割任务,在onehot编码后维度为 224*224*3,(可以看成3张224*224*1的切片)。

 

代码:

一:当维度为 N  1 *
one-hot后 N C *

def make_one_hot(input, num_classes):
    """Convert class index tensor to one hot encoding tensor.
    Args:
         input: A tensor of shape [N, 1, *]
         num_classes: An int of number of class
    Returns:
        A tensor of shape [N, num_classes, *]
    """
    shape = np.array(input.shape)
    shape[1] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape)
    result = result.scatter_(1, torch.LongTensor(input), 1)

    return result
二:当维度为 1 * 
one_hot后 N *
def make_one_hot(input, num_classes):
    """Convert class index tensor to one hot encoding tensor.
    Args:
         input: A tensor of shape [N, 1, *]
         num_classes: An int of number of class
    Returns:
        A tensor of shape [N, num_classes, *]
    """
    shape = np.array(input.shape)
    shape[0] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape)
    result = result.scatter_(0, torch.LongTensor(input), 1)

    return result

* 代表图像大小 例如 224 x 224

 

本文地址:https://blog.csdn.net/wwwww_bw/article/details/107643179

相关标签: Pytroch