K 近邻算法(长文讲解)
💡一则或许对你有用的小广告
欢迎加入小哈的星球 ,你将获得:专属的项目实战 / 1v1 提问 / Java 学习路线 / 学习打卡 / 每月赠书 / 社群讨论
- 新项目:《从零手撸:仿小红书(微服务架构)》 正在持续爆肝中,基于
Spring Cloud Alibaba + Spring Boot 3.x + JDK 17...
,点击查看项目介绍 ;演示链接: http://116.62.199.48:7070 ;- 《从零手撸:前后端分离博客项目(全栈开发)》 2 期已完结,演示链接: http://116.62.199.48/ ;
截止目前, 星球 内专栏累计输出 90w+ 字,讲解图 3441+ 张,还在持续爆肝中.. 后续还会上新更多项目,目标是将 Java 领域典型的项目都整一波,如秒杀系统, 在线商城, IM 即时通讯,权限管理,Spring Cloud Alibaba 微服务等等,已有 3100+ 小伙伴加入学习 ,欢迎点击围观
前言
在机器学习领域,K 近邻算法(K-Nearest Neighbors, KNN)因其直观的逻辑和易于理解的原理,成为许多开发者入门算法的首选。它既可用于分类任务,也能处理回归问题,凭借“以邻为鉴”的思想,在推荐系统、图像识别和数据挖掘等场景中发挥重要作用。本文将从基础概念到实践应用,系统解析 K 近邻算法的原理与实现,帮助读者逐步掌握这一经典算法的核心思想。
知识点一:K 近邻算法的基本概念
什么是 K 近邻算法?
K 近邻算法是一种基于实例的学习方法,其核心思想是:通过计算样本之间的距离,找到与目标样本最相似的 K 个邻居,然后根据这些邻居的类别或数值进行预测。这种“以邻为鉴”的策略,使得算法无需显式训练模型,而是直接基于数据本身的相似性进行推理。
分类与回归的双重应用
K 近邻算法既可用于分类问题,也可用于回归问题:
- 分类任务:根据 K 个最近邻的类别标签,通过投票(多数表决)确定目标样本的类别。
- 回归任务:根据 K 个最近邻的数值,计算平均值或加权平均值作为预测结果。
核心思想的比喻
可以将 K 近邻算法想象成一个“社交网络决策模型”:假设你遇到一个陌生问题,会询问周围 K 个最熟悉的邻居的意见,最终根据他们的选择做出决定。如果邻居们多数选择 A,那么你也倾向选择 A;若邻居们给出的数值平均为 5,那么你也会预测 5。这种“集体智慧”的逻辑,正是 K 近邻算法的直观体现。
知识点二:K 近邻算法的核心原理
距离计算:确定“邻居”的关键
K 近邻算法的准确性依赖于如何定义样本之间的“距离”。常见的距离计算方法包括:
距离类型 | 公式 | 特点 | ||
---|---|---|---|---|
欧氏距离 | ( d = \sqrt{\sum_{i=1}^n (x_i - y_i)^2} ) | 计算两点间直线距离,对数值差异敏感 | ||
曼哈顿距离 | ( d = \sum_{i=1}^n | x_i - y_i | ) | 计算各维度差值的绝对值之和,适合城市街区路径 |
余弦相似度 | ( d = 1 - \frac{\sum_{i=1}^n x_i y_i}{\sqrt{\sum x_i^2} \sqrt{\sum y_i^2}} ) | 关注向量方向而非距离,适用于文本等高维数据 |
举例说明:假设两个二维点 A(1,2) 和 B(4,6),计算欧氏距离为:
[
d_{\text{欧式}} = \sqrt{(4-1)^2 + (6-2)^2} = \sqrt{9 + 16} = 5
]
K 值选择:平衡偏差与方差
K 值是算法的超参数,其选择直接影响模型性能:
- K 过小:模型容易过拟合,对噪声敏感(例如 K=1 时完全依赖最近邻);
- K 过大:模型可能欠拟合,丢失局部特征(例如 K=100 时邻居范围过大)。
通常,K 值通过交叉验证确定。例如,在鸢尾花分类任务中,通过尝试 K=3、5、7,观察准确率的变化,选择最优值。
知识点三:K 近邻算法的实现步骤
步骤 1:数据预处理
K 近邻算法对数据的尺度敏感,需进行标准化(归一化):
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
步骤 2:距离计算与排序
对测试样本与所有训练样本计算距离,并按距离排序:
def compute_distances(test_sample, train_data):
distances = []
for sample in train_data:
dist = np.sqrt(np.sum((test_sample - sample)**2))
distances.append(dist)
return np.argsort(distances) # 返回排序后的索引
步骤 3:选择最近邻并预测
根据排序后的距离,选择前 K 个样本,并进行投票或平均:
def predict_knn(k, sorted_indices, train_labels):
k_neighbors = train_labels[sorted_indices[:k]]
# 分类任务:投票
prediction = np.bincount(k_neighbors).argmax()
# 回归任务:平均
# prediction = np.mean(k_neighbors)
return prediction
完整代码示例(分类任务)
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
iris.data, iris.target, test_size=0.2, random_state=42
)
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.2f}")
知识点四:K 近邻算法的优缺点分析
优点
- 简单直观:无需复杂训练,适合快速验证假设;
- 适应性强:可通过调整 K 值和距离度量,适应不同数据分布;
- 无模型假设:不依赖数据分布假设,适合非线性问题。
缺点
- 计算效率低:每次预测需遍历所有训练样本,不适合大规模数据;
- 对噪声敏感:K 值选择不当或数据存在噪声时,可能降低准确性;
- 维度灾难:高维数据中距离计算失去意义,需降维或特征工程优化。
知识点五:进阶技巧与优化
技巧 1:数据标准化
通过标准化消除量纲差异,例如:
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
X_normalized = scaler.fit_transform(X)
技巧 2:选择合适的距离度量
根据数据类型选择距离:
- 连续数值数据:欧氏距离或曼哈顿距离;
- 文本或高维稀疏数据:余弦相似度或 Jaccard 系数。
技巧 3:优化 K 值选择
使用网格搜索自动调参:
from sklearn.model_selection import GridSearchCV
param_grid = {'n_neighbors': np.arange(1, 20)}
knn = KNeighborsClassifier()
grid_search = GridSearchCV(knn, param_grid, cv=5)
grid_search.fit(X_train, y_train)
best_k = grid_search.best_params_['n_neighbors']
结论
K 近邻算法凭借其直观的逻辑和灵活的应用场景,成为机器学习入门的经典算法。通过理解距离计算、K 值选择等核心原理,并结合数据标准化和参数调优技巧,开发者可以高效解决分类与回归问题。尽管它在计算效率和高维数据上存在局限,但通过与特征工程、降维技术结合,仍能在实际项目中发挥重要作用。建议读者通过动手实现案例(如手写数字识别),进一步巩固对 K 近邻算法的理解。