Astuces de matrice de confusion en Python

WBOY
Libérer: 2023-06-11 10:43:54
original
2534 Les gens l'ont consulté

Avec la popularité de l'apprentissage automatique et de l'exploration de données, de plus en plus de data scientists et de chercheurs commencent à utiliser Python, un langage de programmation de haut niveau, pour traiter et analyser des données, et l'intuitivité et la facilité d'utilisation de Python le rendent populaire. in Il est largement utilisé dans les domaines de l’apprentissage profond et de l’intelligence artificielle. Cependant, de nombreux débutants rencontrent certaines difficultés lors de l'utilisation de Python, dont la difficulté de la matrice de confusion. Dans cet article, nous présenterons l'utilisation des matrices de confusion en Python et quelques techniques utiles pour traiter les matrices de confusion.

1. Qu'est-ce qu'une matrice de confusion

En deep learning et en data mining, la matrice de confusion est un tableau rectangulaire utilisé pour comparer la différence entre les résultats prédits et les résultats réels. Cette matrice montre les performances de l'algorithme de classification, y compris des indicateurs importants tels que l'exactitude, le taux d'erreur, la précision et le rappel de l'algorithme de classification. La matrice de confusion visualise généralement les performances du classificateur et fournit la référence principale pour les résultats de prédiction pour l'amélioration et l'optimisation du classificateur.

Normalement, la matrice de confusion se compose de quatre paramètres :

  • True Positive (TP) : L'algorithme de classification prédit correctement la classe positive comme une classe positive.
  • Faux Négatif (FN) : L'algorithme de classification prédit à tort une classe positive comme une classe négative.
  • Faux Positif (FP) : L'algorithme de classification prédit à tort une classe négative comme une classe positive.
  • True Negative (TN) : L'algorithme de classification prédit correctement une classe négative comme une classe négative.

2. Comment calculer la matrice de confusion

La bibliothèque scikit-learn en Python fournit une fonction pratique pour calculer la matrice de confusion. Cette fonction s'appelle confusion_matrix() et peut être utilisée comme entrée entre le classificateur et les résultats réels de l'ensemble de tests, et renvoie les valeurs des paramètres de la matrice de confusion. La syntaxe de cette fonction est la suivante :

from sklearn.metrics import confusion_matrix
confusion_matrix(y_true, y_pred, labels=None, sample_weight=None)
Copier après la connexion

Parmi eux, y_true représente le résultat correct du classificateur, y_pred représente le résultat de prédiction du classificateur et labels représente le nom de l'étiquette de classe (si non fourni, la valeur par défaut est celle de y_true et la valeur extraite de y_pred), sample_weight représente le poids de chaque échantillon (si cela n'est pas nécessaire, ne définissez pas ce paramètre).

Par exemple, supposons que nous devions calculer la matrice de confusion des données suivantes :

y_true = [1, 0, 1, 2, 0, 1]
y_pred = [1, 0, 2, 1, 0, 2]
Copier après la connexion

Afin de calculer la matrice de confusion, vous pouvez utiliser le code suivant : #🎜 🎜#

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_true, y_pred)
print(cm)
Copier après la connexion

Le résultat est affiché :

array([[2, 0, 0],
       [0, 1, 2],
       [0, 1, 0]])
Copier après la connexion

C'est-à-dire que la matrice de confusion montre que "1" est correctement classé comme "1" 2 fois et "0" est correctement classé comme "0" 1 fois, "2" a été correctement classé comme "2" 0 fois, "1" a été incorrectement classé comme "2" 2 fois, "2" a été incorrectement classé comme "1" 1 fois, "0" C'était incorrectement classé comme « 2 » une fois.

3. Afficher la matrice de confusion

Il existe de nombreuses situations où nous avons besoin d'une meilleure visualisation de la matrice de confusion. La bibliothèque matplotlib en Python peut visualiser des matrices de confusion. Ce qui suit est du code Python qui utilise la bibliothèque matplotlib et sklearn.metrics pour visualiser la matrice de confusion.

import itertools
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

# Compute confusion matrix
cnf_matrix = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names,
                      title='Confusion matrix, without normalization')

# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                      title='Normalized confusion matrix')

plt.show()
Copier après la connexion
Dans le code ci-dessus, nous définissons une fonction personnalisée appelée plot_confusion_matrix(), qui accepte comme paramètres les paramètres de la matrice de confusion, la chaîne de texte du nom de la catégorie et la matrice de confusion comme couleur image Sortie, où la couleur de chaque cellule de la matrice de confusion représente l'ampleur de sa valeur. Ensuite, nous devons calculer la matrice de confusion en utilisant les catégories vraies et prédites respectives et représenter la matrice de confusion à l'aide de la fonction plot_confusion_matrix() définie ci-dessus.

4. Résumé

Le langage Python fournit un grand nombre de bibliothèques de visualisation et d'analyse de données, qui peuvent permettre aux scientifiques et aux chercheurs de mener davantage d'analyses de données d'apprentissage en profondeur et d'intelligence artificielle. rapidement . Dans cet article, nous présentons la matrice de confusion et ses applications, ainsi que comment calculer la matrice de confusion en Python et comment utiliser la bibliothèque matplotlib pour générer des graphiques de la matrice de confusion. La technologie des matrices de confusion a des applications importantes dans les domaines de l’apprentissage profond et de l’intelligence artificielle. Il est donc indispensable d’apprendre la technologie des matrices de confusion.

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Étiquettes associées:
source:php.cn
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal