首頁 後端開發 Python教學 python中K-近鄰演算法的原理與實作(附源碼)

python中K-近鄰演算法的原理與實作(附源碼)

Oct 27, 2018 pm 02:21 PM
python scikit-learn 機器學習 演算法

這篇文章帶給大家的內容是關於python中K-近鄰演算法的原理與實作(附原始碼),有一定的參考價值,有需要的朋友可以參考一下,希望對你有幫助。

k-近鄰演算法透過測量不同特徵值之間的距離方法進行分類。

k-近鄰演算法原理

對於一個存在標籤的訓練樣本集,輸入沒有標籤的新資料後,將新資料的每個特徵與樣本集中將資料對應的特徵進行比較,根據演算法選擇樣本資料集中前k個最相似的數據,選擇k個最相似資料中出現次數最多的分類,作為新資料的分類。

k-近鄰演算法實作

這裡只是單一新資料的預測,對同時多個新資料的預測放在後文中。

假定有訓練樣本集X_train(X_train.shape=(10, 2)),對應的標記y_train(y_train.shape=(10,),包含0、1),使用matplotlib.pyplot 作圖表示如下(綠色的點表示標記0,紅色的點表示標記1):

python中K-近鄰演算法的原理與實作(附源碼)

現有一個新的資料:x(x = np.array([3.18557125, 6.03119673])),作圖表示如下(藍色的點):

python中K-近鄰演算法的原理與實作(附源碼)

#首先,使用歐拉距離公式計算x 到X_train 中每個樣本的距離:

import math

distances = [math.sqrt(np.sum((x_train - x) ** 2)) for x_train in X_train]
登入後複製

第二,對distances 進行升序操作,使用np.argsort() 方法傳回排序後的索引,而不會對原始資料的順序有任何影響:

import numpy as np

nearest = np.argsort(distances)
登入後複製

第三,取k 個距離最近的樣本對應的標記:

topK_y = [y_train[i] for i in nearest[:k]]
登入後複製

最後,對這k 個距離最近的樣本對應的標記進行統計,找出佔比最多標記即為x 的預測分類,此例的預測分類為0:

from collections import Counter

votes = Counter(topK_y)
votes.most_common(1)[0][0]
登入後複製

將上面的程式碼封裝到一個方法中:

import numpy as np
import math

from collections import Counter


def kNN(k, X_train, y_train, x):
    distances = [math.sqrt(np.sum((x_train - x) ** 2)) for x_train in X_train]
    nearest = np.argsort(distances)

    topK_y = [y_train[i] for i in nearest[:k]]
    votes = Counter(topK_y)
    return votes.most_common(1)[0][0]
登入後複製

Scikit Learn 中的k-近鄰演算法

一個典型的機器學習演算法流程是將訓練資料集透過機器學習演算法訓練(fit)出模型,透過這個模型來預測輸入樣例的結果。

python中K-近鄰演算法的原理與實作(附源碼)

對於k-近鄰演算法來說,它是一個特殊的沒有模型的演算法,但是我們將其訓練資料集看作是模型。 Scikit Learn 中就是怎麼處理的。

Scikit Learn 中k-近鄰演算法使用

Scikit Learn 中k-鄰近演算法在neighbors 模組中,初始化時傳入參數n_neighbors 為6,即為上方的k:

from sklearn.neighbors import KNeighborsClassifier

kNN_classifier = KNeighborsClassifier(n_neighbors=6)
登入後複製

fit() 方法根據訓練資料集「訓練」分類器,該方法會傳回分類器本身:

kNN_classifier.fit(X_train, y_train)
登入後複製

predict( ) 方法預測輸入的結果,此方法要求傳入的參數類型為矩陣。因此,這裡先對 x 進行 reshape 操作:

X_predict = x.reshape(1, -1)
y_predict = kNN_classifier.predict(X_predict)
登入後複製

y_predict 值為0,與前面實作的 kNN 方法結果一致。

實作Scikit Learn 中的KNeighborsClassifier 分類器

定義一個KNNClassifier 類,其構造器方法傳入參數k,表示預測時選取的最相似資料的個數字:

class KNNClassifier:
    def __init__(self, k):
        self.k = k
        self._X_train = None
        self._y_train = None
登入後複製

fit() 方法訓練分類器,並且傳回分類器本身:

def fit(self, X_train, y_train):
    self._X_train = X_train
    self._y_train = y_train
    return self
登入後複製

predict() 方法處理資料集進行預測,參數X_predict 類型為矩陣。此方法使用列表解析式對 X_predict 進行了遍歷,對每個待測資料呼叫了一次 _predict() 方法。

def predict(self, X_predict):
    y_predict = [self._predict(x) for x in X_predict]
    return np.array(y_predict)

def _predict(self, x):
    distances = [math.sqrt(np.sum((x_train - x) ** 2))
                 for x_train in self._X_train]
    nearest = np.argsort(distances)

    topK_y = [self._y_train[i] for i in nearest[:self.k]]
    votes = Counter(topK_y)

    return votes.most_common(1)[0][0]
登入後複製

演算法準確性

模型存在的問題

上面透過訓練樣本集訓練出了模型,但是並不知道這個模型的好壞,還有兩個問題。

  1. 如果模型很壞,預測的結果就不是我們想要的。同時實際情況中,很難拿到真實的標記(label),無法檢驗模型。

  2. 訓練模型時訓練樣本並沒有包含所有的標記。

對於第一個問題,通常將樣本集中一定比例(如20%)的數據作為測試數據,其餘數據作為訓練數據。

以 Scikit Learn 中提供的鳶尾花資料為例,其包含了150個樣本。

import numpy as np
from sklearn import datasets

iris = datasets.load_iris()
X = iris.data
y = iris.target
登入後複製

現在將樣本分為20%範例測試資料和80%比例訓練資料:

test_ratio = 0.2
test_size = int(len(X) * test_ratio)

X_train = X[test_size:]
y_train = y[test_size:]

X_test = X[:test_size]
y_test = y[:test_size]
登入後複製

將X_train 和y_train 作為訓練資料用於訓練模型,X_test 和y_test 作為測試資料驗證模型準確性。

對於第二個問題,還是以 Scikit Learn 中提供的鳶尾花資料為例,其標記 y 的內容為:

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
登入後複製

发现0、1、2是以顺序存储的,在将样本划分为训练数据和测试数据过程中,如果训练数据中才对标记只包含0、1,这样的训练数据对于模型的训练将是致命的。以此,应将样本数据先进行随机处理。

np.random.permutation() 方法传入一个整数 n,会返回一个区间在 [0, n) 且随机排序的一维数组。将 X 的长度作为参数传入,返回 X 索引的随机数组:

shuffle_indexes = np.random.permutation(len(X))
登入後複製

将随机化的索引数组分为训练数据的索引与测试数据的索引两部分:

test_ratio = 0.2
test_size = int(len(X) * test_ratio)

test_indexes = shuffle_indexes[:test_size]
train_indexes = shuffle_indexes[test_size:]
登入後複製

再通过两部分的索引将样本数据分为训练数据和测试数据:

X_train = X[train_indexes]
y_train = y[train_indexes]

X_test = X[test_indexes]
y_test = y[test_indexes]
登入後複製

可以将两个问题的解决方案封装到一个方法中,seed 表示随机数种子,作用在 np.random 中:

import numpy as np

def train_test_split(X, y, test_ratio=0.2, seed=None):
    if seed:
        np.random.seed(seed)
    shuffle_indexes = np.random.permutation(len(X))

    test_size = int(len(X) * test_ratio)
    test_indexes = shuffle_indexes[:test_size]
    train_indexes = shuffle_indexes[test_size:]

    X_train = X[train_indexes]
    y_train = y[train_indexes]

    X_test = X[test_indexes]
    y_test = y[test_indexes]

    return X_train, X_test, y_train, y_test
登入後複製

Scikit Learn 中封装了 train_test_split() 方法,放在了 model_selection 模块中:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
登入後複製

算法正确率

通过 train_test_split() 方法对样本数据进行了预处理后,开始训练模型,并且对测试数据进行验证:

from sklearn.neighbors import KNeighborsClassifier

kNN_classifier = KNeighborsClassifier(n_neighbors=6)
kNN_classifier.fit(X_train, y_train)
y_predict = kNN_classifier.predict(X_test)
登入後複製

y_predict 是对测试数据 X_test 的预测结果,其中与 y_test 相等的个数除以 y_test 的个数就是该模型的正确率,将其和 y_test 进行比较可以算出模型的正确率:

def accuracy_score(y_true, y_predict):
    return sum(y_predict == y_true) / len(y_true)
登入後複製

调用该方法,返回一个小于等于1的浮点数:

accuracy_score(y_test, y_predict)
登入後複製

同样在 Scikit Learn 的 metrics 模块中封装了 accuracy_score() 方法:

from sklearn.metrics import accuracy_score

accuracy_score(y_test, y_predict)
登入後複製

Scikit Learn 中的 KNeighborsClassifier 类的父类 ClassifierMixin 中有一个 score() 方法,里面就调用了 accuracy_score() 方法,将测试数据 X_test 和 y_test 作为参数传入该方法中,可以直接计算出算法正确率。

class ClassifierMixin(object):
    def score(self, X, y, sample_weight=None):
        from .metrics import accuracy_score
        return accuracy_score(y, self.predict(X), sample_weight=sample_weight)
登入後複製

超参数

前文中提到的 k 是一种超参数,超参数是在算法运行前需要决定的参数。 Scikit Learn 中 k-近邻算法包含了许多超参数,在初始化构造函数中都有指定:

def __init__(self, n_neighbors=5,
             weights='uniform', algorithm='auto', leaf_size=30,
             p=2, metric='minkowski', metric_params=None, n_jobs=None,
             **kwargs):
    # code here
登入後複製

这些超参数的含义在源代码和官方文档[scikit-learn.org]中都有说明。

算法优缺点

k-近邻算法是一个比较简单的算法,有其优点但也有缺点。

优点是思想简单,但效果强大, 天然的适合多分类问题。

缺点是效率低下,比如一个训练集有 m 个样本,n 个特征,则预测一个新的数据的算法复杂度为 O(m*n);同时该算法可能产生维数灾难,当维数很大时,两个点之间的距离可能也很大,如 (0,0,0,...,0) 和 (1,1,1,...,1)(10000维)之间的距离为100。

源码地址

Github | ML-Algorithms-Action


以上是python中K-近鄰演算法的原理與實作(附源碼)的詳細內容。更多資訊請關注PHP中文網其他相關文章!

本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

Video Face Swap

Video Face Swap

使用我們完全免費的人工智慧換臉工具,輕鬆在任何影片中換臉!

熱工具

記事本++7.3.1

記事本++7.3.1

好用且免費的程式碼編輯器

SublimeText3漢化版

SublimeText3漢化版

中文版,非常好用

禪工作室 13.0.1

禪工作室 13.0.1

強大的PHP整合開發環境

Dreamweaver CS6

Dreamweaver CS6

視覺化網頁開發工具

SublimeText3 Mac版

SublimeText3 Mac版

神級程式碼編輯軟體(SublimeText3)

PHP和Python:解釋了不同的範例 PHP和Python:解釋了不同的範例 Apr 18, 2025 am 12:26 AM

PHP主要是過程式編程,但也支持面向對象編程(OOP);Python支持多種範式,包括OOP、函數式和過程式編程。 PHP適合web開發,Python適用於多種應用,如數據分析和機器學習。

在PHP和Python之間進行選擇:指南 在PHP和Python之間進行選擇:指南 Apr 18, 2025 am 12:24 AM

PHP適合網頁開發和快速原型開發,Python適用於數據科學和機器學習。 1.PHP用於動態網頁開發,語法簡單,適合快速開發。 2.Python語法簡潔,適用於多領域,庫生態系統強大。

Python vs. JavaScript:學習曲線和易用性 Python vs. JavaScript:學習曲線和易用性 Apr 16, 2025 am 12:12 AM

Python更適合初學者,學習曲線平緩,語法簡潔;JavaScript適合前端開發,學習曲線較陡,語法靈活。 1.Python語法直觀,適用於數據科學和後端開發。 2.JavaScript靈活,廣泛用於前端和服務器端編程。

PHP和Python:深入了解他們的歷史 PHP和Python:深入了解他們的歷史 Apr 18, 2025 am 12:25 AM

PHP起源於1994年,由RasmusLerdorf開發,最初用於跟踪網站訪問者,逐漸演變為服務器端腳本語言,廣泛應用於網頁開發。 Python由GuidovanRossum於1980年代末開發,1991年首次發布,強調代碼可讀性和簡潔性,適用於科學計算、數據分析等領域。

vs code 可以在 Windows 8 中運行嗎 vs code 可以在 Windows 8 中運行嗎 Apr 15, 2025 pm 07:24 PM

VS Code可以在Windows 8上運行,但體驗可能不佳。首先確保系統已更新到最新補丁,然後下載與系統架構匹配的VS Code安裝包,按照提示安裝。安裝後,注意某些擴展程序可能與Windows 8不兼容,需要尋找替代擴展或在虛擬機中使用更新的Windows系統。安裝必要的擴展,檢查是否正常工作。儘管VS Code在Windows 8上可行,但建議升級到更新的Windows系統以獲得更好的開發體驗和安全保障。

visual studio code 可以用於 python 嗎 visual studio code 可以用於 python 嗎 Apr 15, 2025 pm 08:18 PM

VS Code 可用於編寫 Python,並提供許多功能,使其成為開發 Python 應用程序的理想工具。它允許用戶:安裝 Python 擴展,以獲得代碼補全、語法高亮和調試等功能。使用調試器逐步跟踪代碼,查找和修復錯誤。集成 Git,進行版本控制。使用代碼格式化工具,保持代碼一致性。使用 Linting 工具,提前發現潛在問題。

notepad 怎麼運行python notepad 怎麼運行python Apr 16, 2025 pm 07:33 PM

在 Notepad 中運行 Python 代碼需要安裝 Python 可執行文件和 NppExec 插件。安裝 Python 並為其添加 PATH 後,在 NppExec 插件中配置命令為“python”、參數為“{CURRENT_DIRECTORY}{FILE_NAME}”,即可在 Notepad 中通過快捷鍵“F6”運行 Python 代碼。

vscode 擴展是否是惡意的 vscode 擴展是否是惡意的 Apr 15, 2025 pm 07:57 PM

VS Code 擴展存在惡意風險,例如隱藏惡意代碼、利用漏洞、偽裝成合法擴展。識別惡意擴展的方法包括:檢查發布者、閱讀評論、檢查代碼、謹慎安裝。安全措施還包括:安全意識、良好習慣、定期更新和殺毒軟件。

See all articles