Autoregressive large language models built on the next-token prediction paradigm are now popular worldwide. Meanwhile, the abundance of synthetic images and videos on the Internet has already demonstrated the power of diffusion models.
Recently, a research team at MIT CSAIL (including Chen Boyuan, a PhD student at MIT) combined the strengths of full-sequence diffusion models and next-token models, proposing a new training and sampling paradigm: Diffusion Forcing (DF).
Paper title: Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion
Paper address: https://arxiv.org/pdf/2407.01392
Project website: https://boyuan.space/diffusion-forcing
Code address: https://github.com/buoyancy99/diffusion-forcing
Diffusion forcing clearly outperforms the two baseline methods, full-sequence diffusion and teacher forcing, in terms of consistency and stability.
In this framework, each token is associated with a random, independent noise level, and a shared next-token (or next-few-tokens) prediction model denoises the tokens according to an arbitrary, independent, per-token schedule.
The method is inspired by the observation that adding noise to a token is a form of partial masking: zero noise means the token is unmasked, while complete noise means the token is fully masked. DF therefore forces the model to learn to unmask any variable set of noisy tokens (Figure 2). At the same time, by parameterizing the predictor as a composition of multiple next-token prediction models, the system can flexibly generate sequences of different lengths and generalize to new trajectories in a compositional manner (Figure 1).
The team instantiated DF for sequence generation as Causal Diffusion Forcing (CDF), in which future tokens depend on past tokens through a causal architecture. They trained the model to denoise all tokens of a sequence at once, where each token has an independent noise level. During sampling, CDF gradually denoises a sequence of Gaussian noise frames into clean samples, and different frames may have different noise levels at each denoising step.
Like a next-token prediction model, CDF can generate sequences of variable length; unlike next-token prediction, it remains stable whether it is predicting the immediate next token or thousands of tokens into the future, even for continuous tokens. Like full-sequence diffusion, it can also receive guidance, enabling high-reward generation. By jointly leveraging causality, flexible horizons, and variable noise schedules, CDF enables a new capability: Monte Carlo Tree Guidance (MCTG). Compared with non-causal full-sequence diffusion models, MCTG greatly improves the sampling of high-reward generations.
Figure 1 gives an overview of these capabilities.
Diffusion Forcing
1. Treating the noising process as partial masking
First, any collection of tokens (whether or not it is a sequence) can be treated as an ordered set indexed by t. Training next-token prediction with teacher forcing can then be interpreted as masking out each token x_t at time t and predicting it from the past x_{1:t−1}. For sequences, this can be described as masking along the time axis.
Full-sequence forward diffusion (i.e., the process of gradually adding noise to the data) can likewise be viewed as a kind of partial masking, which can be called masking along the noise axis. Indeed, after K steps of noising, the noised data x_{1:T}^K is (approximately) white noise and no longer carries any information about the original data.
As shown in Figure 2, the team established a unified view of masking along these two axes.
2. Diffusion forcing: different tokens have different noise levels
The diffusion forcing (DF) framework can be used to train on and sample noisy tokens for sequences of arbitrary length; the key is that the noise level k_t of each token varies across time steps.
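The two masking axes can be made concrete with a small sketch. The code below is illustrative only: the number of steps K, the cosine-style `signal_level` schedule, scalar tokens, and all function names are assumptions chosen for the example, not details from the paper. It contrasts the noise-level pattern of teacher forcing (masking along time), full-sequence diffusion (masking along the noise axis), and diffusion forcing (an independent level per token).

```python
import math
import random

K = 100  # assumed number of diffusion steps (illustrative value)

def signal_level(k):
    """Assumed cosine-style schedule: 1.0 at k=0 (unmasked), ~0.0 at k=K (fully masked)."""
    return math.cos(0.5 * math.pi * k / K) ** 2

def add_noise(x0, k, rng):
    """Forward-noise a scalar token x0 to level k: sqrt(a)*x0 + sqrt(1-a)*eps."""
    a = signal_level(k)
    return math.sqrt(a) * x0 + math.sqrt(1.0 - a) * rng.gauss(0.0, 1.0)

def teacher_forcing_levels(T, t):
    """Masking along the time axis: tokens before index t are clean, the rest fully masked."""
    return [0 if i < t else K for i in range(T)]

def full_sequence_levels(T, k):
    """Masking along the noise axis: every token shares the same noise level k."""
    return [k] * T

def diffusion_forcing_levels(T, rng):
    """Diffusion forcing: each token draws its own independent noise level."""
    return [rng.randint(0, K) for _ in range(T)]

rng = random.Random(0)
tokens = [0.5, -1.0, 2.0, 0.0, 1.5]
levels = diffusion_forcing_levels(len(tokens), rng)
noisy = [add_noise(x, k, rng) for x, k in zip(tokens, levels)]
print(levels)
print(noisy)
```

Seen this way, teacher forcing and full-sequence diffusion are two special cases of the per-token level pattern, which is the unified view the section describes.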
The paper focuses on time-series data, so the team instantiates DF with a causal architecture, yielding Causal Diffusion Forcing (CDF). Put simply, this is a minimal implementation built on a basic recurrent neural network (RNN).
An RNN with weights θ maintains a hidden state z_t that captures the influence of past tokens. The state evolves through a recurrent layer according to the dynamics z_t ∼ p_θ(z_t | z_{t−1}, x_t^{k_t}, k_t). When a noisy input observation x_t^{k_t} arrives, the hidden state is updated in a Markovian manner.
When k_t = 0, this update is the posterior update in Bayesian filtering; when k_t = K (pure noise, no information), it is equivalent to modeling the prior distribution p_θ(z_t | z_{t−1}) in Bayesian filtering.
Given the hidden state z_t, the goal of the observation model p_θ(x_t^0 | z_t) is to predict x_t. The input-output behavior of this unit is the same as that of a standard conditional diffusion model: taking the conditioning variable z_{t−1} and the noisy token x_t^{k_t} as input, it predicts the noise-free x_t = x_t^0 and thereby, through an affine reparameterization, indirectly predicts the noise ε^{k_t}. (Causal) diffusion forcing can therefore be trained directly with the classic diffusion objective. The unit can be parameterized in terms of its noise prediction ε_θ, and the parameters θ are then found by minimizing the following loss:
E [ Σ_{t=1}^{T} ‖ε_t − ε_θ(z_{t−1}, x_t^{k_t}, k_t)‖² ]   (3.1)
where the expectation is taken over the noise levels k_t, the data x_t, the noise ε_t, and the hidden states z_t ∼ p_θ(z_t | z_{t−1}, x_t^{k_t}, k_t).
Algorithm 1 gives the pseudocode. The key point is that this loss captures essential elements of both Bayesian filtering and conditional diffusion. The team also re-derived, for diffusion forcing, common techniques used in diffusion model training; see the appendix of the original paper for details. They also arrived at an informal theorem.
Theorem 3.1 (informal). The diffusion forcing training procedure (Algorithm 1) optimizes a reweighting of an evidence lower bound (ELBO) on the expected log-likelihood of the noised sequences (x_t^{k_t})_{1≤t≤T}, where the expectation is averaged over the noise levels and x_t^{k_t} is noised according to the forward process. Moreover, under appropriate conditions, optimizing (3.1) also simultaneously maximizes a lower bound on the likelihood for all sequences of noise levels.
Diffusion forcing sampling and the resulting capabilities
Algorithm 2 describes the sampling process, which is defined by prescribing a noise schedule on a two-dimensional M × T grid K ∈ [K]^{M×T}; the columns correspond to time steps t, and the rows, indexed by m, determine the noise level.
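As a brief aside before the sampling walkthrough continues, the training objective described above can be made concrete with a toy sketch. Everything here is an assumption for illustration: `ToyCell` is a hypothetical scalar stand-in for the RNN with weights θ, the cosine-style schedule is not taken from the paper, and the code only evaluates the noise-prediction loss on one sequence (no gradient step), so it should not be read as the authors' Algorithm 1.

```python
import math
import random

K = 100  # assumed number of diffusion steps (illustrative)

def signal_level(k):
    """Assumed cosine-style schedule (not taken from the paper)."""
    return math.cos(0.5 * math.pi * k / K) ** 2

class ToyCell:
    """Hypothetical scalar recurrent unit standing in for the RNN with weights theta."""
    def __init__(self, w_z=0.5, w_x=0.3, w_k=0.1):
        self.w_z, self.w_x, self.w_k = w_z, w_x, w_k

    def predict_noise(self, z_prev, x_noisy, k):
        """eps_theta(z_{t-1}, x_t^{k_t}, k_t): a guess of the noise that was added."""
        return self.w_x * x_noisy - self.w_z * z_prev + self.w_k * k / K

    def update(self, z_prev, x_noisy, k):
        """Deterministic stand-in for z_t ~ p_theta(z_t | z_{t-1}, x_t^{k_t}, k_t)."""
        trust = signal_level(k)  # a fully noised token (k=K) barely changes the state
        return math.tanh(self.w_z * z_prev + trust * self.w_x * x_noisy)

def diffusion_forcing_loss(cell, xs, rng):
    """Evaluate the summed noise-prediction error on a single sequence."""
    z, total = 0.0, 0.0
    for x0 in xs:
        k = rng.randint(0, K)            # independent noise level per token
        eps = rng.gauss(0.0, 1.0)        # noise actually added to this token
        a = signal_level(k)
        x_noisy = math.sqrt(a) * x0 + math.sqrt(1.0 - a) * eps
        eps_hat = cell.predict_noise(z, x_noisy, k)
        total += (eps - eps_hat) ** 2    # squared noise-prediction error
        z = cell.update(z, x_noisy, k)   # carry the hidden state forward
    return total

loss = diffusion_forcing_loss(ToyCell(), [0.5, -1.0, 2.0, 0.0], random.Random(0))
print(loss)
```

In a real implementation the cell would be a learned network and the loss would be minimized with a gradient-based optimizer; the sketch only shows how per-token noise levels, the noise prediction, and the recurrent state fit together.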
To generate an entire sequence of length T, the tokens x_{1:T} are first initialized to white noise, corresponding to noise level k = K. Sampling then iterates row by row down the grid, denoising column by column from left to right to the noise levels prescribed by the schedule. By the last row, m = 0, the tokens are clean, i.e., the noise level is K_{0,t} ≡ 0. This sampling paradigm brings the following new capabilities:
- Stable autoregressive generation
- Keep the future uncertain
- Long-term guidance capability
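The grid-based sampling procedure described above can be sketched as follows. This is a minimal illustration under stated assumptions: the three schedule builders (full-sequence, autoregressive, and a "pyramid" schedule that keeps later tokens noisier) and the `toy_denoise` function are hypothetical stand-ins, and the hidden state is reduced to the current token estimate. A trained model would replace `toy_denoise`.

```python
import random

def full_sequence_schedule(T, K):
    """Every token is denoised in lockstep, one level per row."""
    return [[k] * T for k in range(K - 1, -1, -1)]

def autoregressive_schedule(T, K):
    """Fully denoise token 0, then token 1, and so on; later tokens stay at level K."""
    rows = []
    for t in range(T):
        for k in range(K - 1, -1, -1):
            rows.append([0] * t + [k] + [K] * (T - t - 1))
    return rows

def pyramid_schedule(T, K):
    """Near-future tokens are cleaner than far-future ones (keeps the future uncertain)."""
    return [[min(K, max(0, K - r + t)) for t in range(T)] for r in range(1, K + T)]

def toy_denoise(z_prev, x, k_from, k_to):
    """Hypothetical stand-in for the trained model: pull x toward a value implied by z_prev."""
    if k_from == 0:
        return x  # already clean
    target = 0.9 * z_prev + 1.0
    keep = k_to / k_from
    return keep * x + (1.0 - keep) * target

def sample(schedule, T, K, rng):
    xs = [rng.gauss(0.0, 1.0) for _ in range(T)]  # start from white noise (level K)
    levels = [K] * T
    for row in schedule:          # iterate row by row down the grid
        z = 0.0                   # toy hidden state before the first token
        for t in range(T):        # sweep columns left to right
            xs[t] = toy_denoise(z, xs[t], levels[t], row[t])
            levels[t] = row[t]
            z = xs[t]             # toy "hidden state": the current token estimate
    return xs

print(pyramid_schedule(3, 2))
print(sample(pyramid_schedule(3, 4), 3, 4, random.Random(0)))
```

The only thing that changes between the three generation styles is the schedule matrix passed to `sample`, which is the design point the grid formulation is meant to expose.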
Using diffusion forcing for flexible sequential decision-making
These new capabilities of diffusion forcing also open up new possibilities. Building on them, the team designed a new framework for sequential decision-making (SDM) and successfully applied it to robotics and autonomous agents.
First, define a Markov decision process with dynamics p(s_{t+1} | s_t, a_t), observations p(o_t | s_t), and rewards p(r_t | s_t, a_t). The goal is to train a policy π(a_t | o_{1:t}) that maximizes the expected cumulative reward of a trajectory. Tokens are assigned as x_t = [a_t, r_t, o_{t+1}]. A trajectory is a sequence x_{1:T}, whose length may vary; training proceeds as in Algorithm 1.
At each step t of execution, a hidden state z_{t−1} summarizes the past noise-free tokens x_{1:t−1}. Conditioned on this hidden state, a plan x̂_{t:t+H} is sampled according to Algorithm 2, where each x̂_t contains a predicted action, reward, and observation. H is a look-ahead window, similar to the prediction horizon in model predictive control. After the planned action is taken, the environment returns a reward and the next observation, which yield the next token x_t. The hidden state can then be updated according to the posterior p_θ(z_t | z_{t−1}, x_t, 0).
The framework can serve as both a policy and a planner, and its advantages include:
- Flexible planning horizons
- Flexible reward guidance
- Monte Carlo Tree Guidance (MCTG), which accounts for future uncertainty
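The plan-act-update loop described above resembles model predictive control, and a toy sketch may help fix the idea. All pieces here are hypothetical: `sample_plan` stands in for sampling a plan with Algorithm 2, `env_step` is an invented one-dimensional environment, and `update_hidden` replaces the posterior update on a clean token with simple bookkeeping. None of it reflects the authors' actual implementation.

```python
import random

def sample_plan(z_prev, horizon, rng):
    """Hypothetical stand-in for Algorithm 2: returns `horizon` predicted tokens (a, r, o_next)."""
    plan, o = [], z_prev
    for _ in range(horizon):
        a = -0.5 * o + rng.gauss(0.0, 0.01)  # toy policy: push the observation toward 0
        o_next = o + a
        plan.append((a, -abs(o_next), o_next))
        o = o_next
    return plan

def env_step(state, action):
    """Toy environment (assumption): 1D state, reward is negative distance to the origin."""
    state = state + action
    return state, -abs(state), state  # next state, reward, observation

def update_hidden(z_prev, token):
    """Stand-in for the posterior update on a clean (noise level 0) token."""
    _, _, o_next = token
    return o_next  # toy: the hidden state just tracks the last observation

def run_episode(steps=10, horizon=4, seed=0):
    rng = random.Random(seed)
    state, z, total = 5.0, 5.0, 0.0
    for _ in range(steps):
        plan = sample_plan(z, horizon, rng)          # look H steps ahead
        action = plan[0][0]                          # execute only the first planned action
        state, reward, obs = env_step(state, action)
        z = update_hidden(z, (action, reward, obs))  # absorb the real token
        total += reward
    return state, total

final_state, total_reward = run_episode()
print(final_state, total_reward)
```

Because only the first action of each plan is executed, the look-ahead window `horizon` can be changed at run time without retraining, which is one way to read the "flexible planning horizons" advantage.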
The team evaluated diffusion forcing as a generative sequence model across applications including video prediction, time series forecasting, planning, and imitation learning.
Video prediction: consistent, stable sequence generation and infinite rollout
For video generative modeling, they trained a convolutional RNN implementation of causal diffusion forcing on Minecraft gameplay videos and DMLab navigation. Figure 3 shows qualitative results for diffusion forcing versus the baselines. Diffusion forcing rolls out stably, even beyond its training horizon, whereas the teacher forcing and full-sequence diffusion baselines diverge quickly.
Diffusion planning: MCTG, causal uncertainty, flexible horizon control
The capabilities of diffusion forcing bring unique benefits to decision-making. The team evaluated the newly proposed decision-making framework on D4RL, a standard offline reinforcement learning benchmark. Table 1 gives the qualitative and quantitative evaluation results. Diffusion forcing outperforms Diffuser and all other baselines in all 6 environments.
Controllable compositional sequence generation
The team found that subsequences of sequences observed at training time can be flexibly composed simply by modifying the sampling scheme. They ran experiments on a 2D trajectory dataset: on a square plane, all trajectories start at one corner and end at the opposite corner, forming a cross shape. As shown in Figure 1 above, when compositional behavior is not desired, DF can be allowed to keep full memory and replicate the cross-shaped distribution. When composition is desired, the model can be used to generate shorter plans memorylessly with MPC, stitching together sub-trajectories of the cross to obtain a V-shaped trajectory.
Robotics: long-horizon imitation learning and robust visuomotor control
Diffusion forcing also brings new opportunities for visuomotor control of real robots. Imitation learning is a widely used robot control technique that learns a mapping from observations to actions from expert demonstrations. However, a lack of memory often makes imitation learning difficult for long-horizon tasks. DF not only alleviates this shortcoming but also makes imitation learning more robust.
Imitation learning with memory. By teleoperating a Franka robot, the team collected a dataset of videos and actions. As shown in Figure 4, the task is to swap the positions of an apple and an orange using a third slot. The initial positions of the fruits are random, so there are two possible goal states. Moreover, when one fruit is in the third slot, the desired outcome cannot be inferred from the current observation alone: the policy must remember the initial configuration to decide which fruit to move. Unlike commonly used behavior cloning methods, DF naturally integrates memory into its hidden state. DF achieved an 80% success rate, while Diffusion Policy (currently the best memoryless imitation learning algorithm) failed.
In addition, DF handles noise more robustly and facilitates robot pre-training.
Time series forecasting: diffusion forcing is a strong general-purpose sequence model
For multivariate time series forecasting, the team's results show that DF is competitive with previous diffusion models and Transformer-based models.
Please refer to the original paper for more technical details and experimental results.