0

0

PyTorch模型训练准确率异常:常见评估逻辑错误与修正方法

霞舞

霞舞

发布时间:2025-12-02 12:49:11

|

136人浏览过

|

来源于php中文网

原创

PyTorch模型训练准确率异常:常见评估逻辑错误与修正方法

本文针对pytorch模型训练中准确率异常低的问题进行深入探讨。核心原因在于模型评估阶段对正确预测数目的累加逻辑存在错误,以及对模型输入张量进行了不当展平。文章将详细解析这些常见陷阱,提供正确的代码修正方案,确保模型性能评估的准确性,帮助开发者有效诊断并解决训练过程中的此类问题。

1. 问题现象与初步分析

在PyTorch模型训练过程中,开发者有时会遇到模型准确率始终处于极低水平,甚至低于随机猜测的情况,即使调整了批量大小、网络层数、迭代次数和学习率等超参数也无济于事。这种现象往往令人困惑,因为模型结构和数据加载看似正常。实际上,这通常不是模型本身无法学习,而是模型评估逻辑或数据预处理阶段存在缺陷,导致模型性能被错误地衡量。

2. 核心问题诊断:准确率累加错误

通过对提供的代码进行分析,导致模型准确率异常的核心问题之一在于测试阶段对正确预测样本数 n_correct 的累加方式不正确。

在模型测试循环中,计算每个批次的正确预测数目的原始代码如下:

# ... (在测试循环内部)
n_correct = (predictions == labels).sum().item()

这行代码的问题在于,它在每次迭代时都会重新赋值 n_correct,而不是将其与之前批次的正确预测数累加。这意味着 n_correct 最终只会保存最后一个批次的正确预测数,而不是整个测试集上的总和。因此,最终计算出的准确率将是基于单个批次而非整个数据集的,从而导致结果极低且不准确。

修正方法: 要解决此问题,只需将 n_correct 的赋值操作改为累加操作,确保在循环外部初始化 n_correct,并在循环内部使用 += 进行累加:

# ... (在测试循环内部)
n_correct += (predictions == labels).sum().item()

3. 潜在的数据形状处理错误:模型输入展平

除了 n_correct 的累加错误外,代码中还存在一个更基础但同样关键的潜在问题,即对模型输入 inputs 的不当 torch.flatten 操作。

PaperAiBye
PaperAiBye

支持近30多种语言降ai降重,并且支持多种语言免费测句子的ai率,支持英文aigc报告等

下载

在训练循环和测试循环中,都出现了以下代码:

# ... (在训练或测试循环内部)
inputs = torch.flatten(inputs)

假设 input_size 被定义为5,且 DataLoader 提供的 inputs 形状为 (batch_size, 5)。对于 nn.Linear(input_size, hidden_size) 这样的全连接层,它期望的输入是 (batch_size, input_size)。如果对 inputs 进行 torch.flatten 操作,其形状将变为 (batch_size * 5)。这将导致 nn.Linear 层接收到的特征维度与 input_size 不匹配,或者在PyTorch内部进行不正确的自动调整,使模型接收到一个被展平的、不再具有原始特征结构的张量,从而无法进行有效的学习。

修正方法: 应移除训练和测试循环中对 inputs 的 torch.flatten(inputs) 操作。模型期望的输入形状通常由 nn.Linear 层的 in_features 参数决定,即 (batch_size, input_size)。

# 移除此行:
# inputs = torch.flatten(inputs)

关于 labels 的形状处理: 代码中对 labels 也进行了 torch.flatten(labels) 操作。 由于 SDSS 数据集类中 self.y_data 的定义是 xy[:, [0]],这会使得 labels 的原始形状为 (batch_size, 1)。 nn.CrossEntropyLoss 期望的 target 形状是 (batch_size)(对于多分类问题,直接是类别索引)。因此,将 (batch_size, 1) 展平为 (batch_size) 是正确的处理方式。更语义化的做法是使用 labels = labels.squeeze(1),它明确表示移除维度为1的单维度。

4. 修正后的代码示例

以下是模型训练和测试循环中经过修正的关键部分代码:

import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader

# device config
device = torch.device('cpu') # 示例使用CPU,若有GPU可改为'cuda'

input_size = 5
hidden_size = 10
num_classes = 3
num_epochs = 100
batch_size = 10
learning_rate = 0.0001

class SDSS(Dataset):
    def __init__(self):
        xy = np.loadtxt('SDSS.csv', delimiter=',', dtype=np.float32, skiprows=0)
        self.n_samples = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, 1:]) # size [n_samples, n_features]
        self.y_data = torch.from_numpy(xy[:, [

相关专题

更多
pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

431

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

21

2025.12.22

高德地图升级方法汇总
高德地图升级方法汇总

本专题整合了高德地图升级相关教程,阅读专题下面的文章了解更多详细内容。

4

2026.01.16

全民K歌得高分教程大全
全民K歌得高分教程大全

本专题整合了全民K歌得高分技巧汇总,阅读专题下面的文章了解更多详细内容。

1

2026.01.16

C++ 单元测试与代码质量保障
C++ 单元测试与代码质量保障

本专题系统讲解 C++ 在单元测试与代码质量保障方面的实战方法,包括测试驱动开发理念、Google Test/Google Mock 的使用、测试用例设计、边界条件验证、持续集成中的自动化测试流程,以及常见代码质量问题的发现与修复。通过工程化示例,帮助开发者建立 可测试、可维护、高质量的 C++ 项目体系。

10

2026.01.16

java数据库连接教程大全
java数据库连接教程大全

本专题整合了java数据库连接相关教程,阅读专题下面的文章了解更多详细内容。

33

2026.01.15

Java音频处理教程汇总
Java音频处理教程汇总

本专题整合了java音频处理教程大全,阅读专题下面的文章了解更多详细内容。

15

2026.01.15

windows查看wifi密码教程大全
windows查看wifi密码教程大全

本专题整合了windows查看wifi密码教程大全,阅读专题下面的文章了解更多详细内容。

42

2026.01.15

浏览器缓存清理方法汇总
浏览器缓存清理方法汇总

本专题整合了浏览器缓存清理教程汇总,阅读专题下面的文章了解更多详细内容。

7

2026.01.15

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
誉天教育RHCE视频教程
誉天教育RHCE视频教程

共9课时 | 1.4万人学习

尚观Linux RHCE视频教程(二)
尚观Linux RHCE视频教程(二)

共34课时 | 5.7万人学习

尚观RHCE视频教程(一)
尚观RHCE视频教程(一)

共28课时 | 4.8万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号