ホームページ バックエンド開発 Python チュートリアル PythonによるK最近傍アルゴリズムの原理と実装(ソースコード添付)

PythonによるK最近傍アルゴリズムの原理と実装(ソースコード添付)

Oct 27, 2018 pm 02:21 PM
python scikit-learn 機械学習 アルゴリズム

この記事で紹介する内容は、Python での K 最近傍アルゴリズムの原理と実装に関するものです (ソース コードが添付されています)。一定の参考価値があります。必要な友人が参考にしていただければ幸いです。あなたに役立ちます。

k 最近傍アルゴリズムは、異なる特徴値間の距離を測定することによって分類を実行します。

k-最近傍アルゴリズムの原理

ラベル付きのトレーニング サンプル セットの場合、ラベルなしの新しいデータを入力した後、新しいデータの各特徴がサンプルに集中します。データに対応する特徴が比較され、アルゴリズムに従ってサンプル データセット内の最も類似した上位 k 個のデータが選択され、k 個の最も類似したデータの中で最も多く出現した分類が新しいデータの分類として選択されます。 。

k-最近傍アルゴリズムの実装

ここでは 1 つの新しいデータの予測のみを示し、同時に複数の新しいデータを予測する場合は後ほど説明します。

トレーニング サンプル セット X_train (X_train.shape=(10, 2))、対応するマーク y_train (y_train.shape=(10,)、0、1 を含む) があると仮定し、matplotlib を使用します。描画する pyplot 次のように表されます (緑の点はマーク 0 を表し、赤の点はマーク 1 を表します):

PythonによるK最近傍アルゴリズムの原理と実装(ソースコード添付)

新しいデータがあります: x (x = np.array([3.18557125, 6.03119673]))、描画表現は次のとおりです (青い点):

PythonによるK最近傍アルゴリズムの原理と実装(ソースコード添付)

まず、次を使用します。オイラー距離 この式は、x から各サンプルまでの距離を計算します。元のデータの順序に影響します:

import math

distances = [math.sqrt(np.sum((x_train - x) ** 2)) for x_train in X_train]
ログイン後にコピー

3 番目に、k 個の最も近いサンプルに対応するマークを取得します: <div class="code" style="position:relative; padding:0px; margin:0px;"><pre class="brush:php;toolbar:false">import numpy as np nearest = np.argsort(distances)</pre><div class="contentsignin">ログイン後にコピー</div></div>最後に、k 個の最も近いサンプルに対応するマークを取得します。統計を作成し、マークの割合が最も大きい予測分類 (x です) を見つけます。この例の予測分類は 0 です。

topK_y = [y_train[i] for i in nearest[:k]]
ログイン後にコピー

上記のコードを次のようにカプセル化します。メソッド:

from collections import Counter

votes = Counter(topK_y)
votes.most_common(1)[0][0]
ログイン後にコピー

Scikit Learn の K 近傍アルゴリズム

典型的な機械学習アルゴリズムのプロセスでは、トレーニング データ セットを使用してモデルをトレーニング (適合) します。機械学習アルゴリズムを通じて、このモデルを使用して入力サンプルの結果を予測します。

k 最近傍アルゴリズムは、モデルを持たない特殊なアルゴリズムですが、その学習データセットをモデルとみなします。これは、Scikit Learn でどのように処理されるかです。 PythonによるK最近傍アルゴリズムの原理と実装(ソースコード添付)

Scikit Learn の k 近傍アルゴリズムは、

Scikit Learn の k 近傍アルゴリズムは、neighbors モジュール内にあります。初期化時のパラメータ n_neighbors は、 6、これは上記です。 k:

import numpy as np
import math

from collections import Counter


def kNN(k, X_train, y_train, x):
    distances = [math.sqrt(np.sum((x_train - x) ** 2)) for x_train in X_train]
    nearest = np.argsort(distances)

    topK_y = [y_train[i] for i in nearest[:k]]
    votes = Counter(topK_y)
    return votes.most_common(1)[0][0]
ログイン後にコピー

fit()

メソッドは、トレーニング データ セットに基づいて分類子を「トレーニング」し、分類子自体を返します。

from sklearn.neighbors import KNeighborsClassifier

kNN_classifier = KNeighborsClassifier(n_neighbors=6)
ログイン後にコピー

predict( ) メソッドは入力の結果を予測します。このメソッドでは渡されるパラメーターの型が行列である必要があります。したがって、

reshape

操作は最初に x に対して実行されます。<div class="code" style="position:relative; padding:0px; margin:0px;"><pre class="brush:php;toolbar:false">kNN_classifier.fit(X_train, y_train)</pre><div class="contentsignin">ログイン後にコピー</div></div>y_predict 値は 0 であり、以前に実装された kNN メソッドの結果と一致します。

Scikit Learn で KNeighborsClassifier 分類器を実装する

KNNClassifier クラスを定義し、そのコンストラクター メソッドは、予測中に選択された最も類似したデータを表すパラメーター k を渡します。番号:

X_predict = x.reshape(1, -1)
y_predict = kNN_classifier.predict(X_predict)
ログイン後にコピー

fit()

このメソッドは分類子をトレーニングし、分類子自体を返します:

class KNNClassifier:
    def __init__(self, k):
        self.k = k
        self._X_train = None
        self._y_train = None
ログイン後にコピー

predict() メソッドはデータですテスト対象に設定 予測を行うために、パラメーター X_predict のタイプは行列です。このメソッドは、リスト分析を使用して X_predict を走査し、テスト対象のデータごとに

_predict()

メソッドを 1 回呼び出します。 <div class="code" style="position:relative; padding:0px; margin:0px;"><pre class="brush:php;toolbar:false">def fit(self, X_train, y_train):     self._X_train = X_train     self._y_train = y_train     return self</pre><div class="contentsignin">ログイン後にコピー</div></div>アルゴリズムの精度

モデルの問題

モデルはトレーニング サンプル セットを通じてトレーニングされましたが、このモデルがどれほど優れているかはわかりませんが、問題が 2 つあります。

モデルが悪い場合、予測結果は期待どおりではありません。同時に、実際の状況では、実際のラベルを取得してモデルをテストすることは困難です。

  1. モデルをトレーニングする場合、トレーニング サンプルにはすべてのマーカーが含まれているわけではありません。

  2. 最初の質問では、通常、サンプル セット内のデータの一定の割合 (20% など) がテスト データとして使用され、残りのデータがトレーニング データとして使用されます。

  3. Scikit Learn で提供される虹彩データを例に挙げます。これには 150 個のサンプルが含まれています。
def predict(self, X_predict):
    y_predict = [self._predict(x) for x in X_predict]
    return np.array(y_predict)

def _predict(self, x):
    distances = [math.sqrt(np.sum((x_train - x) ** 2))
                 for x_train in self._X_train]
    nearest = np.argsort(distances)

    topK_y = [self._y_train[i] for i in nearest[:self.k]]
    votes = Counter(topK_y)

    return votes.most_common(1)[0][0]
ログイン後にコピー

次に、サンプルを 20% のサンプル テスト データと 80% の比率トレーニング データに分割します。

import numpy as np
from sklearn import datasets

iris = datasets.load_iris()
X = iris.data
y = iris.target
ログイン後にコピー

X_train と y_train をモデルをトレーニングするためのトレーニング データとして使用し、X_test と y_test をモデルのテスト データとして使用します。モデルの精度の検証。

2 番目の質問では、Scikit Learn で提供されている虹彩データを例として、y マークの内容は次のとおりです。

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
ログイン後にコピー

发现0、1、2是以顺序存储的,在将样本划分为训练数据和测试数据过程中,如果训练数据中才对标记只包含0、1,这样的训练数据对于模型的训练将是致命的。以此,应将样本数据先进行随机处理。

np.random.permutation() 方法传入一个整数 n,会返回一个区间在 [0, n) 且随机排序的一维数组。将 X 的长度作为参数传入,返回 X 索引的随机数组:

shuffle_indexes = np.random.permutation(len(X))
ログイン後にコピー

将随机化的索引数组分为训练数据的索引与测试数据的索引两部分:

test_ratio = 0.2
test_size = int(len(X) * test_ratio)

test_indexes = shuffle_indexes[:test_size]
train_indexes = shuffle_indexes[test_size:]
ログイン後にコピー

再通过两部分的索引将样本数据分为训练数据和测试数据:

X_train = X[train_indexes]
y_train = y[train_indexes]

X_test = X[test_indexes]
y_test = y[test_indexes]
ログイン後にコピー

可以将两个问题的解决方案封装到一个方法中,seed 表示随机数种子,作用在 np.random 中:

import numpy as np

def train_test_split(X, y, test_ratio=0.2, seed=None):
    if seed:
        np.random.seed(seed)
    shuffle_indexes = np.random.permutation(len(X))

    test_size = int(len(X) * test_ratio)
    test_indexes = shuffle_indexes[:test_size]
    train_indexes = shuffle_indexes[test_size:]

    X_train = X[train_indexes]
    y_train = y[train_indexes]

    X_test = X[test_indexes]
    y_test = y[test_indexes]

    return X_train, X_test, y_train, y_test
ログイン後にコピー

Scikit Learn 中封装了 train_test_split() 方法,放在了 model_selection 模块中:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
ログイン後にコピー

算法正确率

通过 train_test_split() 方法对样本数据进行了预处理后,开始训练模型,并且对测试数据进行验证:

from sklearn.neighbors import KNeighborsClassifier

kNN_classifier = KNeighborsClassifier(n_neighbors=6)
kNN_classifier.fit(X_train, y_train)
y_predict = kNN_classifier.predict(X_test)
ログイン後にコピー

y_predict 是对测试数据 X_test 的预测结果,其中与 y_test 相等的个数除以 y_test 的个数就是该模型的正确率,将其和 y_test 进行比较可以算出模型的正确率:

def accuracy_score(y_true, y_predict):
    return sum(y_predict == y_true) / len(y_true)
ログイン後にコピー

调用该方法,返回一个小于等于1的浮点数:

accuracy_score(y_test, y_predict)
ログイン後にコピー

同样在 Scikit Learn 的 metrics 模块中封装了 accuracy_score() 方法:

from sklearn.metrics import accuracy_score

accuracy_score(y_test, y_predict)
ログイン後にコピー

Scikit Learn 中的 KNeighborsClassifier 类的父类 ClassifierMixin 中有一个 score() 方法,里面就调用了 accuracy_score() 方法,将测试数据 X_test 和 y_test 作为参数传入该方法中,可以直接计算出算法正确率。

class ClassifierMixin(object):
    def score(self, X, y, sample_weight=None):
        from .metrics import accuracy_score
        return accuracy_score(y, self.predict(X), sample_weight=sample_weight)
ログイン後にコピー

超参数

前文中提到的 k 是一种超参数,超参数是在算法运行前需要决定的参数。 Scikit Learn 中 k-近邻算法包含了许多超参数,在初始化构造函数中都有指定:

def __init__(self, n_neighbors=5,
             weights='uniform', algorithm='auto', leaf_size=30,
             p=2, metric='minkowski', metric_params=None, n_jobs=None,
             **kwargs):
    # code here
ログイン後にコピー

这些超参数的含义在源代码和官方文档[scikit-learn.org]中都有说明。

算法优缺点

k-近邻算法是一个比较简单的算法,有其优点但也有缺点。

优点是思想简单,但效果强大, 天然的适合多分类问题。

缺点是效率低下,比如一个训练集有 m 个样本,n 个特征,则预测一个新的数据的算法复杂度为 O(m*n);同时该算法可能产生维数灾难,当维数很大时,两个点之间的距离可能也很大,如 (0,0,0,...,0) 和 (1,1,1,...,1)(10000维)之间的距离为100。

源码地址

Github | ML-Algorithms-Action


以上がPythonによるK最近傍アルゴリズムの原理と実装(ソースコード添付)の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。

ホットAIツール

Undresser.AI Undress

Undresser.AI Undress

リアルなヌード写真を作成する AI 搭載アプリ

AI Clothes Remover

AI Clothes Remover

写真から衣服を削除するオンライン AI ツール。

Undress AI Tool

Undress AI Tool

脱衣画像を無料で

Clothoff.io

Clothoff.io

AI衣類リムーバー

AI Hentai Generator

AI Hentai Generator

AIヘンタイを無料で生成します。

ホットツール

メモ帳++7.3.1

メモ帳++7.3.1

使いやすく無料のコードエディター

SublimeText3 中国語版

SublimeText3 中国語版

中国語版、とても使いやすい

ゼンドスタジオ 13.0.1

ゼンドスタジオ 13.0.1

強力な PHP 統合開発環境

ドリームウィーバー CS6

ドリームウィーバー CS6

ビジュアル Web 開発ツール

SublimeText3 Mac版

SublimeText3 Mac版

神レベルのコード編集ソフト(SublimeText3)

PSフェザーリングは、遷移の柔らかさをどのように制御しますか? PSフェザーリングは、遷移の柔らかさをどのように制御しますか? Apr 06, 2025 pm 07:33 PM

羽毛の鍵は、その漸進的な性質を理解することです。 PS自体は、勾配曲線を直接制御するオプションを提供しませんが、複数の羽毛、マッチングマスク、および細かい選択により、半径と勾配の柔らかさを柔軟に調整して、自然な遷移効果を実現できます。

インストール後にMySQLの使用方法 インストール後にMySQLの使用方法 Apr 08, 2025 am 11:48 AM

この記事では、MySQLデータベースの操作を紹介します。まず、MySQLWorkBenchやコマンドラインクライアントなど、MySQLクライアントをインストールする必要があります。 1. mysql-uroot-pコマンドを使用してサーバーに接続し、ルートアカウントパスワードでログインします。 2。CreatedAtaBaseを使用してデータベースを作成し、データベースを選択します。 3. createTableを使用してテーブルを作成し、フィールドとデータ型を定義します。 4. INSERTINTOを使用してデータを挿入し、データをクエリし、更新することでデータを更新し、削除してデータを削除します。これらの手順を習得することによってのみ、一般的な問題に対処することを学び、データベースのパフォーマンスを最適化することでMySQLを効率的に使用できます。

mysqlは支払う必要がありますか mysqlは支払う必要がありますか Apr 08, 2025 pm 05:36 PM

MySQLには、無料のコミュニティバージョンと有料エンタープライズバージョンがあります。コミュニティバージョンは無料で使用および変更できますが、サポートは制限されており、安定性要件が低く、技術的な能力が強いアプリケーションに適しています。 Enterprise Editionは、安定した信頼性の高い高性能データベースを必要とするアプリケーションに対する包括的な商業サポートを提供し、サポートの支払いを喜んでいます。バージョンを選択する際に考慮される要因には、アプリケーションの重要性、予算編成、技術スキルが含まれます。完璧なオプションはなく、最も適切なオプションのみであり、特定の状況に応じて慎重に選択する必要があります。

PSフェザーリングをセットアップする方法は? PSフェザーリングをセットアップする方法は? Apr 06, 2025 pm 07:36 PM

PSフェザーリングは、イメージエッジブラー効果であり、エッジエリアのピクセルの加重平均によって達成されます。羽の半径を設定すると、ぼやけの程度を制御でき、値が大きいほどぼやけます。半径の柔軟な調整は、画像とニーズに応じて効果を最適化できます。たとえば、キャラクターの写真を処理する際に詳細を維持するためにより小さな半径を使用し、より大きな半径を使用してアートを処理するときにかすんだ感覚を作成します。ただし、半径が大きすぎるとエッジの詳細を簡単に失う可能性があり、効果が小さすぎると明らかになりません。羽毛効果は画像解像度の影響を受け、画像の理解と効果の把握に従って調整する必要があります。

MySQLインストール後にデータベースのパフォーマンスを最適化する方法 MySQLインストール後にデータベースのパフォーマンスを最適化する方法 Apr 08, 2025 am 11:36 AM

MySQLパフォーマンスの最適化は、インストール構成、インデックス作成、クエリの最適化、監視、チューニングの3つの側面から開始する必要があります。 1。インストール後、INNODB_BUFFER_POOL_SIZEパラメーターやclose query_cache_sizeなど、サーバーの構成に従ってmy.cnfファイルを調整する必要があります。 2。過度のインデックスを回避するための適切なインデックスを作成し、説明コマンドを使用して実行計画を分析するなど、クエリステートメントを最適化します。 3. MySQL独自の監視ツール(ShowProcessList、ShowStatus)を使用して、データベースの健康を監視し、定期的にデータベースをバックアップして整理します。これらの手順を継続的に最適化することによってのみ、MySQLデータベースのパフォーマンスを改善できます。

MySQLはダウンロード後にインストールできません MySQLはダウンロード後にインストールできません Apr 08, 2025 am 11:24 AM

MySQLのインストール障害の主な理由は次のとおりです。1。許可の問題、管理者として実行するか、SUDOコマンドを使用する必要があります。 2。依存関係が欠落しており、関連する開発パッケージをインストールする必要があります。 3.ポート競合では、ポート3306を占めるプログラムを閉じるか、構成ファイルを変更する必要があります。 4.インストールパッケージが破損しているため、整合性をダウンロードして検証する必要があります。 5.環境変数は誤って構成されており、環境変数はオペレーティングシステムに従って正しく構成する必要があります。これらの問題を解決し、各ステップを慎重に確認して、MySQLを正常にインストールします。

高負荷アプリケーションのMySQLパフォーマンスを最適化する方法は? 高負荷アプリケーションのMySQLパフォーマンスを最適化する方法は? Apr 08, 2025 pm 06:03 PM

MySQLデータベースパフォーマンス最適化ガイドリソース集約型アプリケーションでは、MySQLデータベースが重要な役割を果たし、大規模なトランザクションの管理を担当しています。ただし、アプリケーションのスケールが拡大すると、データベースパフォーマンスのボトルネックが制約になることがよくあります。この記事では、一連の効果的なMySQLパフォーマンス最適化戦略を検討して、アプリケーションが高負荷の下で効率的で応答性の高いままであることを保証します。実際のケースを組み合わせて、インデックス作成、クエリ最適化、データベース設計、キャッシュなどの詳細な主要なテクノロジーを説明します。 1.データベースアーキテクチャの設計と最適化されたデータベースアーキテクチャは、MySQLパフォーマンスの最適化の基礎です。いくつかのコア原則は次のとおりです。適切なデータ型を選択し、ニーズを満たす最小のデータ型を選択すると、ストレージスペースを節約するだけでなく、データ処理速度を向上させることもできます。

MySQLインストール後に開始できないサービスのソリューション MySQLインストール後に開始できないサービスのソリューション Apr 08, 2025 am 11:18 AM

MySQLは開始を拒否しましたか?パニックにならないでください、チェックしてみましょう!多くの友人は、MySQLのインストール後にサービスを開始できないことを発見し、彼らはとても不安でした!心配しないでください、この記事はあなたがそれを落ち着いて対処し、その背後にある首謀者を見つけるためにあなたを連れて行きます!それを読んだ後、あなたはこの問題を解決するだけでなく、MySQLサービスの理解と問題のトラブルシューティングのためのあなたのアイデアを改善し、より強力なデータベース管理者になることができます! MySQLサービスは開始に失敗し、単純な構成エラーから複雑なシステムの問題に至るまで、多くの理由があります。最も一般的な側面から始めましょう。基本知識:サービススタートアッププロセスMYSQLサービススタートアップの簡単な説明。簡単に言えば、オペレーティングシステムはMySQL関連のファイルをロードし、MySQLデーモンを起動します。これには構成が含まれます

See all articles