KNN演算法分類的基本原理與實例

王林
發布: 2024-01-23 11:24:20
轉載
735 人瀏覽過

KNN演算法分類的基本原理與實例

KNN演算法是一種簡單易用的分類演算法,適用於小規模資料集和低維特徵空間。它在圖像分類、文字分類等領域中表現出色,因其實現簡單、易於理解而備受青睞。

KNN演算法的基本思想是透過比較待分類樣本的特徵與訓練樣本的特徵,找到最接近的K個鄰居,並根據這K個鄰居的類別確定待分類樣本的類別。 KNN演算法中使用已標記好類別的訓練集和待分類的測試集。 KNN演算法的分類過程包括以下幾個步驟:首先,計算待分類樣本與所有訓練樣本之間的距離;其次,選擇距離最近的K個鄰居;然後,根據K個鄰居的類別進行投票,得出待分類樣本的類別;最後,將待分類樣本的類別確定為投票結果中得票最多的類別。透過這些步驟,KNN演算法可以對待分類樣本進行準確的分類。

1.計算距離

對於未分類的測試樣本,需計算其與訓練集所有樣本的距離,常用歐式、曼哈頓等方法。

2.選擇K個鄰居

根據計算出來的距離,選擇與待分類樣本距離最近的K個訓練集樣本。這些樣本就是待分類樣本的K個鄰居。

3.確定類別

根據K個鄰居的類別來決定待分類樣本的類別。通常採用「多數決法」來決定待分類樣本的類別,即選擇K個鄰居中出現最多的類別作為待分類樣本的類別。

KNN演算法相對簡單,但也有一些需要注意的問題。首先,K值的選擇對演算法的效能有很大的影響,通常需要透過交叉驗證等方法來確定最優的K值。其次,KNN演算法對資料集的規模和維度敏感,對於大規模和高維資料集的處理會出現效率問題。此外,KNN演算法還存在「類別不平衡」的問題,即某些類別的樣本數量較少,可能導致演算法對這些類別的分類效果較差。

以下是使用Python實作KNN演算法的分類實例,程式碼如下:

import numpy as np
from collections import Counter

class KNN:
    def __init__(self, k):
        self.k = k

    def fit(self, X, y):
        self.X_train = X
        self.y_train = y

    def predict(self, X_test):
        predictions = []

        for x_test in X_test:
            distances = []
            for x_train in self.X_train:
                distance = np.sqrt(np.sum((x_test - x_train)**2))
                distances.append(distance)
            idx = np.argsort(distances)[:self.k]
            k_nearest_labels = [self.y_train[i] for i in idx]
            most_common = Counter(k_nearest_labels).most_common(1)
            predictions.append(most_common[0][0])

        return np.array(predictions)
登入後複製

這個KNN類別的建構子中傳入參數k表示選擇多少個鄰居來進行分類。 fit方法用於訓練模型,接受一個訓練集X和它們對應的標籤y。 predict方法用於對測試集進行分類,接受一個測試集X_test,傳回預測的標籤。

在predict方法中,對於每個測試樣本,首先計算它與訓練集中所有樣本的距離,並選擇距離最近的k個樣本。然後,統計這k個樣本中出現最頻繁的標籤,並作為測試樣本的分類標籤。

下面是使用這個KNN類別進行分類的例子,資料集為一個二維平面上的點集,其中紅色點表示類別1,藍色點表示類別2:

import matplotlib.pyplot as plt

# 生成数据集
X = np.random.rand(200, 2) * 5 - 2.5
y = np.zeros(200)
y[np.sum(X**2, axis=1) > 2] = 1

# 分割训练集和测试集
train_idx = np.random.choice(200, 150, replace=False)
test_idx = np.array(list(set(range(200)) - set(train_idx)))
X_train, y_train = X[train_idx], y[train_idx]
X_test, y_test = X[test_idx], y[test_idx]

# 训练模型并进行预测
knn = KNN(k=5)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)

# 计算准确率并绘制分类结果
accuracy = np.mean(y_pred == y_test)
print("Accuracy:", accuracy)

plt.scatter(X_test[:, 0], X_test[:, 1], c=y_pred)
plt.show()
登入後複製

運行這段程式碼後,可以看到分類結果影像。其中,顏色表示預測的類別,紅色表示類別1,藍色表示類別2。根據分類結果,可以計算出模型的準確率。

這個實例展示了KNN演算法在二維平面上的應用,透過計算距離來決定鄰居,並根據鄰居的類別來進行分類。在實際應用中,KNN演算法可以用於影像分類、文字分類等領域,是一種簡單而有效的分類演算法。

以上是KNN演算法分類的基本原理與實例的詳細內容。更多資訊請關注PHP中文網其他相關文章!

來源:163.com
本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn
最新問題
熱門教學
更多>
最新下載
更多>
網站特效
網站源碼
網站素材
前端模板