0

0

最近邻算法(kNN)详解

P粉084495128

P粉084495128

发布时间:2025-07-23 11:04:43

|

474人浏览过

|

来源于php中文网

原创

本文介绍kNN算法,其通过计算不同特征值距离分类样本。以电影分类为例说明原理,还讲解用Numpy实现该算法的步骤,包括数据预处理、模型训练等,也提及超参数搜索函数,最后展示了用sklearn封装好的方法实现,以及相关笔记内容。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

最近邻算法(knn)详解 - php中文网

原理介绍

简言之,kNN算法计算不同特征值之间的距离对样本进行分类。

OK,说完结论,懂的可以直接看代码部分了,如果不能理解的请听我娓娓道来~现在有这么一组数据

电影名称 打斗镜头 拥抱镜头 电影类型
谍影重重 57 2 动作片
叶问3 65 2 动作片
我的特工爷爷 21 4 动作片
奔爱 4 46 爱情片
夜孔雀 8 39 爱情片
代理情人 2 38 爱情片
这个杀手不太冷 49 6

上面6个样本(电影)分别给出其特征(打斗镜头、拥抱镜头)和标签(电影类型)信息,现在给定一个新的样本,我们想知道这部电影的类型。由于是2维数据,我们可以用平面直角坐标系表示。

最近邻算法(kNN)详解 - php中文网        

绿色的点是未知的,红色的黄色的点是已知的。kNN要做的就是计算未知的点到所有已知点的距离,根据距离进行排序。

D谍影重重=(4957)2+(62)28.94D谍影重重=(49−57)2+(6−2)2≈8.94

D叶问3=(4965)2+(62)216.49D叶问3=(49−65)2+(6−2)2≈16.49

D我的特工爷爷=(4921)2+(64)228.07D我的特工爷爷=(49−21)2+(6−4)2≈28.07

D奔爱=(494)2+(646)260.21D奔爱=(49−4)2+(6−46)2≈60.21

D夜孔雀=(498)2+(639)252.63D夜孔雀=(49−8)2+(6−39)2≈52.63

D代理情人=(492)2+(638)256.86D代理情人=(49−2)2+(6−38)2≈56.86

排序后的数据如下,

电影名称 与未知电影距离
谍影重重 8.94
叶问3 16.49
我的特工爷爷 28.07
夜孔雀 52.63
代理情人 56.86
奔爱 60.21

我们在kNN算法中经常会听到说当k=3时、当k=5时...这里的k指的就是样本数。在这个例子中,当k=3时,前三个样本出现最多的电影类型是动作片,因此《这个杀手不太冷》样本也应该归为动作片。同样的,当k=5时,前5个样本出现最多的电影类型也是动作片(35>2553>52),因此样本也属于动作片。

上面提到的是2维数据,但是我们现实中处理的样本可能有3个甚至更多特征,我们无法用视觉来抽象这些特征,但是计算方法还是一样的,只不过根号里做差的数变多了而已。

最近邻算法(kNN)详解 - php中文网        

代码实现——Numpy

机器学习算法的一般流程可以归为三步。

  • 数据预处理
    1. 加载数据
    2. 交叉验证
    3. 归一化
  • 模型训练
  • 模型验证

机器学习的任务就是从海量数据中找到有价值的信息, 所以在使用算法之前,我们要对数据进行预处理。

Python精要参考 pdf版
Python精要参考 pdf版

这本书给出了一份关于python这门优美语言的精要的参考。作者通过一个完整而清晰的入门指引将你带入python的乐园,随后在语法、类型和对象、运算符与表达式、控制流函数与函数编程、类及面向对象编程、模块和包、输入输出、执行环境等多方面给出了详尽的讲解。如果你想加入 python的世界,David M beazley的这本书可不要错过哦。 (封面是最新英文版的,中文版貌似只译到第二版)

下载
In [17]
# 1. 加载莺尾花数据集from sklearn import datasets


iris = datasets.load_iris()
X = iris.data
y = iris.target
   

如果我们查看y标签信息会发现,它的前50个值为0,51—100的值为1,后50个值为2,如果直接交叉验证,取到的测试集数据可能都是label值为2的样本,这并不是我们想要的。所以在这之前,我们需要先对样本打乱顺序。zip()能将可迭代的对象打包成元组,利用 * 操作符可以将元组解压为列表。

In [18]
# 2. 实现交叉验证import numpy as npdef train_test_split(X, y, ratio=0.3):
    # 乱序
    data = list(zip(X, y))
    np.random.shuffle(data)
    X, y = zip(*data)    # 切割
    boundary_X = int((1-ratio) * len(X))
    boundary_y = int((1-ratio) * len(y))    # 将boundary_X和boundary_y之前的作为训练集
    x_train = np.array(X[: boundary_X])
    x_test = np.array(X[boundary_X:])
    y_train = np.array(y[: boundary_y])
    y_test = np.array(y[boundary_y:])    return x_train, x_test, y_train, y_test

x_train, x_test, y_train, y_test = train_test_split(X, y)
   

归一化主要有两种形式:0-1均匀分布和标准正态分布。

In [19]
# 3. 归一化def normalization(data):
    return (data - data.min()) / (data.max() - data.min()) 


def standardization(data):
    return (data - data.mean()) / data.std()


x_train = standardization(x_train)
x_test = standardization(x_test)
   

kNN的“模型训练”有点不同于一般的模型训练过程,它们可能需要求一些参数,而kNN是计算未知点到已知点的距离。从严格意义上来说,这并不算是训练。

In [20]
# 4. 距离计算from collections import Counterclass KNNClassifier:
    def __init__(self, k):
        self._k = k
        self._X_train = None
        self._y_train = None

    def fit(self, X_train, y_train):
        self._X_train = X_train
        self._y_train = y_train    # 预测X_predict样本的分类结果,这里的X_predict用的是交叉验证中的测试集
    def predict(self, X_predict):
        return np.array([self._predict(x) for x in X_predict])    def _predict(self, x):
        # 计算输入样本_X_train到所有已知数据的距离
        distances = np.sqrt(np.sum((self._X_train - x)**2, axis=1))        # 记录distances中前k个小的数对应的类别的出现次数
        votes = Counter(self._y_train[np.argpartition(distances, self._k)[: self._k]])        # most_common(n)可以打印n个出现最多次元素的值和次数
        predict_y = votes.most_common(1)[0][0]        return predict_y    # 计算准确率
    def score(self, X_test, label):
        y_predict = self.predict(X_test)
        n_sample = len(label)
        right_sample = 0
        for i, e in enumerate(label):            if y_predict[i] == e:
                right_sample += 1
        return right_sample / n_sample


knn = KNNClassifier(k=3)
knn.fit(x_train, y_train)
knn.score(x_test, y_test)
       
0.9555555555555556
               

超参数搜索函数

kNN的参数不止是k,距离模式distype也是它的参数。对于k和distype这两种参数的组合,可能会有很多不同的结果,不妨设计一个超参数搜索函数来优化k和distype。

In [21]
class KNNClassifierSuper(KNNClassifier):
    def __init__(self, k, distype):
        super().__init__(k)
        self.distype = distype    def _predict(self, x):
        assert self.distype in ["1", "2", "3"], "Error distance type!"
        if self.distype == "1":
            distances = np.sum(abs(self._X_train - x), axis=1)        elif self.distype == "2":
            distances = np.sqrt(np.sum((self._X_train - x)**2, axis=1))        else:
            distances = np.max(abs(self._X_train - x), axis=1)
        votes = Counter(self._y_train[np.argpartition(distances, self._k)[: self._k]])
        predict_y = votes.most_common(1)[0][0]        return predict_y# ManhattanDistance —— "1"# EuclideanDistance —— "2"# ChebyshevDistance —— "3"for k in range(3, 15, 2):    for distype in range(1, 4):
        knn = KNNClassifierSuper(k, str(distype))
        knn.fit(x_train, y_train)        print("k = {}\tdistype = {}\tscore = {}".format(
            k, distype, knn.score(x_test, y_test)))
       
k = 3	distype = 1	score = 0.9555555555555556
k = 3	distype = 2	score = 0.9555555555555556
k = 3	distype = 3	score = 0.9777777777777777
k = 5	distype = 1	score = 0.9777777777777777
k = 5	distype = 2	score = 0.9777777777777777
k = 5	distype = 3	score = 0.9777777777777777
k = 7	distype = 1	score = 0.9777777777777777
k = 7	distype = 2	score = 0.9777777777777777
k = 7	distype = 3	score = 1.0
k = 9	distype = 1	score = 0.9555555555555556
k = 9	distype = 2	score = 0.9777777777777777
k = 9	distype = 3	score = 1.0
k = 11	distype = 1	score = 0.9777777777777777
k = 11	distype = 2	score = 0.9555555555555556
k = 11	distype = 3	score = 0.9777777777777777
k = 13	distype = 1	score = 0.9777777777777777
k = 13	distype = 2	score = 1.0
k = 13	distype = 3	score = 0.9333333333333333
       

代码实现——sklearn

上面我们用Numpy实现了交叉验证、归一化、距离计算等方法,这些在sklearn中都已经为我们封装好了。

In [31]
from sklearn.model_selection import train_test_split, GridSearchCVfrom sklearn import preprocessingfrom sklearn.neighbors import KNeighborsClassifier# 加载莺尾花数据iris = datasets.load_iris()
X, y = iris.data, iris.target# 交叉验证x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.4)# 归一化x_train = preprocessing.scale(x_train)
x_test = preprocessing.scale(x_test)# 距离计算 + 超参数搜索函数# p = 1 manhattan_distance # p = 2 euclidean_distance# arbitrary p minkowski_distance for k in range(3, 14, 2):    for p in range(1, 5):
        knn = KNeighborsClassifier(n_neighbors=k, p=p)
        knn.fit(x_train, y_train)        print("k = {}\tp = {}\tscore = {}".format(
            k, p, knn.score(x_test, y_test)))
       
k = 3	p = 1	score = 0.9666666666666667
k = 3	p = 2	score = 0.9833333333333333
k = 3	p = 3	score = 0.9666666666666667
k = 3	p = 4	score = 0.9666666666666667
k = 5	p = 1	score = 0.9666666666666667
k = 5	p = 2	score = 0.9833333333333333
k = 5	p = 3	score = 0.9833333333333333
k = 5	p = 4	score = 0.9833333333333333
k = 7	p = 1	score = 0.9833333333333333
k = 7	p = 2	score = 0.9833333333333333
k = 7	p = 3	score = 0.9833333333333333
k = 7	p = 4	score = 0.9833333333333333
k = 9	p = 1	score = 0.9666666666666667
k = 9	p = 2	score = 0.95
k = 9	p = 3	score = 0.9666666666666667
k = 9	p = 4	score = 0.9666666666666667
k = 11	p = 1	score = 0.9666666666666667
k = 11	p = 2	score = 0.95
k = 11	p = 3	score = 0.9333333333333333
k = 11	p = 4	score = 0.9333333333333333
k = 13	p = 1	score = 0.9666666666666667
k = 13	p = 2	score = 0.95
k = 13	p = 3	score = 0.9166666666666666
k = 13	p = 4	score = 0.9166666666666666
       

笔记

  1. ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

ndarray计算的时候尽量用np的属性(np.sum()而不是sum)


  1. ValueError: kth(=3) out of bounds (1)

计算距离的时候np.sum()需要指定axis=1,不然会直接对多维数组进行sum得到一个数值,在np.argpartition会出错。


  1. np.sum()如果不指定axis是无法广播的,会直接返回数值

  1. axis一种较好的理解方式是把他看成消除器。对于shape为(2L, 3L, 4L)的数组arr,np.sum(arr, axis=0)会返回shape为(3L, 4L)的数组,np.sum(arr, axis=1) 会返回shape为(2L, 4L)的数组,np.sum(arr, axis=0)会返回shape为(2L, 3L)的数组。

相关专题

更多
页面置换算法
页面置换算法

页面置换算法是操作系统中用来决定在内存中哪些页面应该被换出以便为新的页面提供空间的算法。本专题为大家提供页面置换算法的相关文章,大家可以免费体验。

403

2023.08.14

Java JVM 原理与性能调优实战
Java JVM 原理与性能调优实战

本专题系统讲解 Java 虚拟机(JVM)的核心工作原理与性能调优方法,包括 JVM 内存结构、对象创建与回收流程、垃圾回收器(Serial、CMS、G1、ZGC)对比分析、常见内存泄漏与性能瓶颈排查,以及 JVM 参数调优与监控工具(jstat、jmap、jvisualvm)的实战使用。通过真实案例,帮助学习者掌握 Java 应用在生产环境中的性能分析与优化能力。

19

2026.01.20

PS使用蒙版相关教程
PS使用蒙版相关教程

本专题整合了ps使用蒙版相关教程,阅读专题下面的文章了解更多详细内容。

61

2026.01.19

java用途介绍
java用途介绍

本专题整合了java用途功能相关介绍,阅读专题下面的文章了解更多详细内容。

87

2026.01.19

java输出数组相关教程
java输出数组相关教程

本专题整合了java输出数组相关教程,阅读专题下面的文章了解更多详细内容。

39

2026.01.19

java接口相关教程
java接口相关教程

本专题整合了java接口相关内容,阅读专题下面的文章了解更多详细内容。

10

2026.01.19

xml格式相关教程
xml格式相关教程

本专题整合了xml格式相关教程汇总,阅读专题下面的文章了解更多详细内容。

13

2026.01.19

PHP WebSocket 实时通信开发
PHP WebSocket 实时通信开发

本专题系统讲解 PHP 在实时通信与长连接场景中的应用实践,涵盖 WebSocket 协议原理、服务端连接管理、消息推送机制、心跳检测、断线重连以及与前端的实时交互实现。通过聊天系统、实时通知等案例,帮助开发者掌握 使用 PHP 构建实时通信与推送服务的完整开发流程,适用于即时消息与高互动性应用场景。

19

2026.01.19

微信聊天记录删除恢复导出教程汇总
微信聊天记录删除恢复导出教程汇总

本专题整合了微信聊天记录相关教程大全,阅读专题下面的文章了解更多详细内容。

160

2026.01.18

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
10分钟--Midjourney创作自己的漫画
10分钟--Midjourney创作自己的漫画

共1课时 | 0.1万人学习

Midjourney 关键词系列整合
Midjourney 关键词系列整合

共13课时 | 0.9万人学习

AI绘画教程
AI绘画教程

共2课时 | 0.2万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号