0

0

TensorFlow MNIST手写数字分类:训练集准确率极低,问题出在哪儿?

霞舞

霞舞

发布时间:2025-03-10 11:26:32

|

1012人浏览过

|

来源于php中文网

原创

tensorflow mnist手写数字分类:训练集准确率极低,问题出在哪儿?

TensorFlow MNIST手写数字分类:低训练集准确率的根本原因及修复方案

在使用TensorFlow进行MNIST手写数字分类时,许多开发者会遇到一个难题:即使对训练集和测试集进行了像素归一化,训练集的准确率仍然异常低。本文将深入分析此问题,并结合代码示例提供有效的解决方案。

问题根源在于原始代码中y_pred的计算方式。代码中y_pred = tf.nn.softmax(tf.matmul(X, W) + B)这一行,错误地将softmax函数应用于未经softmax处理的预测结果。tf.nn.softmax_cross_entropy_with_logits函数期望输入的是未经softmax处理的预测值(logits)。原始代码却将softmax后的结果传入该函数,导致交叉熵损失函数计算错误,最终影响模型训练效果,导致训练集准确率极低。

为了解决这个问题,我们需要调整y_pred的计算方式以及准确率的计算方式。正确的做法是在损失函数计算后应用softmax函数获取最终的预测概率,而损失函数计算则使用未经softmax处理的预测值。

AdsGo AI
AdsGo AI

全自动 AI 广告专家,助您在数分钟内完成广告搭建、优化及扩量

下载

修正后的代码如下:

# 导入必要的库
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
import os
import pickle

# 超参数设置
numClasses = 10
inputSize = 784
batch_size = 64
learning_rate = 0.05

# 下载数据集
mnist = input_data.read_data_sets('original_data/', one_hot=True)

train_img = mnist.train.images
train_label = mnist.train.labels
test_img = mnist.test.images
test_label = mnist.test.labels
train_img /= 255.0
test_img /= 255.0


X = tf.compat.v1.placeholder(tf.float32, shape=[None, inputSize])
y = tf.compat.v1.placeholder(tf.float32, shape=[None, numClasses])
W = tf.Variable(tf.random_normal([inputSize, numClasses], stddev=0.1))
B = tf.Variable(tf.constant(0.1), [numClasses])
y_pred = tf.matmul(X, W) + B  # 修正:移除softmax

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_pred)) + 0.01 * tf.nn.l2_loss(W)
opt = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(tf.nn.softmax(y_pred), 1))  # 修正:在计算准确率时应用softmax
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

saver = tf.train.Saver()
multiclass_parameters = {}

# 运行
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # 开始训练
    for epoch in range(20):
        total_batch = int(len(train_img) / batch_size)

        for batch in range(total_batch):
            batch_input = train_img[batch * batch_size: (batch + 1) * batch_size]
            batch_label = train_label[batch * batch_size: (batch + 1) * batch_size]

            _, trainingLoss = sess.run([opt, loss], feed_dict={X: batch_input, y: batch_label})

        train_acc = sess.run(accuracy, feed_dict={X: train_img, y: train_label})
        print("Epoch %d Training Accuracy %g" % (epoch + 1, train_acc))

通过以上修正,tf.nn.softmax_cross_entropy_with_logits函数能够正确计算损失,模型得以有效训练,最终显著提升训练集准确率。 请注意,在计算最终预测概率时,仍然需要使用tf.nn.softmax函数。

相关专题

更多
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

24

2025.12.22

Python 深度学习框架与TensorFlow入门
Python 深度学习框架与TensorFlow入门

本专题深入讲解 Python 在深度学习与人工智能领域的应用,包括使用 TensorFlow 搭建神经网络模型、卷积神经网络(CNN)、循环神经网络(RNN)、数据预处理、模型优化与训练技巧。通过实战项目(如图像识别与文本生成),帮助学习者掌握 如何使用 TensorFlow 开发高效的深度学习模型,并将其应用于实际的 AI 问题中。

35

2026.01.07

c++ 根号
c++ 根号

本专题整合了c++根号相关教程,阅读专题下面的文章了解更多详细内容。

17

2026.01.23

c++空格相关教程合集
c++空格相关教程合集

本专题整合了c++空格相关教程,阅读专题下面的文章了解更多详细内容。

22

2026.01.23

yy漫画官方登录入口地址合集
yy漫画官方登录入口地址合集

本专题整合了yy漫画入口相关合集,阅读专题下面的文章了解更多详细内容。

91

2026.01.23

漫蛙最新入口地址汇总2026
漫蛙最新入口地址汇总2026

本专题整合了漫蛙最新入口地址大全,阅读专题下面的文章了解更多详细内容。

124

2026.01.23

C++ 高级模板编程与元编程
C++ 高级模板编程与元编程

本专题深入讲解 C++ 中的高级模板编程与元编程技术,涵盖模板特化、SFINAE、模板递归、类型萃取、编译时常量与计算、C++17 的折叠表达式与变长模板参数等。通过多个实际示例,帮助开发者掌握 如何利用 C++ 模板机制编写高效、可扩展的通用代码,并提升代码的灵活性与性能。

14

2026.01.23

php远程文件教程合集
php远程文件教程合集

本专题整合了php远程文件相关教程,阅读专题下面的文章了解更多详细内容。

65

2026.01.22

PHP后端开发相关内容汇总
PHP后端开发相关内容汇总

本专题整合了PHP后端开发相关内容,阅读专题下面的文章了解更多详细内容。

59

2026.01.22

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Git 教程
Git 教程

共21课时 | 2.9万人学习

Git版本控制工具
Git版本控制工具

共8课时 | 1.5万人学习

Git中文开发手册
Git中文开发手册

共0课时 | 0人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号