Table des matières
Principe du modèle VAE
Ensemble de données
Implémentation du modèle VAE
Formation du modèle VAE
Générer de nouvelles images de chiffres manuscrits
Résumé
Maison développement back-end Tutoriel Python Exemple d'algorithme VAE en Python

Exemple d'algorithme VAE en Python

Jun 11, 2023 pm 07:58 PM
python 实例 vae算法

VAE est un modèle génératif, son nom complet est Variational Autoencoder et sa traduction chinoise est variational autoencoder. Il s'agit d'un algorithme d'apprentissage non supervisé qui peut être utilisé pour générer de nouvelles données, telles que des images, de l'audio, du texte, etc. Comparés aux auto-encodeurs ordinaires, les VAE sont plus flexibles et plus puissants et peuvent générer des données plus complexes et plus réalistes.

Python est l'un des langages de programmation les plus utilisés et l'un des principaux outils d'apprentissage en profondeur. En Python, il existe de nombreux excellents frameworks d'apprentissage automatique et d'apprentissage profond, tels que TensorFlow, PyTorch, Keras, etc., qui ont tous des implémentations VAE.

Cet article utilisera un exemple de code Python pour présenter comment utiliser TensorFlow pour implémenter l'algorithme VAE et générer de nouvelles images de chiffres manuscrites.

Principe du modèle VAE

VAE est une méthode d'apprentissage non supervisée qui peut extraire des fonctionnalités potentielles des données et utiliser ces fonctionnalités pour générer de nouvelles données. VAE apprend la distribution des données en considérant la distribution de probabilité des variables latentes. Il mappe les données originales dans un espace latent et convertit l'espace latent en données reconstruites via un décodeur.

La structure du modèle de VAE comprend deux parties : l'encodeur et le décodeur. L'encodeur compresse les données d'origine dans l'espace de variables latentes et le décodeur mappe les variables latentes à l'espace de données d'origine. Entre l'encodeur et le décodeur, il existe également une couche de reparamétrage pour garantir que l'échantillonnage des variables latentes est différentiable.

La fonction de perte de VAE se compose de deux parties. L'une est l'erreur de reconstruction, qui est la distance entre les données d'origine et les données générées par le décodeur. L'autre partie est le terme de régularisation, qui est utilisé pour limiter la distribution. des variables latentes.

Ensemble de données

Nous utiliserons l'ensemble de données MNIST pour entraîner le modèle VAE et générer de nouvelles images de chiffres manuscrits. L'ensemble de données MNIST contient un ensemble d'images de chiffres manuscrites, chaque image est une image en niveaux de gris 28 × 28.

Nous pouvons utiliser l'API fournie par TensorFlow pour charger l'ensemble de données MNIST et convertir l'image sous forme vectorielle. Le code est le suivant :

import tensorflow as tf
import numpy as np

# 加载MNIST数据集
mnist = tf.keras.datasets.mnist

# 加载训练集和测试集
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 将图像转换为向量形式
x_train = x_train.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.
x_train = x_train.reshape((-1, 28 * 28))
x_test = x_test.reshape((-1, 28 * 28))
Copier après la connexion

Implémentation du modèle VAE

Nous pouvons utiliser TensorFlow pour implémenter le modèle VAE. Le codeur et le décodeur sont tous deux des réseaux neuronaux multicouches et la couche de reparamétrage est une couche aléatoire.

Le code d'implémentation du modèle VAE est le suivant :

import tensorflow_probability as tfp

# 定义编码器
encoder_inputs = tf.keras.layers.Input(shape=(784,))
x = tf.keras.layers.Dense(256, activation='relu')(encoder_inputs)
x = tf.keras.layers.Dense(128, activation='relu')(x)
mean = tf.keras.layers.Dense(10)(x)
logvar = tf.keras.layers.Dense(10)(x)

# 定义重参数化层
def sampling(args):
    mean, logvar = args
    epsilon = tfp.distributions.Normal(0., 1.).sample(tf.shape(mean))
    return mean + tf.exp(logvar / 2) * epsilon

z = tf.keras.layers.Lambda(sampling)([mean, logvar])

# 定义解码器
decoder_inputs = tf.keras.layers.Input(shape=(10,))
x = tf.keras.layers.Dense(128, activation='relu')(decoder_inputs)
x = tf.keras.layers.Dense(256, activation='relu')(x)
decoder_outputs = tf.keras.layers.Dense(784, activation='sigmoid')(x)

# 构建模型
vae = tf.keras.models.Model(encoder_inputs, decoder_outputs)

# 定义损失函数
reconstruction = -tf.reduce_sum(encoder_inputs * tf.math.log(1e-10 + decoder_outputs) + 
                                (1 - encoder_inputs) * tf.math.log(1e-10 + 1 - decoder_outputs), axis=1)
kl_divergence = -0.5 * tf.reduce_sum(1 + logvar - tf.square(mean) - tf.exp(logvar), axis=-1)
vae_loss = tf.reduce_mean(reconstruction + kl_divergence)

vae.add_loss(vae_loss)
vae.compile(optimizer='rmsprop')
vae.summary()
Copier après la connexion

Lors de l'écriture du code, vous devez faire attention aux points suivants :

  • Utilisez la couche Lambda pour implémenter des opérations de paramétrage lourdes
  • La fonction de perte inclut la reconstruction termes d'erreur et de régularisation
  • Ajoutez la fonction de perte au modèle, il n'est pas nécessaire de calculer manuellement le gradient, vous pouvez directement utiliser l'optimiseur pour la formation

Formation du modèle VAE

Nous pouvons utiliser l'ensemble de données MNIST pour former le Modèle VAE. Le code pour entraîner le modèle est le suivant :

vae.fit(x_train, x_train,
        epochs=50,
        batch_size=128,
        validation_data=(x_test, x_test))
Copier après la connexion

Pendant l'entraînement, nous pouvons utiliser plusieurs époques et des lots plus grands pour améliorer l'effet d'entraînement.

Générer de nouvelles images de chiffres manuscrits

Une fois la formation terminée, nous pouvons utiliser le modèle VAE pour générer de nouvelles images de chiffres manuscrits. Le code pour générer l'image est le suivant :

import matplotlib.pyplot as plt

# 随机生成潜在变量
z = np.random.normal(size=(1, 10))

# 将潜在变量解码为图像
generated = vae.predict(z)

# 将图像转换为灰度图像
generated = generated.reshape((28, 28))
plt.imshow(generated, cmap='gray')
plt.show()
Copier après la connexion

Nous pouvons générer différentes images de chiffres manuscrits en exécutant le code plusieurs fois. Ces images sont générées sur la base de la distribution des données apprises par VAE, avec diversité et créativité.

Résumé

Cet article présente comment implémenter l'algorithme VAE à l'aide de TensorFlow en Python, et démontre son application via l'ensemble de données MNIST et la génération de nouvelles images de chiffres manuscrites. En apprenant l’algorithme VAE, non seulement de nouvelles données peuvent être générées, mais également des caractéristiques potentielles des données peuvent être extraites, offrant ainsi une nouvelle idée pour l’analyse des données et la reconnaissance de formes.

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn

Outils d'IA chauds

Undresser.AI Undress

Undresser.AI Undress

Application basée sur l'IA pour créer des photos de nu réalistes

AI Clothes Remover

AI Clothes Remover

Outil d'IA en ligne pour supprimer les vêtements des photos.

Undress AI Tool

Undress AI Tool

Images de déshabillage gratuites

Clothoff.io

Clothoff.io

Dissolvant de vêtements AI

AI Hentai Generator

AI Hentai Generator

Générez AI Hentai gratuitement.

Article chaud

R.E.P.O. Crystals d'énergie expliqués et ce qu'ils font (cristal jaune)
4 Il y a quelques semaines By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O. Meilleurs paramètres graphiques
4 Il y a quelques semaines By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O. Comment réparer l'audio si vous n'entendez personne
4 Il y a quelques semaines By 尊渡假赌尊渡假赌尊渡假赌
WWE 2K25: Comment déverrouiller tout dans Myrise
1 Il y a quelques mois By 尊渡假赌尊渡假赌尊渡假赌

Outils chauds

Bloc-notes++7.3.1

Bloc-notes++7.3.1

Éditeur de code facile à utiliser et gratuit

SublimeText3 version chinoise

SublimeText3 version chinoise

Version chinoise, très simple à utiliser

Envoyer Studio 13.0.1

Envoyer Studio 13.0.1

Puissant environnement de développement intégré PHP

Dreamweaver CS6

Dreamweaver CS6

Outils de développement Web visuel

SublimeText3 version Mac

SublimeText3 version Mac

Logiciel d'édition de code au niveau de Dieu (SublimeText3)

Python vs C: applications et cas d'utilisation comparés Python vs C: applications et cas d'utilisation comparés Apr 12, 2025 am 12:01 AM

Python convient à la science des données, au développement Web et aux tâches d'automatisation, tandis que C convient à la programmation système, au développement de jeux et aux systèmes intégrés. Python est connu pour sa simplicité et son écosystème puissant, tandis que C est connu pour ses capacités de contrôle élevées et sous-jacentes.

Comment utiliser les journaux Debian Apache pour améliorer les performances du site Web Comment utiliser les journaux Debian Apache pour améliorer les performances du site Web Apr 12, 2025 pm 11:36 PM

Cet article expliquera comment améliorer les performances du site Web en analysant les journaux Apache dans le système Debian. 1. Bases de l'analyse du journal APACH LOG enregistre les informations détaillées de toutes les demandes HTTP, y compris l'adresse IP, l'horodatage, l'URL de la demande, la méthode HTTP et le code de réponse. Dans Debian Systems, ces journaux sont généralement situés dans les répertoires /var/log/apache2/access.log et /var/log/apache2/error.log. Comprendre la structure du journal est la première étape d'une analyse efficace. 2.

Python: jeux, GUIS, et plus Python: jeux, GUIS, et plus Apr 13, 2025 am 12:14 AM

Python excelle dans les jeux et le développement de l'interface graphique. 1) Le développement de jeux utilise Pygame, fournissant des fonctions de dessin, audio et d'autres fonctions, qui conviennent à la création de jeux 2D. 2) Le développement de l'interface graphique peut choisir Tkinter ou Pyqt. Tkinter est simple et facile à utiliser, PYQT a des fonctions riches et convient au développement professionnel.

Laravel (PHP) contre Python: environnements de développement et écosystèmes Laravel (PHP) contre Python: environnements de développement et écosystèmes Apr 12, 2025 am 12:10 AM

La comparaison entre Laravel et Python dans l'environnement de développement et l'écosystème est la suivante: 1. L'environnement de développement de Laravel est simple, seul PHP et compositeur sont nécessaires. Il fournit une riche gamme de packages d'extension tels que Laravelforge, mais la maintenance des forfaits d'extension peut ne pas être opportun. 2. L'environnement de développement de Python est également simple, seuls Python et PIP sont nécessaires. L'écosystème est énorme et couvre plusieurs champs, mais la gestion de la version et de la dépendance peut être complexe.

PHP et Python: comparaison de deux langages de programmation populaires PHP et Python: comparaison de deux langages de programmation populaires Apr 14, 2025 am 12:13 AM

PHP et Python ont chacun leurs propres avantages et choisissent en fonction des exigences du projet. 1.Php convient au développement Web, en particulier pour le développement rapide et la maintenance des sites Web. 2. Python convient à la science des données, à l'apprentissage automatique et à l'intelligence artificielle, avec syntaxe concise et adaptée aux débutants.

Le rôle de Debian Sniffer dans la détection des attaques DDOS Le rôle de Debian Sniffer dans la détection des attaques DDOS Apr 12, 2025 pm 10:42 PM

Cet article traite de la méthode de détection d'attaque DDOS. Bien qu'aucun cas d'application directe de "Debiansniffer" n'ait été trouvé, les méthodes suivantes ne peuvent être utilisées pour la détection des attaques DDOS: technologie de détection d'attaque DDOS efficace: détection basée sur l'analyse du trafic: identification des attaques DDOS en surveillant des modèles anormaux de trafic réseau, tels que la croissance soudaine du trafic, une surtension dans des connexions sur des ports spécifiques, etc. Par exemple, les scripts Python combinés avec les bibliothèques Pyshark et Colorama peuvent surveiller le trafic réseau en temps réel et émettre des alertes. Détection basée sur l'analyse statistique: en analysant les caractéristiques statistiques du trafic réseau, telles que les données

Certificat NGINX SSL Mise à jour du tutoriel Debian Certificat NGINX SSL Mise à jour du tutoriel Debian Apr 13, 2025 am 07:21 AM

Cet article vous guidera sur la façon de mettre à jour votre certificat NGINXSSL sur votre système Debian. Étape 1: Installez d'abord CERTBOT, assurez-vous que votre système a des packages CERTBOT et Python3-CERTBOT-NGINX installés. Si ce n'est pas installé, veuillez exécuter la commande suivante: Sudoapt-getUpDaSuDoapt-GetInstallCertBotpyThon3-Certerbot-Nginx Étape 2: Obtenez et configurez le certificat Utilisez la commande Certbot pour obtenir le certificat LETSCRYPT et configure

Comment Debian Readdir s'intègre à d'autres outils Comment Debian Readdir s'intègre à d'autres outils Apr 13, 2025 am 09:42 AM

La fonction ReadDir dans le système Debian est un appel système utilisé pour lire le contenu des répertoires et est souvent utilisé dans la programmation C. Cet article expliquera comment intégrer ReadDir avec d'autres outils pour améliorer sa fonctionnalité. Méthode 1: combinant d'abord le programme de langue C et le pipeline, écrivez un programme C pour appeler la fonction readdir et sortir le résultat: # include # include # include # includeIntmain (intargc, char * argv []) {dir * dir; structDirent * entrée; if (argc! = 2) {

See all articles