Why do Transformers perform so well, and where does the in-context learning capability they bring to many large language models come from? The Transformer has become the dominant architecture in deep learning, yet the theoretical basis for its strong performance remains under-studied.
Recently, researchers from Google AI, ETH Zurich, and Google DeepMind conducted a new study in which they reverse-engineered trained Transformers in an attempt to uncover the optimization algorithms hidden inside them. The paper is titled "Uncovering mesa-optimization algorithms in Transformers".
Paper link: https://arxiv.org/abs/2309.05858
The authors demonstrate that minimizing a generic autoregressive loss gives rise to an auxiliary gradient-based optimization algorithm running in the forward pass of the Transformer. This phenomenon has recently been termed "mesa-optimization." Furthermore, the researchers found that the resulting mesa-optimization algorithm exhibits in-context few-shot learning capabilities, independent of model size. The new results therefore complement earlier accounts of how few-shot learning emerges in large language models.
The researchers hypothesize that the success of Transformers rests on an architectural bias toward implementing a mesa-optimization algorithm in the forward pass: (i) defining an internal learning objective, and (ii) optimizing it.
Figure 1: Illustration of the new hypothesis: optimizing the weights θ of the autoregressive Transformer f_θ gives rise to mesa-optimization implemented in the forward pass of the model. As the input sequence s_1, ..., s_t is processed up to time step t, the Transformer (i) creates an internal training set consisting of input-target association pairs, (ii) defines, via the resulting dataset, an internal objective function that measures the performance of an internal model with weights W, and (iii) optimizes this objective and uses the learned model to generate future predictions.
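The three steps in the Figure 1 caption can be written out for the simplest case. The block below is a sketch under the assumption that the internal model is linear and that the internal objective is a squared error over consecutive sequence elements; the step size η and the exact form of the loss are illustrative choices, not quoted from the paper.

```latex
\begin{align*}
% (i) internal training set built from the context (assumed pairing of neighbours)
\mathcal{D}_t &= \{(s_i,\, s_{i+1})\}_{i=1}^{t-1} \\
% (ii) internal objective measuring the internal model W on that set
L_t(W) &= \frac{1}{2} \sum_{i=1}^{t-1} \lVert s_{i+1} - W s_i \rVert^2 \\
% (iii) one gradient step on the objective, then prediction with the updated model
W &\leftarrow W - \eta \nabla_W L_t(W)
   = W + \eta \sum_{i=1}^{t-1} (s_{i+1} - W s_i)\, s_i^\top,
\qquad \hat{s}_{t+1} = W s_t
\end{align*}
```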
The contributions of this study include the following:
The study builds on recent work showing that Transformers explicitly trained to solve few-shot tasks in context can implement gradient descent (GD) algorithms. Here, the authors show that these results generalize to autoregressive sequence modeling, a typical approach to training LLMs.
First, the authors analyze Transformers trained on simple linear dynamics, where each sequence is generated by a different W* to prevent cross-sequence memorization. In this simple setup, they show how the Transformer creates a mesa-dataset and then uses preconditioned GD to optimize the mesa-objective.
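The linear-dynamics setup can be sketched in a few lines. This is a minimal illustration, assuming the recurrence s_{t+1} = W* s_t + noise with a Gaussian W* drawn fresh for every sequence; the helper names, the Gaussian choice, and the noise level are assumptions made here, not details taken from the paper.

```python
import random

def random_matrix(dim, rng, scale=0.5):
    """Hypothetical helper: draw a dim x dim ground-truth matrix W*."""
    return [[rng.gauss(0.0, scale) for _ in range(dim)] for _ in range(dim)]

def matvec(W, x):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def sample_sequence(dim, length, rng, noise_std=0.01):
    """Sample s_1..s_T with s_{t+1} = W* s_t + noise.

    A new W* is drawn on every call, so no two sequences share dynamics
    and the model cannot memorize W* across sequences.
    """
    W_star = random_matrix(dim, rng)
    s = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    seq = [s]
    for _ in range(length - 1):
        s = [v + rng.gauss(0.0, noise_std) for v in matvec(W_star, s)]
        seq.append(s)
    return W_star, seq

rng = random.Random(0)
batch = [sample_sequence(dim=3, length=8, rng=rng) for _ in range(4)]
```

Because W* changes per sequence, the only way to predict s_{t+1} well is to infer the dynamics from the elements already seen in the same sequence.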
Next, the authors train deep Transformers on a token structure that aggregates adjacent sequence elements. Interestingly, this simple preprocessing yields extremely sparse weight matrices (fewer than 1% of the weights are non-zero), which makes it possible to reverse-engineer an algorithm from them.
For single-layer linear self-attention, the weights correspond to one gradient descent step. For deep Transformers, interpretability becomes harder. The study therefore relies on linear probing and examines whether hidden activations can predict the autoregressive targets or the preprocessed inputs.
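The correspondence between one linear self-attention layer and one gradient descent step can be checked numerically. The sketch below assumes a squared-error objective over neighbouring pairs (s_i, s_{i+1}), initial weights W_0 = 0, and an attention layer with keys s_i, values η·s_{i+1} and query s_t; the function names and the toy sequence are illustrative, not the paper's code.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def gd_one_step_prediction(seq, eta):
    """Take one GD step from W = 0 on 0.5 * sum ||y - W x||^2, then predict W s_t."""
    dim = len(seq[0])
    W = [[0.0] * dim for _ in range(dim)]
    grad = [[0.0] * dim for _ in range(dim)]
    for x, y in zip(seq[:-1], seq[1:]):
        pred = [dot(row, x) for row in W]
        for r in range(dim):
            for c in range(dim):
                # d/dW of 0.5 * ||y - W x||^2 is -(y - W x) x^T
                grad[r][c] -= (y[r] - pred[r]) * x[c]
    W1 = [[W[r][c] - eta * grad[r][c] for c in range(dim)] for r in range(dim)]
    return [dot(row, seq[-1]) for row in W1]

def linear_attention_prediction(seq, eta):
    """Linear (softmax-free) attention: sum_i value_i * (key_i . query)."""
    query = seq[-1]
    out = [0.0] * len(query)
    for key, target in zip(seq[:-1], seq[1:]):
        score = dot(key, query)
        for r in range(len(out)):
            out[r] += eta * target[r] * score
    return out

seq = [[1.0, 0.0], [0.5, 1.0], [-0.2, 0.7], [0.3, -0.4]]
gd_pred = gd_one_step_prediction(seq, eta=0.1)
attn_pred = linear_attention_prediction(seq, eta=0.1)
```

Both functions return the same vector up to floating-point error, because with W_0 = 0 the updated weights are η Σ s_{i+1} s_iᵀ, and applying them to the query s_t gives exactly the attention sum.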
Interestingly, the predictive performance of both probes improves gradually with network depth. This finding suggests that preconditioned GD is hidden in the model.
Figure 2: Reverse engineering of a trained linear self-attention layer.
The study found that the trained layer can be fitted perfectly when all degrees of freedom of the construction are used, including not only a learned learning rate η but also a set of learned initial weights W_0. Importantly, as shown in Figure 2, the learned one-step algorithm is still far outperformed by a single mesa layer.
With a simple weight setting, one that basic optimization finds easily, this mesa layer can solve the task studied here optimally. This result demonstrates that a hard-coded inductive bias in favor of mesa-optimization is beneficial.
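One way to see why a layer that fully solves the internal problem beats a single gradient step is to compare both on the same objective. The sketch below assumes the internal objective is ridge-regularized least squares and that the mesa layer returns its closed-form minimizer; the regularization strength `reg`, the helper names, and the toy sequence are assumptions made for illustration.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [row[:] + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= f * M[col][c]
    x = [0.0] * n
    for r in reversed(range(n)):
        x[r] = (M[r][n] - sum(M[r][c] * x[c] for c in range(r + 1, n))) / M[r][r]
    return x

def mesa_weights(pairs, reg):
    """Minimizer of 0.5 * sum ||y - W x||^2 + 0.5 * reg * ||W||_F^2."""
    d_in, d_out = len(pairs[0][0]), len(pairs[0][1])
    # A = sum x x^T + reg * I (symmetric), B = sum y x^T, and W solves W A = B.
    A = [[reg * (r == c) + sum(x[r] * x[c] for x, _ in pairs)
          for c in range(d_in)] for r in range(d_in)]
    B = [[sum(y[r] * x[c] for x, y in pairs) for c in range(d_in)]
         for r in range(d_out)]
    return [solve(A, row) for row in B]

def objective(W, pairs, reg):
    fit = sum((yr - dot(row, x)) ** 2 for x, y in pairs for row, yr in zip(W, y))
    return 0.5 * fit + 0.5 * reg * sum(w * w for row in W for w in row)

seq = [[1.0, 0.0], [0.5, 1.0], [-0.2, 0.7], [0.3, -0.4], [0.1, 0.2]]
pairs = list(zip(seq[:-1], seq[1:]))
reg, eta = 0.1, 0.1
W_mesa = mesa_weights(pairs, reg)
# One GD step from W_0 = 0 on the unregularized loss, for comparison.
W_gd = [[eta * sum(y[r] * x[c] for x, y in pairs) for c in range(2)] for r in range(2)]
```

Since the objective is strictly convex for `reg > 0`, `objective(W_mesa, pairs, reg)` can never exceed `objective(W_gd, pairs, reg)`, whatever step size the one-step learner uses.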
Armed with theoretical insights into the multi-layer case, the authors first analyze deep linear and softmax attention-only Transformers. They format the input according to a 4-channel structure that corresponds to the choice W_0 = 0.
As with the single-layer model, the authors see clear structure in the weights of the trained model. As a first reverse-engineering analysis, the study exploits this structure and builds an algorithm (RevAlg-d, where d is the number of layers) with 16 parameters per layer and head (instead of 3200). The authors found that this compressed, though still complicated, expression can describe the trained model. In particular, it allows interpolating between the weights of the actual Transformer and those of RevAlg-d almost losslessly.
Although the RevAlg-d expression explains the trained multi-layer Transformer, it is hard to interpret as a mesa-optimization algorithm. The authors therefore used linear regression probing (Alain & Bengio, 2017; Akyürek et al., 2023) to look for signatures of the hypothesized mesa-optimization algorithm.
On the deep linear self-attention Transformer shown in Figure 3, both probes can be linearly decoded, and decoding performance increases with sequence length and network depth. The authors thus uncover a basic optimization algorithm that descends layer by layer on the original mesa-objective L_t(W) while improving the condition number of the mesa-optimization problem. This leads to a rapid decline of the mesa-objective L_t(W). Performance also improves markedly as depth increases.
The rapid descent of the autoregressive objective L_t(W) can therefore be attributed to optimization carried out on progressively better preprocessed data.
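The link between a better condition number and faster descent can be illustrated on a toy quadratic. This is a minimal sketch, assuming a diagonal Hessian and a diagonal preconditioner that only partially whitens the problem; none of the numbers or names come from the paper.

```python
def gd(hessian_diag, target, steps, lr, precond=None):
    """Minimize 0.5 * sum_i h_i * (w_i - target_i)^2 by (preconditioned) GD.

    Returns the loss after every step.
    """
    w = [0.0] * len(target)
    precond = precond or [1.0] * len(target)
    losses = []
    for _ in range(steps):
        grad = [h * (wi - ti) for h, wi, ti in zip(hessian_diag, w, target)]
        w = [wi - lr * p * g for wi, p, g in zip(w, precond, grad)]
        losses.append(0.5 * sum(h * (wi - ti) ** 2
                                for h, wi, ti in zip(hessian_diag, w, target)))
    return losses

hessian_diag = [100.0, 1.0]          # condition number 100
target = [1.0, 1.0]

# Plain GD: the step size is limited by the largest curvature.
plain = gd(hessian_diag, target, steps=10, lr=1.0 / 100.0)

# Preconditioned GD: effective curvatures become [10, 1] (condition number 10),
# so a ten times larger step size is stable.
precond = [0.1, 1.0]
preconditioned = gd(hessian_diag, target, steps=10, lr=1.0 / 10.0, precond=precond)
```

After the same number of steps, the preconditioned run ends at a lower loss, because the flat direction is no longer held back by the step size that the steep direction requires.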
Figure 3: Reverse engineering of multi-layer Transformers trained on constructed token inputs.
This shows that a Transformer trained on the constructed tokens makes its predictions by mesa-optimization. Interestingly, when the sequence elements are given directly, the Transformer constructs the tokens itself by grouping the elements, which the research team calls "creating the mesa-dataset".
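The grouping step can be pictured as follows. This is a sketch under the assumption that each constructed token concatenates the current element with the previous one, with zero padding at the first position; the exact token layout and the function names are illustrative rather than taken from the paper.

```python
def build_tokens(seq):
    """Concatenate each element s_t with its predecessor s_{t-1} (zeros at t = 0)."""
    dim = len(seq[0])
    tokens = []
    for t, current in enumerate(seq):
        previous = seq[t - 1] if t > 0 else [0.0] * dim
        tokens.append(list(current) + list(previous))
    return tokens

def mesa_dataset(tokens, dim):
    """Read (input, target) pairs off the tokens: (s_{t-1}, s_t).

    The first token is skipped because its 'previous' half is only padding.
    """
    return [(tok[dim:], tok[:dim]) for tok in tokens[1:]]

seq = [[1.0], [2.0], [3.0]]
tokens = build_tokens(seq)
pairs = mesa_dataset(tokens, dim=1)
```

Once every token carries both an input and its target, later layers have an input-target training set available at each position without looking anything up again.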
The study finds that Transformers trained on sequence prediction tasks under a standard autoregressive objective develop gradient-based inference algorithms. The latest multi-task and meta-learning results can therefore also be applied to the conventional self-supervised LLM training setting.
In addition, the study found that the learned autoregressive inference algorithm can be repurposed, without retraining, to solve supervised in-context learning tasks, so that these results can be explained within a single unified framework.
So what does this have to do with in-context learning? According to the study, a Transformer trained on an autoregressive sequence task implements a suitable form of mesa-optimization and can therefore perform few-shot in-context learning without any fine-tuning.
The study hypothesizes that mesa-optimization also exists in LLMs and improves their in-context learning ability. Interestingly, the study also observed that effectively adapting the prompt for an LLM can lead to substantial improvements in in-context learning.
Interested readers can consult the original paper for more details.