0

0

使用VGG16进行MNIST手写数字识别的迁移学习

霞舞

霞舞

发布时间:2025-08-19 19:24:25

|

914人浏览过

|

来源于php中文网

原创

使用vgg16进行mnist手写数字识别的迁移学习

本文档旨在指导读者使用VGG16模型进行MNIST手写数字识别的迁移学习。我们将重点介绍如何构建模型、加载预训练权重、以及解决可能遇到的GPU配置问题。通过本文,读者可以掌握利用VGG16进行图像分类任务迁移学习的基本方法,并了解如何调试TensorFlow在GPU上的运行环境。

VGG16迁移学习实现MNIST手写数字识别

迁移学习是一种强大的技术,它允许我们利用在大规模数据集上预训练的模型,并将其应用于新的、通常较小的数据集。 这可以显著减少训练时间和所需的计算资源,同时还能提高模型的性能。 在本文中,我们将使用在ImageNet数据集上预训练的VGG16模型进行MNIST手写数字识别。

数据准备

首先,我们需要加载MNIST数据集并进行预处理。MNIST数据集包含 0 到 9 的手写数字的灰度图像。 为了与 VGG16 兼容,我们需要将灰度图像转换为三通道图像,并将图像大小调整为 VGG16 期望的输入大小。

import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras import layers, models
import numpy as np

# 加载MNIST数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# 将灰度图像转换为三通道图像
x_train = np.stack([x_train]*3, axis=-1)
x_test = np.stack([x_test]*3, axis=-1)

# 调整图像大小为 (75, 75, 3)
img_height, img_width = 75, 75
x_train = tf.image.resize(x_train, (img_height, img_width)).numpy()
x_test = tf.image.resize(x_test, (img_height, img_width)).numpy()

# 归一化像素值
x_train = x_train / 255.0
x_test = x_test / 255.0

构建VGG16迁移学习模型

接下来,我们将创建一个基于VGG16的迁移学习模型。我们将加载预训练的VGG16模型,移除其顶层(分类层),并添加我们自己的自定义分类层。

class VGG16TransferLearning(tf.keras.Model):
  def __init__(self, base_model, models):
    super(VGG16TransferLearning, self).__init__()
    #base model
    self.base_model = base_model

   # other layers
    self.flatten = tf.keras.layers.Flatten()
    self.dense1 = tf.keras.layers.Dense(512, activation='relu')
    self.dense2 = tf.keras.layers.Dense(512, activation='relu')
    self.dense3 = tf.keras.layers.Dense(10)
    self.layers_list = [self.flatten, self.dense1, self.dense2, self.dense3]

    #instantiate the base model with other layers
    self.model = models.Sequential(
      [self.base_model, *self.layers_list]
    )

  def call(self, *args, **kwargs):
    activation_list = []
    out = args[0]

    for layer in self.model.layers:
      out = layer(out)
      activation_list.append(out)
    if kwargs.get('training', False):
      return out
    else:
      prob = tf.nn.softmax(out)
      return out, prob
# 加载预训练的VGG16模型,不包含顶层
base_model = VGG16(weights="imagenet", include_top=False, input_shape=x_train[0].shape)

# 冻结VGG16模型的权重,防止在训练过程中被修改
base_model.trainable = False

# 创建迁移学习模型
model = VGG16TransferLearning(base_model, models)

编译和训练模型

现在,我们可以编译和训练我们的迁移学习模型。

# 编译模型
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
          optimizer=tf.keras.optimizers.legacy.Adam(),
          metrics=['accuracy'])

# 训练模型
model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test))

GPU配置问题及解决

在训练过程中,可能会遇到Kernel Restarting的问题,这通常是由于GPU配置不正确导致的。即使你的电脑配备了GPU,TensorFlow也可能无法正确识别和使用它。

白果AI论文
白果AI论文

论文AI生成学术工具,真实文献,免费不限次生成论文大纲 10 秒生成逻辑框架,10 分钟产出初稿,智能适配 80+学科。支持嵌入图表公式与合规文献引用

下载

以下是一些可能的解决方法

  1. 检查TensorFlow是否检测到GPU:

    print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

    如果输出为0,则表示TensorFlow没有检测到GPU。

  2. 安装正确的CUDA和cuDNN版本: TensorFlow需要特定版本的CUDA和cuDNN才能正常使用GPU。请参考TensorFlow官方文档,安装与你的TensorFlow版本兼容的CUDA和cuDNN版本。

  3. 设置环境变量: 确保CUDA和cuDNN的路径已添加到环境变量中。

  4. 使用tf.config.experimental.set_memory_growth: 在某些情况下,TensorFlow可能无法正确分配GPU内存。可以使用以下代码限制TensorFlow使用的GPU内存量:

    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
      try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
          tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
      except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

通过以上步骤,可以解决TensorFlow无法识别GPU的问题,从而避免Kernel Restarting错误。

注意事项和总结

  • 数据预处理: 数据预处理对于迁移学习至关重要。确保你的数据格式与预训练模型的要求相匹配。
  • 冻结层: 在迁移学习中,通常会冻结预训练模型的部分或全部层,以防止过度拟合。
  • GPU配置: 确保TensorFlow正确配置并使用GPU,以加快训练速度。
  • 超参数调整: 根据你的数据集和任务,可能需要调整学习率、批大小等超参数。

通过本文,我们学习了如何使用VGG16模型进行MNIST手写数字识别的迁移学习。我们还讨论了可能遇到的GPU配置问题以及如何解决它们。 迁移学习是一种强大的技术,可以显著提高模型的性能并减少训练时间。 通过掌握本文介绍的技术,你可以将迁移学习应用于各种图像分类任务。

相关专题

更多
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

23

2025.12.22

Python 深度学习框架与TensorFlow入门
Python 深度学习框架与TensorFlow入门

本专题深入讲解 Python 在深度学习与人工智能领域的应用,包括使用 TensorFlow 搭建神经网络模型、卷积神经网络(CNN)、循环神经网络(RNN)、数据预处理、模型优化与训练技巧。通过实战项目(如图像识别与文本生成),帮助学习者掌握 如何使用 TensorFlow 开发高效的深度学习模型,并将其应用于实际的 AI 问题中。

23

2026.01.07

Java JVM 原理与性能调优实战
Java JVM 原理与性能调优实战

本专题系统讲解 Java 虚拟机(JVM)的核心工作原理与性能调优方法,包括 JVM 内存结构、对象创建与回收流程、垃圾回收器(Serial、CMS、G1、ZGC)对比分析、常见内存泄漏与性能瓶颈排查,以及 JVM 参数调优与监控工具(jstat、jmap、jvisualvm)的实战使用。通过真实案例,帮助学习者掌握 Java 应用在生产环境中的性能分析与优化能力。

19

2026.01.20

PS使用蒙版相关教程
PS使用蒙版相关教程

本专题整合了ps使用蒙版相关教程,阅读专题下面的文章了解更多详细内容。

61

2026.01.19

java用途介绍
java用途介绍

本专题整合了java用途功能相关介绍,阅读专题下面的文章了解更多详细内容。

87

2026.01.19

java输出数组相关教程
java输出数组相关教程

本专题整合了java输出数组相关教程,阅读专题下面的文章了解更多详细内容。

39

2026.01.19

java接口相关教程
java接口相关教程

本专题整合了java接口相关内容,阅读专题下面的文章了解更多详细内容。

10

2026.01.19

xml格式相关教程
xml格式相关教程

本专题整合了xml格式相关教程汇总,阅读专题下面的文章了解更多详细内容。

13

2026.01.19

PHP WebSocket 实时通信开发
PHP WebSocket 实时通信开发

本专题系统讲解 PHP 在实时通信与长连接场景中的应用实践,涵盖 WebSocket 协议原理、服务端连接管理、消息推送机制、心跳检测、断线重连以及与前端的实时交互实现。通过聊天系统、实时通知等案例,帮助开发者掌握 使用 PHP 构建实时通信与推送服务的完整开发流程,适用于即时消息与高互动性应用场景。

19

2026.01.19

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Excel 教程
Excel 教程

共162课时 | 12.5万人学习

Go语言web开发--经典项目电子商城
Go语言web开发--经典项目电子商城

共23课时 | 1.3万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号