SimCLR (Simple Framework for Contrastive Learning of Representations) is a self-supervised framework for learning image representations. Unlike traditional supervised learning methods, SimCLR does not rely on labeled data; instead, it uses contrastive learning to learn features that capture high-level semantic information from unlabeled images.
SimCLR has been shown to outperform previous state-of-the-art unsupervised learning methods on various image classification benchmarks, and the representations it learns transfer readily to downstream tasks such as object detection, semantic segmentation, and few-shot learning with minimal fine-tuning on smaller labeled datasets.
The main idea of SimCLR is to learn a good representation of an image by comparing it with augmented versions of the same image produced by an augmentation module T. Each image is mapped through an encoder network f(.) to a representation, and a projection head g(.) then maps the learned features into a low-dimensional space. A contrastive loss is computed between the projections of two augmented views of the same image, encouraging views of the same image to have similar representations and views of different images to have dissimilar ones.
In this article we will delve into the SimCLR framework and explore its key components, including data augmentation, the contrastive loss function, and the encoder and projection head architecture.
We use the garbage classification dataset from Kaggle for our experiments.
The most important component of SimCLR is the augmentation module that transforms the images. The authors of the SimCLR paper argue that strong data augmentation is essential for unsupervised learning, so we follow the approach recommended in the paper.
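A minimal sketch of such an augmentation module is shown below. It follows the transforms recommended in the paper (random resized crop, horizontal flip, color jitter, random grayscale, and Gaussian blur) and applies the pipeline twice to produce two views of each image; the class name ContrastiveTransformations and the strength parameter s are illustrative choices, not something prescribed by the article.

```python
import torch
from torchvision import transforms

class ContrastiveTransformations:
    """Apply the same augmentation pipeline twice to produce two views of an image."""

    def __init__(self, size=224, s=1.0, n_views=2):
        # Color distortion strength s follows the SimCLR paper's recipe
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            # Kernel size is roughly 10% of the image size and must be odd
            transforms.GaussianBlur(kernel_size=int(0.1 * size) // 2 * 2 + 1),
            transforms.ToTensor(),
        ])
        self.n_views = n_views

    def __call__(self, x):
        # Return n_views independently augmented versions of the same image
        return [self.transform(x) for _ in range(self.n_views)]
```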
The next step is to define a PyTorch Dataset.
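A Dataset along these lines, which walks one sub-folder per class and returns the two augmented views for each image, might look like the following; the class name GarbageDataset and the folder layout are assumptions about how the Kaggle data is stored on disk.

```python
import os
from PIL import Image
from torch.utils.data import Dataset

class GarbageDataset(Dataset):
    """Return a pair of augmented views for every image under root_dir/<class>/<image>."""

    def __init__(self, root_dir, transform=None):
        self.transform = transform
        self.image_paths = []
        for cls in sorted(os.listdir(root_dir)):
            cls_dir = os.path.join(root_dir, cls)
            if not os.path.isdir(cls_dir):
                continue
            for fname in os.listdir(cls_dir):
                self.image_paths.append(os.path.join(cls_dir, fname))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform is not None:
            # ContrastiveTransformations returns a list of two augmented tensors
            return self.transform(img)
        return img
```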
As an example, we use the smaller ResNet18 model as the backbone, which takes 224x224 images as input. We set the required parameters and build the DataLoader.
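A possible way to set the parameters and build the DataLoader is sketched below; the batch size of 128 and the 224x224 input size match what is used later in the article, while the dataset path and worker count are placeholders.

```python
from torch.utils.data import DataLoader

# Hyperparameters used in this article
batch_size = 128
image_size = 224
num_workers = 4

# Path to the Kaggle garbage classification dataset (placeholder)
data_root = "./garbage_classification"

train_dataset = GarbageDataset(
    root_dir=data_root,
    transform=ContrastiveTransformations(size=image_size, n_views=2),
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,   # keep batches full so the contrastive loss always sees 2 * batch_size views
    pin_memory=True,
)
```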
With the data prepared, we can now build the model. The augmentation module above produces two augmented views of each image, which are forward-passed through the encoder to obtain the corresponding representations. The goal of SimCLR is to maximize the similarity between these representations, encouraging the model to learn a general representation of an object from two different augmented views.
The choice of encoder network is not restricted and can be any architecture. As mentioned above, we use ResNet18 for a simple demonstration. The representations learned by the encoder determine the similarity scores, and to improve their quality SimCLR uses a projection head that projects the encoding vectors into a richer latent space. Here we project the 512-dimensional ResNet18 features into a 256-dimensional space. It sounds complicated, but in fact it is just an MLP with a ReLU.
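One way to wrap the ResNet18 backbone and the projection head in a single module is sketched below; the hidden width of the MLP (512) and the class name SimCLRModel are assumptions.

```python
import torch.nn as nn
from torchvision import models

class SimCLRModel(nn.Module):
    """ResNet18 encoder f(.) followed by a projection head g(.)."""

    def __init__(self, feature_dim=512, projection_dim=256):
        super().__init__()
        # Encoder: ResNet18 with randomly initialized weights and its classification layer removed
        backbone = models.resnet18(weights=None)
        backbone.fc = nn.Identity()
        self.encoder = backbone

        # Projection head: a small MLP with one ReLU (512 -> 512 -> 256)
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, projection_dim),
        )

    def forward(self, x):
        h = self.encoder(x)            # representation used for downstream tasks
        z = self.projection_head(h)    # projection used by the contrastive loss
        return h, z
```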
The contrastive loss function, known as the normalized temperature-scaled cross-entropy loss (NT-Xent), is a key component of SimCLR: it encourages the model to learn similar representations for the same image and different representations for different images.
NT-Xent loss is computed using a pair of augmented views of an image passed through the encoder network to obtain their corresponding representations. The goal of contrastive loss is to encourage representations of two augmented views of the same image to be similar while forcing representations of different images to be dissimilar.
NT-Xent applies a softmax to the pairwise similarities of the augmented view representations. The softmax is computed over all pairs of representations within the mini-batch, giving a probability distribution of similarities for each image. A temperature parameter scales the pairwise similarities before the softmax, which helps to obtain better gradients during optimization.
Given this probability distribution over similarities, the NT-Xent loss is computed by maximizing the log-likelihood of matched representations of the same image and minimizing the log-likelihood of mismatched representations of different images.
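Concretely, for a positive pair (i, j) produced from the same image, the per-example loss defined in the SimCLR paper is

$$
\ell_{i,j} = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)}
$$

where sim(u, v) is the cosine similarity of the projections, τ is the temperature, and the sum runs over the other 2N - 1 views in a mini-batch of N images; the final loss averages ℓ over every positive pair.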
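A compact implementation of NT-Xent along these lines is sketched below: it concatenates the projections of the two views, builds the cosine-similarity matrix, masks out self-similarities, and applies a temperature-scaled cross-entropy. The class name NTXentLoss and the default temperature of 0.5 are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NTXentLoss(nn.Module):
    """Normalized temperature-scaled cross-entropy loss over a batch of view pairs."""

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        # z_i, z_j: projections of the two augmented views, shape (N, projection_dim)
        n = z_i.size(0)
        z = torch.cat([z_i, z_j], dim=0)                 # (2N, d)
        z = F.normalize(z, dim=1)                        # cosine similarity via dot products

        sim = torch.matmul(z, z.t()) / self.temperature  # (2N, 2N) similarity matrix

        # Mask out self-similarity so a view is never treated as its own negative
        mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
        sim.masked_fill_(mask, float("-inf"))

        # The positive for sample i is its other view: i <-> i + n
        targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)

        # Cross-entropy over the remaining 2N - 1 candidates maximizes the positive pair's probability
        return F.cross_entropy(sim, targets)
```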
With all the preparations complete, let's train SimCLR and see how it performs!
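A minimal training loop for 10 epochs, using the model, loss, and DataLoader sketched above, could look like the following; the Adam optimizer, its learning rate, and the checkpoint filename are assumptions rather than choices fixed by the article.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SimCLRModel().to(device)
criterion = NTXentLoss(temperature=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

epochs = 10
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for views in train_loader:
        # views is a list of two augmented batches produced by ContrastiveTransformations
        x_i, x_j = views[0].to(device), views[1].to(device)

        _, z_i = model(x_i)
        _, z_j = model(x_j)
        loss = criterion(z_i, z_j)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"epoch {epoch + 1}: loss = {total_loss / len(train_loader):.4f}")

# Save the pre-trained weights for downstream fine-tuning
torch.save(model.state_dict(), "simclr_resnet18.pth")
```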
The code above trains for 10 epochs. Once pre-training is complete, we can use the pre-trained encoder for whatever downstream task we want. This can be done with the code below.
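A sketch of this downstream step is given below: it loads the pre-trained SimCLR weights, freezes the encoder, and trains only a linear classification head self.linear on the labeled garbage-classification images. Using ImageFolder for the labeled data, the fine-tuning hyperparameters, and the lack of a separate validation split are all simplifying assumptions.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class DownstreamClassifier(nn.Module):
    """Frozen SimCLR encoder plus a trainable linear classification head."""

    def __init__(self, pretrained_path, num_classes, feature_dim=512):
        super().__init__()
        # Load the SimCLR model pre-trained above and keep only its encoder
        simclr = SimCLRModel()
        simclr.load_state_dict(torch.load(pretrained_path, map_location="cpu"))
        self.encoder = simclr.encoder

        # Freeze all encoder weights; only the classification head is trained
        for param in self.encoder.parameters():
            param.requires_grad = False

        # Classification head for the downstream task
        self.linear = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        with torch.no_grad():
            h = self.encoder(x)
        return self.linear(h)


# Labeled data for fine-tuning: one sub-folder per class, standard (non-contrastive) transforms
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
labeled_dataset = datasets.ImageFolder("./garbage_classification", transform=eval_transform)
labeled_loader = DataLoader(labeled_dataset, batch_size=128, shuffle=True, num_workers=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
classifier = DownstreamClassifier("simclr_resnet18.pth",
                                  num_classes=len(labeled_dataset.classes)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.linear.parameters(), lr=1e-3)

for epoch in range(10):
    classifier.train()
    correct, total = 0, 0
    for images, labels in labeled_loader:
        images, labels = images.to(device), labels.to(device)
        logits = classifier(images)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    print(f"epoch {epoch + 1}: train accuracy = {correct / total:.3f}")
```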
The most important part of the code above is loading the SimCLR model we just trained, freezing all of its weights, and then adding a classification head self.linear to perform the downstream classification task.
This article has introduced the SimCLR framework and used it to pre-train a ResNet18 with randomly initialized weights. Pre-training is a powerful technique in deep learning for training models on large datasets and learning useful features that can be transferred to other tasks. The SimCLR paper shows that larger batch sizes lead to better performance; our implementation uses a batch size of only 128 and trains for only 10 epochs, so this is far from the model's best performance, and further training is required for any meaningful performance comparison.
The figure below shows the performance results reported by the authors of the paper.