🚏pytorch损失函数之nn.CrossEntropyLoss()
2023-2-7
| 2023-2-7
0  |  阅读时长 0 分钟
type
status
password
date
slug
summary
category
URL
tags
icon

nn.CrossEntropyLoss()函数

nn.CrossEntropyLoss()这个损失函数用于多分类问题虽然说的是交叉熵,但是和我理解的交叉熵不一样。nn.CrossEntropyLoss()nn.logSoftmax()nn.NLLLoss()的整合,可以直接使用它来替换网络中的这两个操作。下面我们来看一下计算过程。
  • 首先输入是size是(minibatch,C)。这里的C是类别数。损失函数的计算如下:
    • 损失函数中也有权重weight参数设置,若设置权重,则公式为:
      其他参数不具体说,和nn.BCELoss()设置差不多,默认情况下,对minibatch的loss求均值。
      💡
      注意这里的标签值class,并不参与直接计算,而是作为一个索引,索引对象为实际类别

      举例说明

      我们一共有三种类别,批量大小为1(为了好计算),那么输入size为(1,3),具体值为torch.Tensor([[-0.7715, -0.6205,-0.2562]])。标签值为target = torch.tensor([0]),这里标签值为0,表示属于第0类。loss计算如下:
       
      我们在看看是否等价nn.logSoftmax()和nn.NLLLoss()的整合:
      可以看出nn.LogSoftmax()的对输入的操作就是:x是输入向量。对应上式中的tensor([[-1.3447, -1.1937, -0.8294]])
      而nn.NLLLoss()的操作是:这里没有设置权重,也就是权重默认为1,表示目标类所对应输入x中值,则loss就为
       
    • pytorch
    • 数据目录概述pytorch_lightning模型训练
      Loading...
      目录