0

0

PyTorch CrossEntropyLoss中的数据类型错误解析与最佳实践

花韻仙語

花韻仙語

发布时间:2025-10-26 12:03:14

|

707人浏览过

|

来源于php中文网

原创

PyTorch CrossEntropyLoss中的数据类型错误解析与最佳实践

本文深入探讨了pytorch中`crossentropyloss`常见的`runtimeerror: expected scalar type long but found float`错误。该错误通常源于目标标签(target)的数据类型不符合`crossentropyloss`的预期。我们将详细解析错误原因,并提供如何在训练循环中正确使用`crossentropyloss`,包括标签类型转换、输入顺序以及避免重复应用softmax等关键最佳实践,以确保模型训练的稳定性和准确性。

深度学习的分类任务中,torch.nn.CrossEntropyLoss是一个非常常用的损失函数。它结合了LogSoftmax和负对数似然损失(NLLLoss),能够高效地处理多分类问题。然而,初学者在使用时常会遇到一个特定的运行时错误:RuntimeError: expected scalar type Long but found Float。这个错误明确指出,CrossEntropyLoss在处理其目标标签(target)时,期望的数据类型是torch.Long(即64位整数),但实际接收到的是torch.Float。

理解CrossEntropyLoss的工作原理

CrossEntropyLoss函数在PyTorch中通常接收两个主要参数:

  1. input (或 logits):这是模型的原始输出,通常是未经Softmax激活函数处理的“对数几率”(logits)。它的形状通常是 (N, C),其中 N 是批量大小,C 是类别数量。对于图像任务,如果模型输出是像素级别的分类(如U-Net),则形状可能是 (N, C, H, W)。
  2. target (或 labels):这是真实的类别标签。它应该包含每个样本的类别索引,其数据类型必须是torch.long(或torch.int64)。它的形状通常是 (N),对于像素级别的分类,形状可能是 (N, H, W)。target中的值应介于 0 到 C-1 之间,代表对应的类别索引。

关键点: CrossEntropyLoss内部会自行执行Softmax操作,因此,向其传入经过Softmax处理的概率值是不正确的,这可能导致数值不稳定或不准确的损失计算。

RuntimeError: expected scalar type Long but found Float 错误解析与修正

这个错误的核心在于target张量的数据类型不匹配。在提供的代码片段中,错误发生在以下这行:

loss = criterion(output, labels.float())

尽管labels张量在创建时已经被明确指定为long类型:

labels = Variable(torch.FloatTensor(10).uniform_(0, 120).long())

但在计算损失时,又通过.float()方法将其强制转换回了float类型。这就是导致CrossEntropyLoss抛出错误的原因。

修正方法: 只需移除对labels的.float()调用,确保target张量保持其long类型即可。

# 错误代码
# loss = criterion(output, labels.float())

# 正确代码
loss = criterion(output, labels)

训练循环中的常见误用及修正

除了上述直接的类型转换错误,在提供的train_one_epoch函数中,也存在一些与CrossEntropyLoss使用相关的常见误区。

1. 标签数据类型转换错误

在train_one_epoch函数内部,标签被错误地转换成了float类型:

labels = labels.to(device).float() # 错误:将标签转换为float类型

这会直接导致CrossEntropyLoss接收到float类型的标签,再次触发同样的RuntimeError。

修正方法: 确保标签在送入损失函数前是long类型。

labels = labels.to(device).long() # 正确:将标签转换为long类型

2. CrossEntropyLoss输入参数顺序和类型错误

在train_one_epoch函数中,计算损失的行是:

Videoleap
Videoleap

Videoleap是一个一体化的视频编辑平台

下载
loss = criterion(labels, torch.argmax(outputs, dim=1)) # 错误:参数顺序和类型不符

这里存在两个问题:

  • 参数顺序错误: criterion(即CrossEntropyLoss)期望的第一个参数是模型的输出(logits),第二个参数是真实标签(target)。这里却反了过来。
  • target参数类型错误: torch.argmax(outputs, dim=1) 已经是一个预测结果的类别索引,它不应该作为CrossEntropyLoss的target参数传入。target参数应是真实的、未经模型处理的类别标签。

修正方法: 将模型的原始输出(logits)作为第一个参数,真实的long类型标签作为第二个参数。

3. 预先应用Softmax的错误

在计算outputs时,代码中显式地应用了F.softmax:

outputs = F.softmax(model(inputs.float()), dim=1) # 错误:CrossEntropyLoss内部已包含Softmax

由于CrossEntropyLoss内部已经包含了Softmax操作,再次应用F.softmax会导致:

  • 冗余计算: 增加了不必要的计算开销。
  • 数值稳定性问题: 两次Softmax操作可能导致数值精度下降,尤其是在处理非常大或非常小的对数几率时。

修正方法: 直接将模型的原始输出(logits)传递给CrossEntropyLoss。

优化后的训练函数示例

综合以上修正,以下是train_one_epoch函数的一个优化版本,遵循了CrossEntropyLoss的最佳实践:

import torch
import torch.nn as nn
import torch.nn.functional as F
import time

# 假设 model, optimizer, dataloaders, device 已经定义

def train_one_epoch(model, optimizer, data_loader, device):
    model.train()
    running_loss = 0.0
    start_time = time.time()
    total = 0
    correct = 0

    # 确保 data_loader 是实际的 DataLoader 对象
    # 这里假设 dataloaders['train'] 是一个可迭代的 DataLoader
    current_data_loader = data_loader # 如果传入的是字符串'train',需要根据实际情况获取
    if isinstance(data_loader, str):
        current_data_loader = dataloaders[data_loader] # 假设 dataloaders 是一个全局字典

    for i, (inputs, labels) in enumerate(current_data_loader):
        inputs = inputs.to(device)
        # 核心修正:确保标签是long类型
        labels = labels.to(device).long() 

        optimizer.zero_grad()

        # 修正:直接使用模型的原始输出(logits),不应用Softmax
        # 假设 model(inputs.float()) 返回的是 logits
        logits = model(inputs.float()) 

        # 打印形状以调试
        # print("Inputs shape:", inputs.shape)
        # print("Logits shape:", logits.shape)
        # print("Labels shape:", labels.shape)

        # 修正:CrossEntropyLoss的正确使用方式是 (logits, target_indices)
        loss = criterion(logits, labels) 

        loss.backward()
        optimizer.step()

        # 计算准确率时,需要对logits应用argmax
        _, predicted = torch.max(logits.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total

        running_loss += loss.item()
        if i % 10 == 0:    # print every 10 batches
            batch_time = time.time()
            speed = (i+1)/(batch_time-start_time)
            print('[%5d] loss: %.3f, speed: %.2f, accuracy: %.2f %%' %
                  (i, running_loss, speed, accuracy))
            running_loss = 0.0
            total = 0
            correct = 0

验证模型函数 (val_model) 的注意事项

val_model函数在处理标签时使用了labels = labels.to(device).long(),这是正确的。同时,outputs = model(inputs.float()) 假设模型输出的是logits,然后用 torch.max(outputs.data, 1) 来获取预测类别,这也是标准做法。

唯一需要注意的是,model.val() 应该更正为 model.eval(),这会将模型设置为评估模式,禁用Dropout和BatchNorm等层,以确保评估结果的稳定性。

def val_model(model, data_loader, device): # 添加 device 参数
    model.eval() # 修正:使用 model.eval()
    start_time = time.time()
    total = 0
    correct = 0

    current_data_loader = data_loader
    if isinstance(data_loader, str):
        current_data_loader = dataloaders[data_loader]

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(current_data_loader):
            inputs = inputs.to(device)
            labels = labels.to(device).long() # 正确

            outputs = model(inputs.float()) # 假设 model 输出 logits

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            # 修正:(predicted == labels).sum() 返回一个标量,直接 .item() 即可
            correct += (predicted == labels).sum().item() 
        accuracy = 100 * correct / total

        print('Finished Testing')
        print('Testing accuracy: %.1f %%' %(accuracy))

总结与最佳实践

处理PyTorch中的CrossEntropyLoss时,请牢记以下关键点:

  1. 目标标签的数据类型: CrossEntropyLoss的target参数必须是torch.long类型(即64位整数),且包含类别索引(从0到C-1)。
  2. 模型输出: CrossEntropyLoss的input参数应是模型的原始输出(logits),即未经Softmax激活函数处理的对数几率。
  3. 避免重复Softmax: 不要在将模型输出传递给CrossEntropyLoss之前手动应用F.softmax,因为CrossEntropyLoss内部已经包含了此操作。
  4. 参数顺序: CrossEntropyLoss的调用格式是 loss = criterion(logits, target_labels)。
  5. 评估模式: 在验证或测试模型时,务必使用model.eval()来设置模型为评估模式,并在torch.no_grad()上下文管理器中执行前向传播,以节省内存和计算。

遵循这些原则,可以有效避免RuntimeError: expected scalar type Long but found Float以及其他与CrossEntropyLoss使用相关的常见问题,确保模型训练的顺利进行。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
数据类型有哪几种
数据类型有哪几种

数据类型有整型、浮点型、字符型、字符串型、布尔型、数组、结构体和枚举等。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

310

2023.10.31

php数据类型
php数据类型

本专题整合了php数据类型相关内容,阅读专题下面的文章了解更多详细内容。

222

2025.10.31

css中float用法
css中float用法

css中float属性允许元素脱离文档流并沿其父元素边缘排列,用于创建并排列、对齐文本图像、浮动菜单边栏和重叠元素。想了解更多float的相关内容,可以阅读本专题下面的文章。

580

2024.04.28

C++中int、float和double的区别
C++中int、float和double的区别

本专题整合了c++中int和double的区别,阅读专题下面的文章了解更多详细内容。

102

2025.10.23

C++类型转换方式
C++类型转换方式

本专题整合了C++类型转换相关内容,想了解更多相关内容,请阅读专题下面的文章。

301

2025.07.15

点击input框没有光标怎么办
点击input框没有光标怎么办

点击input框没有光标的解决办法:1、确认输入框焦点;2、清除浏览器缓存;3、更新浏览器;4、使用JavaScript;5、检查硬件设备;6、检查输入框属性;7、调试JavaScript代码;8、检查页面其他元素;9、考虑浏览器兼容性。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

186

2023.11.24

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

433

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

24

2025.12.22

C++ 设计模式与软件架构
C++ 设计模式与软件架构

本专题深入讲解 C++ 中的常见设计模式与架构优化,包括单例模式、工厂模式、观察者模式、策略模式、命令模式等,结合实际案例展示如何在 C++ 项目中应用这些模式提升代码可维护性与扩展性。通过案例分析,帮助开发者掌握 如何运用设计模式构建高质量的软件架构,提升系统的灵活性与可扩展性。

14

2026.01.30

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Git 教程
Git 教程

共21课时 | 3.2万人学习

Git版本控制工具
Git版本控制工具

共8课时 | 1.5万人学习

Git中文开发手册
Git中文开发手册

共0课时 | 0人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号