0

0

PyTorch模型训练准确率异常:常见评估逻辑错误与修正方法

霞舞

霞舞

发布时间:2025-12-02 12:49:11

|

136人浏览过

|

来源于php中文网

原创

PyTorch模型训练准确率异常:常见评估逻辑错误与修正方法

本文针对pytorch模型训练中准确率异常低的问题进行深入探讨。核心原因在于模型评估阶段对正确预测数目的累加逻辑存在错误,以及对模型输入张量进行了不当展平。文章将详细解析这些常见陷阱,提供正确的代码修正方案,确保模型性能评估的准确性,帮助开发者有效诊断并解决训练过程中的此类问题。

1. 问题现象与初步分析

在PyTorch模型训练过程中,开发者有时会遇到模型准确率始终处于极低水平,甚至低于随机猜测的情况,即使调整了批量大小、网络层数、迭代次数和学习率等超参数也无济于事。这种现象往往令人困惑,因为模型结构和数据加载看似正常。实际上,这通常不是模型本身无法学习,而是模型评估逻辑或数据预处理阶段存在缺陷,导致模型性能被错误地衡量。

2. 核心问题诊断:准确率累加错误

通过对提供的代码进行分析,导致模型准确率异常的核心问题之一在于测试阶段对正确预测样本数 n_correct 的累加方式不正确。

在模型测试循环中,计算每个批次的正确预测数目的原始代码如下:

# ... (在测试循环内部)
n_correct = (predictions == labels).sum().item()

这行代码的问题在于,它在每次迭代时都会重新赋值 n_correct,而不是将其与之前批次的正确预测数累加。这意味着 n_correct 最终只会保存最后一个批次的正确预测数,而不是整个测试集上的总和。因此,最终计算出的准确率将是基于单个批次而非整个数据集的,从而导致结果极低且不准确。

九歌
九歌

九歌--人工智能诗歌写作系统

下载

修正方法: 要解决此问题,只需将 n_correct 的赋值操作改为累加操作,确保在循环外部初始化 n_correct,并在循环内部使用 += 进行累加:

# ... (在测试循环内部)
n_correct += (predictions == labels).sum().item()

3. 潜在的数据形状处理错误:模型输入展平

除了 n_correct 的累加错误外,代码中还存在一个更基础但同样关键的潜在问题,即对模型输入 inputs 的不当 torch.flatten 操作。

在训练循环和测试循环中,都出现了以下代码:

# ... (在训练或测试循环内部)
inputs = torch.flatten(inputs)

假设 input_size 被定义为5,且 DataLoader 提供的 inputs 形状为 (batch_size, 5)。对于 nn.Linear(input_size, hidden_size) 这样的全连接层,它期望的输入是 (batch_size, input_size)。如果对 inputs 进行 torch.flatten 操作,其形状将变为 (batch_size * 5)。这将导致 nn.Linear 层接收到的特征维度与 input_size 不匹配,或者在PyTorch内部进行不正确的自动调整,使模型接收到一个被展平的、不再具有原始特征结构的张量,从而无法进行有效的学习。

修正方法: 应移除训练和测试循环中对 inputs 的 torch.flatten(inputs) 操作。模型期望的输入形状通常由 nn.Linear 层的 in_features 参数决定,即 (batch_size, input_size)。

# 移除此行:
# inputs = torch.flatten(inputs)

关于 labels 的形状处理: 代码中对 labels 也进行了 torch.flatten(labels) 操作。 由于 SDSS 数据集类中 self.y_data 的定义是 xy[:, [0]],这会使得 labels 的原始形状为 (batch_size, 1)。 nn.CrossEntropyLoss 期望的 target 形状是 (batch_size)(对于多分类问题,直接是类别索引)。因此,将 (batch_size, 1) 展平为 (batch_size) 是正确的处理方式。更语义化的做法是使用 labels = labels.squeeze(1),它明确表示移除维度为1的单维度。

4. 修正后的代码示例

以下是模型训练和测试循环中经过修正的关键部分代码:

import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader

# device config
device = torch.device('cpu') # 示例使用CPU,若有GPU可改为'cuda'

input_size = 5
hidden_size = 10
num_classes = 3
num_epochs = 100
batch_size = 10
learning_rate = 0.0001

class SDSS(Dataset):
    def __init__(self):
        xy = np.loadtxt('SDSS.csv', delimiter=',', dtype=np.float32, skiprows=0)
        self.n_samples = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, 1:]) # size [n_samples, n_features]
        self.y_data = torch.from_numpy(xy[:, [

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

465

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

27

2025.12.22

Rust内存安全机制与所有权模型深度实践
Rust内存安全机制与所有权模型深度实践

本专题围绕 Rust 语言核心特性展开,深入讲解所有权机制、借用规则、生命周期管理以及智能指针等关键概念。通过系统级开发案例,分析内存安全保障原理与零成本抽象优势,并结合并发场景讲解 Send 与 Sync 特性实现机制。帮助开发者真正理解 Rust 的设计哲学,掌握在高性能与安全性并重场景中的工程实践能力。

2

2026.03.05

PHP高性能API设计与Laravel服务架构实践
PHP高性能API设计与Laravel服务架构实践

本专题围绕 PHP 在现代 Web 后端开发中的高性能实践展开,重点讲解基于 Laravel 框架构建可扩展 API 服务的核心方法。内容涵盖路由与中间件机制、服务容器与依赖注入、接口版本管理、缓存策略设计以及队列异步处理方案。同时结合高并发场景,深入分析性能瓶颈定位与优化思路,帮助开发者构建稳定、高效、易维护的 PHP 后端服务体系。

58

2026.03.04

AI安装教程大全
AI安装教程大全

2026最全AI工具安装教程专题:包含各版本AI绘图、AI视频、智能办公软件的本地化部署手册。全篇零基础友好,附带最新模型下载地址、一键安装脚本及常见报错修复方案。每日更新,收藏这一篇就够了,让AI安装不再报错!

30

2026.03.04

Swift iOS架构设计与MVVM模式实战
Swift iOS架构设计与MVVM模式实战

本专题聚焦 Swift 在 iOS 应用架构设计中的实践,系统讲解 MVVM 模式的核心思想、数据绑定机制、模块拆分策略以及组件化开发方法。内容涵盖网络层封装、状态管理、依赖注入与性能优化技巧。通过完整项目案例,帮助开发者构建结构清晰、可维护性强的 iOS 应用架构体系。

59

2026.03.03

C++高性能网络编程与Reactor模型实践
C++高性能网络编程与Reactor模型实践

本专题围绕 C++ 在高性能网络服务开发中的应用展开,深入讲解 Socket 编程、多路复用机制、Reactor 模型设计原理以及线程池协作策略。内容涵盖 epoll 实现机制、内存管理优化、连接管理策略与高并发场景下的性能调优方法。通过构建高并发网络服务器实战案例,帮助开发者掌握 C++ 在底层系统与网络通信领域的核心技术。

25

2026.03.03

Golang 测试体系与代码质量保障:工程级可靠性建设
Golang 测试体系与代码质量保障:工程级可靠性建设

Go语言测试体系与代码质量保障聚焦于构建工程级可靠性系统。本专题深入解析Go的测试工具链(如go test)、单元测试、集成测试及端到端测试实践,结合代码覆盖率分析、静态代码扫描(如go vet)和动态分析工具,建立全链路质量监控机制。通过自动化测试框架、持续集成(CI)流水线配置及代码审查规范,实现测试用例管理、缺陷追踪与质量门禁控制,确保代码健壮性与可维护性,为高可靠性工程系统提供质量保障。

79

2026.02.28

Golang 工程化架构设计:可维护与可演进系统构建
Golang 工程化架构设计:可维护与可演进系统构建

Go语言工程化架构设计专注于构建高可维护性、可演进的企业级系统。本专题深入探讨Go项目的目录结构设计、模块划分、依赖管理等核心架构原则,涵盖微服务架构、领域驱动设计(DDD)在Go中的实践应用。通过实战案例解析接口抽象、错误处理、配置管理、日志监控等关键工程化技术,帮助开发者掌握构建稳定、可扩展Go应用的最佳实践方法。

61

2026.02.28

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
誉天教育RHCE视频教程
誉天教育RHCE视频教程

共9课时 | 1.5万人学习

尚观Linux RHCE视频教程(二)
尚观Linux RHCE视频教程(二)

共34课时 | 6万人学习

尚观RHCE视频教程(一)
尚观RHCE视频教程(一)

共28课时 | 4.9万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号