Next-Token Prediction Meets Full-Sequence Diffusion
Diffusion Forcing combines the strengths of full-sequence diffusion models and next-token prediction models, and at sampling time it can act as either, or as a mix of the two, for different applications without retraining. Watch this video to learn more.
“With Diffusion Forcing, we are taking a step toward bringing video generation and robotics closer together,” says senior author Vincent Sitzmann, MIT assistant professor and member of CSAIL, where he leads the Scene Representation group.
Transcript
00:00:01 (air whooshes) (bright music) - With Diffusion Forcing, the problem we tackled is that of trying to learn how the world works and how to accomplish certain tasks, first by just watching how other people accomplish these tasks and then, afterwards, being immersed in the world yourself and interacting with it. The specific problem we're focusing on
00:00:20 in Diffusion Forcing is called sequence prediction. And the goal there is to basically be able to, given a set of observations, try to predict what would have to happen next in a sequence to get to a certain goal. And specifically, Diffusion Forcing is kind of combining the strengths of two kinds of models that are around today. One is models like ChatGPT,
00:00:37 which we call next-token prediction models, and the other ones are models that generate videos like Sora, which we call video diffusion models. - Now, diffusion policy can actually make robots do very complicated things with their hands. But what we've done here is come up with a really simple example of where the traditional diffusion policy would fail. The reason for that
00:00:56 is that the diffusion policies we normally use actually have very little memory. So, it turns out that if the current situation, the thing that they can currently see in their cameras, doesn't completely describe what you need to do, then they will fail. So, this is the setup here. The task is very simple. They just need to swap the fruit. But in order to do that,
00:01:16 you have to move the fruit in such a way that the image that you're looking at instantaneously doesn't tell you what the actions you have to execute will be. So, it's a little puzzle, a very simple puzzle, but it shows that you need memory to accomplish this task. Before, that was very hard, and with Diffusion Forcing, it becomes very easy. So, in order to do this multi-step task, the robot is gonna end up in the middle
00:01:42 with a picture in its cameras that looks just like it could have been a different beginning for a different task. So, if it doesn't remember that it's been doing this task in sequence, then it'll be confused and start doing the task again. This is one of the simplest possible tasks we could come up with that requires memory and shows the use of memory. - Diffusion Forcing is not only just about robotics.
00:02:04 It is a generative sequence model that has a wide range of applications in video, natural language processing, and many other things. - In a nutshell, Diffusion Forcing allows you to combine the strengths of these two kinds of models, which is to generate sequences of flexible lengths. So, it can basically always go one step ahead and always plan one step further, and at the same time be able to plan a sequence
00:02:26 such that it arrives at a certain goal. The way we tackle this in Diffusion Forcing is by training a model where we give it a video sequence and then destroy some of the frames in that video sequence to varying degrees. So, imagine some of the frames get really blurry, or really noisy specifically, and some of them remain clean. And essentially what we are doing in that way is we are forcing the model to fill in any gap of any video.
00:02:48 So, you could give it the first frame and the last frame, and those are very clean, and then all of the ones in between are really noisy, and then the model is forced to basically fill in the whole gap. And what this does is essentially combine the strengths of both models like ChatGPT and models like Sora. Specifically, this model at test time will be able to always predict the next frame in the sequence.
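The training recipe described here (an independent noise level for every frame, with the model forced to fill in the gaps) can be sketched in a few lines. Below is a minimal, illustrative sketch assuming a PyTorch-style setup; `CausalDenoiser`, the GRU backbone, and all hyperparameters are placeholders, not the paper's actual architecture.

```python
# Minimal sketch of the Diffusion Forcing training idea: each frame in a
# video sequence gets its own, independently sampled noise level, and a
# causal sequence model is trained to denoise every frame. Names such as
# `CausalDenoiser` are illustrative placeholders.
import torch
import torch.nn as nn

T_STEPS = 1000                                   # number of diffusion noise levels
betas = torch.linspace(1e-4, 0.02, T_STEPS)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

class CausalDenoiser(nn.Module):
    """Placeholder causal model: predicts the noise added to each frame,
    conditioned on the (noisy) past frames and each frame's noise level."""
    def __init__(self, frame_dim, hidden=256):
        super().__init__()
        self.rnn = nn.GRU(frame_dim + 1, hidden, batch_first=True)
        self.head = nn.Linear(hidden, frame_dim)

    def forward(self, noisy_frames, noise_levels):
        # noisy_frames: (B, T, D); noise_levels: (B, T) integers in [0, T_STEPS)
        k = (noise_levels.float() / T_STEPS).unsqueeze(-1)      # (B, T, 1)
        h, _ = self.rnn(torch.cat([noisy_frames, k], dim=-1))
        return self.head(h)                                      # predicted noise

def diffusion_forcing_loss(model, frames):
    # frames: (B, T, D) clean video (or state/action) sequence
    B, T, D = frames.shape
    # Key difference from full-sequence diffusion: an independent noise
    # level per frame, so some frames stay nearly clean while others are
    # almost pure noise.
    k = torch.randint(0, T_STEPS, (B, T))
    a = alphas_cumprod[k].unsqueeze(-1)                          # (B, T, 1)
    eps = torch.randn_like(frames)
    noisy = a.sqrt() * frames + (1 - a).sqrt() * eps
    pred = model(noisy, k)
    return nn.functional.mse_loss(pred, eps)
```

Because every frame carries its own noise level, the same trained model can behave like a full-sequence diffusion model (all frames equally noisy) or like a next-token predictor (past frames clean, future frames noisy), which is the combination of strengths described in the transcript.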
00:03:07 And at the same time, you can also have it always predict the next frame in a sequence while still telling it, "At some point, you should arrive at this particular end state." Imagine the robot that you put into the kitchen with dirty dishes. You give it the first image of the kitchen with dirty dishes, you give it an image of the kitchen with clean dishes, and now you ask it, "Okay, please tell me
00:03:24 all of the in-between steps that you need to clean up the dishes." And then the robot will generate the video, and that will then be exactly the plan it has to follow to, in fact, clean the dishes. - Because of the flexibility of Diffusion Forcing, we can generate videos that are much, much longer than the maximum length the model is trained on. That's a very powerful property of it. - We're really excited about Diffusion Forcing
00:03:45 as a new building block in this big picture of building machines that can observe the world and then learn from it. The core limitation I think really is that at the time of writing the paper, we trained Diffusion Forcing on relatively small datasets and relatively small models just to really investigate the core methodological improvements over baselines. So, I think one big question is what happens
00:04:04 if you train this model on a much larger dataset and with a much larger compute budget? We're really hopeful that when we take Diffusion Forcing and train it on more data with more compute resources, the model will improve to a similar degree as existing models have improved when trained with more data and more compute. But we haven't shown that just yet. What are the next steps for Diffusion Forcing?
00:04:24 One next step is exactly to train it on larger datasets with more powerful architectures and see if we can get performance improvements that way. Another thing we're really excited about is putting Diffusion Forcing on a robot and using it as a model for the robot to decide what to do next, what data to collect next, and how it could interact with the world to best learn about how the world works.
00:04:44 And so this is a framework that is referred to as reinforcement learning. So, combining reinforcement learning with Diffusion Forcing I think is a really exciting concept as well.
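As a companion to the training sketch above, here is a rough illustration of the goal-conditioned planning described in the kitchen example: the current observation and the goal image are kept clean while the in-between frames start as pure noise and are repeatedly denoised. It reuses the placeholder `CausalDenoiser`, `T_STEPS`, and `alphas_cumprod` from the training sketch and uses a simplified DDPM-style update, not the paper's exact sampler.

```python
# Rough illustration of goal-conditioned planning: clamp the first and last
# frames to clean observations, denoise the frames in between. Assumes the
# placeholder model and noise schedule defined in the training sketch above.
import torch

@torch.no_grad()
def plan_to_goal(model, first_frame, goal_frame, horizon):
    # first_frame, goal_frame: (D,) clean observations; horizon: total frames
    D = first_frame.shape[0]
    frames = torch.randn(1, horizon, D)              # in-between frames start as noise
    frames[0, 0] = first_frame                       # current observation, kept clean
    frames[0, -1] = goal_frame                       # desired end state, kept clean
    levels = torch.full((1, horizon), T_STEPS - 1)   # max noise for middle frames
    levels[0, 0] = 0                                 # conditioning frames get level 0
    levels[0, -1] = 0

    for k in reversed(range(T_STEPS)):
        eps_hat = model(frames, levels.clamp(max=k)) # predict noise at the current level
        a = alphas_cumprod[k]
        x0_hat = (frames - (1 - a).sqrt() * eps_hat) / a.sqrt()
        if k > 0:
            # re-noise the estimate down to the next (lower) noise level
            a_prev = alphas_cumprod[k - 1]
            frames = a_prev.sqrt() * x0_hat + (1 - a_prev).sqrt() * torch.randn_like(frames)
        else:
            frames = x0_hat
        frames[0, 0] = first_frame                   # re-clamp the conditioning frames
        frames[0, -1] = goal_frame
    return frames[0]                                 # planned sequence, endpoints included
```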