Variational Autoencoder (VAE) ialah rangkaian neural pembelajaran tanpa pengawasan yang digunakan untuk pemampatan dan penjanaan imej. Berbanding dengan pengekod auto tradisional, VAE boleh membina semula imej input dan menjana imej baharu yang serupa dengannya. Idea teras adalah untuk mengekod imej input ke dalam pengedaran pembolehubah terpendam dan sampel daripadanya untuk menghasilkan imej baharu. VAE adalah unik dalam menggunakan inferens variasi untuk melatih model, mencapai pembelajaran parameter dengan memaksimumkan sempadan bawah antara data yang diperhatikan dan dijana. Kaedah ini membolehkan VAE mempelajari struktur asas data dan keupayaan untuk menjana sampel baharu. VAE telah mencapai kejayaan yang luar biasa dalam banyak bidang, termasuk tugas seperti penjanaan imej, penyuntingan atribut dan pembinaan semula imej.
VAE (pengekod automatik variasi) mempunyai struktur yang serupa dengan pengekod automatik dan terdiri daripada pengekod dan penyahkod. Pengekod memampatkan imej input ke dalam taburan pembolehubah terpendam, termasuk vektor min dan vektor varians. Penyahkod mengambil sampel pembolehubah terpendam untuk menghasilkan imej baharu. Untuk menjadikan taburan pembolehubah pendam lebih munasabah, VAE memperkenalkan istilah regularisasi bagi pencapahan KL untuk menjadikan pengagihan pembolehubah pendam lebih dekat dengan taburan normal piawai. Melakukannya boleh meningkatkan ekspresif dan keupayaan penjanaan model.
Berikut mengambil set data digit tulisan tangan MNIST sebagai contoh untuk memperkenalkan proses pelaksanaan VAE.
Pertama, kita perlu mengimport perpustakaan dan set data yang diperlukan.
import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.autograd import Variable # 加载数据集 transform = transforms.Compose([ transforms.ToTensor(), ]) train_dataset = datasets.MNIST(root='./data/', train=True, transform=transform, download=True) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
Seterusnya, tentukan struktur rangkaian pengekod dan penyahkod.
# 定义编码器 class Encoder(nn.Module): def __init__(self): super(Encoder, self).__init__() self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1) self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1) self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) self.fc1 = nn.Linear(128 * 7 * 7, 256) self.fc21 = nn.Linear(256, 20) # 均值向量 self.fc22 = nn.Linear(256, 20) # 方差向量 def forward(self, x): x = nn.functional.relu(self.conv1(x)) x = nn.functional.relu(self.conv2(x)) x = nn.functional.relu(self.conv3(x)) x = x.view(-1, 128 * 7 * 7) x = nn.functional.relu(self.fc1(x)) mean = self.fc21(x) log_var = self.fc22(x) return mean, log_var # 定义解码器 class Decoder(nn.Module): def __init__(self): super(Decoder, self).__init__() self.fc1 = nn.Linear(20, 256) self.fc2 = nn.Linear(256, 128 * 7 * 7) self.conv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1) self.conv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1) self.conv3 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1) def forward(self, x): x = nn.functional.relu(self.fc1(x)) x = nn.functional.relu(self.fc2(x)) x = x.view(-1, 128, 7, 7) x = nn.functional.relu(self.conv1(x)) x = nn.functional.relu(self.conv2(x)) x = nn.functional.sigmoid(self.conv3(x)) return x # 定义VAE模型 class VAE(nn.Module): def __init__(self): super(VAE, self).__init__() self.encoder = Encoder() self.decoder = Decoder() def reparameterize(self, mean, log_var): std = torch.exp(0.5 * log_var) eps = torch.randn_like(std) return eps * std + mean def forward(self, x): mean, log_var = self.encoder(x)
Langkah seterusnya ialah proses perambatan ke hadapan model VAE, yang merangkumi pensampelan daripada pembolehubah terpendam untuk menjana imej baharu, dan mengira syarat penyelarasan ralat pembinaan semula dan perbezaan KL.
z = self.reparameterize(mean, log_var) x_recon = self.decoder(z) return x_recon, mean, log_var def loss_function(self, x_recon, x, mean, log_var): recon_loss = nn.functional.binary_cross_entropy(x_recon, x, size_average=False) kl_loss = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp()) return recon_loss + kl_loss def sample(self, num_samples): z = torch.randn(num_samples, 20) samples = self.decoder(z) return samples
Akhir sekali, kami menentukan pengoptimum dan mula melatih model.
# 定义优化器 vae = VAE() optimizer = optim.Adam(vae.parameters(), lr=1e-3) # 开始训练模型 num_epochs = 10 for epoch in range(num_epochs): for batch_idx, (data, _) in enumerate(train_loader): data = Variable(data) optimizer.zero_grad() x_recon, mean, log_var = vae(data) loss = vae.loss_function(x_recon, data, mean, log_var) loss.backward() optimizer.step() if batch_idx % 100 == 0: print('Epoch [{}/{}], Batch [{}/{}], Loss: {:.4f}'.format( epoch+1, num_epochs, batch_idx+1, len(train_loader), loss.data.item()))
Selepas latihan selesai, kita boleh menggunakan VAE untuk menjana imej digit tulisan tangan baharu.
# 生成手写数字图像 samples = vae.sample(10) fig, ax = plt.subplots(1, 10, figsize=(10, 1)) for i in range(10): ax[i].imshow(samples[i].detach().numpy().reshape(28, 28), cmap='gray') ax[i].axis('off') plt.show()
VAE ialah model pemampatan imej dan generatif yang berkuasa yang mencapai pemampatan imej dengan mengekod imej input ke dalam pengedaran pembolehubah terpendam sambil mengambil sampel daripadanya untuk menjana imej baharu. Berbeza daripada pengekod auto tradisional, VAE juga memperkenalkan istilah penyelarasan perbezaan KL untuk menjadikan pengedaran pembolehubah pendam lebih munasabah. Apabila melaksanakan VAE, adalah perlu untuk menentukan struktur rangkaian pengekod dan penyahkod, dan mengira syarat regularisasi ralat pembinaan semula dan perbezaan KL. Dengan melatih model VAE, taburan pembolehubah terpendam bagi imej input boleh dipelajari dan imej baharu boleh dihasilkan daripadanya.
Di atas adalah pengenalan asas dan proses pelaksanaan VAE Saya harap ia akan membantu pembaca.
Atas ialah kandungan terperinci Proses pelaksanaan pemampatan imej: pengekod auto variasi. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!