
GAN algorithm example in Python

Jun 10, 2023

A Generative Adversarial Network (GAN) is a deep learning algorithm that generates new data by having two neural networks compete with each other. GANs are widely used for generation tasks on images, audio, text, and other kinds of data. In this article, we will use Python to write an example GAN that generates images of handwritten digits.

  1. Dataset preparation

We will use the MNIST dataset as our training data. It contains 60,000 training images and 10,000 test images; each image is a 28x28 grayscale image. We will use the TensorFlow library to load and process the dataset, so the TensorFlow and NumPy libraries need to be installed first.

import tensorflow as tf
import numpy as np

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

Dataset preprocessing

train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize pixel values to the range [-1, 1]
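
As an optional refinement that is not used by the rest of this article, the preprocessed images can also be wrapped in a tf.data.Dataset so that batches are shuffled every epoch; the training loop later in the article slices the NumPy array directly instead. A minimal sketch (the BUFFER_SIZE and BATCH_SIZE values here are our assumptions, matching the batch size used below):

# Optional alternative input pipeline (not used by the training loop below)
BUFFER_SIZE = 60000   # number of training images
BATCH_SIZE = 256      # same batch size as used later in the article
train_dataset = tf.data.Dataset.from_tensor_slices(train_images)
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)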

  2. GAN architecture design and training

Our GAN consists of two neural networks: a generator network and a discriminator network. The generator receives a noise vector as input and outputs a 28x28 image. The discriminator receives a 28x28 image as input and outputs a score indicating how likely the image is to be real.

Both networks are convolutional neural networks (CNNs). In the generator, we use transposed convolution (Conv2DTranspose) layers to decode the noise vector into a 28x28 image. In the discriminator, we use convolutional layers to classify the input images.

The input to the generator network is a noise vector of length 100. We stack the network layers using tf.keras.Sequential.

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())

    model.add(tf.keras.layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch dimension, which is not fixed

    model.add(tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())

    model.add(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.LeakyReLU())

    model.add(tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model
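As a quick sanity check that is not part of the original article, you can instantiate the untrained generator and confirm that a single noise vector is decoded into a 28x28x1 image (the variable names here are only illustrative):

# Illustrative smoke test for the generator (not from the original article)
sample_generator = make_generator_model()
sample_noise = tf.random.normal([1, 100])       # one noise vector of length 100
sample_image = sample_generator(sample_noise, training=False)
print(sample_image.shape)                       # expected: (1, 28, 28, 1)
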

The input to the discriminator network is a 28x28 image. We again stack the network layers using tf.keras.Sequential.

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Dropout(0.3))

    model.add(tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Dropout(0.3))

    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(1))

    return model
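Similarly, as an illustrative check that is not from the original article, the untrained discriminator can be applied to a batch of real images; it returns one raw logit per image, where larger values mean the discriminator leans towards "real":

# Illustrative smoke test for the discriminator (not from the original article)
sample_discriminator = make_discriminator_model()
decision = sample_discriminator(train_images[:1], training=False)
print(decision.shape)   # expected: (1, 1) -- one logit for the single input image
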

Next, we will write the training code. In each batch we update both the generator and the discriminator: gradients are recorded with tf.GradientTape() and applied with the tf.keras.optimizers.Adam optimizer.

generator = make_generator_model()
discriminator = make_discriminator_model()

Loss function

We use binary cross-entropy for both networks; from_logits=True is required because the discriminator's final Dense(1) layer outputs a raw logit rather than a probability.

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

Discriminator loss function

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

Generator loss function

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

Optimizer

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
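
As a possible refinement that this article does not use, DCGAN-style setups often use a slightly higher learning rate of 2e-4 together with a lower momentum term beta_1=0.5; a sketch with hypothetical variable names:

# Hypothetical alternative optimizer settings (DCGAN-style); not used below
generator_optimizer_alt = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer_alt = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)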

Define training function

@tf.function
def train_step(images):
    # BATCH_SIZE is defined below, before train_step is first called
    noise = tf.random.normal([BATCH_SIZE, 100])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

BATCH_SIZE = 256
EPOCHS = 100

for epoch in range(EPOCHS):
    for i in range(train_images.shape[0] // BATCH_SIZE):
        batch_images = train_images[i*BATCH_SIZE:(i+1)*BATCH_SIZE]
        train_step(batch_images)
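
For long training runs it can be useful to save the model weights periodically. The original article does not do this; the snippet below is only a sketch using tf.train.Checkpoint, and the checkpoint directory name is an arbitrary choice:

import os

# Optional checkpointing sketch (not part of the original training loop);
# './training_checkpoints' is an arbitrary directory name.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

# For example, save every 10 epochs inside the epoch loop:
# if (epoch + 1) % 10 == 0:
#     checkpoint.save(file_prefix=checkpoint_prefix)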
  3. Generate new images

After training is complete, we will use the generator network to generate new images. We will randomly generate 16 noise vectors of length 100 and feed them into the generator network to produce new handwritten digit images.

import matplotlib.pyplot as plt

def generate_and_save_images(model, epoch, test_input):
    # Note: `training` is set to False, so all layers
    # run in inference mode (e.g. batchnorm).
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()


Randomly generate noise vector

noise = tf.random.normal([16, 100])
generate_and_save_images(generator, 0, noise)

The result shows that the generator can produce new handwritten digit images. We can improve the model by increasing the number of training epochs, and further improve the GAN by trying other hyperparameter combinations and network architectures.
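
One common tweak worth experimenting with (our suggestion, not something covered in this article) is one-sided label smoothing, which replaces the discriminator's real-label target of 1.0 with a slightly smaller value so the discriminator does not become overconfident:

# Hypothetical variant of discriminator_loss with one-sided label smoothing
def discriminator_loss_smoothed(real_output, fake_output, smooth=0.9):
    # real images are labelled 0.9 instead of 1.0
    real_loss = cross_entropy(tf.ones_like(real_output) * smooth, real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss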

In short, GANs are a very useful class of deep learning algorithms that can generate many types of data. In this article, we wrote an example GAN in Python that generates images of handwritten digits, and showed how to train it and use the generator network to produce new images.
