0

0

Keras模型训练与评估精度不一致问题解析与解决方案

花韻仙語

花韻仙語

发布时间:2025-08-11 17:20:02

|

1030人浏览过

|

来源于php中文网

原创

Keras模型训练与评估精度不一致问题解析与解决方案

本文深入探讨了Keras模型在训练过程中(model.fit)报告的精度与模型评估(model.evaluate)精度不一致的常见问题。通过分析两者计算机制的差异,特别是批量更新和指标聚合方式,揭示了产生差异的根本原因。文章提供了通过引入validation_data并在自定义回调中监控val_accuracy的解决方案,确保训练过程中的监控指标与最终评估结果保持一致,从而提高模型训练的可靠性和可解释性。

1. 问题现象与初步分析

在使用keras进行模型训练时,我们可能会观察到model.fit在每个epoch结束时打印的accuracy(训练精度)与训练结束后使用model.evaluate在相同训练集上计算得到的精度存在差异。例如,fit报告的精度可能达到1.0,而evaluate的结果却略低于1.0。这种差异尤其在自定义回调函数中依赖logs['accuracy']进行逻辑判断(如提前停止)时,可能导致意外的行为。

造成这种差异的根本原因在于model.fit和model.evaluate计算指标的方式不同:

  • model.fit中的训练精度(accuracy):在每个epoch内,模型会分批次(batch)处理数据并更新权重。model.fit报告的accuracy是该epoch内所有批次精度的平均值。重要的是,每个批次的精度是在该批次数据被处理之前(或在权重更新之后但尚未处理下一个批次之前)计算的。这意味着,对于一个epoch内的不同批次,模型的权重可能在不断变化,因此计算出的精度是基于动态变化的模型状态。当一个epoch结束时,报告的accuracy是整个epoch中,模型在处理各个批次时所达到的平均性能。

  • model.evaluate中的精度:model.evaluate函数在调用时,会使用模型当前的最终权重来对整个数据集进行一次性(或分批次)评估。它不会在评估过程中更新权重。因此,model.evaluate的结果代表了模型在固定权重下的整体性能。

当batch_size较小,或者模型在训练初期权重变化较大时,model.fit报告的平均精度与model.evaluate在最终权重下计算的精度之间就可能出现显著差异。

2. 解决方案:引入验证集与监控val_accuracy

解决这一问题的关键在于,让model.fit在每个epoch结束时,使用当前epoch的最终权重,在一个固定数据集上计算指标。这可以通过fit方法的validation_data参数来实现。即使我们希望在训练集上进行评估以比较,也可以将训练集本身作为validation_data。

当validation_data被提供时,Keras会在每个epoch结束时,使用该epoch的最终模型权重对验证数据进行一次评估,并报告val_loss和val_accuracy等指标。这些val_accuracy值将与model.evaluate在相同数据集上得到的结果更加一致,因为它反映的是模型在固定权重下的表现。

聚好用AI
聚好用AI

可免费AI绘图、AI音乐、AI视频创作,聚集全球顶级AI,一站式创意平台

下载

对于自定义的提前停止回调,也应该监控val_accuracy而不是accuracy。

2.1 示例代码(修正后)

以下是修正后的代码示例,展示了如何通过引入validation_data并调整自定义回调来解决精度不一致问题:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import keras
import random
from keras import layers
from keras.callbacks import EarlyStopping
from keras.optimizers import Adam


def random_seed(seed_num=1):
    """
    设置随机种子以确保结果可复现。
    """
    np.random.seed(seed_num)
    tf.random.set_seed(seed_num)
    random.seed(seed_num)


class CustomEarlyStopping(keras.callbacks.Callback):
    """
    自定义提前停止回调,根据验证精度停止训练。
    """
    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        # 监控 'val_accuracy' 而不是 'accuracy'
        accuracy = logs.get("val_accuracy")
        if accuracy is not None and accuracy >= self.threshold:
            print(f"\n达到验证精度阈值 {self.threshold},停止训练。")
            self.model.stop_training = True

# 1. 数据准备
x = np.arange(-20, 30, 0.1)
y = np.zeros_like(x)
df = pd.DataFrame({'x': x, 'y': y})
# 创建一个简单的二分类问题:x < 10 为 0,否则为 1
df.y = df.x.map(lambda x_val: 0 if x_val < 10 else 1)
X_train = df.drop(columns='y')
y_train = df.y

# 2. 模型构建
random_seed() # 设置随机种子
model = keras.Sequential([
    layers.Input(shape=X_train.shape[-1]),
    layers.Normalization(), # 数据归一化层
    layers.Dense(1, activation='relu'), # 第一个全连接层
    layers.Dense(1, activation='sigmoid'), # 输出层,用于二分类
])

# 3. 模型编译
model.compile(
    optimizer=Adam(learning_rate=0.1), # 使用Adam优化器
    loss='binary_crossentropy', # 二元交叉熵损失函数
    metrics=['accuracy'], # 监控精度
)

# 4. 模型训练
history = model.fit(
    X_train, y_train,
    validation_data=(X_train, y_train), # 关键:将训练集也作为验证集
    batch_size=128,
    epochs=300, # 增加epochs以确保模型充分训练
    callbacks=[
        CustomEarlyStopping(1.0) # 使用自定义提前停止回调
    ]
)
history_df = pd.DataFrame(history.history)

# 5. 结果验证
# 获取history中记录的最后一个训练精度(注意这里仍然是训练精度)
last_accuracy_fit = history_df.accuracy.tolist()[-1]
# 获取history中记录的最后一个验证精度
last_accuracy_val = history_df.val_accuracy.tolist()[-1]
# 使用model.evaluate在训练集上进行评估
predict_accuracy = model.evaluate(X_train, y_train, verbose=0)[-1] # verbose=0 不打印进度条

print(f'Fit报告的最后一个训练精度 (accuracy): {last_accuracy_fit:.6f}')
print(f'Fit报告的最后一个验证精度 (val_accuracy): {last_accuracy_val:.6f}')
print(f'model.evaluate评估的精度: {predict_accuracy:.6f}')

# 预期输出:
# Fit报告的最后一个训练精度 (accuracy): 1.000000
# Fit报告的最后一个验证精度 (val_accuracy): 1.000000
# model.evaluate评估的精度: 1.000000

2.2 修正点解析

  1. validation_data=(X_train, y_train): 在model.fit中加入了validation_data参数,并将训练数据本身作为验证数据传入。这使得Keras在每个epoch结束时,都会使用该epoch的最终模型权重对X_train和y_train进行一次完整的评估,并生成val_accuracy指标。
  2. accuracy = logs.get("val_accuracy"): 在自定义的CustomEarlyStopping回调中,将监控的指标从logs["accuracy"]改为了logs.get("val_accuracy")。这样,提前停止的判断依据就是模型在epoch结束时,使用该epoch的最终权重在固定验证集(这里是训练集)上的表现。
  3. 增加epochs: 为了确保模型有足够的机会达到理想精度,将epochs从100增加到了300。

通过这些修改,model.fit报告的val_accuracy与model.evaluate的结果将保持高度一致,因为它们都是在相同的固定数据集上,使用相同的模型最终权重进行计算的。

3. 注意事项与总结

  • 理解指标含义:始终要区分model.fit在训练过程中报告的批次平均精度(accuracy)和epoch结束时在固定数据集上评估的精度(val_accuracy)。前者是动态的,后者是静态的,更具代表性。
  • 验证集的重要性:在实际项目中,validation_data通常应该是与训练集独立的数据集,用于监控模型的泛化能力,防止过拟合。本教程中为了演示精度一致性问题,使用了训练集作为验证集,但在生产环境中应避免。
  • model.evaluate的权威性:model.evaluate始终是评估模型在特定数据集上最终性能的“黄金标准”,因为它是在模型训练完成后,使用其最终权重进行的一次性评估。
  • 提前停止策略:在使用提前停止回调时,务必基于val_loss或val_accuracy等验证指标进行判断,而不是训练指标,以确保模型在验证集上表现良好时停止训练,避免过拟合。

通过理解Keras内部指标计算的机制并正确配置model.fit的参数,我们可以更准确地监控模型训练过程,并确保训练结果的可靠性。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

WorkBuddy
WorkBuddy

腾讯云推出的AI原生桌面智能体工作台

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
TypeScript类型系统进阶与大型前端项目实践
TypeScript类型系统进阶与大型前端项目实践

本专题围绕 TypeScript 在大型前端项目中的应用展开,深入讲解类型系统设计与工程化开发方法。内容包括泛型与高级类型、类型推断机制、声明文件编写、模块化结构设计以及代码规范管理。通过真实项目案例分析,帮助开发者构建类型安全、结构清晰、易维护的前端工程体系,提高团队协作效率与代码质量。

49

2026.03.13

Python异步编程与Asyncio高并发应用实践
Python异步编程与Asyncio高并发应用实践

本专题围绕 Python 异步编程模型展开,深入讲解 Asyncio 框架的核心原理与应用实践。内容包括事件循环机制、协程任务调度、异步 IO 处理以及并发任务管理策略。通过构建高并发网络请求与异步数据处理案例,帮助开发者掌握 Python 在高并发场景中的高效开发方法,并提升系统资源利用率与整体运行性能。

89

2026.03.12

C# ASP.NET Core微服务架构与API网关实践
C# ASP.NET Core微服务架构与API网关实践

本专题围绕 C# 在现代后端架构中的微服务实践展开,系统讲解基于 ASP.NET Core 构建可扩展服务体系的核心方法。内容涵盖服务拆分策略、RESTful API 设计、服务间通信、API 网关统一入口管理以及服务治理机制。通过真实项目案例,帮助开发者掌握构建高可用微服务系统的关键技术,提高系统的可扩展性与维护效率。

276

2026.03.11

Go高并发任务调度与Goroutine池化实践
Go高并发任务调度与Goroutine池化实践

本专题围绕 Go 语言在高并发任务处理场景中的实践展开,系统讲解 Goroutine 调度模型、Channel 通信机制以及并发控制策略。内容包括任务队列设计、Goroutine 池化管理、资源限制控制以及并发任务的性能优化方法。通过实际案例演示,帮助开发者构建稳定高效的 Go 并发任务处理系统,提高系统在高负载环境下的处理能力与稳定性。

59

2026.03.10

Kotlin Android模块化架构与组件化开发实践
Kotlin Android模块化架构与组件化开发实践

本专题围绕 Kotlin 在 Android 应用开发中的架构实践展开,重点讲解模块化设计与组件化开发的实现思路。内容包括项目模块拆分策略、公共组件封装、依赖管理优化、路由通信机制以及大型项目的工程化管理方法。通过真实项目案例分析,帮助开发者构建结构清晰、易扩展且维护成本低的 Android 应用架构体系,提升团队协作效率与项目迭代速度。

99

2026.03.09

JavaScript浏览器渲染机制与前端性能优化实践
JavaScript浏览器渲染机制与前端性能优化实践

本专题围绕 JavaScript 在浏览器中的执行与渲染机制展开,系统讲解 DOM 构建、CSSOM 解析、重排与重绘原理,以及关键渲染路径优化方法。内容涵盖事件循环机制、异步任务调度、资源加载优化、代码拆分与懒加载等性能优化策略。通过真实前端项目案例,帮助开发者理解浏览器底层工作原理,并掌握提升网页加载速度与交互体验的实用技巧。

105

2026.03.06

Rust内存安全机制与所有权模型深度实践
Rust内存安全机制与所有权模型深度实践

本专题围绕 Rust 语言核心特性展开,深入讲解所有权机制、借用规则、生命周期管理以及智能指针等关键概念。通过系统级开发案例,分析内存安全保障原理与零成本抽象优势,并结合并发场景讲解 Send 与 Sync 特性实现机制。帮助开发者真正理解 Rust 的设计哲学,掌握在高性能与安全性并重场景中的工程实践能力。

230

2026.03.05

PHP高性能API设计与Laravel服务架构实践
PHP高性能API设计与Laravel服务架构实践

本专题围绕 PHP 在现代 Web 后端开发中的高性能实践展开,重点讲解基于 Laravel 框架构建可扩展 API 服务的核心方法。内容涵盖路由与中间件机制、服务容器与依赖注入、接口版本管理、缓存策略设计以及队列异步处理方案。同时结合高并发场景,深入分析性能瓶颈定位与优化思路,帮助开发者构建稳定、高效、易维护的 PHP 后端服务体系。

619

2026.03.04

AI安装教程大全
AI安装教程大全

2026最全AI工具安装教程专题:包含各版本AI绘图、AI视频、智能办公软件的本地化部署手册。全篇零基础友好,附带最新模型下载地址、一键安装脚本及常见报错修复方案。每日更新,收藏这一篇就够了,让AI安装不再报错!

173

2026.03.04

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
10分钟--Midjourney创作自己的漫画
10分钟--Midjourney创作自己的漫画

共1课时 | 0.1万人学习

Midjourney 关键词系列整合
Midjourney 关键词系列整合

共13课时 | 0.9万人学习

AI绘画教程
AI绘画教程

共2课时 | 0.2万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号