Heim > Technologie-Peripheriegeräte > KI > So verwenden Sie ein siamesisches Netzwerk, um Datensätze mit unausgeglichenen Stichproben zu verarbeiten (mit Beispielcode)

So verwenden Sie ein siamesisches Netzwerk, um Datensätze mit unausgeglichenen Stichproben zu verarbeiten (mit Beispielcode)

王林
Freigeben: 2024-01-22 16:15:05
nach vorne
870 Leute haben es durchsucht

So verwenden Sie ein siamesisches Netzwerk, um Datensätze mit unausgeglichenen Stichproben zu verarbeiten (mit Beispielcode)

Das siamesische Netzwerk ist ein neuronales Netzwerkmodell, das für das metrische Lernen verwendet wird und lernen kann, wie ein Ähnlichkeits- oder Differenzmaß zwischen zwei Eingaben berechnet wird. Aufgrund seiner Flexibilität ist es in zahlreichen Anwendungen wie Gesichtserkennung, semantischer Ähnlichkeitsberechnung und Textvergleich beliebt. Allerdings kann das siamesische Netzwerk beim Umgang mit unausgeglichenen Datensätzen auf Probleme stoßen, da es sich möglicherweise zu sehr auf Stichproben aus einigen wenigen Klassen konzentriert und die Mehrheit der Stichproben ignoriert. Um dieses Problem zu lösen, können verschiedene Techniken verwendet werden. Ein Ansatz besteht darin, den Datensatz durch Unterabtastung oder Überabtastung auszugleichen. Bei der Unterabtastung werden zufällig einige Stichproben aus der Mehrheitsklasse entfernt, sodass sie der Anzahl der Stichproben aus der Minderheitsklasse entsprechen. Durch Überabtastung wird die Anzahl der Stichproben in der Minderheitsklasse erhöht, indem neue Stichproben kopiert oder generiert werden, sodass sie der Anzahl der Stichproben in der Mehrheitsklasse entspricht. Dies gleicht den Datensatz effektiv aus, kann jedoch zu Informationsverlust oder Überanpassungsproblemen führen. Eine andere Methode ist die Gewichtsanpassung. Durch die Zuweisung höherer Gewichtungen zu Minderheitenklassenstichproben kann die Aufmerksamkeit des siamesischen Netzwerks für die Minderheitenklasse erhöht werden. Dies verbessert die Modellleistung, indem es sich auf einige wenige Klassen konzentriert, ohne den Datensatz zu ändern. Darüber hinaus können einige fortschrittliche metrische Lernalgorithmen auch verwendet werden, um die Leistung siamesischer Netzwerke zu verbessern, z. B. generative kontradiktorische Netzwerke (GAN), die auf kontradiktorischen generativen Netzwerken basieren.

1. Resampling-Technologie

Die Anzahl der Kategorieproben variiert stark. Um den Datensatz auszugleichen, können Resampling-Techniken verwendet werden. Zu den häufigsten gehören Unterabtastung und Überabtastung, um eine übermäßige Konzentration auf einige wenige Kategorien zu verhindern.

Unterabtastung besteht darin, die Stichprobengröße der Mehrheitskategorie und der Minderheitskategorie auszugleichen, indem einige Stichproben der Mehrheitskategorie gelöscht werden, sodass sie die gleiche Anzahl an Stichproben wie die Minderheitskategorie aufweist. Dieser Ansatz kann den Fokus des Modells auf die Mehrheitskategorie verringern, aber möglicherweise gehen auch einige nützliche Informationen verloren.

Überabtastung besteht darin, das Problem des Stichprobenungleichgewichts auszugleichen, indem Stichproben der Minderheitsklasse kopiert werden, sodass die Minderheitsklasse und die Mehrheitsklasse die gleiche Anzahl von Stichproben haben. Obwohl eine Überabtastung die Anzahl der Stichproben aus Minderheitenklassen erhöhen kann, kann sie auch zu Überanpassungsproblemen führen.

2. Stichprobengewichtungstechnik

Eine andere Möglichkeit, mit unausgeglichenen Datensätzen umzugehen, ist die Verwendung der Stichprobengewichtstechnik. Mit dieser Methode können Stichproben verschiedener Kategorien unterschiedlich gewichtet werden, um ihre Bedeutung im Datensatz widerzuspiegeln.

Ein gängiger Ansatz besteht darin, Klassenhäufigkeiten zu verwenden, um das Gewicht von Stichproben zu berechnen. Insbesondere kann das Gewicht jeder Probe als $$

w_i=frac{1}{n_ccdot n_i}

festgelegt werden, wobei n_c die Anzahl der Proben in Kategorie c und n_i die Kategorie ist, zu der Probe i gehört gehört Anzahl der Proben. Diese Methode kann den Datensatz ausgleichen, indem den Stichproben aus Minderheitenklassen ein höheres Gewicht verliehen wird.

3. Ändern Sie die Verlustfunktion

Siamesische Netzwerke verwenden normalerweise eine kontrastive Verlustfunktion, um das Modell zu trainieren, beispielsweise eine Triplett-Verlustfunktion oder eine Kosinus-Verlustfunktion. Beim Umgang mit unausgeglichenen Datensätzen kann eine verbesserte Kontrastverlustfunktion verwendet werden, damit das Modell den Stichproben aus der Minderheitsklasse mehr Aufmerksamkeit schenkt.

Ein gängiger Ansatz ist die Verwendung einer gewichteten Kontrastverlustfunktion, bei der Stichproben aus der Minderheitsklasse höhere Gewichte haben. Konkret kann die Verlustfunktion in die folgende Form geändert werden:

L=frac{1}{N}sum_{i=1}^N w_icdot L_i

wobei N die Anzahl der Stichproben und w_i ist Probe i Das Gewicht von L_i ist der Kontrastverlust von Probe i.

4. Kombinieren Sie mehrere Methoden

Um mit unausgeglichenen Datensätzen umzugehen, können schließlich mehrere Methoden kombiniert werden, um das siamesische Netzwerk zu trainieren. Beispielsweise kann man Resampling-Techniken und Stichprobengewichtungstechniken verwenden, um den Datensatz auszugleichen, und dann eine verbesserte Kontrastverlustfunktion verwenden, um das Modell zu trainieren. Diese Methode kann die Vorteile verschiedener Techniken voll ausnutzen und eine bessere Leistung bei unausgeglichenen Datensätzen erzielen.

Für unausgeglichene Datensätze besteht eine gängige Lösung darin, eine gewichtete Verlustfunktion zu verwenden, bei der selteneren Klassen höhere Gewichtungen zugewiesen werden. Hier ist ein einfaches Beispiel, das zeigt, wie ein siamesisches Netzwerk mit einer gewichteten Verlustfunktion in Keras implementiert wird, um unausgeglichene Datensätze zu verarbeiten:

from keras.layers import Input, Conv2D, Lambda, Dense, Flatten, MaxPooling2D
from keras.models import Model
from keras import backend as K
import numpy as np

# 定义输入维度和卷积核大小
input_shape = (224, 224, 3)
kernel_size = 3

# 定义共享的卷积层
conv1 = Conv2D(64, kernel_size, activation='relu', padding='same')
pool1 = MaxPooling2D(pool_size=(2, 2))
conv2 = Conv2D(128, kernel_size, activation='relu', padding='same')
pool2 = MaxPooling2D(pool_size=(2, 2))
conv3 = Conv2D(256, kernel_size, activation='relu', padding='same')
pool3 = MaxPooling2D(pool_size=(2, 2))
conv4 = Conv2D(512, kernel_size, activation='relu', padding='same')
flatten = Flatten()

# 定义共享的全连接层
dense1 = Dense(512, activation='relu')
dense2 = Dense(512, activation='relu')

# 定义距离度量层
def euclidean_distance(vects):
    x, y = vects
    sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
    return K.sqrt(K.maximum(sum_square, K.epsilon()))

# 定义Siamese网络
input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

processed_a = conv1(input_a)
processed_a = pool1(processed_a)
processed_a = conv2(processed_a)
processed_a = pool2(processed_a)
processed_a = conv3(processed_a)
processed_a = pool3(processed_a)
processed_a = conv4(processed_a)
processed_a = flatten(processed_a)
processed_a = dense1(processed_a)
processed_a = dense2(processed_a)

processed_b = conv1(input_b)
processed_b = pool1(processed_b)
processed_b = conv2(processed_b)
processed_b = pool2(processed_b)
processed_b = conv3(processed_b)
processed_b = pool3(processed_b)
processed_b = conv4(processed_b)
processed_b = flatten(processed_b)
processed_b = dense1(processed_b)
processed_b = dense2(processed_b)

distance = Lambda(euclidean_distance)([processed_a, processed_b])

model = Model([input_a, input_b], distance)

# 定义加权损失函数
def weighted_binary_crossentropy(y_true, y_pred):
    class1_weight = K.variable(1.0)
    class2_weight = K.variable(1.0)
    class1_mask = K.cast(K.equal(y_true, 0), 'float32')
    class2_mask = K.cast(K.equal(y_true, 1), 'float32')
    class1_loss = class1_weight * K.binary_crossentropy(y_true, y_pred) * class1_mask
    class2_loss = class2_weight * K.binary_crossentropy(y_true, y_pred) * class2_mask
    return K.mean(class1_loss + class2_loss)

# 编译模型,使用加权损失函数和Adam优化器
model.compile(loss=weighted_binary_crossentropy, optimizer='adam')

# 训练模型
model.fit([X_train[:, 0], X_train[:, 1]], y_train, batch_size=32, epochs=10, validation_data=([X_val[:, 0], X_val[:, 1]], y_val))
Nach dem Login kopieren

Wobei die Funktion „weighted_binary_crossentropy“ die gewichtete Verlustfunktion definiert, sind class1_weight und class2_weight die Kategorien 1 bzw. 2. Das Gewicht der Kategorie 2, class1_mask und class2_mask sind Masken, die zur Abschirmung von Kategorie 1 und Kategorie 2 verwendet werden. Wenn Sie ein Modell trainieren, müssen Sie Trainingsdaten und Validierungsdaten an die beiden Eingänge des Modells übergeben und die Zielvariable als dritten Parameter an die Anpassungsmethode übergeben. Bitte beachten Sie, dass dies nur ein Beispiel ist und nicht garantiert werden kann, dass das Problem unausgeglichener Datensätze vollständig gelöst wird. In praktischen Anwendungen kann es notwendig sein, verschiedene Lösungen auszuprobieren und diese an die spezifische Situation anzupassen.

Das obige ist der detaillierte Inhalt vonSo verwenden Sie ein siamesisches Netzwerk, um Datensätze mit unausgeglichenen Stichproben zu verarbeiten (mit Beispielcode). Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Verwandte Etiketten:
Quelle:163.com
Erklärung dieser Website
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn
Beliebte Tutorials
Mehr>
Neueste Downloads
Mehr>
Web-Effekte
Quellcode der Website
Website-Materialien
Frontend-Vorlage