Rumah > pembangunan bahagian belakang > Tutorial Python > Contoh algoritma GAN dalam Python

Contoh algoritma GAN dalam Python

王林
Lepaskan: 2023-06-10 09:53:50
asal
1226 orang telah melayarinya

Generative Adversarial Networks (GAN) ialah algoritma pembelajaran mendalam yang menjana data baharu melalui dua rangkaian neural yang bersaing antara satu sama lain. GAN digunakan secara meluas untuk tugas penjanaan dalam imej, audio, teks dan medan lain. Dalam artikel ini, kami akan menggunakan Python untuk menulis contoh algoritma GAN untuk menjana imej digit tulisan tangan.

  1. Penyediaan Set Data

Kami akan menggunakan set data MNIST sebagai set data latihan kami. Set data MNIST mengandungi 60,000 imej latihan dan 10,000 imej ujian, setiap imej ialah imej skala kelabu 28x28. Kami akan menggunakan perpustakaan TensorFlow untuk memuatkan dan memproses set data. Sebelum memuatkan set data, kami perlu memasang pustaka TensorFlow dan pustaka NumPy.

import tensorflow sebagai tf
import numpy sebagai np

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

prapemprosesan set data

train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images.shape) #127.5 # Normalize nilai piksel kepada julat [-1, 1]

  1. reka bentuk dan latihan seni bina GAN

GAN kami akan merangkumi dua rangkaian neural: Rangkaian penjana dan rangkaian diskriminasi. Rangkaian penjana akan menerima vektor hingar sebagai input dan output imej 28x28. Rangkaian diskriminator akan menerima imej 28x28 sebagai input dan output kebarangkalian bahawa imej itu adalah imej sebenar.

Seni bina kedua-dua rangkaian penjana dan rangkaian diskriminator akan menggunakan rangkaian neural konvolusi (CNN). Dalam rangkaian penjana, kami akan menggunakan lapisan dekonvolusi untuk menyahkod vektor hingar menjadi imej 28x28. Dalam rangkaian diskriminator, kami akan menggunakan lapisan konvolusi untuk mengklasifikasikan imej input.

Input kepada rangkaian penjana ialah vektor hingar dengan panjang 100. Kami akan menyusun lapisan rangkaian dengan menggunakan fungsi tf.keras.Sequential.

def make_generator_model():

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.LeakyReLU())

model.add(tf.keras.layers.Reshape((7, 7, 256)))
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size没有限制

model.add(tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
assert model.output_shape == (None, 7, 7, 128)
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.LeakyReLU())

model.add(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
assert model.output_shape == (None, 14, 14, 64)
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.LeakyReLU())

model.add(tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
assert model.output_shape == (None, 28, 28, 1)

return model
Salin selepas log masuk

Input kepada rangkaian diskriminator ialah imej 28x28. Kami akan menyusun lapisan rangkaian dengan menggunakan fungsi tf.keras.Sequential.

def make_discriminator_model():

model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                 input_shape=[28, 28, 1]))
model.add(tf.keras.layers.LeakyReLU())
model.add(tf.keras.layers.Dropout(0.3))

model.add(tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
model.add(tf.keras.layers.LeakyReLU())
model.add(tf.keras.layers.Dropout(0.3))

model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(1))

return model
Salin selepas log masuk

Seterusnya, kami akan menulis kod latihan. Kami akan melatih rangkaian penjana dan rangkaian diskriminator secara bergilir-gilir dalam setiap kelompok. Semasa proses latihan, kami akan merekodkan kecerunan dengan menggunakan fungsi tf.GradientTape() dan kemudian mengoptimumkan rangkaian menggunakan fungsi tf.keras.optimizers.Adam().

penjana = make_generator_model()
discriminator = make_discriminator_model()

Loss function

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logis)<🎜

Fungsi kehilangan diskriminator

definisi_kerugian_diskriminasi(keluaran_sebenar, keluaran_palsu):

real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss
Salin selepas log masuk

Fungsi kehilangan penjana

def generator_loss(output_palsu): >

return cross_entropy(tf.ones_like(fake_output), fake_output)
Salin selepas log masuk

rreee

Fungsi kehilangan penjana

def generator_loss(output_palsu): >

noise = tf.random.normal([BATCH_SIZE, 100])

with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    generated_images = generator(noise, training=True)

    real_output = discriminator(images, training=True)
    fake_output = discriminator(generated_images, training=True)

    gen_loss = generator_loss(fake_output)
    disc_loss = discriminator_loss(real_output, fake_output)

gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
Salin selepas log masuk

for i in range(train_images.shape[0] // BATCH_SIZE):
    batch_images = train_images[i*BATCH_SIZE:(i+1)*BATCH_SIZE]
    train_step(batch_images)
Salin selepas log masuk
<🎜

generator_optimizer = tf.keras.optimizers.Adam(1e-4)

discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

Tentukan fungsi latihan


@tf.function

def train_step(imej):

# 注意 training` 设定为 False
# 因此,所有层都在推理模式下运行(batchnorm)。
predictions = model(test_input, training=False)

fig = plt.figure(figsize=(4, 4))

for i in range(predictions.shape[0]):
    plt.subplot(4, 4, i+1)
    plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
    plt.axis('off')

plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
plt.show()

Salin selepas log masuk
BATCH_SIZE = 256
    EPOCHS = 100
  1. untuk epoch dalam julat(EPOCHS):
rree 🎜>

Jana imej baharu

Selepas latihan selesai, kami akan menggunakan rangkaian penjana untuk menjana imej baharu. Kami akan menjana 100 vektor hingar secara rawak dan memasukkannya ke dalam rangkaian penjana untuk menjana imej baharu digit tulisan tangan.

import matplotlib.pyplot sebagai plt

def generate_and_save_images(model, epoch, test_input):

rrreee
Menjana vektor hingar secara rawak

noise = tf.random .normal([16, 100])

generate_and_save_images(generator, 0, noise)

Hasilnya menunjukkan bahawa penjana telah berjaya menghasilkan imej digit tulisan tangan baharu. Kita boleh meningkatkan prestasi model dengan menambah bilangan tempoh latihan secara beransur-ansur. Di samping itu, kami boleh meningkatkan lagi prestasi GAN dengan mencuba gabungan hiperparameter lain dan seni bina rangkaian. Ringkasnya, algoritma GAN ialah algoritma pembelajaran mendalam yang sangat berguna yang boleh digunakan untuk menjana pelbagai jenis data. Dalam artikel ini, kami menulis contoh algoritma GAN untuk menjana imej digit tulisan tangan menggunakan Python dan menunjukkan cara untuk melatih dan menggunakan rangkaian penjana untuk menjana imej baharu.

Atas ialah kandungan terperinci Contoh algoritma GAN dalam Python. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Label berkaitan:
sumber:php.cn
Kenyataan Laman Web ini
Kandungan artikel ini disumbangkan secara sukarela oleh netizen, dan hak cipta adalah milik pengarang asal. Laman web ini tidak memikul tanggungjawab undang-undang yang sepadan. Jika anda menemui sebarang kandungan yang disyaki plagiarisme atau pelanggaran, sila hubungi admin@php.cn
Tutorial Popular
Lagi>
Muat turun terkini
Lagi>
kesan web
Kod sumber laman web
Bahan laman web
Templat hujung hadapan