0

0

PyTorch模型训练不收敛?精度计算错误排查与修正

花韻仙語

花韻仙語

发布时间:2025-12-03 08:42:13

|

649人浏览过

|

来源于php中文网

原创

PyTorch模型训练不收敛?精度计算错误排查与修正

本文针对pytorch模型训练中出现的精度不提升、甚至低于随机猜测的问题进行深入探讨。核心原因在于模型评估阶段对正确预测数量的累积方式不当,以及输入数据维度处理和标签维度的潜在错误。教程将详细分析这些常见错误,并提供正确的代码实现,帮助开发者有效诊断并解决模型训练效果不佳的困境,确保模型评估的准确性。

引言:模型训练精度不佳的困扰

深度学习模型开发过程中,开发者常常会遇到模型训练效果不理想,甚至在经过数百个Epoch后,其在测试集上的表现仍然不如随机猜测的情况。面对这种问题,许多人会首先尝试调整批量大小、网络层数、学习率或Epoch数量等超参数,但往往徒劳无功。这不仅令人沮丧,也阻碍了模型的进一步优化。实际上,有时问题的根源并非出在模型架构或训练算法本身,而在于数据处理流程、维度匹配或最常见的——模型性能评估逻辑中的细微错误。

原始代码分析与潜在问题诊断

为了更好地理解并解决这类问题,我们将对提供的PyTorch代码进行逐一分析,并指出其中可能导致模型精度异常的关键点。

1. 数据加载与预处理

代码中定义了SDSS和testSDSS两个几乎完全相同的数据集类,都从SDSS.csv文件中加载数据。

  • 重复定义: 这种重复定义是不必要的。通常,一个Dataset类足以,然后通过torch.utils.data.random_split或手动加载不同文件来创建训练集和测试集。
  • 数据维度:
    • self.x_data = torch.from_numpy(xy[:, 1:]):特征数据,形状为 [n_samples, n_features] (例如 [N, 5])。
    • self.y_data = torch.from_numpy(xy[:, [0]]):标签数据,形状为 [n_samples, 1] (例如 [N, 1])。这种形状对于分类任务的CrossEntropyLoss而言,通常需要进一步处理为 [n_samples]。

2. 模型架构

NeuralNet是一个简单的两层全连接网络,中间使用LeakyReLU激活函数,输出层直接连接到num_classes,没有显式的softmax,这与nn.CrossEntropyLoss的内部实现是兼容的。

3. 训练循环中的关键问题

训练循环遵循了PyTorch的标准模式:前向传播、损失计算、反向传播、优化器更新。然而,其中存在几个关键的维度处理问题:

  • 问题A:输入数据维度处理错误 inputs = torch.flatten(inputs) 在训练循环的每次迭代中,都执行了 inputs = torch.flatten(inputs)。

    • 解析: DataLoader通常会输出形状为 [batch_size, input_size] 的批次数据。如果 input_size 为5,那么 inputs 的形状是 [batch_size, 5]。对其执行 torch.flatten(inputs) 会将其转换为 [batch_size * 5]。然而,模型中的第一个线性层 self.l1 = nn.Linear(input_size, hidden_size) 期望的输入是 [batch_size, input_size]。这种维度不匹配会导致模型在接收数据时出现问题,或者更隐蔽地,模型会尝试将 [batch_size * 5] 解释为 [batch_size, 5],从而导致训练失败或模型无法学习。正确做法是移除这行代码。
  • 问题B:标签数据维度处理 labels = torch.flatten(labels) 在训练循环中,标签也被 flatten 处理。

    • 解析: nn.CrossEntropyLoss 期望分类任务的标签是形状为 [batch_size] 的1D张量,其中包含类别索引(0到num_classes-1)。由于原始标签 self.y_data 是 [n_samples, 1],经过 DataLoader 批处理后为 [batch_size, 1]。对其执行 flatten 确实可以将其转换为 [batch_size]。但是,更语义化且更推荐的做法是使用 labels.squeeze(1) 来移除单维度。
  • 问题C:训练步数统计 n_total_steps = len(dataset) 在训练循环外,n_total_steps 被设置为 len(dataset),即总样本数。

    • 解析: 在 print 语句中,step {i+1}/{n_total_steps} 中的 n_total_steps 应该表示批次的总数量,而非样本总数。正确的应该是 len(data_loader)。

4. 评估循环:精度计算的核心错误

这是导致模型精度“低于随机猜测”的最主要原因。

Whimsical
Whimsical

Whimsical推出的AI思维导图工具

下载
  • 错误代码:

    # ...
    n_correct = 0
    n_samples = 0
    for inputs, labels in test_loader:
        # ...
        _, predictions = torch.max(outputs, 1)
        n_samples += labels.shape[0]
        n_correct = (predictions == labels).sum().item() # 核心错误所在!
    
    acc = 100 * n_correct / n_samples
    print(f'accuracy = {acc}')
  • 问题解析: 在 for 循环中,每次迭代处理一个批次时,n_correct = (predictions == labels).sum().item() 这行代码都会重新赋值 n_correct。这意味着 n_correct 变量在每次批次处理后都会被当前批次的正确预测数覆盖,而不是将所有批次的正确预测数累加起来。最终,当循环结束时,n_correct 将只包含最后一个批次的正确预测数量。因此,计算出的 acc 值将是极低的,完全无法反映模型在整个测试集上的真实性能。

解决方案与代码修正

针对上述诊断出的问题,我们将进行以下修正:

  1. 修正精度累积: 将评估循环中的 n_correct = ... 改为 n_correct += ...。
  2. 修正输入维度: 移除训练和评估循环中对 inputs 的 flatten 操作。
  3. 修正标签维度: 将训练和评估循环中对 labels 的 flatten 操作改为 labels.squeeze(1)。
  4. 修正训练步数统计: 将 n_total_steps 的计算方式改为 len(data_loader)。
  5. 优化数据集类: 建议合并SDSS和testSDSS,并使用random_split进行数据集划分。

以下是修正后的关键代码片段:

import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.optim

# 设备配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 优先使用GPU

# 超参数
input_size = 5
hidden_size = 10
num_classes = 3
num_epochs = 100
batch_size = 10
learning_rate = 0.0001

# 修正后的Dataset类,支持训练/测试划分
class SDSSDataset(Dataset):
    def __init__(self, filepath='SDSS.csv'):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32, skiprows=0)
        self.n_samples = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, 1:]) # 特征
        self.y_data = torch.from_numpy(xy[:, [0]]) # 标签,形状 [n_samples, 1]

    def __getitem__(self, index):
        # 返回特征和标签,标签在这里不做squeeze,留给DataLoader之后处理
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.n_samples

# 加载数据集并进行训练/测试集划分
full_dataset = SDSSDataset()
train_size = int(0.8 * len(full_dataset)) # 80% 用于训练
test_size = len(full_dataset) - train_size # 20% 用于测试
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# 模型定义 (不变)
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet,self).__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.LeakyReLU()
        self.l2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.l1(x)
        out = self.relu(out)
        out = self.l2(out)
        return out

model = NeuralNet(input_size, hidden_size, num_classes).to(device) # 将模型移到设备

# 损失函数和优化器 (不变)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr

相关专题

更多
python中print函数的用法
python中print函数的用法

python中print函数的语法是“print(value1, value2, ..., sep=' ', end=' ', file=sys.stdout, flush=False)”。本专题为大家提供print相关的文章、下载、课程内容,供大家免费下载体验。

184

2023.09.27

页面置换算法
页面置换算法

页面置换算法是操作系统中用来决定在内存中哪些页面应该被换出以便为新的页面提供空间的算法。本专题为大家提供页面置换算法的相关文章,大家可以免费体验。

400

2023.08.14

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

431

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

19

2025.12.22

Java 桌面应用开发(JavaFX 实战)
Java 桌面应用开发(JavaFX 实战)

本专题系统讲解 Java 在桌面应用开发领域的实战应用,重点围绕 JavaFX 框架,涵盖界面布局、控件使用、事件处理、FXML、样式美化(CSS)、多线程与UI响应优化,以及桌面应用的打包与发布。通过完整示例项目,帮助学习者掌握 使用 Java 构建现代化、跨平台桌面应用程序的核心能力。

34

2026.01.14

php与html混编教程大全
php与html混编教程大全

本专题整合了php和html混编相关教程,阅读专题下面的文章了解更多详细内容。

14

2026.01.13

PHP 高性能
PHP 高性能

本专题整合了PHP高性能相关教程大全,阅读专题下面的文章了解更多详细内容。

33

2026.01.13

MySQL数据库报错常见问题及解决方法大全
MySQL数据库报错常见问题及解决方法大全

本专题整合了MySQL数据库报错常见问题及解决方法,阅读专题下面的文章了解更多详细内容。

18

2026.01.13

PHP 文件上传
PHP 文件上传

本专题整合了PHP实现文件上传相关教程,阅读专题下面的文章了解更多详细内容。

12

2026.01.13

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
10分钟--Midjourney创作自己的漫画
10分钟--Midjourney创作自己的漫画

共1课时 | 0.1万人学习

Midjourney 关键词系列整合
Midjourney 关键词系列整合

共13课时 | 0.9万人学习

AI绘画教程
AI绘画教程

共2课时 | 0.2万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号