torch.nn.CrossEntropyLoss

交叉熵损失,会自动给输出的logits做softmax,然后和真是标签计算交叉熵,然后再取相反数

https://zhuanlan.zhihu.com/p/383044774

CrossEntropyLoss(y_hat, y_truth) = -sum(y_truth_one_hot * log(softmax(y_hat)))

输入的y_hat是(n, C),n是样本数,C是类别数,y_truth是(n,1),表示n个样本真实类别的编号,这个编号会在函数内部被转换成one-hot编码


torch.nn.CrossEntropyLoss
https://jcdu.top/2022/05/17/torch.nn.CrossEntropyLoss/
作者
horizon86
发布于
2022年5月17日
许可协议