0

0

K-Fold交叉验证中准确率、精确率、召回率和F1分数的正确计算方法

心靈之曲

心靈之曲

发布时间:2025-12-04 11:42:06

|

493人浏览过

|

来源于php中文网

原创

K-Fold交叉验证中准确率、精确率、召回率和F1分数的正确计算方法

本文旨在指导读者如何在k-fold交叉验证中准确计算分类模型的准确率、精确率、召回率和f1分数。我们将探讨手动实现可能存在的问题,并重点介绍如何利用scikit-learn库中的`cross_val_score`函数,以简洁、高效且标准化的方式完成这些评估任务,确保模型评估结果的可靠性和公正性。

K-Fold交叉验证与模型评估的重要性

在机器学习模型的开发过程中,评估模型的泛化能力至关重要。K-Fold交叉验证是一种广泛使用的技术,它通过将数据集划分为K个子集(折叠),轮流使用其中K-1个子集作为训练数据,剩余一个子集作为测试数据,重复K次,最终将K次评估结果取平均,从而更全面地衡量模型的性能,减少因特定训练/测试集划分而导致的评估偏差。

对于分类任务,常用的评估指标包括:

  • 准确率 (Accuracy):正确预测的样本数占总样本数的比例。
  • 精确率 (Precision):在所有被预测为正类的样本中,实际为正类的比例。
  • 召回率 (Recall):在所有实际为正类的样本中,被正确预测为正类的比例。
  • F1分数 (F1 Score):精确率和召回率的调和平均值,综合考虑了两者的表现。

手动实现K-Fold评估的潜在问题

尽管可以手动编写循环来实现K-Fold交叉验证,但在实践中,这种做法常常会引入错误或不规范的行为。例如,在一个简单的循环中重复使用train_test_split函数来生成K个折叠,可能会导致以下问题:

  1. 非标准化的折叠划分:train_test_split默认是随机划分,如果不在循环外部显式控制,每次迭代的训练集和测试集可能不是严格意义上的K-Fold划分(即测试集之间不重叠,且每个样本恰好出现在测试集中一次)。
  2. 缺乏分层抽样:对于分类问题,特别是当类别不平衡时,仅仅随机划分可能导致某些折叠中的类别分布与原始数据集差异较大,从而影响评估结果的可靠性。
  3. 代码冗余与复杂性:手动管理数据划分、模型训练、预测和指标计算会使代码变得冗长且容易出错。

以下是一个手动实现K-Fold评估的示例,展示了其基本思路但存在上述潜在问题:

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.neural_network import MLPClassifier # 假设使用MLP模型

# 示例数据
X = np.random.rand(100, 10)
y = np.random.randint(0, 2, 100)
clf = MLPClassifier(random_state=42, max_iter=100) # 示例分类器
n_folds = 5

# 手动实现K-Fold(存在潜在问题)
total_accuracy = 0
total_precision = 0
total_recall = 0
total_f1 = 0

print("--- 手动K-Fold评估(不推荐) ---")
for fold in range(n_folds):
    # 每次循环都随机划分,不保证是标准K-Fold
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/n_folds, random_state=fold)

    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    total_accuracy += accuracy_score(y_test, y_pred)
    total_precision += precision_score(y_test, y_pred, zero_division=0)
    total_recall += recall_score(y_test, y_pred, zero_division=0)
    total_f1 += f1_score(y_test, y_pred, zero_division=0)

print(f"平均准确率: {total_accuracy / n_folds:.2f}")
print(f"平均精确率: {total_precision / n_folds:.2f}")
print(f"平均召回率: {total_recall / n_folds:.2f}")
print(f"平均F1分数: {total_f1 / n_folds:.2f}")

注意事项:上述手动实现方式的主要问题在于每次迭代都调用train_test_split,它默认是随机划分,并且没有确保每次划分的测试集是K-Fold交叉验证中不重叠的“折叠”。要正确实现K-Fold,需要使用KFold或StratifiedKFold对象来生成索引。然而,更推荐的方法是直接使用Scikit-learn提供的cross_val_score函数。

抠抠图
抠抠图

免费在线AI智能批量抠图,AI图片编辑,智能印花提取。

下载

使用 cross_val_score 进行标准化评估

Scikit-learn库提供了cross_val_score函数,它封装了K-Fold交叉验证的整个过程,包括数据划分、模型训练、预测和指标计算,极大地简化了代码并确保了评估的正确性和标准化。

cross_val_score 函数的关键参数包括:

  • estimator:要评估的机器学习模型实例。
  • X:特征数据。
  • y:目标变量。
  • cv:交叉验证的折叠数(K值)。对于分类任务,当cv是一个整数时,cross_val_score默认使用StratifiedKFold,确保每个折叠中的类别比例与原始数据集相似,这对于处理不平衡数据集尤为重要。
  • scoring:指定要计算的评估指标。可以是一个字符串(如'accuracy'、'precision'、'recall'、'f1'),也可以是一个可调用对象或一个指标名称列表。

下面是使用cross_val_score计算准确率、精确率、召回率和F1分数的示例代码:

import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.neural_network import MLPClassifier # 假设使用MLP模型

# 示例数据
X = np.random.rand(100, 10)
y = np.random.randint(0, 2, 100) # 假设二分类问题
clf = MLPClassifier(random_state=42, max_iter=100) # 示例分类器
n_folds = 5 # K-Fold的K值

print("\n--- 使用 cross_val_score 进行标准化评估 ---")

# 计算平均准确率
accuracy_scores = cross_val_score(clf, X, y, cv=n_folds, scoring='accuracy')
print(f'平均准确率: {accuracy_scores.mean():.2f} (标准差: {accuracy_scores.std():.2f})')

# 计算平均精确率
# 注意:对于二分类,默认是针对正类(标签为1)计算。
# 如果是多分类或需要指定正类,可能需要使用 make_scorer 或指定 average 参数
precision_scores = cross_val_score(clf, X, y, cv=n_folds, scoring='precision', error_score='raise')
print(f'平均精确率: {precision_scores.mean():.2f} (标准差: {precision_scores.std():.2f})')

# 计算平均召回率
recall_scores = cross_val_score(clf, X, y, cv=n_folds, scoring='recall', error_score='raise')
print(f'平均召回率: {recall_scores.mean():.2f} (标准差: {recall_scores.std():.2f})')

# 计算平均F1分数
f1_scores = cross_val_score(clf, X, y, cv=n_folds, scoring='f1', error_score='raise')
print(f'平均F1分数: {f1_scores.mean():.2f} (标准差: {f1_scores.std():.2f})')

# 提示:如果需要计算多分类的加权/宏平均/微平均指标,
# 可以使用 'precision_weighted', 'recall_macro', 'f1_micro' 等 scoring 字符串。
# 例如:
# f1_macro_scores = cross_val_score(clf, X, y, cv=n_folds, scoring='f1_macro')
# print(f'平均F1宏平均: {f1_macro_scores.mean():.2f}')

评估结果解读与注意事项

  • 平均值与标准差:cross_val_score返回的是一个数组,包含了K次交叉验证中每次的评估分数。通常,我们会计算这些分数的平均值作为模型的最终评估结果。同时,计算标准差可以帮助我们了解模型性能在不同折叠上的波动性,标准差越小,说明模型越稳定。
  • scoring参数的灵活性:除了上述常用的字符串,scoring参数还可以接受一个评分函数(通过make_scorer创建)或一个包含多个字符串的列表(需要结合cross_validate函数)。这为更复杂的评估需求提供了极大的灵活性。
  • error_score参数:当某些指标(如精确率、召回率)在某些折叠中因分母为零(例如,测试集中没有预测为正类的样本)而无法计算时,error_score参数可以控制行为。默认是'raise',会抛出错误。可以设置为一个数值(如0),表示在这种情况下该指标得分为0。
  • 计算效率:cross_val_score在内部会为每个折叠重新训练模型,因此计算成本与手动循环相同。但它提供了更清晰、更少出错的接口。

总结

通过本文的介绍,我们理解了在K-Fold交叉验证中正确计算模型评估指标的重要性,以及手动实现可能带来的挑战。Scikit-learn的cross_val_score函数提供了一种简洁、可靠且标准化的方法来执行这一任务,它能够自动处理数据划分(包括分层抽样),并计算各种分类指标的平均值和标准差。在实际项目中,强烈推荐使用cross_val_score来评估模型的泛化能力,从而做出更明智的模型选择和优化决策。

相关专题

更多
js 字符串转数组
js 字符串转数组

js字符串转数组的方法:1、使用“split()”方法;2、使用“Array.from()”方法;3、使用for循环遍历;4、使用“Array.split()”方法。本专题为大家提供js字符串转数组的相关的文章、下载、课程内容,供大家免费下载体验。

257

2023.08.03

js截取字符串的方法
js截取字符串的方法

js截取字符串的方法有substring()方法、substr()方法、slice()方法、split()方法和slice()方法。本专题为大家提供字符串相关的文章、下载、课程内容,供大家免费下载体验。

208

2023.09.04

java基础知识汇总
java基础知识汇总

java基础知识有Java的历史和特点、Java的开发环境、Java的基本数据类型、变量和常量、运算符和表达式、控制语句、数组和字符串等等知识点。想要知道更多关于java基础知识的朋友,请阅读本专题下面的的有关文章,欢迎大家来php中文网学习。

1465

2023.10.24

字符串介绍
字符串介绍

字符串是一种数据类型,它可以是任何文本,包括字母、数字、符号等。字符串可以由不同的字符组成,例如空格、标点符号、数字等。在编程中,字符串通常用引号括起来,如单引号、双引号或反引号。想了解更多字符串的相关内容,可以阅读本专题下面的文章。

619

2023.11.24

java读取文件转成字符串的方法
java读取文件转成字符串的方法

Java8引入了新的文件I/O API,使用java.nio.file.Files类读取文件内容更加方便。对于较旧版本的Java,可以使用java.io.FileReader和java.io.BufferedReader来读取文件。在这些方法中,你需要将文件路径替换为你的实际文件路径,并且可能需要处理可能的IOException异常。想了解更多java的相关内容,可以阅读本专题下面的文章。

550

2024.03.22

php中定义字符串的方式
php中定义字符串的方式

php中定义字符串的方式:单引号;双引号;heredoc语法等等。想了解更多字符串的相关内容,可以阅读本专题下面的文章。

545

2024.04.29

go语言字符串相关教程
go语言字符串相关教程

本专题整合了go语言字符串相关教程,阅读专题下面的文章了解更多详细内容。

161

2025.07.29

c++字符串相关教程
c++字符串相关教程

本专题整合了c++字符串相关教程,阅读专题下面的文章了解更多详细内容。

81

2025.08.07

高德地图升级方法汇总
高德地图升级方法汇总

本专题整合了高德地图升级相关教程,阅读专题下面的文章了解更多详细内容。

9

2026.01.16

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
PostgreSQL 教程
PostgreSQL 教程

共48课时 | 7.3万人学习

好课诞生记
好课诞生记

共20课时 | 6万人学习

swift开发文档
swift开发文档

共33课时 | 19.6万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号