Google recently released Muse, a new text-to-image generation model. Rather than the currently popular diffusion approach, it uses a classic Transformer architecture to achieve state-of-the-art image generation performance. Compared with diffusion or autoregressive models, Muse is also considerably more efficient.
Paper link: https://arxiv.org/pdf/2301.00704.pdf
Project link: https://muse-model.github.io/
Muse is trained on a masked modeling task in discrete token space: given text embeddings extracted from a pre-trained large language model (LLM), Muse is trained to predict randomly masked image tokens.
Compared with pixel-space diffusion models (such as Imagen and DALL-E 2), Muse uses discrete tokens and requires fewer sampling iterations, so it is significantly more efficient;
Compared with autoregressive models (such as Parti), Muse is more efficient because it uses parallel decoding.
Using a pre-trained LLM enables fine-grained language understanding, which translates into high-fidelity image generation and an understanding of visual concepts such as objects, spatial relationships, pose, and cardinality.
In the experiments, the Muse model with only 900M parameters achieved a new SOTA on CC3M, with an FID score of 6.06.
The 3B-parameter Muse model achieved an FID of 7.88 in zero-shot COCO evaluation, along with a CLIP score of 0.32.
Muse also directly supports several image editing applications without fine-tuning or inverting the model: inpainting, outpainting, and mask-free editing.
The Muse framework contains multiple components. The training pipeline consists of a pre-trained T5-XXL text encoder, a base model, and a super-resolution model.
1. Pre-trained text encoder
Consistent with the conclusions of previous studies, the researchers found that using a pre-trained large language model (LLM) helps generate higher-quality images.
For example, the embeddings extracted from the T5-XXL language model carry rich information about objects (nouns), actions (verbs), visual attributes (adjectives), spatial relationships (prepositions), and other properties such as cardinality and composition.
The researchers therefore hypothesized that the Muse model learns to map these rich visual and semantic concepts in the LLM embeddings onto the generated images.
Some recent work has shown that the conceptual representations learned by LLMs and those learned by models trained on vision tasks can be roughly "linearly mapped" onto each other.
Given an input text caption, passing it through the frozen T5-XXL encoder yields 4096-dimensional language embedding vectors, which are then linearly projected to the hidden size of the Transformer models (base and super-resolution).
2. Use VQGAN for Semantic Tokenization
The VQGAN model consists of an encoder and a decoder, with a quantization layer that maps the input image into a sequence of tokens from a learned codebook.
Both the encoder and the decoder are built entirely from convolutional layers, so that images of different resolutions can be encoded.
The encoder includes several downsampling blocks to reduce the spatial dimension of the input, while the decoder has a corresponding number of upsampling blocks to map latents back to the original image size.
The researchers trained two VQGAN models. The first has a downsampling rate of f=16 and produces the tokens for the base model from 256×256 pixel images, giving a 16×16 grid of tokens. The second has a downsampling rate of f=8 and produces the tokens for the super-resolution model from 512×512 images, giving a 64×64 grid.
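The relationship between image size, downsampling rate, and token count can be checked with a few lines of arithmetic. This is only an illustrative sketch; the helper name `token_grid` is made up for this example and does not come from the Muse code.

```python
from typing import Tuple

def token_grid(image_size: int, downsample_factor: int) -> Tuple[int, int]:
    """Return (side length of the token grid, total number of tokens)."""
    if image_size % downsample_factor != 0:
        raise ValueError("image size must be divisible by the downsampling factor")
    side = image_size // downsample_factor
    return side, side * side

base_side, base_tokens = token_grid(256, 16)  # tokenizer for the base model
sr_side, sr_tokens = token_grid(512, 8)       # tokenizer for the super-resolution model
print(base_side, base_tokens)  # 16 256
print(sr_side, sr_tokens)      # 64 4096
```

The totals of 256 and 4096 tokens are the sequence lengths that the base and super-resolution Transformers have to predict.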
The discrete tokens obtained after encoding capture the high-level semantics of the image while discarding low-level noise. Because the tokens are discrete, a cross-entropy loss can be used at the output to predict the masked tokens in the next stage.
3. Base Model
The Muse base model is a masked Transformer whose inputs are the projected T5 embeddings and the image tokens.
The researchers leave all text embeddings unmasked and randomly mask out a varying fraction of the image tokens, replacing each masked token with a special [MASK] token.
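The masking step can be sketched as follows. This is a minimal illustration, not the authors' implementation: the sentinel `MASK_ID`, the helper `mask_image_tokens`, and the codebook size of 8192 are assumptions made for the example.

```python
import random
from typing import List, Tuple

MASK_ID = -1  # hypothetical id reserved for the special [MASK] token

def mask_image_tokens(
    tokens: List[int], mask_rate: float, rng: random.Random
) -> Tuple[List[int], List[int]]:
    """Replace a random subset of image tokens with MASK_ID.

    Returns the masked sequence and the sorted positions that were masked.
    Text embeddings are not part of `tokens`, so they are never masked.
    """
    num_to_mask = max(1, round(mask_rate * len(tokens)))
    positions = sorted(rng.sample(range(len(tokens)), num_to_mask))
    masked = list(tokens)
    for pos in positions:
        masked[pos] = MASK_ID
    return masked, positions

rng = random.Random(0)
# A 16x16 grid of token ids; the codebook size of 8192 is an assumption.
image_tokens = [rng.randrange(8192) for _ in range(256)]
masked_tokens, masked_positions = mask_image_tokens(image_tokens, mask_rate=0.5, rng=rng)
print(len(masked_positions), "of", len(image_tokens), "tokens masked")
```

The returned positions are what the loss is later computed over; unmasked positions keep their original token ids.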
The image tokens are then linearly mapped to image input embeddings of the required Transformer input (hidden size) dimension, together with learned 2D positional embeddings.
As in the original Transformer architecture, the model consists of several Transformer layers that use self-attention blocks, cross-attention blocks, and MLP blocks to extract features.
At the output layer, an MLP converts each masked image embedding into a set of logits (one per entry in the VQGAN codebook), and a cross-entropy loss is applied with the ground-truth token as the target.
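The loss described above can be sketched in plain Python as a cross-entropy averaged over the masked positions only. The function name `masked_cross_entropy` and the toy numbers are illustrative assumptions; a real implementation would use a tensor library.

```python
import math
from typing import Sequence

def masked_cross_entropy(
    logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    masked_positions: Sequence[int],
) -> float:
    """Mean cross-entropy over masked positions only.

    logits[i] holds one score per codebook entry for position i;
    targets[i] is the ground-truth token id at position i.
    """
    total = 0.0
    for pos in masked_positions:
        row = logits[pos]
        peak = max(row)  # subtract the max for numerical stability
        log_norm = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += log_norm - row[targets[pos]]  # equals -log softmax(row)[target]
    return total / len(masked_positions)

# Toy example: 3 positions, a codebook of 4 entries, positions 0 and 2 masked.
toy_logits = [[2.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
toy_targets = [0, 1, 3]
loss = masked_cross_entropy(toy_logits, toy_targets, masked_positions=[0, 2])
print(loss)
```

Position 1 is unmasked in the toy example, so its logits do not contribute to the loss.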
During training, the base model learns to predict all masked tokens at each step; during inference, mask prediction is performed iteratively, which greatly improves quality.
4. Super-resolution model
The researchers found that directly predicting 512×512 images causes the model to focus on low-level details rather than high-level semantics.
Using a cascade of models improves this situation:
First, a base model generates a 16×16 latent map (corresponding to a 256×256 image); then a super-resolution model upsamples the base latent map to 64×64 (corresponding to a 512×512 image). The super-resolution model is trained after the base model has been trained.
As mentioned earlier, the researchers trained two VQGAN models in total: one with 16×16 latent resolution and 256×256 spatial resolution, the other with 64×64 latent resolution and 512×512 spatial resolution.
Since the base model outputs tokens corresponding to a 16×16 latent map, the super-resolution module learns to "translate" the low-resolution latent map into a high-resolution one, and the final high-resolution image is then obtained by decoding with the high-resolution VQGAN. This translation model is also trained with text conditioning and cross-attention, in a manner similar to the base model.
5. Decoder fine-tuning
In order to further improve the model's ability to generate details, the researchers chose to increase the capacity of the VQGAN decoder by adding more residual layers and channels while keeping the capacity of the encoder unchanged.
The new decoder is then fine-tuned while the VQGAN encoder weights, the codebook, and the Transformers (i.e., the base model and the super-resolution model) are kept frozen. This approach improves the visual quality of the generated images without the need to retrain any other model components (because the visual tokens remain fixed).
The fine-tuned decoder reconstructs more and sharper details.
6. Variable Masking Rate
The model is trained with a variable mask rate based on a cosine schedule: for each training example, a mask rate r ∈ [0, 1] is drawn from a truncated arccos distribution with the following density function:
$$p(r) = \frac{2}{\pi}\,(1 - r^2)^{-1/2}$$
The expected value of the mask rate is 0.64, with a strong bias toward higher mask rates, which makes the prediction problem harder.
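One way to draw such mask rates is sketched below. It assumes the density is p(r) = (2/π)(1 − r²)^(−1/2), which is consistent with the stated expected value, since the mean of that density is 2/π ≈ 0.64. The function name `sample_mask_rate` is made up for this example.

```python
import math
import random

def sample_mask_rate(rng: random.Random) -> float:
    """Draw r in [0, 1] with density (2/pi) * (1 - r^2)^(-1/2).

    If u is uniform on [0, 1], then r = cos(pi/2 * u) has this density,
    because P(r <= x) = 1 - (2/pi) * arccos(x).
    """
    return math.cos(math.pi / 2 * rng.random())

rng = random.Random(0)
rates = [sample_mask_rate(rng) for _ in range(100_000)]
mean_rate = sum(rates) / len(rates)
print("empirical mean:", mean_rate)
print("2 / pi        :", 2 / math.pi)
```

The empirical mean should land close to 2/π, matching the expected mask rate quoted in the text.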
Random masking rates are not only crucial for the parallel sampling scheme, but also enable a number of zero-shot, out-of-the-box editing capabilities.
7. Classifier-Free Guidance (CFG)
The researchers use classifier-free guidance (CFG) to improve image generation quality and text-image alignment.
During training, the text condition is removed for a randomly selected 10% of samples, in which case the attention mechanism reduces to self-attention over the image tokens alone.
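The training-time dropping of the text condition can be sketched as follows. The helper `maybe_drop_text` and the use of `None` to represent "no text condition" are assumptions made for illustration only.

```python
import random
from typing import List, Optional

def maybe_drop_text(
    text_embedding: List[float], rng: random.Random, drop_prob: float = 0.1
) -> Optional[List[float]]:
    """Return None (no text condition) with probability drop_prob."""
    if rng.random() < drop_prob:
        return None
    return text_embedding

rng = random.Random(0)
embedding = [0.1, 0.2, 0.3]  # placeholder for a projected T5 embedding
dropped = sum(maybe_drop_text(embedding, rng) is None for _ in range(10_000))
print("fraction of samples trained without text:", dropped / 10_000)
```

Samples that come back as `None` would be trained unconditionally, which is what later makes an unconditional logit available at inference time.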
In the inference stage, a conditional logit lc and an unconditional logit lu are computed for each masked token, and the final logit lg is formed by moving away from the unconditional logit by an amount t, the guidance scale:
$$l_g = (1 + t)\,l_c - t\,l_u$$
Intuitively, CFG trades diversity for fidelity. Unlike previous methods, however, Muse reduces the loss of diversity by linearly increasing the guidance scale t over the course of sampling, so that early tokens are sampled more freely with low or no guidance, while the conditioning prompt has increasing influence on later tokens.
The researchers also exploit this mechanism for negative prompting: replacing the unconditional logit lu with a logit conditioned on a negative prompt encourages the generated image to have features associated with the positive prompt and to avoid features associated with the negative prompt.
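The guidance step can be sketched as below, assuming the usual combination l_g = (1 + t)·l_c − t·l_u. The exact shape of the linear schedule is not given here, so `guidance_scale` (rising from 0 at the first step to a hypothetical `max_scale` at the last) is an assumption, as are all function names.

```python
from typing import List

def guided_logits(cond: List[float], uncond: List[float], t: float) -> List[float]:
    """Combine logits element-wise as l_g = (1 + t) * l_c - t * l_u."""
    return [(1.0 + t) * c - t * u for c, u in zip(cond, uncond)]

def guidance_scale(step: int, num_steps: int, max_scale: float) -> float:
    """Assumed schedule: linear from 0 at the first step to max_scale at the last."""
    if num_steps == 1:
        return max_scale
    return max_scale * step / (num_steps - 1)

cond = [2.0, 1.0, 0.0]    # logits given the (positive) prompt
uncond = [1.0, 1.0, 1.0]  # logits with no text, or given a negative prompt
for step in range(3):
    t = guidance_scale(step, num_steps=3, max_scale=2.0)
    print(step, t, guided_logits(cond, uncond, t))
# 0 0.0 [2.0, 1.0, 0.0]
# 1 1.0 [3.0, 1.0, -1.0]
# 2 2.0 [4.0, 1.0, -2.0]
```

With t = 0 the conditional logits are used unchanged; as t grows, entries where the conditional and unconditional logits disagree are pushed further apart. Passing logits from a negative prompt as `uncond` gives the negative-prompting variant.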
8. Iterative parallel decoding during inference
A key component of the model's inference-time efficiency is parallel decoding, which predicts multiple output tokens in a single forward pass. The key assumption is a Markovian property: many tokens are conditionally independent given the other tokens.
Decoding follows a cosine schedule: at each step, a fixed fraction of the masked tokens with the highest prediction confidence is selected, these tokens are treated as unmasked for the remaining steps, and the set of masked tokens shrinks accordingly.
With this procedure, the base model can infer 256 tokens in only 24 decoding steps, and the super-resolution model can infer 4096 tokens in 8 decoding steps, compared with 256 or 4096 steps for an autoregressive model and hundreds of steps for a diffusion model.
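The decoding loop can be sketched as follows. This is a simplified illustration under stated assumptions, not the authors' code: the number of tokens left masked after step k of T is taken to be floor(N·cos(π/2 · k/T)), the Transformer is replaced by a random stand-in `toy_predict`, and all names are hypothetical.

```python
import math
import random
from typing import Callable, Dict, List, Optional, Tuple

def masked_after_step(step: int, num_steps: int, num_tokens: int) -> int:
    """Tokens still masked after `step` of `num_steps` (assumed cosine schedule)."""
    if step >= num_steps:
        return 0  # everything must be decoded by the final step
    return math.floor(num_tokens * math.cos(math.pi / 2 * step / num_steps))

Predictor = Callable[[List[Optional[int]], List[int]], Dict[int, Tuple[int, float]]]

def parallel_decode(num_tokens: int, num_steps: int, predict: Predictor) -> List[Optional[int]]:
    tokens: List[Optional[int]] = [None] * num_tokens  # None marks a masked position
    for step in range(1, num_steps + 1):
        masked = [i for i, tok in enumerate(tokens) if tok is None]
        if not masked:
            break
        # One forward pass yields a (token, confidence) pair for every masked position.
        predictions = predict(tokens, masked)
        num_to_commit = len(masked) - masked_after_step(step, num_steps, num_tokens)
        # Commit the most confident predictions; the rest stay masked.
        ranked = sorted(masked, key=lambda i: predictions[i][1], reverse=True)
        for i in ranked[:num_to_commit]:
            tokens[i] = predictions[i][0]
    return tokens

rng = random.Random(0)

def toy_predict(tokens: List[Optional[int]], masked: List[int]) -> Dict[int, Tuple[int, float]]:
    """Stand-in for the Transformer: random token ids and confidences."""
    return {i: (rng.randrange(8192), rng.random()) for i in masked}

decoded = parallel_decode(256, 24, toy_predict)
print(sum(tok is not None for tok in decoded), "tokens decoded in 24 steps")
```

Each loop iteration corresponds to one forward pass, so the cost is 24 passes for 256 tokens instead of one pass per token as in autoregressive decoding.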
Although recent research, including progressive distillation and better ODE solvers, has greatly reduced the number of sampling steps for diffusion models, these methods have not yet been widely validated in large-scale text-to-image generation.
The researchers trained a series of base Transformer models of different sizes (from 600M to 3B parameters), all conditioned on T5-XXL.
Quality of generated images
The experiments tested Muse's ability to handle text prompts with different properties, including a basic understanding of cardinality: for non-singular objects, Muse does not render the same object pixels multiple times, but adds contextual variation to make the whole image more realistic.
For example, the size and orientation of the elephants, the color of the wine bottle wrappers, the spin of the tennis balls, etc.
Quantitative comparison
The researchers compared Muse with other methods on the CC3M and COCO datasets. The metrics include Fréchet Inception Distance (FID), which measures sample quality and diversity, and CLIP score, which measures image/text alignment.
The results show that the 632M Muse model achieved SOTA results on CC3M, improving the FID score while also achieving a state-of-the-art CLIP score.
On the MS-COCO dataset, the 3B model achieved an FID score of 7.88, slightly better than the 8.1 achieved by the similarly sized Parti-3B model.