Heim > Backend-Entwicklung > Python-Tutorial > Beispiel für die Migration des Bildstils in Python

Beispiel für die Migration des Bildstils in Python

WBOY
Freigeben: 2023-06-11 20:44:25
Original
1404 Leute haben es durchsucht

Bildstilübertragung ist eine auf Deep Learning basierende Technologie, die den Stil eines Bildes auf ein anderes Bild übertragen kann. In den letzten Jahren wurde die Bildstilübertragungstechnologie in den Bereichen Kunst sowie Film- und Fernsehspezialeffekte weit verbreitet eingesetzt. In diesem Artikel stellen wir vor, wie Sie die Bildstilmigration mithilfe der Python-Sprache implementieren.

1. Was ist Bildstilübertragung? Mit der Bildstilübertragung kann der Stil eines Bildes auf ein anderes Bild übertragen werden. Der Stil kann der Malstil des Künstlers, der Aufnahmestil des Fotografen oder andere Stile sein. Ziel der Bildstilübertragung ist es, den Inhalt des Originalbildes beizubehalten und ihm gleichzeitig einen neuen Stil zu verleihen.

Die Bildstilübertragungstechnologie ist eine Deep-Learning-Technologie, die auf einem Faltungs-Neuronalen Netzwerk (CNN) basiert. Ihre Kernidee besteht darin, die Inhalts- und Stilinformationen des Bildes durch ein vorab trainiertes CNN-Modell zu extrahieren und Optimierungsmethoden zu verwenden, um diese beiden zu synthetisieren in ein neues auf dem Bild. Typischerweise werden die Inhaltsinformationen eines Bildes durch die tiefen Faltungsschichten von CNN extrahiert, während die Stilinformationen des Bildes durch die Korrelation zwischen den Faltungskernen von CNN extrahiert werden.

2. Bildstilmigration implementieren

Zu den Hauptschritten zur Implementierung der Bildstilmigration in Python gehören das Laden von Bildern, die Vorverarbeitung von Bildern, das Erstellen von Modellen, das Berechnen von Verlustfunktionen sowie die Verwendung von Optimierungsmethoden zum Iterieren und Ausgeben von Ergebnissen. Als Nächstes werden wir diese Schritt für Schritt behandeln.

Bilder laden
  1. Zuerst müssen wir ein Originalbild und ein Referenzbild laden. Das Originalbild ist das Bild, dessen Stil übertragen werden muss, und das Referenzbild ist das Stilbild, das übertragen werden soll. Das Laden von Bildern kann mit dem PIL-Modul (Python Imaging Library) von Python erfolgen.
from PIL import Image
import numpy as np

# 载入原始图像和参考图像
content_image = Image.open('content.jpg')
style_image = Image.open('style.jpg')

# 将图像转化为numpy数组,方便后续处理
content_array = np.array(content_image)
style_array = np.array(style_image)
Nach dem Login kopieren

Vorverarbeitung von Bildern
  1. Die Vorverarbeitung umfasst die Konvertierung von Originalbildern und Referenzbildern in ein Format, das das neuronale Netzwerk verarbeiten kann, d. h. die Konvertierung des Bildes in einen Tensor und die gleichzeitige Durchführung einer Standardisierung. Hier verwenden wir zur Vervollständigung das von PyTorch bereitgestellte Vorverarbeitungsmodul.
import torch
import torch.nn as nn
import torchvision.transforms as transforms

# 定义预处理函数
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 将图像进行预处理
content_tensor = preprocess(content_image).unsqueeze(0).to(device)
style_tensor = preprocess(style_image).unsqueeze(0).to(device)
Nach dem Login kopieren

Modelle erstellen
  1. Das Bildstilübertragungsmodell kann Modelle verwenden, die in großen Bilddatenbanken trainiert wurden. Zu den häufig verwendeten Modellen gehören VGG19 und ResNet. Hier verwenden wir zur Vervollständigung das VGG19-Modell. Zuerst müssen wir das vorab trainierte VGG19-Modell laden und die letzte vollständig verbundene Schicht entfernen, sodass nur die Faltungsschicht übrig bleibt. Anschließend müssen wir die Inhalts- und Stilinformationen des Bildes anpassen, indem wir die Gewichte der Faltungsschicht ändern.
import torchvision.models as models

class VGG(nn.Module):
    def __init__(self, requires_grad=False):
        super(VGG, self).__init__()
        vgg19 = models.vgg19(pretrained=True).features
        self.slice1 = nn.Sequential()
        self.slice2 = nn.Sequential()
        self.slice3 = nn.Sequential()
        self.slice4 = nn.Sequential()
        self.slice5 = nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg19[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg19[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg19[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg19[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg19[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        h_relu1 = self.slice1(x)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        return h_relu1, h_relu2, h_relu3, h_relu4, h_relu5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VGG().to(device).eval()
Nach dem Login kopieren

Berechnen Sie die Verlustfunktion
  1. Da das Ziel der Bildstilübertragung darin besteht, den Inhalt des Originalbilds beizubehalten und ihm gleichzeitig einen neuen Stil zu verleihen, müssen wir eine Verlustfunktion definieren, um dieses Ziel zu erreichen. Die Verlustfunktion besteht aus zwei Teilen: dem Inhaltsverlust und dem Stilverlust.

Inhaltsverlust kann definiert werden, indem der mittlere quadratische Fehler zwischen dem Originalbild und dem generierten Bild in der Feature-Map der Faltungsschicht berechnet wird. Der Stilverlust wird durch Berechnen des mittleren quadratischen Fehlers zwischen der Gram-Matrix zwischen der Feature-Map des generierten Bildes und dem Stilbild in der Faltungsschicht definiert. Die Gram-Matrix ist hier die Korrelationsmatrix zwischen den Faltungskernen der Feature-Map.

def content_loss(content_features, generated_features):
    return torch.mean((content_features - generated_features)**2)

def gram_matrix(input):
    batch_size , h, w, f_map_num = input.size()
    features = input.view(batch_size * h, w * f_map_num)
    G = torch.mm(features, features.t())
    return G.div(batch_size * h * w * f_map_num)

def style_loss(style_features, generated_features):
    style_gram = gram_matrix(style_features)
    generated_gram = gram_matrix(generated_features)
    return torch.mean((style_gram - generated_gram)**2)

content_weight = 1
style_weight = 1000

def compute_loss(content_features, style_features, generated_features):
    content_loss_fn = content_loss(content_features, generated_features[0])
    style_loss_fn = style_loss(style_features, generated_features[1])
    loss = content_weight * content_loss_fn + style_weight * style_loss_fn
    return loss, content_loss_fn, style_loss_fn
Nach dem Login kopieren

Iterieren Sie mit Optimierungsmethoden
  1. Nach der Berechnung der Verlustfunktion können wir Optimierungsmethoden verwenden, um die Pixelwerte des generierten Bildes anzupassen, um die Verlustfunktion zu minimieren. Zu den häufig verwendeten Optimierungsmethoden gehören die Gradientenabstiegsmethode und der L-BFGS-Algorithmus. Hier verwenden wir den von PyTorch bereitgestellten LBFGS-Optimierer, um die Bildmigration abzuschließen. Die Anzahl der Iterationen kann nach Bedarf angepasst werden. Normalerweise können 2000 Iterationen bessere Ergebnisse erzielen.
from torch.optim import LBFGS

generated = content_tensor.detach().clone().requires_grad_(True).to(device)

optimizer = LBFGS([generated])

for i in range(2000):

    def closure():
        optimizer.zero_grad()
        generated_features = model(generated)
        loss, content_loss_fn, style_loss_fn = compute_loss(content_features, style_features, generated_features)
        loss.backward()
        return content_loss_fn + style_loss_fn

    optimizer.step(closure)

    if i % 100 == 0:
        print('Iteration:', i)
        print('Total loss:', closure().tolist())
Nach dem Login kopieren

Ergebnisse ausgeben
  1. Schließlich können wir das generierte Bild lokal speichern und den Effekt der Bildstilmigration beobachten.
import matplotlib.pyplot as plt

generated_array = generated.cpu().detach().numpy()
generated_array = np.squeeze(generated_array, 0)
generated_array = generated_array.transpose(1, 2, 0)
generated_array = np.clip(generated_array, 0, 1)

plt.imshow(generated_array)
plt.axis('off')
plt.show()

Image.fromarray(np.uint8(generated_array * 255)).save('generated.jpg')
Nach dem Login kopieren

3. Zusammenfassung

In diesem Artikel wird die Verwendung der Python-Sprache zur Implementierung der Bildstilübertragungstechnologie vorgestellt. Indem wir das Bild laden, das Bild vorverarbeiten, das Modell erstellen, die Verlustfunktion berechnen, mit der Optimierungsmethode iterieren und das Ergebnis ausgeben, können wir den Stil eines Bildes auf ein anderes übertragen. In praktischen Anwendungen können wir Parameter wie Referenzbilder und die Anzahl der Iterationen an unterschiedliche Anforderungen anpassen, um bessere Ergebnisse zu erzielen.

Das obige ist der detaillierte Inhalt vonBeispiel für die Migration des Bildstils in Python. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Verwandte Etiketten:
Quelle:php.cn
Erklärung dieser Website
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn
Beliebte Tutorials
Mehr>
Neueste Downloads
Mehr>
Web-Effekte
Quellcode der Website
Website-Materialien
Frontend-Vorlage