Maison > Périphériques technologiques > IA > le corps du texte

Créez un classificateur d'apprentissage profond pour les photos de chats et de chiens à l'aide de TensorFlow et Keras.

PHPz
Libérer: 2023-05-16 09:34:16
avant
1272 Les gens l'ont consulté

Créez un classificateur dapprentissage profond pour les photos de chats et de chiens à laide de TensorFlow et Keras.

Dans cet article, nous utiliserons TensorFlow et Keras pour créer un classificateur d'images capable de différencier les images de chats et de chiens. Pour ce faire, nous utiliserons l'ensemble de données cats_vs_dogs de l'ensemble de données TensorFlow. L'ensemble de données se compose de 25 000 images étiquetées de chats et de chiens, dont 80 % sont utilisées pour l'entraînement, 10 % pour la validation et 10 % pour les tests.

Chargement des données

Nous commençons par charger l'ensemble de données à l'aide des ensembles de données TensorFlow. Divisez l'ensemble de données en ensemble d'entraînement, ensemble de validation et ensemble de test, représentant respectivement 80 %, 10 % et 10 % des données, et définissez une fonction pour afficher quelques exemples d'images dans l'ensemble de données.

<code>import tensorflow as tfimport matplotlib.pyplot as pltimport tensorflow_datasets as tfds# 加载数据(train_data, validation_data, test_data), info = tfds.load('cats_vs_dogs', split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'], with_info=True, as_supervised=True)# 获取图像的标签label_names = info.features['label'].names# 定义一个函数来显示一些样本图像plt.figure(figsize=(10, 10))for i, (image, label) in enumerate(train_data.take(9)):ax = plt.subplot(3, 3, i + 1)plt.imshow(image)plt.title(label_names[label])plt.axis('off')</code>
Copier après la connexion

Créez un classificateur dapprentissage profond pour les photos de chats et de chiens à laide de TensorFlow et Keras.

Prétraitement des données

Avant d'entraîner le modèle, les données doivent être prétraitées. L'image sera redimensionnée à une taille uniforme de 150 x 150 pixels, les valeurs des pixels seront normalisées entre 0 et 1 et les données seront traitées par lots afin de pouvoir être importées dans le modèle par lots.

<code>IMG_SIZE = 150</code>
Copier après la connexion
<code>def format_image(image, label):image = tf.cast(image, tf.float32) / 255.0# Normalize the pixel valuesimage = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))# Resize to the desired sizereturn image, labelbatch_size = 32train_data = train_data.map(format_image).shuffle(1000).batch(batch_size)validation_data = validation_data.map(format_image).batch(batch_size)test_data = test_data.map(format_image).batch(batch_size)</code>
Copier après la connexion

Créez un classificateur dapprentissage profond pour les photos de chats et de chiens à laide de TensorFlow et Keras.

Création du modèle

Cet article utilisera le modèle MobileNet V2 pré-entraîné comme modèle de base et y ajoutera une couche de pooling moyenne globale et une couche compacte pour la classification. Cet article va geler les poids du modèle de base afin que seuls les poids de la couche supérieure soient mis à jour pendant l'entraînement.

<code>base_model = tf.keras.applications.MobileNetV2(input_shape=(IMG_SIZE, IMG_SIZE, 3), include_top=False, weights='imagenet')base_model.trainable = False</code>
Copier après la connexion
<code>global_average_layer = tf.keras.layers.GlobalAveragePooling2D()prediction_layer = tf.keras.layers.Dense(1)model = tf.keras.Sequential([base_model,global_average_layer,prediction_layer])model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.0001),loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),metrics=['accuracy'])</code>
Copier après la connexion
Copier après la connexion

Entraînement du modèle

Cet article entraînera le modèle pendant 3 cycles et le validera sur l'ensemble de validation après chaque cycle. Nous sauvegarderons le modèle après la formation afin de pouvoir l'utiliser lors de futurs tests.

<code>global_average_layer = tf.keras.layers.GlobalAveragePooling2D()prediction_layer = tf.keras.layers.Dense(1)model = tf.keras.Sequential([base_model,global_average_layer,prediction_layer])model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.0001),loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),metrics=['accuracy'])</code>
Copier après la connexion
Copier après la connexion
<code>history = model.fit(train_data,epochs=3,validation_data=validation_data)</code>
Copier après la connexion

Créez un classificateur dapprentissage profond pour les photos de chats et de chiens à laide de TensorFlow et Keras.

Historique du modèle

Si vous voulez savoir comment fonctionne la couche Mobilenet V2, l'image ci-dessous est le résultat de cette couche.

Créez un classificateur dapprentissage profond pour les photos de chats et de chiens à laide de TensorFlow et Keras.

Évaluer le modèle

Une fois la formation terminée, le modèle sera évalué sur l'ensemble de test pour voir ses performances sur les nouvelles données.

<code>loaded_model = tf.keras.models.load_model('cats_vs_dogs.h5')test_loss, test_accuracy = loaded_model.evaluate(test_data)</code>
Copier après la connexion
<code>print('Test accuracy:', test_accuracy)</code>
Copier après la connexion

Prédiction

Enfin, cet article utilisera le modèle pour prédire quelques exemples d'images dans l'ensemble de test et afficher les résultats.

<code>for image , _ in test_.take(90) : passpre = loaded_model.predict(image)plt.figure(figsize = (10 , 10))j = Nonefor value in enumerate(pre) : plt.subplot(7,7,value[0]+1)plt.imshow(image[value[0]])plt.xticks([])plt.yticks([])if value[1] > pre.mean() :j = 1color = 'blue' if j == _[value[0]] else 'red'plt.title('dog' , color = color)else : j = 0color = 'blue' if j == _[value[0]] else 'red'plt.title('cat' , color = color)plt.show()</code>
Copier après la connexion

Créez un classificateur dapprentissage profond pour les photos de chats et de chiens à laide de TensorFlow et Keras.

Fait ! Nous avons créé un classificateur d'images capable de différencier les images de chats et de chiens à l'aide de TensorFlow et Keras. Avec quelques ajustements et ajustements, cette approche peut également être appliquée à d’autres problèmes de classification d’images.

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Étiquettes associées:
source:51cto.com
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal