


























我们来系统性地讲解 k 近邻算法(k-Nearest Neighbors, KNN)。我会从直观理解、核心原理、关键细节、代码示例到优缺点,层层递进。
KNN 可能是机器学习中最简单、最直观的算法。它的核心思想只有一句话:
要判断一个新样本的类别或数值,就去看它在训练数据中最相似的 k 个邻居是谁,然后让这些邻居投票决定。
它不需要显式的训练过程,属于惰性学习算法——所有计算都推迟到预测时进行。
以分类问题为例:
\[d = \sqrt{\sum_{i=1}^{n}(x_i - y_i)^2} \]
\[d = \sum_{i=1}^{n}|x_i - y_i| \]
如果特征量纲不一致(如“收入”用元、“年龄”用岁),量级大的特征会主导距离计算。
必须做的事:标准化(Z-score)或归一化(Min-Max),让所有特征在相同尺度上。
import numpy as np
from collections import Counter
class KNNClassifier:
def __init__(self, k=3, distance='euclidean'):
self.k = k
self.distance = distance
def fit(self, X, y):
self.X_train = X
self.y_train = y
def predict(self, X):
return np.array([self._predict_one(x) for x in X])
def _predict_one(self, x):
# 计算所有距离
if self.distance == 'euclidean':
distances = np.sqrt(np.sum((self.X_train - x) ** 2, axis=1))
elif self.distance == 'manhattan':
distances = np.sum(np.abs(self.X_train - x), axis=1)
# 获取 k 个最近邻居的索引
k_indices = np.argsort(distances)[:self.k]
k_labels = self.y_train[k_indices]
# 多数投票
most_common = Counter(k_labels).most_common(1)
return most_common[0][0]
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.datasets import load_iris
# 加载数据
iris = load_iris()
X, y = iris.data, iris.target
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 关键步骤:标准化!
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# 寻找最优 k 值
best_k, best_score = 0, 0
for k in range(1, 31):
knn = KNeighborsClassifier(n_neighbors=k)
scores = cross_val_score(knn, X_train_scaled, y_train, cv=5)
if scores.mean() > best_score:
best_k, best_score = k, scores.mean()
print(f"最优 k 值: {best_k}, 交叉验证准确率: {best_score:.3f}")
# 训练最终模型
final_knn = KNeighborsClassifier(n_neighbors=best_k, weights='distance')
final_knn.fit(X_train_scaled, y_train)
print(f"测试集准确率: {final_knn.score(X_test_scaled, y_test):.3f}")
| 优点 | 缺点 |
|---|---|
| 简单直观,无需训练 | 预测慢,计算量大 |
| 没有假设数据分布 | 需存储全部训练数据 |
| 多分类问题天然支持 | 对高维数据效果差(维度灾难) |
| 新数据加入容易 | 对不平衡数据敏感 |
| 只有一个超参数 k | 缺失值处理麻烦 |
algorithm='ball_tree' 即可KNN 是“理解成本极低,但用好需要细致处理”的典型算法。关键三件事:标准化特征、精心选择 k 值、根据问题选择合适的距离度量。它常作为复杂问题的第一个基线模型,帮你快速建立对问题的认知。
此内容由惯性聚合(RSS阅读器)自动聚合整理,仅供阅读参考。 原文来自 — 版权归原作者所有。