
Diffusion for World Modeling: Visual Details Matter in Atari
(To prevent confusion, this is the final version of (Alonso et al., 2023) and is not related to (Ding et al., 2024).)

Eloi Alonso (University of Geneva), Adam Jelley (University of Edinburgh), Vincent Micheli (University of Geneva), Anssi Kanervisto (Microsoft Research), Amos Storkey (University of Edinburgh), Tim Pearce (Microsoft Research), François Fleuret (University of Geneva)

Equal contribution. Equal supervision. Contact: eloi.alonso@unige.ch and adam.jelley@ed.ac.uk
Abstract

World models constitute a promising approach for training reinforcement learning agents in a safe and sample-efficient manner. Recent world models predominantly operate on sequences of discrete latent variables to model environment dynamics. However, this compression into a compact discrete representation may ignore visual details that are important for reinforcement learning. Concurrently, diffusion models have become a dominant approach for image generation, challenging well-established methods modeling discrete latents. Motivated by this paradigm shift, we introduce diamond (DIffusion As a Model Of eNvironment Dreams), a reinforcement learning agent trained in a diffusion world model. We analyze the key design choices that are required to make diffusion suitable for world modeling, and demonstrate how improved visual details can lead to improved agent performance. diamond achieves a mean human normalized score of 1.46 on the competitive Atari 100k benchmark; a new best for agents trained entirely within a world model. We further demonstrate that diamond’s diffusion world model can stand alone as an interactive neural game engine by training on static Counter-Strike: Global Offensive gameplay. To foster future research on diffusion for world modeling, we release our code, agents, videos and playable world models at https://diamond-wm.github.io.

1 Introduction

Generative models of environments, or “world models” (Ha and Schmidhuber, 2018), are becoming increasingly important as a component for generalist agents to plan and reason about their environment (LeCun, 2022). Reinforcement Learning (RL) has demonstrated a wide variety of successes in recent years (Silver et al., 2016; Degrave et al., 2022; Ouyang et al., 2022), but is well-known to be sample inefficient, which limits real-world applications. World models have shown promise for training reinforcement learning agents across diverse environments (Hafner et al., 2023; Schrittwieser et al., 2020), with greatly improved sample-efficiency (Ye et al., 2021), which can enable learning from experience in the real world (Wu et al., 2023).

Figure 1: Unrolling imagination of diamond over time. The top row depicts a policy $\pi_\phi$ taking a sequence of actions in the imagination of our learned diffusion world model $\mathbf{D}_\theta$. The environment time $t$ flows along the horizontal axis, while the vertical axis represents the denoising time $\tau$ flowing backward from $\mathcal{T}$ to $0$. Concretely, given (clean) past observations $\mathbf{x}^0_{<t}$, actions $a_{<t}$, and starting from an initial noisy sample $\mathbf{x}_t^{\mathcal{T}}$, we simulate a reverse noising process $\{\mathbf{x}_t^{\tau}\}_{\tau=\mathcal{T}}^{\tau=0}$ by repeatedly calling $\mathbf{D}_\theta$, and obtain the (clean) next observation $\mathbf{x}_t^0$. The imagination procedure is autoregressive in that the predicted observation $\mathbf{x}_t^0$ and the action $a_t$ taken by the policy become part of the conditioning for the next time step. Animated visualizations of this procedure can be found at https://diamond-wm.github.io.

Recent world modeling methods (Hafner et al., 2021; Micheli et al., 2023; Robine et al., 2023; Hafner et al., 2023; Zhang et al., 2023) often model environment dynamics as a sequence of discrete latent variables. Discretization of the latent space helps to avoid compounding error over multi-step time horizons. However, this encoding may lose information, resulting in a loss of generality and reconstruction quality. This may be problematic for more real-world scenarios where the information required for the task is less well-defined, such as training autonomous vehicles (Hu et al., 2023). In this case, small details in the visual input, such as a traffic light or a pedestrian in the distance, may change the policy of an agent. Increasing the number of discrete latents can mitigate this lossy compression, but comes with an increased computational cost (Micheli et al., 2023).

In the meantime, diffusion models (Sohl-Dickstein et al., 2015; Ho et al., 2020; Song et al., 2020) have become a dominant paradigm for high-resolution image generation (Rombach et al., 2022; Podell et al., 2023). This class of methods, in which the model learns to reverse a noising process, challenges well-established approaches modeling discrete tokens (Esser et al., 2021; Ramesh et al., 2021; Chang et al., 2023), and thereby offers a promising alternative to alleviate the need for discretization in world modeling. Additionally, diffusion models are known to be easily conditionable and to flexibly model complex multi-modal distributions without mode collapse. These properties are instrumental to world modeling, since adherence to conditioning should allow a world model to reflect an agent’s actions more closely, resulting in more reliable credit assignment, and modeling multi-modal distributions should provide greater diversity of training scenarios for an agent.

Motivated by these characteristics, we propose diamond (DIffusion As a Model Of eNvironment Dreams), a reinforcement learning agent trained in a diffusion world model. Careful design choices are necessary to ensure our diffusion world model is efficient and stable over long time horizons, and we provide a qualitative analysis to illustrate their importance. diamond achieves a mean human normalized score of 1.46 on the well-established Atari 100k benchmark; a new state of the art for agents trained entirely within a world model. Additionally, operating in image space has the benefit of enabling our diffusion world model to be a drop-in substitute for the environment, which provides greater insight into world model and agent behaviors. In particular, we find the improved performance in some games follows from better modeling of critical visual details. To further demonstrate the effectiveness of our world model in isolation, we train diamond’s diffusion world model on 87 hours of static Counter-Strike: Global Offensive (CSGO) gameplay (Pearce and Zhu, 2022) to produce an interactive neural game engine for the popular in-game map, Dust II. We release our code, agents and playable world models at https://diamond-wm.github.io.

2 Preliminaries

2.1 Reinforcement learning and world models

We model the environment as a standard Partially Observable Markov Decision Process (pomdp) (Sutton and Barto, 2018), $(\mathcal{S}, \mathcal{A}, \mathcal{O}, T, R, O, \gamma)$, where $\mathcal{S}$ is a set of states, $\mathcal{A}$ is a set of discrete actions, and $\mathcal{O}$ is a set of image observations. The transition function $T : \mathcal{S} \times \mathcal{A} \times \mathcal{S} \to [0,1]$ describes the environment dynamics $p(\mathbf{s}_{t+1} \mid \mathbf{s}_t, \mathbf{a}_t)$, and the reward function $R : \mathcal{S} \times \mathcal{A} \times \mathcal{S} \to \mathbb{R}$ maps transitions to scalar rewards. Agents cannot directly access states $s_t$ and only see the environment through image observations $x_t \in \mathcal{O}$, emitted according to observation probabilities $p(\mathbf{x}_t \mid \mathbf{s}_t)$, described by the observation function $O : \mathcal{S} \times \mathcal{O} \to [0,1]$. The goal is to obtain a policy $\pi$ that maps observations to actions in order to maximize the expected discounted return $\mathbb{E}_\pi[\sum_{t \ge 0} \gamma^t r_t]$, where $\gamma \in [0,1]$ is a discount factor. World models (Ha and Schmidhuber, 2018) are generative models of environments, i.e. models of $p(s_{t+1}, r_t \mid s_t, a_t)$. These models can be used as simulated environments to train RL agents (Sutton, 1991) in a sample-efficient manner (Wu et al., 2023). In this paradigm, the training procedure typically consists of cycling through the three following steps: collect data with the RL agent in the real environment; train the world model on all the collected data; train the RL agent in the world model environment (commonly referred to as "in imagination").
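To make this cycle concrete, the following minimal Python sketch outlines the loop; the three callables are placeholders for the collection, world-model-fitting, and imagination stages (the actual procedure is given in Algorithm 1), so their names and signatures are assumptions for illustration rather than our implementation.

```python
def training_cycle(collect_experience, fit_world_model, train_agent_in_imagination, num_epochs):
    """Sketch of the collect / world-model / imagination cycle; the callables are placeholders."""
    dataset = []
    for _ in range(num_epochs):
        dataset.extend(collect_experience())   # 1) act in the real environment with the current policy
        fit_world_model(dataset)               # 2) train the world model on all data collected so far
        train_agent_in_imagination(dataset)    # 3) train the RL agent inside the world model
    return dataset
```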

2.2 Score-based diffusion models

Diffusion models (Sohl-Dickstein et al., 2015) are a class of generative models inspired by non-equilibrium thermodynamics that generate samples by reversing a noising process.

We consider a diffusion process $\{\mathbf{x}^\tau\}_{\tau \in [0, \mathcal{T}]}$ indexed by a continuous time variable $\tau \in [0, \mathcal{T}]$, with corresponding marginals $\{p^\tau\}_{\tau \in [0, \mathcal{T}]}$, and boundary conditions $p^0 = p^{\text{data}}$ and $p^{\mathcal{T}} = p^{\text{prior}}$, where $p^{\text{prior}}$ is a tractable unstructured prior distribution, such as a Gaussian. Note that we use $\tau$ and superscripts for the diffusion process time, in order to keep $t$ and subscripts for the environment time.

This diffusion process can be described as the solution to a standard stochastic differential equation (SDE) (Song et al., 2020),

$$d\mathbf{x} = \mathbf{f}(\mathbf{x}, \tau)\, d\tau + g(\tau)\, d\mathbf{w}, \qquad (1)$$

where $\mathbf{w}$ is the Wiener process (Brownian motion), $\mathbf{f}$ a vector-valued function acting as a drift coefficient, and $g$ a scalar-valued function known as the diffusion coefficient of the process.

To obtain a generative model, which maps from noise to data, we must reverse this process. Remarkably, Anderson (1982) shows that the reverse process is also a diffusion process, running backwards in time, and described by the following SDE,

$$d\mathbf{x} = \left[\mathbf{f}(\mathbf{x}, \tau) - g(\tau)^2 \nabla_{\mathbf{x}} \log p^{\tau}(\mathbf{x})\right] d\tau + g(\tau)\, d\bar{\mathbf{w}}, \qquad (2)$$

where $\bar{\mathbf{w}}$ is the reverse-time Wiener process, and $\nabla_{\mathbf{x}} \log p^{\tau}(\mathbf{x})$ is the (Stein) score function, the gradient of the log-marginals with respect to the support. Therefore, to reverse the forward noising process, we are left to define the functions $\mathbf{f}$ and $g$ (in Section 3.1), and to estimate the unknown score functions $\nabla_{\mathbf{x}} \log p^{\tau}(\mathbf{x})$, associated with the marginals $\{p^\tau\}_{\tau \in [0, \mathcal{T}]}$ along the process. In practice, it is possible to use a single time-dependent score model $\mathbf{S}_\theta(\mathbf{x}, \tau)$ to estimate these score functions (Song et al., 2020).

At any point in time, estimating the score function is not trivial since we do not have access to the true score function. Fortunately, Hyvärinen (2005) introduces the score matching objective, which, surprisingly, enables training a score model from data samples without knowledge of the underlying score function. To access samples from the marginal $p^\tau$, we need to simulate the forward process from time $0$ to time $\tau$, as we only have clean data samples. This is costly in general, but if $\mathbf{f}$ is affine, we can analytically reach any time $\tau$ in the forward process in a single step, by applying a Gaussian perturbation kernel $p^{0\tau}$ to clean data samples (Song et al., 2020). Since the kernel is differentiable, score matching simplifies to a denoising score matching objective (Vincent, 2011),

$$\mathcal{L}(\theta) = \mathbb{E}\left[\|\mathbf{S}_\theta(\mathbf{x}^\tau, \tau) - \nabla_{\mathbf{x}^\tau} \log p^{0\tau}(\mathbf{x}^\tau \mid \mathbf{x}^0)\|^2\right], \qquad (3)$$

where the expectation is over the diffusion time $\tau$ and the noised sample $\mathbf{x}^\tau \sim p^{0\tau}(\mathbf{x}^\tau \mid \mathbf{x}^0)$, obtained by applying the $\tau$-level perturbation kernel to a clean sample $\mathbf{x}^0 \sim p^{\text{data}}(\mathbf{x}^0)$. Importantly, as the kernel $p^{0\tau}$ is a known Gaussian distribution, this objective becomes a simple $L_2$ reconstruction loss,

$$\mathcal{L}(\theta) = \mathbb{E}\left[\|\mathbf{D}_\theta(\mathbf{x}^\tau, \tau) - \mathbf{x}^0\|^2\right], \qquad (4)$$

with the reparameterization $\mathbf{D}_\theta(\mathbf{x}^\tau, \tau) = \mathbf{S}_\theta(\mathbf{x}^\tau, \tau)\,\sigma^2(\tau) + \mathbf{x}^\tau$, where $\sigma(\tau)$ is the standard deviation of the $\tau$-level perturbation kernel.
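For concreteness, a minimal PyTorch sketch of Equation 4 and the associated score recovery is given below; `D_theta` is a placeholder denoiser, and we assume the variance-exploding Gaussian kernel $\mathbf{x}^\tau = \mathbf{x}^0 + \sigma(\tau)\,\epsilon$ adopted later in Section 3.1.

```python
import torch

def denoising_reconstruction_loss(D_theta, x0, sigma):
    """Equation 4: the denoiser regresses the clean sample from a noised one.
    Assumes the Gaussian kernel x_tau = x0 + sigma * eps (variance sigma^2), as in Section 3.1."""
    x_tau = x0 + sigma * torch.randn_like(x0)
    return ((D_theta(x_tau, sigma) - x0) ** 2).mean()

def score_from_denoiser(D_theta, x_tau, sigma):
    """Reparameterization D(x, tau) = S(x, tau) * sigma^2 + x, inverted to recover the score estimate."""
    return (D_theta(x_tau, sigma) - x_tau) / sigma ** 2
```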

2.3 Diffusion for world modeling

The score-based diffusion model described in Section 2.2 provides an unconditional generative model of $p^{\text{data}}$. To serve as a world model, we need a conditional generative model of the environment dynamics, $p(\mathbf{x}_{t+1} \mid \mathbf{x}_{\le t}, a_{\le t})$, where we consider the general case of a pomdp, in which the Markovian state $s_t$ is unknown and can be approximated from past observations and actions. We can condition a diffusion model on this history to estimate and generate the next observation directly, as shown in Figure 1. This modifies Equation 4 as follows,

$$\mathcal{L}(\theta) = \mathbb{E}\left[\|\mathbf{D}_\theta(\mathbf{x}_{t+1}^\tau, \tau, \mathbf{x}_{\le t}^0, a_{\le t}) - \mathbf{x}_{t+1}^0\|^2\right]. \qquad (5)$$

During training, we sample a trajectory segment $(\mathbf{x}_{\le t}^0, a_{\le t}, \mathbf{x}_{t+1}^0)$ from the agent’s replay dataset, and we obtain the noised next observation $\mathbf{x}_{t+1}^\tau \sim p^{0\tau}(\mathbf{x}_{t+1}^\tau \mid \mathbf{x}_{t+1}^0)$ by applying the $\tau$-level perturbation kernel. In summary, this diffusion process for world modeling resembles the standard diffusion process described in Section 2.2, with a score model conditioned on past observations and actions.

To sample the next observation, we iteratively solve the reverse SDE in Equation 2, as illustrated in Figure 1. While we can in principle resort to any ODE or SDE solver, there is an inherent trade-off between sampling quality and the Number of Function Evaluations (NFE), which directly determines the inference cost of the diffusion world model (see Appendix A for more details).

3 Method

3.1 Practical choice of diffusion paradigm

Building on the background provided in Section 2, we now introduce diamond as a practical realization of a diffusion-based world model. In particular, we now define the drift and diffusion coefficients $\mathbf{f}$ and $g$ introduced in Section 2.2, corresponding to a particular choice of diffusion paradigm. While ddpm (Ho et al., 2020) is an example of one such choice (as described in Appendix B) and would historically be the natural candidate, we instead build upon the edm formulation proposed in Karras et al. (2022). The practical implications of this choice are discussed in Section 5.1. In what follows, we describe how we adapt edm to build our diffusion-based world model.

We consider the perturbation kernel $p^{0\tau}(\mathbf{x}_{t+1}^\tau \mid \mathbf{x}_{t+1}^0) = \mathcal{N}(\mathbf{x}_{t+1}^\tau; \mathbf{x}_{t+1}^0, \sigma^2(\tau)\mathbf{I})$, where $\sigma(\tau)$ is a real-valued function of diffusion time called the noise schedule. This corresponds to setting the drift and diffusion coefficients to $\mathbf{f}(\mathbf{x}, \tau) = \mathbf{0}$ (affine) and $g(\tau) = \sqrt{2\dot{\sigma}(\tau)\sigma(\tau)}$.

We use the network preconditioning introduced by Karras et al. (2022), and so parameterize $\mathbf{D}_\theta$ in Equation 5 as the weighted sum of the noised observation and the prediction of a neural network $\mathbf{F}_\theta$,

$$\mathbf{D}_\theta(\mathbf{x}_{t+1}^\tau, y_t^\tau) = c_{\text{skip}}^\tau\, \mathbf{x}_{t+1}^\tau + c_{\text{out}}^\tau\, \mathbf{F}_\theta\big(c_{\text{in}}^\tau\, \mathbf{x}_{t+1}^\tau, y_t^\tau\big), \qquad (6)$$

where for brevity we define $y_t^\tau \coloneqq (c_{\text{noise}}^\tau, \mathbf{x}^0_{\le t}, a_{\le t})$ to include all conditioning variables.

The preconditioners $c_{\text{in}}^\tau$ and $c_{\text{out}}^\tau$ are selected to keep the network’s input and output at unit variance for any noise level $\sigma(\tau)$, $c_{\text{noise}}^\tau$ is an empirical transformation of the noise level, and $c_{\text{skip}}^\tau$ is given in terms of $\sigma(\tau)$ and the standard deviation of the data distribution $\sigma_{\text{data}}$ as $c_{\text{skip}}^\tau = \sigma_{\text{data}}^2 / (\sigma_{\text{data}}^2 + \sigma^2(\tau))$. These preconditioners are fully described in Appendix C.
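As a reference, a sketch of these preconditioners with the standard choices of Karras et al. (2022) is given below; `sigma_data = 0.5` and the exact forms of `c_out`, `c_in`, and `c_noise` are the usual edm defaults, standing in for the values listed in Appendix C.

```python
import torch

def edm_preconditioners(sigma: torch.Tensor, sigma_data: float = 0.5):
    """Standard edm preconditioners (Karras et al., 2022); the exact values we use are in Appendix C."""
    c_skip = sigma_data ** 2 / (sigma_data ** 2 + sigma ** 2)               # -> 1 as sigma -> 0, -> 0 for large sigma
    c_out = sigma * sigma_data / torch.sqrt(sigma_data ** 2 + sigma ** 2)   # keeps the training target at unit variance
    c_in = 1.0 / torch.sqrt(sigma_data ** 2 + sigma ** 2)                   # keeps the network input at unit variance
    c_noise = torch.log(sigma) / 4.0                                        # empirical noise-level embedding
    return c_skip, c_out, c_in, c_noise
```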

Combining Equations 5 and 6 provides insight into the training objective of $\mathbf{F}_\theta$,

$$\mathcal{L}(\theta) = \mathbb{E}\Big[\big\|\underbrace{\mathbf{F}_\theta\big(c_{\text{in}}^\tau \mathbf{x}_{t+1}^\tau, y_t^\tau\big)}_{\text{Network prediction}} - \underbrace{\tfrac{1}{c_{\text{out}}^\tau}\big(\mathbf{x}_{t+1}^0 - c_{\text{skip}}^\tau \mathbf{x}_{t+1}^\tau\big)}_{\text{Network training target}}\big\|^2\Big]. \qquad (7)$$

The network training target adaptively mixes signal and noise depending on the degradation level $\sigma(\tau)$. When $\sigma(\tau) \gg \sigma_{\text{data}}$, we have $c_{\text{skip}}^\tau \to 0$, and the training target for $\mathbf{F}_\theta$ is dominated by the clean signal $\mathbf{x}_{t+1}^0$. Conversely, when the noise level is low, $\sigma(\tau) \to 0$, we have $c_{\text{skip}}^\tau \to 1$, and the target becomes the difference between the clean and the perturbed signal, i.e. the added Gaussian noise. Intuitively, this prevents the training objective from becoming trivial in the low-noise regime. In practice, this objective has high variance at the extremes of the noise schedule, so Karras et al. (2022) sample the noise level $\sigma(\tau)$ from an empirically chosen log-normal distribution in order to concentrate training around medium-noise regions, as described in Appendix C.
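Putting Equations 6 and 7 together, one training step of the conditional denoiser can be sketched as below, reusing `edm_preconditioners` from above; the signature of `F_theta` and the log-normal parameters `p_mean` and `p_std` are illustrative assumptions rather than the exact values of Appendix C.

```python
import torch
import torch.nn.functional as F

def world_model_training_step(F_theta, obs_hist, act_hist, next_obs, sigma_data=0.5,
                              p_mean=-0.4, p_std=1.2):
    """One conditional training step implementing Equation 7 (sketch)."""
    b = next_obs.shape[0]
    # Sample a noise level from a log-normal distribution to concentrate on medium-noise regions.
    sigma = torch.exp(p_mean + p_std * torch.randn(b, 1, 1, 1, device=next_obs.device))
    c_skip, c_out, c_in, c_noise = edm_preconditioners(sigma, sigma_data)
    # Apply the Gaussian perturbation kernel to the clean next observation.
    noisy_next_obs = next_obs + sigma * torch.randn_like(next_obs)
    # Network prediction, conditioned on the noise level and the history of observations and actions.
    prediction = F_theta(c_in * noisy_next_obs, c_noise, obs_hist, act_hist)
    # Adaptive target that mixes clean signal and noise depending on sigma.
    target = (next_obs - c_skip * noisy_next_obs) / c_out
    return F.mse_loss(prediction, target)
```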

We use a standard 2D U-Net for the vector field $\mathbf{F}_\theta$ (Ronneberger et al., 2015), and we keep a buffer of $L$ past observations and actions that we use to condition the model. We concatenate these past observations to the next noisy observation channel-wise, and we input actions through adaptive group normalization layers (Zheng et al., 2020) in the residual blocks (He et al., 2015) of the U-Net.
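The two conditioning mechanisms can be sketched as follows; the block below is illustrative (layer sizes, group counts, and names are not the exact architecture of Appendix D), but shows how past frames are stacked channel-wise and how action embeddings modulate the normalization statistics.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ActionConditionedResBlock(nn.Module):
    """Illustrative residual block: action embeddings modulate GroupNorm (adaptive group normalization)."""
    def __init__(self, channels: int, act_emb_dim: int):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels, affine=False)
        self.to_scale_shift = nn.Linear(act_emb_dim, 2 * channels)
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, h, act_emb):
        scale, shift = self.to_scale_shift(act_emb).chunk(2, dim=-1)
        h_norm = self.norm(h) * (1 + scale[..., None, None]) + shift[..., None, None]
        return h + self.conv(F.silu(h_norm))

def build_unet_input(noisy_next_obs, obs_history):
    """Channel-wise conditioning: stack the L past observations (B, L, C, H, W) with the noisy next frame."""
    b, l, c, h, w = obs_history.shape
    return torch.cat([obs_history.reshape(b, l * c, h, w), noisy_next_obs], dim=1)
```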

As discussed in Section 2.3 and Appendix A, there are many possible sampling methods to generate the next observation from the trained diffusion model. While our codebase supports a variety of sampling schemes, we found Euler’s method to be effective without incurring the cost of additional NFE required by higher order samplers, or the unnecessary complexity of stochastic sampling.
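A minimal Euler sampler for the corresponding probability-flow ODE is sketched below; the signature of the denoiser `D_theta` and the noise-schedule values (`sigma_min`, `sigma_max`, `rho`) are assumptions for illustration, not the exact settings of Appendix C.

```python
import torch

def build_sigma_schedule(n: int, sigma_min=2e-3, sigma_max=20.0, rho=7.0):
    """Decreasing noise levels for n denoising steps, using the spacing of Karras et al. (2022)."""
    ramp = torch.linspace(0, 1, n)
    sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    return torch.cat([sigmas, torch.zeros(1)])  # end exactly at zero noise

@torch.no_grad()
def sample_next_obs(D_theta, obs_hist, act_hist, shape, n_steps=3, device="cpu"):
    """Euler solver for dx/dsigma = (x - D(x, sigma)) / sigma, conditioned on past observations and actions."""
    sigmas = build_sigma_schedule(n_steps).to(device)
    x = sigmas[0] * torch.randn(shape, device=device)        # start from pure noise
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        denoised = D_theta(x, sigma, obs_hist, act_hist)      # estimate of the clean next observation
        d = (x - denoised) / sigma                            # ODE derivative
        x = x + (sigma_next - sigma) * d                      # Euler step toward lower noise
    return x
```

With `n_steps = 3` this matches the denoising budget we use for Atari (Section 5.2).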

3.2 Reinforcement learning in imagination

Given the diffusion model from Section 3.1, we now complete our world model with a reward and termination model, required for training an RL agent in imagination. Since estimating the reward and termination are scalar prediction problems, we use a separate model $R_\psi$ consisting of standard cnn (LeCun et al., 1989; He et al., 2015) and lstm (Hochreiter and Schmidhuber, 1997; Gers et al., 2000) layers to handle partial observability. The RL agent involves an actor-critic network parameterized by a shared cnn-lstm with policy and value heads. The policy $\pi_\phi$ is trained with reinforce with a value baseline, and we use a Bellman error with $\lambda$-returns to train the value network $V_\phi$, similar to Micheli et al. (2023). We train the agent entirely in imagination as described in Section 2.1. The agent only interacts with the real environment for data collection. After each collection stage, the current world model is updated by training on all data collected so far. Then, the agent is trained with RL in the updated world model environment, and these steps are repeated. This procedure is detailed in Algorithm 1, and is similar to Kaiser et al. (2019); Hafner et al. (2020); Micheli et al. (2023). We provide architecture details, hyperparameters, and RL objectives in Appendices D, E, and F, respectively.
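As an illustration of the value targets used in imagination, a sketch of bootstrapped $\lambda$-returns is given below; the tensor shapes and the discount and $\lambda$ values are illustrative, and the exact objectives are given in Appendix F.

```python
import torch

def lambda_returns(rewards, values, terminations, gamma=0.99, lam=0.95):
    """Bootstrapped lambda-returns over an imagined trajectory (sketch).
    rewards, terminations: (B, T); values: (B, T + 1), including a final bootstrap value."""
    T = rewards.shape[1]
    cont = 1.0 - terminations.float()
    returns = torch.zeros_like(rewards)
    next_return = values[:, -1]                                     # bootstrap from the last value estimate
    for t in reversed(range(T)):
        blended = (1 - lam) * values[:, t + 1] + lam * next_return  # mix 1-step and multi-step targets
        returns[:, t] = rewards[:, t] + gamma * cont[:, t] * blended
        next_return = returns[:, t]
    return returns
```

The value network is regressed towards these returns, and the policy gradient uses the return minus the value baseline as the advantage.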

4 Experiments

4.1 Atari 100k benchmark

For comprehensive evaluation of diamond we use the established Atari 100k benchmark (Kaiser et al., 2019), consisting of 26 games that test a wide range of agent capabilities. For each game, an agent is only allowed to take 100k actions in the environment, which is roughly equivalent to 2 hours of human gameplay, to learn to play the game before evaluation. As a reference, unconstrained Atari agents are usually trained for 50 million steps, a 500-fold increase in experience. We trained diamond from scratch for 5 random seeds on each game. Each run utilized around 12GB of VRAM and took approximately 2.9 days on a single Nvidia RTX 4090 (1.03 GPU years in total).

Table 1: Returns on the 26 games of the Atari 100k benchmark after 2 hours of real-time experience, and human-normalized aggregate metrics. Bold numbers indicate the best performing methods. diamond notably outperforms other world model baselines in terms of mean score over 5 seeds.
Game Random Human SimPLe TWM IRIS DreamerV3 STORM diamond (ours)
Alien 227.8 7127.7 616.9 674.6 420.0 959.0 983.6 744.1
Amidar 5.8 1719.5 74.3 121.8 143.0 139.0 204.8 225.8
Assault 222.4 742.0 527.2 682.6 1524.4 706.0 801.0 1526.4
Asterix 210.0 8503.3 1128.3 1116.6 853.6 932.0 1028.0 3698.5
BankHeist 14.2 753.1 34.2 466.7 53.1 649.0 641.2 19.7
BattleZone 2360.0 37187.5 4031.2 5068.0 13074.0 12250.0 13540.0 4702.0
Boxing 0.1 12.1 7.8 77.5 70.1 78.0 79.7 86.9
Breakout 1.7 30.5 16.4 20.0 83.7 31.0 15.9 132.5
ChopperCommand 811.0 7387.8 979.4 1697.4 1565.0 420.0 1888.0 1369.8
CrazyClimber 10780.5 35829.4 62583.6 71820.4 59324.2 97190.0 66776.0 99167.8
DemonAttack 152.1 1971.0 208.1 350.2 2034.4 303.0 164.6 288.1
Freeway 0.0 29.6 16.7 24.3 31.1 0.0 33.5 33.3
Frostbite 65.2 4334.7 236.9 1475.6 259.1 909.0 1316.0 274.1
Gopher 257.6 2412.5 596.8 1674.8 2236.1 3730.0 8239.6 5897.9
Hero 1027.0 30826.4 2656.6 7254.0 7037.4 11161.0 11044.3 5621.8
Jamesbond 29.0 302.8 100.5 362.4 462.7 445.0 509.0 427.4
Kangaroo 52.0 3035.0 51.2 1240.0 838.2 4098.0 4208.0 5382.2
Krull 1598.0 2665.5 2204.8 6349.2 6616.4 7782.0 8412.6 8610.1
KungFuMaster 258.5 22736.3 14862.5 24554.6 21759.8 21420.0 26182.0 18713.6
MsPacman 307.3 6951.6 1480.0 1588.4 999.1 1327.0 2673.5 1958.2
Pong -20.7 14.6 12.8 18.8 14.6 18.0 11.3 20.4
PrivateEye 24.9 69571.3 35.0 86.6 100.0 882.0 7781.0 114.3
Qbert 163.9 13455.0 1288.8 3330.8 745.7 3405.0 4522.5 4499.3
RoadRunner 11.5 7845.0 5640.6 9109.0 9614.6 15565.0 17564.0 20673.2
Seaquest 68.4 42054.7 683.3 774.4 661.3 618.0 525.2 551.2
UpNDown 533.4 11693.2 3350.3 15981.7 3546.2 9234.0 7985.0 3856.3
#Superhuman (↑) 0 N/A 1 8 10 9 10 11
Mean (↑) 0.000 1.000 0.332 0.956 1.046 1.097 1.266 1.459
IQM (↑) 0.000 1.000 0.130 0.459 0.501 0.497 0.636 0.641

We compare with other recent methods training an agent entirely within a world model in Table 1, including storm (Zhang et al., 2023), DreamerV3 (Hafner et al., 2023), iris (Micheli et al., 2023), twm (Robine et al., 2023), and SimPLe (Kaiser et al., 2019). A broader comparison to model-free and search-based methods, including bbf (Schwarzer et al., 2023) and EfficientZero (Ye et al., 2021), the current best performing methods on this benchmark, is provided in Appendix J. bbf and EfficientZero use techniques that are orthogonal and not directly comparable to our approach, such as periodic network resets in combination with hyperparameter scheduling for bbf, and computationally expensive lookahead Monte-Carlo tree search for EfficientZero. Combining these additional components with our world model would be an interesting direction for future work.

4.2 Results on the Atari 100k benchmark


Figure 2: Mean and interquartile mean human normalized scores. diamond, in blue, obtains a mean HNS of 1.46 and an IQM of 0.64.

Table 1 provides scores for all games, and the mean and interquartile mean (IQM) of human-normalized scores (HNS) (Wang et al., 2016). Following the recommendations of Agarwal et al. (2021) on the limitations of point estimates, we provide stratified bootstrap confidence intervals for the mean and IQM in Figure 2, as well as performance profiles and additional metrics in Appendix H.
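For reference, the aggregation in Table 1 and Figure 2 follows these definitions; the sketch below shows the per-game human-normalized score and a simple interquartile mean (the confidence intervals in Figure 2 additionally use the stratified bootstrap of Agarwal et al. (2021), which is not reproduced here).

```python
import numpy as np

def human_normalized_score(score, random_score, human_score):
    """HNS = (score - random) / (human - random), computed per game."""
    return (score - random_score) / (human_score - random_score)

def interquartile_mean(scores):
    """IQM: mean of the middle 50% of scores, more robust to outliers than the plain mean."""
    x = np.sort(np.asarray(scores).ravel())
    n = len(x)
    return x[n // 4: n - n // 4].mean()
```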

Our results demonstrate that diamond performs strongly across the benchmark, outperforming human players on 11 games, and achieving a superhuman mean HNS of 1.46, a new best among agents trained entirely within a world model. diamond also achieves an IQM on par with storm, and greater than all other baselines. We find that diamond performs particularly well on environments where capturing small details is important, such as Asterix, Breakout and Road Runner. We provide further qualitative analysis of the visual quality of the world model in Section 5.3.

5 Analysis

5.1 Choice of diffusion framework

As explained in Section 2, we could in principle use any diffusion model variant in our world model. While diamond utilizes edm (Karras et al., 2022) as described in Section 3, ddpm (Ho et al., 2020) would also be a natural candidate, having been used in many image generation applications (Rombach et al., 2022; Nichol and Dhariwal, 2021). We justify this design decision in this section.

To provide a fair comparison of ddpm with our edm implementation, we train both variants with the same network architecture, on a shared static dataset of 100k frames collected with an expert policy on the game Breakout. As discussed in Section 2.3, the number of denoising steps is directly related to the inference cost of the world model, and so fewer steps will reduce the cost of training an agent on imagined trajectories. Ho et al. (2020) use a thousand denoising steps, and Rombach et al. (2022) employ hundreds of steps for Stable Diffusion. However, for our world model to be computationally comparable with other world model baselines (such as iris, which requires 16 NFE for each timestep), we need at most tens of denoising steps, and preferably fewer. Unfortunately, if the number of denoising steps is set too low, the visual quality will degrade, leading to compounding error.

To investigate the stability of the diffusion variants, we display imagined trajectories generated autoregressively up to $t = 1000$ timesteps in Figure 3, for different numbers of denoising steps $n \le 10$. We see that using ddpm (Figure 3(a)) in this regime leads to severe compounding error, causing the world model to quickly drift out of distribution. In contrast, the edm-based diffusion world model (Figure 3(b)) appears much more stable over long time horizons, even for a single denoising step. A quantitative analysis of this compounding error is provided in Appendix K.

(a) ddpm-based world model trajectories.
(b) edm-based world model trajectories.
Figure 3: Imagined trajectories with diffusion world models based on ddpm (left) and edm (right). The initial observation at $t = 0$ is common, and each row corresponds to a decreasing number of denoising steps $n$. We observe that ddpm-based generation suffers from compounding error, and that the smaller the number of denoising steps, the faster the error accumulates. In contrast, our edm-based world model appears much more stable, even for $n = 1$.

This surprising result is a consequence of the improved training objective described in Equation 7, compared to the simpler noise prediction objective employed by ddpm. While predicting the noise works well for intermediate noise levels, this objective causes the model to learn the identity function when the noise is dominant ($\sigma_{\text{noise}} \gg \sigma_{\text{data}} \implies \xi_\theta(\mathbf{x}^\tau_{t+1}, y_t^\tau) \to \mathbf{x}^\tau_{t+1}$), where $\xi_\theta$ is the noise prediction network of ddpm. This gives a poor estimate of the score function at the beginning of the sampling procedure, which degrades the generation quality and leads to compounding error.

In contrast, the adaptive mixing of signal and noise employed by edm, described in Section 3.1, means that the model is trained to predict the clean image when the noise is dominant ($\sigma_{\text{noise}} \gg \sigma_{\text{data}} \implies \mathbf{F}_\theta(\mathbf{x}^\tau_{t+1}, y_t^\tau) \to \mathbf{x}^0_{t+1}$). This gives a better estimate of the score function in the absence of signal, so the model is able to produce higher quality generations with fewer denoising steps, as illustrated in Figure 3(b).
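The difference between the two parameterizations can be made explicit by writing how each recovers a clean-image estimate in the variance-exploding form $\mathbf{x}^\tau = \mathbf{x}^0 + \sigma\,\epsilon$ used here; this is a schematic comparison (reusing `edm_preconditioners` from Section 3.1), not the exact ddpm formulation of Appendix B.

```python
def denoised_from_noise_prediction(x_tau, sigma, xi_theta, cond):
    """ddpm-style: predict the noise and subtract it. Any prediction error is multiplied by sigma,
    so when sigma >> sigma_data a near-identity prediction yields a poor clean-image (score) estimate."""
    return x_tau - sigma * xi_theta(x_tau, sigma, cond)

def denoised_from_edm(x_tau, sigma, F_theta, cond, sigma_data=0.5):
    """edm-style: skip connection plus scaled network output. For sigma >> sigma_data, c_skip -> 0 and
    the network directly predicts the clean image, so errors are not amplified by the noise level."""
    c_skip, c_out, c_in, c_noise = edm_preconditioners(sigma, sigma_data)
    return c_skip * x_tau + c_out * F_theta(c_in * x_tau, c_noise, cond)
```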

5.2 Choice of the number of denoising steps

While we found that our edm-based world model was very stable with just a single denoising step, as shown for Breakout in the last row of Figure 3(b), we discuss here how this choice would limit the visual quality of the model in some cases. We provide a more quantitative analysis in Appendix L.

As discussed in Section 2.2, our score model is equivalent to a denoising autoencoder (Vincent et al., 2008) trained with an $L_2$ reconstruction loss. The optimal single-step prediction is thus the expectation over possible reconstructions for a given noisy input, which can be out of distribution if this posterior distribution is multimodal. While some games like Breakout have deterministic transitions that can be accurately modeled with a single denoising step (see Figure 3(b)), in some other games partial observability gives rise to multimodal observation distributions. In this case, an iterative solver is necessary to drive the sampling procedure towards a particular mode, as illustrated in the game Boxing in Figure 4. We therefore set $n = 3$ in all of our experiments.

Figure 4: Single-step (top row) versus multi-step (bottom row) sampling in Boxing. Movements of the black player are unpredictable, so that single-step denoising interpolates between possible outcomes and results in blurry predictions. In contrast, multi-step sampling produces a crisp image by driving the generation towards a particular mode. Interestingly, the policy controls the white player, so his actions are known to the world model. This information removes any ambiguity, and so we observe that both single-step and multi-step sampling correctly predict the white player’s position.

5.3 Qualitative visual comparison with iris

We now compare to iris (Micheli et al., 2023), a well-established world model that uses a discrete autoencoder (Van Den Oord et al., 2017) to convert images to discrete tokens, and composes these tokens over time with an autoregressive transformer (Radford et al., 2019). For fair comparison, we train both world models on the same static datasets of 100k frames collected with expert policies. This comparison is displayed in Figure 5 below.

(a) iris
(b) diamond
Figure 5: Consecutive frames imagined with iris (left) and diamond (right). The white boxes highlight inconsistencies between frames, which we see only arise in trajectories generated with iris. In Asterix (top row), an enemy (orange) becomes a reward (red) in the second frame, before reverting to an enemy in the third, and again to a reward in the fourth. In Breakout (middle row), the bricks and score are inconsistent between frames. In Road Runner (bottom row), the rewards (small blue dots on the road) are inconsistently rendered between frames. None of these inconsistencies occur with diamond. In Breakout, the score is even reliably updated by +7 when a red brick is broken (see https://en.wikipedia.org/wiki/Breakout_(video_game)#Gameplay).

We see in Figure 5 that the trajectories imagined by diamond are generally of higher visual quality and more faithful to the true environment compared to the trajectories imagined by iris. In particular, the trajectories generated by iris contain visual inconsistencies between frames (highlighted by white boxes), such as enemies being displayed as rewards and vice-versa. These inconsistencies may only represent a few pixels in the generated images, but can have significant consequences for reinforcement learning. For example, since an agent should generally target rewards and avoid enemies, these small visual discrepancies can make it more challenging to learn an optimal policy.

These improvements in the consistency of visual details are generally reflected by greater agent performance on these games, as shown in Table 1. Since the agent component of these methods is similar, this improvement can likely be attributed to the world model.

Finally, we note that this improvement is not simply the result of increased computation. Both world models are rendering frames at the same resolution ($64 \times 64$), and diamond requires only 3 NFE per frame compared to 16 NFE per frame for iris. This is further reflected by the fact that diamond has significantly fewer parameters and takes less time to train than iris, as provided in Appendix H.

6 Scaling the diffusion world model to Counter-Strike: Global Offensive
(This section was added after NeurIPS acceptance, following community interest in later CS:GO experiments.)

To investigate the ability of diamond’s diffusion world model to learn to model more complex 3D environments, we train the world model in isolation on static data from the popular video game Counter-Strike: Global Offensive (CS:GO). We use the Online dataset of 5.5M frames (95 hours) of online human gameplay captured at 16Hz from the map Dust II by Pearce and Zhu (2022). We randomly hold out 0.5M frames (corresponding to 500 episodes, or 8 hours) for testing, and use the remaining 5M frames (87 hours) for training. There is no reinforcement learning agent or online data collection involved in these experiments.

To reduce the computational cost, we reduce the resolution from $280 \times 150$ to $56 \times 30$ for world modeling. We then introduce a second, smaller diffusion model as an upsampler to improve the generated images at the original resolution (Saharia et al., 2022b). We scale the channels of the U-Net to increase the number of parameters from 4M for our Atari models to 381M for our CS:GO model (including 51M for the upsampler). The combined model was trained for 12 days on an RTX 4090.

Finally, we introduce stochastic sampling and increase the number of denoising steps for the upsampler to 10, which we found to improve the resulting visual quality of the generations, while keeping the dynamics model the same (in particular, still using only 3 denoising steps). This enables a reasonable tradeoff between visual quality and inference cost, with the model running at 10Hz on an RTX 3090. Typical generations of the model are provided in Figure 6 below.

Figure 6: Images captured from people playing with keyboard and mouse inside diamond’s diffusion world model. This model was trained on 87 hours of static Counter-Strike: Global Offensive (CS:GO) gameplay (Pearce and Zhu, 2022) to produce an interactive neural game engine for the popular in-game map, Dust II. Best viewed as videos at https://diamond-wm.github.io.

We find the model is able to generate stable trajectories over hundreds of timesteps, although it is more likely to drift out-of-distribution in less frequently visited areas of the map. Due to the limited memory of the model, approaching walls or losing visibility may cause the model to forget the current state and instead generate a new weapon or area of the map. Interestingly, we find the model wrongly enables successive jumps by generalizing the effect of a jump on the geometry of the scene, since multiple jumps do not appear often enough in the training gameplay for the model to learn that mid-air jumps should be ignored. We expect scaling the model and data to address many of these limitations, with the exception of the limited memory of the model. Quantitative measurements of the capabilities of the CS:GO world model and attempts to address these limitations are left to future work.

7 Related work

World models. The idea of reinforcement learning (RL) in the imagination of a neural network world model was introduced by Ha and Schmidhuber (2018). SimPLe (Kaiser et al., 2019) applied world models to Atari, and introduced the Atari 100k benchmark to focus on sample efficiency. Dreamer (Hafner et al., 2020) introduced RL from the latent space of a recurrent state space model (RSSM). DreamerV2 (Hafner et al., 2021) demonstrated that using discrete latents could help to reduce compounding error, and DreamerV3 (Hafner et al., 2023) was able to achieve human-level performance on a wide range of domains with fixed hyperparameters. TWM (Robine et al., 2023) adapts DreamerV2’s RSSM to use a transformer architecture, while STORM (Zhang et al., 2023) adapts DreamerV3 in a similar way but with a different tokenization approach. Alternatively, IRIS (Micheli et al., 2023) builds a language of image tokens with a discrete autoencoder, and composes these tokens over time with an autoregressive transformer.

Generative vision models. There are parallels between these world models and image generation models which suggests that developments in generative vision models could provide benefits to world modeling. Following the rise of transformers in natural language processing (Vaswani et al.,, 2017; Devlin et al.,, 2018; Radford et al.,, 2019), VQGAN (Esser et al.,, 2021) and DALL·E (Ramesh et al.,, 2021) convert images to discrete tokens with discrete autoencoders (Van Den Oord et al.,, 2017), and leverage the sequence modeling abilities of autoregressive transformers to build powerful text-to-image generative models. Concurrently, diffusion models (Sohl-Dickstein et al.,, 2015; Ho et al.,, 2020; Song et al.,, 2020) gained traction (Dhariwal and Nichol,, 2021; Rombach et al.,, 2022), and have become a dominant paradigm for high-resolution image generation (Saharia et al., 2022a, ; Ramesh et al.,, 2022; Podell et al.,, 2023).

The same trends have taken place in the recent developments of video generation methods. VideoGPT (Yan et al.,, 2021) provides a minimal video generation architecture by combining a discrete autoencoder with an autoregressive transformer. Godiva (Wu et al.,, 2021) enables text conditioning with promising generalization. Phenaki (Villegas et al.,, 2023) allows arbitrary length video generation with sequential prompt conditioning. TECO (Yan et al.,, 2023) improves upon autoregressive modeling by using MaskGit (Chang et al.,, 2022), and enables longer temporal dependencies by compressing input sequence embeddings. Diffusion models have also seen a resurgence in video generation using 3D U-Nets to provide high quality but short-duration video (Singer et al.,, 2023; Bar-Tal et al.,, 2024). Recently, transformer-based diffusion models such as DiT (Peebles and Xie,, 2023) and Sora (Brooks et al.,, 2024) have shown improved scalability for both image and video generation, respectively.

Diffusion for reinforcement learning. There has also been much interest in combining diffusion models with reinforcement learning. This includes taking advantage of the flexibility of diffusion models as a policy (Wang et al.,, 2022; Ajay et al.,, 2022; Pearce et al.,, 2023), as planners (Janner et al.,, 2022; Liang et al.,, 2023), as reward models (Nuti et al.,, 2023), and trajectory modeling for data augmentation in offline RL (Lu et al.,, 2023; Ding et al.,, 2024; Jackson et al.,, 2024). diamond represents the first use of diffusion models as world models for learning online in imagination.

Generative game engines. Playable games running entirely on neural networks have recently been growing in scope. GameGAN (Kim et al.,, 2020) learns generative models of games using a GAN (Goodfellow et al.,, 2014) while Bamford and Lucas, (2020) use a Neural GPU (Kaiser and Sutskever,, 2015). Concurrent work includes Genie (Bruce et al.,, 2024), which generates playable platformer environments from image prompts, and GameNGen (Valevski et al.,, 2024), which similarly leverages a diffusion model to obtain a high resolution simulator of the game DOOM, but at a larger scale.

8 Limitations

We identify three main limitations of our work for future research. First, our main evaluation is focused on discrete control environments, and applying diamond to the continuous domain may provide additional insights. Second, the use of frame stacking for conditioning is a minimal mechanism to provide a memory of past observations. Integrating an autoregressive transformer over environment time, using an approach such as Peebles and Xie, (2023), would enable longer-term memory and better scalability. We include an initial investigation into a potential cross-attention architecture in Appendix M, but found frame-stacking more effective in our early experiments. Third, we leave potential integration of the reward/termination prediction into the diffusion model for future work, since combining these objectives and extracting representations from a diffusion model is not trivial (Luo et al.,, 2023; Xu et al.,, 2023) and would make our world model unnecessarily complex.

9 Conclusion and Broader Impact

We have introduced diamond, a reinforcement learning agent trained in a diffusion world model. We explained the key design choices we made to adapt diffusion for world modeling and to make our world model stable over long time horizons with a low number of denoising steps. diamond achieves a mean human normalized score of 1.461.461.461.46 on the well-established Atari 100k benchmark; a new best among agents trained entirely within a world model. We analyzed our improved performance in some games and found that it likely follows from better modeling of critical visual details. We further demonstrated diamond’s diffusion world model can successfully model 3D environments and serve as a real-time neural game engine by training on static Counter-Strike: Global Offensive gameplay.

World models constitute a promising direction to address sample efficiency and safety concerns associated with training agents in the real world. However, imperfections in the world model may lead to suboptimal or unexpected agent behaviors. We hope that the development of more faithful and interactive world models will contribute to broader efforts to further reduce these risks.

Acknowledgments and Disclosure of Funding

We would like to thank Andrew Foong, Bálint Máté, Clément Vignac, Maxim Peter, Pedro Sanchez, Rich Turner, Stéphane Nguyen, Tom Lee, Trevor McInroe and Weipu Zhang for insightful discussions and comments. Adam and Eloi met during an internship at Microsoft Research Cambridge, and would like to thank the Game Intelligence team, including Anssi Kanervisto, Dave Bignell, Gunshi Gupta, Katja Hofmann, Lukas Schäfer, Raluca Georgescu, Sam Devlin, Sergio Valcarcel Macua, Shanzheng Tan, Tabish Rashid, Tarun Gupta, Tim Pearce, and Yuhan Cao, for their support in the early stages of this project, and a great summer.

References

  • Agarwal et al., (2021) Agarwal, R., Schwarzer, M., Castro, P. S., Courville, A. C., and Bellemare, M. (2021). Deep reinforcement learning at the edge of the statistical precipice. Advances in Neural Information Processing Systems, 34:29304–29320.
  • Ajay et al., (2022) Ajay, A., Du, Y., Gupta, A., Tenenbaum, J., Jaakkola, T., and Agrawal, P. (2022). Is conditional generative modeling all you need for decision-making? International Conference on Learning Representations.
  • Alonso et al., (2023) Alonso, E., Jelley, A., Kanervisto, A., and Pearce, T. (2023). Diffusion world models. OpenReview.
  • Anderson, (1982) Anderson, B. D. (1982). Reverse-time diffusion equation models. Stochastic Processes and their Applications, 12(3):313–326.
  • Ascher and Petzold, (1998) Ascher, U. M. and Petzold, L. R. (1998). Computer methods for ordinary differential equations and differential-algebraic equations. Society for Industrial and Applied Mathematics.
  • Bamford and Lucas, (2020) Bamford, C. and Lucas, S. M. (2020). Neural game engine: Accurate learning of generalizable forward models from pixels. In 2020 IEEE Conference on Games (CoG), pages 81–88. IEEE.
  • Bar-Tal et al., (2024) Bar-Tal, O., Chefer, H., Tov, O., Herrmann, C., Paiss, R., Zada, S., Ephrat, A., Hur, J., Li, Y., Michaeli, T., et al. (2024). Lumiere: A space-time diffusion model for video generation. arXiv preprint arXiv:2401.12945.
  • Brooks et al., (2024) Brooks, T., Peebles, B., Holmes, C., DePue, W., Guo, Y., Jing, L., Schnurr, D., Taylor, J., Luhman, T., Luhman, E., Ng, C., Wang, R., and Ramesh, A. (2024). Video generation models as world simulators.
  • Bruce et al., (2024) Bruce, J., Dennis, M. D., Edwards, A., Parker-Holder, J., Shi, Y., Hughes, E., Lai, M., Mavalankar, A., Steigerwald, R., Apps, C., et al. (2024). Genie: Generative interactive environments. In Forty-first International Conference on Machine Learning.
  • Chang et al., (2023) Chang, H., Zhang, H., Barber, J., Maschinot, A., Lezama, J., Jiang, L., Yang, M.-H., Murphy, K. P., Freeman, W. T., Rubinstein, M., Li, Y., and Krishnan, D. (2023). Muse: Text-to-image generation via masked generative transformers. In International Conference on Machine Learning, pages 4055–4075. PMLR.
  • Chang et al., (2022) Chang, H., Zhang, H., Jiang, L., Liu, C., and Freeman, W. T. (2022). Maskgit: Masked generative image transformer. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 11315–11325.
  • Çiçek et al., (2016) Çiçek, Ö., Abdulkadir, A., Lienkamp, S. S., Brox, T., and Ronneberger, O. (2016). 3d u-net: learning dense volumetric segmentation from sparse annotation. In Medical Image Computing and Computer-Assisted Intervention–MICCAI 2016: 19th International Conference, Athens, Greece, October 17-21, 2016, Proceedings, Part II 19, pages 424–432. Springer.
  • Degrave et al., (2022) Degrave, J., Felici, F., Buchli, J., Neunert, M., Tracey, B., Carpanese, F., Ewalds, T., Hafner, R., Abdolmaleki, A., de Las Casas, D., et al. (2022). Magnetic control of tokamak plasmas through deep reinforcement learning. Nature, 602:414–419.
  • Devlin et al., (2018) Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. (2018). Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
  • Dhariwal and Nichol, (2021) Dhariwal, P. and Nichol, A. (2021). Diffusion models beat gans on image synthesis. Advances in Neural Information Processing Systems, 34:8780–8794.
  • Ding et al., (2024) Ding, Z., Zhang, A., Tian, Y., and Zheng, Q. (2024). Diffusion world model. arXiv preprint arXiv:2402.03570.
  • Elfwing et al., (2018) Elfwing, S., Uchibe, E., and Doya, K. (2018). Sigmoid-weighted linear units for neural network function approximation in reinforcement learning. Neural networks, 107:3–11.
  • Esser et al., (2021) Esser, P., Rombach, R., and Ommer, B. (2021). Taming transformers for high-resolution image synthesis. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 12873–12883.
  • Gers et al., (2000) Gers, F. A., Schmidhuber, J., and Cummins, F. (2000). Learning to forget: Continual prediction with LSTM. Neural Computation, 12(10):2451–2471.
  • Goodfellow et al., (2014) Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., and Bengio, Y. (2014). Generative adversarial nets. Advances in Neural Information Processing Systems, 27.
  • Ha and Schmidhuber, (2018) Ha, D. and Schmidhuber, J. (2018). Recurrent world models facilitate policy evolution. Advances in Neural Information Processing Systems, 31:2451–2463.
  • Hafner et al., (2020) Hafner, D., Lillicrap, T., Ba, J., and Norouzi, M. (2020). Dream to control: Learning behaviors by latent imagination. In International Conference on Learning Representations.
  • Hafner et al., (2021) Hafner, D., Lillicrap, T. P., Norouzi, M., and Ba, J. (2021). Mastering atari with discrete world models. In International Conference on Learning Representations.
  • Hafner et al., (2023) Hafner, D., Pasukonis, J., Ba, J., and Lillicrap, T. (2023). Mastering diverse domains through world models. arXiv preprint arXiv:2301.04104.
  • He et al., (2015) He, K., Zhang, X., Ren, S., and Sun, J. (2015). Deep residual learning for image recognition. arXiv preprint arXiv:1512.03385.
  • Heusel et al., (2017) Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B., and Hochreiter, S. (2017). Gans trained by a two time-scale update rule converge to a local nash equilibrium. Advances in Neural Information Processing Systems, 30.
  • Ho et al., (2020) Ho, J., Jain, A., and Abbeel, P. (2020). Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems, 33:6840–6851.
  • Ho et al., (2022) Ho, J., Salimans, T., Gritsenko, A., Chan, W., Norouzi, M., and Fleet, D. J. (2022). Video diffusion models. URL https://arxiv. org/abs/2204.03458.
  • Hochreiter and Schmidhuber, (1997) Hochreiter, S. and Schmidhuber, J. (1997). Long short-term memory. Neural Computation, 9(8):1735–1780.
  • Hu et al., (2023) Hu, A., Russell, L., Yeo, H., Murez, Z., Fedoseev, G., Kendall, A., Shotton, J., and Corrado, G. (2023). Gaia-1: A generative world model for autonomous driving. arXiv preprint arXiv:2309.17080.
  • Hyvärinen, (2005) Hyvärinen, A. (2005). Estimation of non-normalized statistical models by score matching. Journal of Machine Learning Research, 6:695–709.
  • Jackson et al., (2024) Jackson, M. T., Matthews, M. T., Lu, C., Ellis, B., Whiteson, S., and Foerster, J. (2024). Policy-guided diffusion. arXiv preprint arXiv:2404.06356.
  • Janner et al., (2022) Janner, M., Du, Y., Tenenbaum, J., and Levine, S. (2022). Planning with diffusion for flexible behavior synthesis. In International Conference on Machine Learning, pages 9902–9915. PMLR.
  • Kaiser et al., (2019) Kaiser, L., Babaeizadeh, M., Milos, P., Osinski, B., Campbell, R. H., Czechowski, K., Erhan, D., Finn, C., Kozakowski, P., Levine, S., et al. (2019). Model-based reinforcement learning for atari. arXiv preprint arXiv:1903.00374.
  • Kaiser and Sutskever, (2015) Kaiser, Ł. and Sutskever, I. (2015). Neural gpus learn algorithms. arXiv preprint arXiv:1511.08228.
  • Kapturowski et al., (2018) Kapturowski, S., Ostrovski, G., Quan, J., Munos, R., and Dabney, W. (2018). Recurrent experience replay in distributed reinforcement learning. International Conference on Learning Representations.
  • Karras et al., (2022) Karras, T., Aittala, M., Aila, T., and Laine, S. (2022). Elucidating the design space of diffusion-based generative models. Advances in Neural Information Processing Systems, 35:26565–26577.
  • Kim et al., (2020) Kim, S. W., Zhou, Y., Philion, J., Torralba, A., and Fidler, S. (2020). Learning to simulate dynamic environments with gamegan. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 1231–1240.
  • LeCun, (2022) LeCun, Y. (2022). A path towards autonomous machine intelligence version 0.9. 2, 2022-06-27. OpenReview.
  • LeCun et al., (1989) LeCun, Y., Boser, B., Denker, J. S., Henderson, D., Howard, R. E., Hubbard, W., and Jackel, L. D. (1989). Backpropagation applied to handwritten zip code recognition. Neural computation, 1(4):541–551.
  • Liang et al., (2023) Liang, Z., Mu, Y., Ding, M., Ni, F., Tomizuka, M., and Luo, P. (2023). Adaptdiffuser: Diffusion models as adaptive self-evolving planners. International Conference on Machine Learning.
  • Lu et al., (2023) Lu, C., Ball, P. J., and Parker-Holder, J. (2023). Synthetic experience replay. arXiv preprint arXiv:2303.06614.
  • Luo et al., (2023) Luo, G., Dunlap, L., Park, D. H., Holynski, A., and Darrell, T. (2023). Diffusion hyperfeatures: Searching through time and space for semantic correspondence. In Advances in Neural Information Processing Systems.
  • Micheli et al., (2023) Micheli, V., Alonso, E., and Fleuret, F. (2023). Transformers are sample-efficient world models. International Conference on Learning Representations.
  • Mnih et al., (2016) Mnih, V., Badia, A. P., Mirza, M., Graves, A., Lillicrap, T., Harley, T., Silver, D., and Kavukcuoglu, K. (2016). Asynchronous methods for deep reinforcement learning. In Balcan, M. F. and Weinberger, K. Q., editors, Proceedings of The 33rd International Conference on Machine Learning, volume 48 of Proceedings of Machine Learning Research, pages 1928–1937, New York, New York, USA.
  • Mnih et al., (2015) Mnih, V., Kavukcuoglu, K., Silver, D., Rusu, A. A., Veness, J., Bellemare, M. G., Graves, A., Riedmiller, M., Fidjeland, A. K., Ostrovski, G., et al. (2015). Human-level control through deep reinforcement learning. Nature, 518(7540):529–533.
  • Nichol and Dhariwal, (2021) Nichol, A. Q. and Dhariwal, P. (2021). Improved denoising diffusion probabilistic models. International Conference on Machine Learning.
  • Nuti et al., (2023) Nuti, F., Franzmeyer, T., and Henriques, J. F. (2023). Extracting reward functions from diffusion models. arXiv preprint arXiv:2306.01804.
  • Ouyang et al., (2022) Ouyang, L., Wu, J., Jiang, X., Almeida, D., Wainwright, C., Mishkin, P., Zhang, C., Agarwal, S., Slama, K., Ray, A., et al. (2022). Training language models to follow instructions with human feedback. Advances in Neural Information Processing Systems, 35:27730–27744.
  • Pearce et al., (2023) Pearce, T., Rashid, T., Kanervisto, A., Bignell, D., Sun, M., Georgescu, R., Macua, S. V., Tan, S. Z., Momennejad, I., Hofmann, K., and Devlin, S. (2023). Imitating human behaviour with diffusion models. The Eleventh International Conference on Learning Representations.
  • Pearce and Zhu, (2022) Pearce, T. and Zhu, J. (2022). Counter-strike deathmatch with large-scale behavioural cloning. In 2022 IEEE Conference on Games (CoG), pages 104–111. IEEE.
  • Peebles and Xie, (2023) Peebles, W. and Xie, S. (2023). Scalable diffusion models with transformers. In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), pages 4195–4205.
  • Podell et al., (2023) Podell, D., English, Z., Lacey, K., Blattmann, A., Dockhorn, T., Müller, J., Penna, J., and Rombach, R. (2023). Sdxl: Improving latent diffusion models for high-resolution image synthesis. arXiv preprint arXiv:2307.01952.
  • Radford et al., (2019) Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., and Sutskever, I. (2019). Language models are unsupervised multitask learners.
  • Ramesh et al., (2022) Ramesh, A., Dhariwal, P., Nichol, A., Chu, C., and Chen, M. (2022). Hierarchical text-conditional image generation with clip latents. arXiv preprint arXiv:2204.06125, 1(2):3.
  • Ramesh et al., (2021) Ramesh, A., Pavlov, M., Goh, G., Gray, S., Voss, C., Radford, A., Chen, M., and Sutskever, I. (2021). Zero-shot text-to-image generation. International Conference on Machine Learning.
  • Robine et al., (2023) Robine, J., Höftmann, M., Uelwer, T., and Harmeling, S. (2023). Transformer-based world models are happy with 100k interactions. International Conference on Learning Representations.
  • Rombach et al., (2022) Rombach, R., Blattmann, A., Lorenz, D., Esser, P., and Ommer, B. (2022). High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pages 10684–10695.
  • Ronneberger et al., (2015) Ronneberger, O., Fischer, P., and Brox, T. (2015). U-net: Convolutional networks for biomedical image segmentation. In Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18, pages 234–241. Springer.
  • (60) Saharia, C., Chan, W., Saxena, S., Li, L., Whang, J., Denton, E. L., Ghasemipour, K., Gontijo Lopes, R., Karagol Ayan, B., Salimans, T., et al. (2022a). Photorealistic text-to-image diffusion models with deep language understanding. Advances in Neural Information Processing Systems, 35:36479–36494.
  • (61) Saharia, C., Ho, J., Chan, W., Salimans, T., Fleet, D. J., and Norouzi, M. (2022b). Image super-resolution via iterative refinement. IEEE transactions on pattern analysis and machine intelligence.
  • Santana and Hotz, (2016) Santana, E. and Hotz, G. (2016). Learning a driving simulator. arXiv preprint arXiv:1608.01230.
  • Schrittwieser et al., (2020) Schrittwieser, J., Antonoglou, I., Hubert, T., Simonyan, K., Sifre, L., Schmitt, S., Guez, A., Lockhart, E., Hassabis, D., Graepel, T., et al. (2020). Mastering atari, go, chess and shogi by planning with a learned model. Nature, 588(7839):604–609.
  • Schwarzer et al., (2023) Schwarzer, M., Ceron, J. S. O., Courville, A., Bellemare, M. G., Agarwal, R., and Castro, P. S. (2023). Bigger, better, faster: Human-level atari with human-level efficiency. International Conference on Machine Learning.
  • Silver et al., (2016) Silver, D., Huang, A., Maddison, C. J., Guez, A., Sifre, L., Van Den Driessche, G., Schrittwieser, J., Antonoglou, I., Panneershelvam, V., Lanctot, M., et al. (2016). Mastering the game of go with deep neural networks and tree search. Nature, 529:484–489.
  • Singer et al., (2023) Singer, U., Polyak, A., Hayes, T., Yin, X., An, J., Zhang, S., Hu, Q., Yang, H., Ashual, O., Gafni, O., et al. (2023). Make-a-video: Text-to-video generation without text-video data. International Conference on Learning Representations.
  • Skorokhodov et al., (2022) Skorokhodov, I., Tulyakov, S., and Elhoseiny, M. (2022). Stylegan-v: A continuous video generator with the price, image quality and perks of stylegan2. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 3626–3636.
  • Sohl-Dickstein et al., (2015) Sohl-Dickstein, J., Weiss, E., Maheswaranathan, N., and Ganguli, S. (2015). Deep unsupervised learning using nonequilibrium thermodynamics. International Conference on Machine Learning.
  • Song et al., (2020) Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., and Poole, B. (2020). Score-based generative modeling through stochastic differential equations. International Conference on Learning Representations.
  • Sutton, (1991) Sutton, R. S. (1991). Dyna, an integrated architecture for learning, planning, and reacting. ACM Sigart Bulletin, 2(4):160–163.
  • Sutton and Barto, (2018) Sutton, R. S. and Barto, A. G. (2018). Reinforcement learning: An introduction. MIT press.
  • Unterthiner et al., (2018) Unterthiner, T., Van Steenkiste, S., Kurach, K., Marinier, R., Michalski, M., and Gelly, S. (2018). Towards accurate generative models of video: A new metric & challenges. arXiv preprint arXiv:1812.01717.
  • Valevski et al., (2024) Valevski, D., Leviathan, Y., Arar, M., and Fruchter, S. (2024). Diffusion models are real-time game engines.
  • Van Den Oord et al., (2017) Van Den Oord, A., Vinyals, O., et al. (2017). Neural discrete representation learning. Advances in Neural Information Processing Systems, 30.
  • Vaswani et al., (2017) Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., and Polosukhin, I. (2017). Attention is all you need. Advances in Neural Information Processing Systems, 30.
  • Villegas et al., (2023) Villegas, R., Babaeizadeh, M., Kindermans, P.-J., Moraldo, H., Zhang, H., Saffar, M. T., Castro, S., Kunze, J., and Erhan, D. (2023). Phenaki: Variable length video generation from open domain textual descriptions. International Conference on Learning Representations.
  • Vincent, (2011) Vincent, P. (2011). A connection between score matching and denoising autoencoders. Neural computation, 23(7):1661–1674.
  • Vincent et al., (2008) Vincent, P., Larochelle, H., Bengio, Y., and Manzagol, P.-A. (2008). Extracting and composing robust features with denoising autoencoders. In International Conference on Machine learning.
  • Wang et al., (2022) Wang, Z., Hunt, J. J., and Zhou, M. (2022). Diffusion policies as an expressive policy class for offline reinforcement learning. International Conference on Learning Representations.
  • Wang et al., (2016) Wang, Z., Schaul, T., Hessel, M., Hasselt, H., Lanctot, M., and Freitas, N. (2016). Dueling network architectures for deep reinforcement learning. International Conference on Machine Learning.
  • Wu et al., (2021) Wu, C., Huang, L., Zhang, Q., Li, B., Ji, L., Yang, F., Sapiro, G., and Duan, N. (2021). Godiva: Generating open-domain videos from natural descriptions. arXiv preprint arXiv:2104.14806.
  • Wu et al., (2023) Wu, P., Escontrela, A., Hafner, D., Abbeel, P., and Goldberg, K. (2023). Daydreamer: World models for physical robot learning. In Conference on Robot Learning, pages 2226–2240. PMLR.
  • Wu and He, (2018) Wu, Y. and He, K. (2018). Group normalization. In Proceedings of the European Conference on Computer Vision (ECCV).
  • Xu et al., (2023) Xu, J., Liu, S., Vahdat, A., Byeon, W., Wang, X., and De Mello, S. (2023). Open-vocabulary panoptic segmentation with text-to-image diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 2955–2966.
  • Yan et al., (2023) Yan, W., Hafner, D., James, S., and Abbeel, P. (2023). Temporally consistent transformers for video generation. International Conference on Machine Learning.
  • Yan et al., (2021) Yan, W., Zhang, Y., Abbeel, P., and Srinivas, A. (2021). Videogpt: Video generation using vq-vae and transformers. arXiv preprint arXiv:2104.10157.
  • Ye et al., (2021) Ye, W., Liu, S., Kurutach, T., Abbeel, P., and Gao, Y. (2021). Mastering atari games with limited data. Advances in Neural Information Processing Systems, 34.
  • Zhang et al., (2018) Zhang, R., Isola, P., Efros, A. A., Shechtman, E., and Wang, O. (2018). The unreasonable effectiveness of deep features as a perceptual metric. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 586–595.
  • Zhang et al., (2023) Zhang, W., Wang, G., Sun, J., Yuan, Y., and Huang, G. (2023). Storm: Efficient stochastic transformer based world models for reinforcement learning. In Thirty-seventh Conference on Neural Information Processing Systems.
  • Zheng et al., (2020) Zheng, H., Fu, J., Zeng, Y., Luo, J., and Zha, Z.-J. (2020). Learning semantic-aware normalization for generative adversarial networks. In Advances in Neural Information Processing Systems.

Appendix A Sampling observations in diamond

We describe here how we sample an observation 𝐱t0superscriptsubscript𝐱𝑡0\mathbf{x}_{t}^{0}bold_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT from our diffusion world model. We initialize the procedure with a noisy observation 𝐱t𝒯ppriorsimilar-tosuperscriptsubscript𝐱𝑡𝒯superscript𝑝𝑝𝑟𝑖𝑜𝑟\mathbf{x}_{t}^{\mathcal{T}}\sim p^{prior}bold_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT start_POSTSUPERSCRIPT caligraphic_T end_POSTSUPERSCRIPT ∼ italic_p start_POSTSUPERSCRIPT italic_p italic_r italic_i italic_o italic_r end_POSTSUPERSCRIPT, and iteratively solve the reverse SDE in Equation 2 from τ=𝒯𝜏𝒯\tau=\mathcal{T}italic_τ = caligraphic_T to τ=0𝜏0\tau=0italic_τ = 0, using the learned score model 𝐒θ(𝐱tτ,τ,𝐱<t0,a<t)subscript𝐒𝜃superscriptsubscript𝐱𝑡𝜏𝜏superscriptsubscript𝐱absent𝑡0subscript𝑎absent𝑡\mathbf{S}_{\theta}(\mathbf{x}_{t}^{\tau},\tau,\mathbf{x}_{<t}^{0},a_{<t})bold_S start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT ( bold_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_τ end_POSTSUPERSCRIPT , italic_τ , bold_x start_POSTSUBSCRIPT < italic_t end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT , italic_a start_POSTSUBSCRIPT < italic_t end_POSTSUBSCRIPT ) conditioned on past observations 𝐱<t0superscriptsubscript𝐱absent𝑡0\mathbf{x}_{<t}^{0}bold_x start_POSTSUBSCRIPT < italic_t end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT and actions a<tsubscript𝑎absent𝑡a_{<t}italic_a start_POSTSUBSCRIPT < italic_t end_POSTSUBSCRIPT. This procedure is illustrated in Figure 1.

In fact, there are many possible sampling methods for a given learned score model 𝐒θsubscript𝐒𝜃\mathbf{S}_{\theta}bold_S start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT (Karras et al.,, 2022). Notably, Song et al., (2020) introduce a corresponding “probability flow" ordinary differential equation (ODE), with marginals equivalent to the stochastic process described in Section 2.2. In that case, the solving procedure is deterministic, and the only randomness comes from sampling the initial condition. In practice, this means that for a given score model, we can resort to any ODE or SDE solver, from simple first order methods like Euler (deterministic) and Euler–Maruyama (stochastic) schemes, to higher-order methods like Heun’s method (Ascher and Petzold,, 1998).

Regardless of the choice of solver, each step introduces truncation errors, resulting from the local score approximation and the discretization of the continuous process. Higher order samplers may reduce this truncation error, but come at the cost of additional Number of Function Evaluations (NFE) – how many forward passes of the network are required to generate a sample. This local error generally scales superlinearly with respect to the step size (for instance Euler’s method is 𝒪(h2)𝒪superscript2\mathcal{O}(h^{2})caligraphic_O ( italic_h start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ) for step size hhitalic_h), so increasing the number of denoising steps improves the visual quality of the generated next frame. Therefore, there is a trade-off between visual quality and NFE that directly determines the inference cost of the diffusion world model.

Appendix B Link between DDPM and continuous-time score-based diffusion models

Denoising Diffusion Probabilistic Models (ddpm, Ho et al., (2020)) can be described as a discrete version of the diffusion process introduced in Section 2.2, as described in Song et al., (2020). The discrete forward process is a Markov chain characterized by a discrete noise schedule 0<β1,,βi,βN<1formulae-sequence0subscript𝛽1subscript𝛽𝑖subscript𝛽𝑁10<\beta_{1},\dots,\beta_{i},\dots\beta_{N}<10 < italic_β start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … , italic_β start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , … italic_β start_POSTSUBSCRIPT italic_N end_POSTSUBSCRIPT < 1, and a variance-preserving Gaussian transition kernel,

p(𝐱i|𝐱i1)=𝒩(𝐱i;1βi𝐱i1,βi𝐈).𝑝conditionalsuperscript𝐱𝑖superscript𝐱𝑖1𝒩superscript𝐱𝑖1subscript𝛽𝑖superscript𝐱𝑖1subscript𝛽𝑖𝐈p(\mathbf{x}^{i}|\mathbf{x}^{i-1})=\mathcal{N}(\mathbf{x}^{i};\sqrt{1-\beta_{i% }}\mathbf{x}^{i-1},\beta_{i}\mathbf{I}).italic_p ( bold_x start_POSTSUPERSCRIPT italic_i end_POSTSUPERSCRIPT | bold_x start_POSTSUPERSCRIPT italic_i - 1 end_POSTSUPERSCRIPT ) = caligraphic_N ( bold_x start_POSTSUPERSCRIPT italic_i end_POSTSUPERSCRIPT ; square-root start_ARG 1 - italic_β start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_ARG bold_x start_POSTSUPERSCRIPT italic_i - 1 end_POSTSUPERSCRIPT , italic_β start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT bold_I ) . (8)

In the continuous time limit N𝑁N\to\inftyitalic_N → ∞, the Markov chain becomes a diffusion process, and the discrete noise schedule becomes a time-dependent function β(τ)𝛽𝜏\beta(\tau)italic_β ( italic_τ ). This diffusion process can be described by an SDE with drift coefficient 𝐟(𝐱,τ)=12β(τ)𝐱𝐟𝐱𝜏12𝛽𝜏𝐱\mathbf{f}(\mathbf{x},\tau)=-\frac{1}{2}\beta(\tau)\mathbf{x}bold_f ( bold_x , italic_τ ) = - divide start_ARG 1 end_ARG start_ARG 2 end_ARG italic_β ( italic_τ ) bold_x and diffusion coefficient g(τ)=β(τ)𝑔𝜏𝛽𝜏g(\tau)=\sqrt{\beta(\tau)}italic_g ( italic_τ ) = square-root start_ARG italic_β ( italic_τ ) end_ARG (Song et al.,, 2020).

Appendix C EDM network preconditioners and training

Karras et al., (2022) use the following preconditioners for normalization and rescaling purposes (as mentioned in Section 3.1) to improve network training:

cinτ=1σ(τ)2+σdata2superscriptsubscript𝑐𝑖𝑛𝜏1𝜎superscript𝜏2superscriptsubscript𝜎𝑑𝑎𝑡𝑎2c_{in}^{\tau}=\frac{1}{\sqrt{\sigma(\tau)^{2}+\sigma_{data}^{2}}}italic_c start_POSTSUBSCRIPT italic_i italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_τ end_POSTSUPERSCRIPT = divide start_ARG 1 end_ARG start_ARG square-root start_ARG italic_σ ( italic_τ ) start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT + italic_σ start_POSTSUBSCRIPT italic_d italic_a italic_t italic_a end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT end_ARG end_ARG (9)
coutτ=σ(τ)σdataσ(τ)2+σdata2superscriptsubscript𝑐𝑜𝑢𝑡𝜏𝜎𝜏subscript𝜎𝑑𝑎𝑡𝑎𝜎superscript𝜏2superscriptsubscript𝜎𝑑𝑎𝑡𝑎2c_{out}^{\tau}=\frac{\sigma(\tau)\sigma_{data}}{\sqrt{\sigma(\tau)^{2}+\sigma_% {data}^{2}}}italic_c start_POSTSUBSCRIPT italic_o italic_u italic_t end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_τ end_POSTSUPERSCRIPT = divide start_ARG italic_σ ( italic_τ ) italic_σ start_POSTSUBSCRIPT italic_d italic_a italic_t italic_a end_POSTSUBSCRIPT end_ARG start_ARG square-root start_ARG italic_σ ( italic_τ ) start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT + italic_σ start_POSTSUBSCRIPT italic_d italic_a italic_t italic_a end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT end_ARG end_ARG (10)
cnoiseτ=14log(σ(τ))superscriptsubscript𝑐𝑛𝑜𝑖𝑠𝑒𝜏14𝜎𝜏c_{noise}^{\tau}=\frac{1}{4}\log(\sigma(\tau))italic_c start_POSTSUBSCRIPT italic_n italic_o italic_i italic_s italic_e end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_τ end_POSTSUPERSCRIPT = divide start_ARG 1 end_ARG start_ARG 4 end_ARG roman_log ( italic_σ ( italic_τ ) ) (11)
cskipτ=σdata2σdata2+σ2(τ),superscriptsubscript𝑐𝑠𝑘𝑖𝑝𝜏superscriptsubscript𝜎𝑑𝑎𝑡𝑎2superscriptsubscript𝜎𝑑𝑎𝑡𝑎2superscript𝜎2𝜏c_{skip}^{\tau}=\frac{\sigma_{data}^{2}}{\sigma_{data}^{2}+\sigma^{2}(\tau)},italic_c start_POSTSUBSCRIPT italic_s italic_k italic_i italic_p end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_τ end_POSTSUPERSCRIPT = divide start_ARG italic_σ start_POSTSUBSCRIPT italic_d italic_a italic_t italic_a end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT end_ARG start_ARG italic_σ start_POSTSUBSCRIPT italic_d italic_a italic_t italic_a end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT + italic_σ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ( italic_τ ) end_ARG , (12)

where σdata=0.5subscript𝜎𝑑𝑎𝑡𝑎0.5\sigma_{data}=0.5italic_σ start_POSTSUBSCRIPT italic_d italic_a italic_t italic_a end_POSTSUBSCRIPT = 0.5.

The noise parameter σ(τ)𝜎𝜏\sigma(\tau)italic_σ ( italic_τ ) is sampled to maximize the effectiveness of training as follows:

log(σ(τ))𝒩(Pmean,Pstd2),similar-to𝜎𝜏𝒩subscript𝑃𝑚𝑒𝑎𝑛superscriptsubscript𝑃𝑠𝑡𝑑2\log(\sigma(\tau))\sim\mathcal{N}(P_{mean},P_{std}^{2}),roman_log ( italic_σ ( italic_τ ) ) ∼ caligraphic_N ( italic_P start_POSTSUBSCRIPT italic_m italic_e italic_a italic_n end_POSTSUBSCRIPT , italic_P start_POSTSUBSCRIPT italic_s italic_t italic_d end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ) , (13)

where Pmean=0.4,Pstd=1.2formulae-sequencesubscript𝑃𝑚𝑒𝑎𝑛0.4subscript𝑃𝑠𝑡𝑑1.2P_{mean}=-0.4,P_{std}=1.2italic_P start_POSTSUBSCRIPT italic_m italic_e italic_a italic_n end_POSTSUBSCRIPT = - 0.4 , italic_P start_POSTSUBSCRIPT italic_s italic_t italic_d end_POSTSUBSCRIPT = 1.2. Refer to Karras et al., (2022) for an in-depth analysis.

Appendix D Model architectures

The diffusion model 𝐃θsubscript𝐃𝜃\mathbf{D}_{\theta}bold_D start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT is a standard U-Net 2D (Ronneberger et al.,, 2015), conditioned on the last 4 frames and actions, as well as the diffusion time τ𝜏\tauitalic_τ. We use frame stacking for observation conditioning, and adaptive group normalization (Zheng et al.,, 2020) for action and diffusion time conditioning.

The reward/termination model Rψsubscript𝑅𝜓R_{\psi}italic_R start_POSTSUBSCRIPT italic_ψ end_POSTSUBSCRIPT layers are shared except for the final prediction heads. The model takes as input a sequence of frames and actions, and forwards it through convolutional residual blocks (He et al.,, 2015) followed by an LSTM cell (Mnih et al.,, 2016; Hochreiter and Schmidhuber,, 1997; Gers et al.,, 2000). Before starting the imagination procedure, we burn-in (Kapturowski et al.,, 2018) the conditioning frames and actions to initialize the hidden and cell states of the LSTM.

The weights of the policy πϕsubscript𝜋italic-ϕ\pi_{\phi}italic_π start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT and value network Vϕsubscript𝑉italic-ϕV_{\phi}italic_V start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT are shared except for the last layer. In the following, we refer to (π,V)ϕsubscript𝜋𝑉italic-ϕ(\pi,V)_{\phi}( italic_π , italic_V ) start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT as the "actor-critic" network, even though V𝑉Vitalic_V is technically a state-value network, not a critic. This network takes as input a frame, and forwards it through convolutional trunk followed by an LSTM cell. The convolutional trunk consists of four residual blocks and 2x2 max-pooling with stride 2. The main path of the residual blocks consists of a group normalization (Wu and He,, 2018) layer, a SiLU activation (Elfwing et al.,, 2018), and a 3x3 convolution with stride 1 and padding 1. Before starting the imagination procedure, we burn-in the conditioning frames to initialize the hidden and cell states of the LSTM.

Please refer to Table 2 below for hyperparameter values, and to Algorithm 1 for a detailed summary of the training procedure.

Table 2: Architecture details for diamond.
Hyperparameter Value
Diffusion Model (𝐃θsubscript𝐃𝜃\mathbf{D}_{\theta}bold_D start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT)
Observation conditioning mechanism Frame stacking
Action conditioning mechanism Adaptive Group Normalization
Diffusion time conditioning mechanism Adaptive Group Normalization
Residual blocks layers [2, 2, 2, 2]
Residual blocks channels [64, 64, 64, 64]
Residual blocks conditioning dimension 256
Reward/Termination Model (Rψsubscript𝑅𝜓R_{\psi}italic_R start_POSTSUBSCRIPT italic_ψ end_POSTSUBSCRIPT)
Action conditioning mechanisms Adaptive Group Normalization
Residual blocks layers [2, 2, 2, 2]
Residual blocks channels [32, 32, 32, 32]
Residual blocks conditioning dimension 128
LSTM dimension 512
Actor-Critic Model (πϕsubscript𝜋italic-ϕ\pi_{\phi}italic_π start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT and Vϕsubscript𝑉italic-ϕV_{\phi}italic_V start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT)
Residual blocks layers [1, 1, 1, 1]
Residual blocks channels [32, 32, 64, 64]
LSTM dimension 512

Appendix E Training hyperparameters

Table 3: Hyperparameters for diamond.
Hyperparameter Value
Training loop
Number of epochs 1000
Training steps per epoch 400
Batch size 32
Environment steps per epoch 100
Epsilon (greedy) for collection 0.01
RL hyperparameters
Imagination horizon (H𝐻Hitalic_H) 15
Discount factor (γ𝛾\gammaitalic_γ) 0.985
Entropy weight (η𝜂\etaitalic_η) 0.001
λ𝜆\lambdaitalic_λ-returns coefficient (λ𝜆\lambdaitalic_λ) 0.95
Sequence construction during training
For 𝐃θsubscript𝐃𝜃\mathbf{D}_{\theta}bold_D start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT, number of conditioning observations and actions (L𝐿Litalic_L) 4
For Rψsubscript𝑅𝜓R_{\psi}italic_R start_POSTSUBSCRIPT italic_ψ end_POSTSUBSCRIPT, burn-in length (BRsubscript𝐵𝑅B_{R}italic_B start_POSTSUBSCRIPT italic_R end_POSTSUBSCRIPT), set to L𝐿Litalic_L in practice 4
For Rψsubscript𝑅𝜓R_{\psi}italic_R start_POSTSUBSCRIPT italic_ψ end_POSTSUBSCRIPT, training sequence length (BR+Hsubscript𝐵𝑅𝐻B_{R}+Hitalic_B start_POSTSUBSCRIPT italic_R end_POSTSUBSCRIPT + italic_H) 19
For πϕsubscript𝜋italic-ϕ\pi_{\phi}italic_π start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT and Vϕsubscript𝑉italic-ϕV_{\phi}italic_V start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT, burn-in length (Bπ,Vsubscript𝐵𝜋𝑉B_{\pi,V}italic_B start_POSTSUBSCRIPT italic_π , italic_V end_POSTSUBSCRIPT), set to L𝐿Litalic_L in practice 4
Optimization
Optimizer AdamW
Learning rate 1e-4
Epsilon 1e-8
Weight decay (𝐃θsubscript𝐃𝜃\mathbf{D}_{\theta}bold_D start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT) 1e-2
Weight decay (Rψsubscript𝑅𝜓R_{\psi}italic_R start_POSTSUBSCRIPT italic_ψ end_POSTSUBSCRIPT) 1e-2
Weight decay (πϕsubscript𝜋italic-ϕ\pi_{\phi}italic_π start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT and Vϕsubscript𝑉italic-ϕV_{\phi}italic_V start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT) 0
Diffusion Sampling
Method Euler
Number of steps 3
Environment
Image observation dimensions 64×\times×64×\times×3
Action space Discrete (up to 18 actions)
Frameskip 4
Max noop 30
Termination on life loss True
Reward clipping {1,0,1}101\{-1,0,1\}{ - 1 , 0 , 1 }

Appendix F Reinforcement learning objectives

In what follows, we note 𝐱tsubscript𝐱𝑡\mathbf{x}_{t}bold_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT, rtsubscript𝑟𝑡r_{t}italic_r start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT and dtsubscript𝑑𝑡d_{t}italic_d start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT the observations, rewards, and boolean episode terminations predicted by our world model. We note H𝐻Hitalic_H the imagination horizon, Vϕsubscript𝑉italic-ϕV_{\phi}italic_V start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT the value network, πϕsubscript𝜋italic-ϕ\pi_{\phi}italic_π start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT the policy network, and atsubscript𝑎𝑡a_{t}italic_a start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT the actions taken by the policy within the world model.

We use λ𝜆\lambdaitalic_λ-returns to balance bias and variance as the regression target for the value network. Given an imagined trajectory of length H𝐻Hitalic_H, we can define the λ𝜆\lambdaitalic_λ-return recursively as follows,

Λt={rt+γ(1dt)[(1λ)Vϕ(𝐱t+1)+λΛt+1]ift<HVϕ(𝐱H)ift=H.subscriptΛ𝑡casessubscript𝑟𝑡𝛾1subscript𝑑𝑡delimited-[]1𝜆subscript𝑉italic-ϕsubscript𝐱𝑡1𝜆subscriptΛ𝑡1if𝑡𝐻subscript𝑉italic-ϕsubscript𝐱𝐻if𝑡𝐻\Lambda_{t}=\begin{cases}r_{t}+\gamma(1-d_{t})\Big{[}(1-\lambda)V_{\phi}(% \mathbf{x}_{t+1})+\lambda\Lambda_{t+1}\Big{]}&\text{if}\quad t<H\\ V_{\phi}(\mathbf{x}_{H})&\text{if}\quad t=H.\\ \end{cases}roman_Λ start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT = { start_ROW start_CELL italic_r start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT + italic_γ ( 1 - italic_d start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) [ ( 1 - italic_λ ) italic_V start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT ( bold_x start_POSTSUBSCRIPT italic_t + 1 end_POSTSUBSCRIPT ) + italic_λ roman_Λ start_POSTSUBSCRIPT italic_t + 1 end_POSTSUBSCRIPT ] end_CELL start_CELL if italic_t < italic_H end_CELL end_ROW start_ROW start_CELL italic_V start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT ( bold_x start_POSTSUBSCRIPT italic_H end_POSTSUBSCRIPT ) end_CELL start_CELL if italic_t = italic_H . end_CELL end_ROW (14)

The value network Vϕsubscript𝑉italic-ϕV_{\phi}italic_V start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT is trained to minimize V(ϕ)subscript𝑉italic-ϕ\mathcal{L}_{V}(\phi)caligraphic_L start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ( italic_ϕ ), the expected squared difference with λ𝜆\lambdaitalic_λ-returns over imagined trajectories,

V(ϕ)=𝔼πϕ[t=0H1(Vϕ(𝐱t)sg(Λt))2],subscript𝑉italic-ϕsubscript𝔼subscript𝜋italic-ϕdelimited-[]superscriptsubscript𝑡0𝐻1superscriptsubscript𝑉italic-ϕsubscript𝐱𝑡sgsubscriptΛ𝑡2\mathcal{L}_{V}(\phi)=\mathbb{E}_{\pi_{\phi}}\left[\sum_{t=0}^{H-1}\big{(}V_{% \phi}(\mathbf{x}_{t})-\mathrm{sg}(\Lambda_{t})\big{)}^{2}\right],caligraphic_L start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ( italic_ϕ ) = blackboard_E start_POSTSUBSCRIPT italic_π start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT end_POSTSUBSCRIPT [ ∑ start_POSTSUBSCRIPT italic_t = 0 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_H - 1 end_POSTSUPERSCRIPT ( italic_V start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT ( bold_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) - roman_sg ( roman_Λ start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) ) start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ] , (15)

where sg()sg\operatorname{sg}(\cdot)roman_sg ( ⋅ ) denotes the gradient stopping operation, meaning that the target is a constant in the gradient-based optimization, as classically established in the literature (Mnih et al.,, 2015; Hafner et al.,, 2021; Micheli et al.,, 2023).

As we can generate large amounts of on-policy trajectories in imagination, we use a simple reinforce objective to train the policy, with the value Vϕ(𝐱t)subscript𝑉italic-ϕsubscript𝐱𝑡V_{\phi}(\mathbf{x}_{t})italic_V start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT ( bold_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) as a baseline to reduce the variance of the gradients (Sutton and Barto,, 2018). The policy is trained to minimize the following objective, combining reinforce and a weighted entropy maximization objective to maintain sufficient exploration,

π(ϕ)=𝔼πϕ[t=0H1log(πϕ(at𝐱t))sg(ΛtVϕ(𝐱t))+η(πϕ(at𝐱t))].subscript𝜋italic-ϕsubscript𝔼subscript𝜋italic-ϕdelimited-[]superscriptsubscript𝑡0𝐻1subscript𝜋italic-ϕconditionalsubscript𝑎𝑡subscript𝐱absent𝑡sgsubscriptΛ𝑡subscript𝑉italic-ϕsubscript𝐱𝑡𝜂subscript𝜋italic-ϕconditionalsubscript𝑎𝑡subscript𝐱absent𝑡\mathcal{L}_{\pi}(\phi)=-\mathbb{E}_{\pi_{\phi}}\left[\sum_{t=0}^{H-1}\log% \left(\pi_{\phi}\left(a_{t}\mid\mathbf{x}_{\leq t}\right)\right)\operatorname{% sg}\left(\Lambda_{t}-V_{\phi}\left(\mathbf{x}_{t}\right)\right)+\eta% \operatorname{\mathcal{H}}\left(\pi_{\phi}\left(a_{t}\mid\mathbf{x}_{\leq t}% \right)\right)\right].caligraphic_L start_POSTSUBSCRIPT italic_π end_POSTSUBSCRIPT ( italic_ϕ ) = - blackboard_E start_POSTSUBSCRIPT italic_π start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT end_POSTSUBSCRIPT [ ∑ start_POSTSUBSCRIPT italic_t = 0 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_H - 1 end_POSTSUPERSCRIPT roman_log ( italic_π start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT ( italic_a start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ∣ bold_x start_POSTSUBSCRIPT ≤ italic_t end_POSTSUBSCRIPT ) ) roman_sg ( roman_Λ start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT - italic_V start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT ( bold_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) ) + italic_η caligraphic_H ( italic_π start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT ( italic_a start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ∣ bold_x start_POSTSUBSCRIPT ≤ italic_t end_POSTSUBSCRIPT ) ) ] . (16)

Appendix G diamond algorithm

We summarize the overall training procedure of diamond in Algorithm 1 below. We denote as 𝒟𝒟\mathcal{D}caligraphic_D the replay dataset where the agent stores data collected from the real environment, and other notations are introduced in previous sections or are self-explanatory.

Procedure training_loop():
       for epochs do
             collect_experience(steps_collect)
             for steps_diffusion_model do
                   update_diffusion_model()
                  
            for steps_reward_end_model do
                   update_reward_end_model()
                  
            for steps_actor_critic do
                   update_actor_critic()
                  
            
      
Procedure collect_experience(n𝑛nitalic_n):
       𝐱00env.reset()superscriptsubscript𝐱00env.reset()\mathbf{x}_{0}^{0}\leftarrow\texttt{env.reset()}bold_x start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT ← env.reset()
       for t=0𝑡0t=0italic_t = 0 to n1𝑛1n-1italic_n - 1 do
             Sample atπϕ(at𝐱t0)similar-tosubscript𝑎𝑡subscript𝜋italic-ϕconditionalsubscript𝑎𝑡superscriptsubscript𝐱𝑡0a_{t}\sim\pi_{\phi}(a_{t}\mid\mathbf{x}_{t}^{0})italic_a start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ∼ italic_π start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT ( italic_a start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ∣ bold_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT )
             𝐱t+10,rt,dtenv.step(at)superscriptsubscript𝐱𝑡10subscript𝑟𝑡subscript𝑑𝑡env.step(subscript𝑎𝑡)\mathbf{x}_{t+1}^{0},r_{t},d_{t}\leftarrow\texttt{env.step(}a_{t}\texttt{)}bold_x start_POSTSUBSCRIPT italic_t + 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT , italic_r start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT , italic_d start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ← env.step( italic_a start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT )
             𝒟𝒟{𝐱t0,at,rt,dt}𝒟𝒟superscriptsubscript𝐱𝑡0subscript𝑎𝑡subscript𝑟𝑡subscript𝑑𝑡\mathcal{D}\leftarrow\mathcal{D}\cup\{\mathbf{x}_{t}^{0},a_{t},r_{t},d_{t}\}caligraphic_D ← caligraphic_D ∪ { bold_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT , italic_a start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT , italic_r start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT , italic_d start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT }
             if dt=1subscript𝑑𝑡1d_{t}=1italic_d start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT = 1 then
                   𝐱t+10env.reset()superscriptsubscript𝐱𝑡10env.reset()\mathbf{x}_{t+1}^{0}\leftarrow\texttt{env.reset()}bold_x start_POSTSUBSCRIPT italic_t + 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT ← env.reset()
                  
            
      
Procedure update_diffusion_model():
       Sample sequence (𝐱tL+10,atL+1,,𝐱t0,at,𝐱t+10)𝒟similar-tosuperscriptsubscript𝐱𝑡𝐿10subscript𝑎𝑡𝐿1superscriptsubscript𝐱𝑡0subscript𝑎𝑡superscriptsubscript𝐱𝑡10𝒟(\mathbf{x}_{t-L+1}^{0},a_{t-L+1},\dots,\mathbf{x}_{t}^{0},a_{t},\mathbf{x}_{t% +1}^{0})\sim\mathcal{D}( bold_x start_POSTSUBSCRIPT italic_t - italic_L + 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT , italic_a start_POSTSUBSCRIPT italic_t - italic_L + 1 end_POSTSUBSCRIPT , … , bold_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT , italic_a start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT , bold_x start_POSTSUBSCRIPT italic_t + 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT ) ∼ caligraphic_D
       Sample log(σ)𝒩(Pmean,Pstd2)similar-to𝜎𝒩subscript𝑃𝑚𝑒𝑎𝑛superscriptsubscript𝑃𝑠𝑡𝑑2\log(\sigma)\sim\mathcal{N}(P_{mean},P_{std}^{2})roman_log ( italic_σ ) ∼ caligraphic_N ( italic_P start_POSTSUBSCRIPT italic_m italic_e italic_a italic_n end_POSTSUBSCRIPT , italic_P start_POSTSUBSCRIPT italic_s italic_t italic_d end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ) // log-normal sigma distribution from EDM
       Define τ:=σassign𝜏𝜎\tau:=\sigmaitalic_τ := italic_σ // default identity schedule from EDM
       Sample 𝐱t+1τ𝒩(𝐱t+10,σ2𝐈)similar-tosuperscriptsubscript𝐱𝑡1𝜏𝒩superscriptsubscript𝐱𝑡10superscript𝜎2𝐈\mathbf{x}_{t+1}^{\tau}\sim\mathcal{N}(\mathbf{x}_{t+1}^{0},\sigma^{2}\mathbf{% I})bold_x start_POSTSUBSCRIPT italic_t + 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_τ end_POSTSUPERSCRIPT ∼ caligraphic_N ( bold_x start_POSTSUBSCRIPT italic_t + 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT , italic_σ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT bold_I ) // Add independent Gaussian noise
       Compute 𝐱^t+10=𝐃θ(𝐱t+1τ,τ,𝐱tL+10,atL+1,,𝐱t0,at)superscriptsubscript^𝐱𝑡10subscript𝐃𝜃superscriptsubscript𝐱𝑡1𝜏𝜏superscriptsubscript𝐱𝑡𝐿10subscript𝑎𝑡𝐿1superscriptsubscript𝐱𝑡0subscript𝑎𝑡\hat{\mathbf{x}}_{t+1}^{0}=\mathbf{D}_{\theta}(\mathbf{x}_{t+1}^{\tau},\tau,% \mathbf{x}_{t-L+1}^{0},a_{t-L+1},\dots,\mathbf{x}_{t}^{0},a_{t})over^ start_ARG bold_x end_ARG start_POSTSUBSCRIPT italic_t + 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT = bold_D start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT ( bold_x start_POSTSUBSCRIPT italic_t + 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_τ end_POSTSUPERSCRIPT , italic_τ , bold_x start_POSTSUBSCRIPT italic_t - italic_L + 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT , italic_a start_POSTSUBSCRIPT italic_t - italic_L + 1 end_POSTSUBSCRIPT , … , bold_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT , italic_a start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT )
       Compute reconstruction loss (θ)=𝐱^t+10𝐱t+102𝜃superscriptnormsuperscriptsubscript^𝐱𝑡10superscriptsubscript𝐱𝑡102\mathcal{L}(\theta)=\|\hat{\mathbf{x}}_{t+1}^{0}-\mathbf{x}_{t+1}^{0}\|^{2}caligraphic_L ( italic_θ ) = ∥ over^ start_ARG bold_x end_ARG start_POSTSUBSCRIPT italic_t + 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT - bold_x start_POSTSUBSCRIPT italic_t + 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT ∥ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT
       Update 𝐃θsubscript𝐃𝜃\mathbf{D}_{\theta}bold_D start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT
      
Procedure update_reward_end_model():
       Sample indexes {t,,t+L+H1}𝑡𝑡𝐿𝐻1\mathcal{I}\coloneqq\{t,\dots,t+L+H-1\}caligraphic_I ≔ { italic_t , … , italic_t + italic_L + italic_H - 1 } // burn-in + imagination horizon
       Sample sequence (𝐱i0,ai,ri,di)i𝒟similar-tosubscriptsuperscriptsubscript𝐱𝑖0subscript𝑎𝑖subscript𝑟𝑖subscript𝑑𝑖𝑖𝒟(\mathbf{x}_{i}^{0},a_{i},r_{i},d_{i})_{i\in\mathcal{I}}\sim\mathcal{D}( bold_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT , italic_a start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_r start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_d start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) start_POSTSUBSCRIPT italic_i ∈ caligraphic_I end_POSTSUBSCRIPT ∼ caligraphic_D
       Initialize h=c=0𝑐0h=c=0italic_h = italic_c = 0 // LSTM hidden and cell states
       for i𝑖i\in\mathcal{I}italic_i ∈ caligraphic_I do
             Compute r^i,d^i,h,c=Rψ(𝐱i,ai,h,c)subscript^𝑟𝑖subscript^𝑑𝑖𝑐subscript𝑅𝜓subscript𝐱𝑖subscript𝑎𝑖𝑐\hat{r}_{i},\hat{d}_{i},h,c=R_{\psi}(\mathbf{x}_{i},a_{i},h,c)over^ start_ARG italic_r end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , over^ start_ARG italic_d end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_h , italic_c = italic_R start_POSTSUBSCRIPT italic_ψ end_POSTSUBSCRIPT ( bold_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_a start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_h , italic_c )
            
      Compute (ψ)=iCE(r^i,sign(ri))+CE(d^i,di)𝜓subscript𝑖CEsubscript^𝑟𝑖signsubscript𝑟𝑖CEsubscript^𝑑𝑖subscript𝑑𝑖\mathcal{L}(\psi)=\sum_{i\in\mathcal{I}}\mathrm{CE}(\hat{r}_{i},\mathrm{sign}(% r_{i}))+\mathrm{CE}(\hat{d}_{i},d_{i})caligraphic_L ( italic_ψ ) = ∑ start_POSTSUBSCRIPT italic_i ∈ caligraphic_I end_POSTSUBSCRIPT roman_CE ( over^ start_ARG italic_r end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , roman_sign ( italic_r start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ) + roman_CE ( over^ start_ARG italic_d end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_d start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) // CE: cross-entropy loss
       Update Rψsubscript𝑅𝜓R_{\psi}italic_R start_POSTSUBSCRIPT italic_ψ end_POSTSUBSCRIPT
      
Procedure update_actor_critic():
       Sample initial buffer (𝐱tL+10,atL+1,,𝐱t0)𝒟similar-tosuperscriptsubscript𝐱𝑡𝐿10subscript𝑎𝑡𝐿1superscriptsubscript𝐱𝑡0𝒟(\mathbf{x}_{t-L+1}^{0},a_{t-L+1},\dots,\mathbf{x}_{t}^{0})\sim\mathcal{D}( bold_x start_POSTSUBSCRIPT italic_t - italic_L + 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT , italic_a start_POSTSUBSCRIPT italic_t - italic_L + 1 end_POSTSUBSCRIPT , … , bold_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT ) ∼ caligraphic_D
       Burn-in buffer with Rψsubscript𝑅𝜓R_{\psi}italic_R start_POSTSUBSCRIPT italic_ψ end_POSTSUBSCRIPT, πϕsubscript𝜋italic-ϕ\pi_{\phi}italic_π start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT and Vϕsubscript𝑉italic-ϕV_{\phi}italic_V start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT to initialize LSTM states
       for i=t𝑖𝑡i=titalic_i = italic_t to t+H1𝑡𝐻1t+H-1italic_t + italic_H - 1 do
             Sample aiπϕ(ai𝐱i0)similar-tosubscript𝑎𝑖subscript𝜋italic-ϕconditionalsubscript𝑎𝑖superscriptsubscript𝐱𝑖0a_{i}\sim\pi_{\phi}(a_{i}\mid\mathbf{x}_{i}^{0})italic_a start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ∼ italic_π start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT ( italic_a start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ∣ bold_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT )
             Sample reward risubscript𝑟𝑖r_{i}italic_r start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT and termination disubscript𝑑𝑖d_{i}italic_d start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT with Rψsubscript𝑅𝜓R_{\psi}italic_R start_POSTSUBSCRIPT italic_ψ end_POSTSUBSCRIPT
             Sample next observation 𝐱i+10superscriptsubscript𝐱𝑖10\mathbf{x}_{i+1}^{0}bold_x start_POSTSUBSCRIPT italic_i + 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT by simulating reverse diffusion process with 𝐃θsubscript𝐃𝜃\mathbf{D}_{\theta}bold_D start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT
            
      Compute Vϕ(𝐱i)subscript𝑉italic-ϕsubscript𝐱𝑖V_{\phi}(\mathbf{x}_{i})italic_V start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT ( bold_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) for i=t,,t+H𝑖𝑡𝑡𝐻i=t,\dots,t+Hitalic_i = italic_t , … , italic_t + italic_H
       Compute RL losses V(ϕ)subscript𝑉italic-ϕ\mathcal{L}_{V}(\phi)caligraphic_L start_POSTSUBSCRIPT italic_V end_POSTSUBSCRIPT ( italic_ϕ ) and π(ϕ)subscript𝜋italic-ϕ\mathcal{L}_{\pi}(\phi)caligraphic_L start_POSTSUBSCRIPT italic_π end_POSTSUBSCRIPT ( italic_ϕ )
       Update πϕsubscript𝜋italic-ϕ\pi_{\phi}italic_π start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT and Vϕsubscript𝑉italic-ϕV_{\phi}italic_V start_POSTSUBSCRIPT italic_ϕ end_POSTSUBSCRIPT
      
Algorithm 1 diamond

Appendix H Additional performance comparisons

We provide performance profiles (Agarwal et al.,, 2021) for diamond and baselines below.


Refer to caption

Figure 7: Performance profiles, i.e. fraction of runs above a given human normalized score.

As additional angles of comparison, we also provide parameter counts and approximate training times for iris, DreamerV3 and diamond in Table 4 below. We see that diamond has the highest mean HNS, with fewer parameters than both iris and DreamerV3. diamond also trains faster than iris, although is slower than DreamerV3.

Table 4: Number of parameters, training time, and mean human-normalized score (HNS).
iris DreamerV3 diamond (ours)
#parameters (↓) 30M 18M 13M
Training days (↓) 4.1 <1 2.9
Mean HNS (↑) 1.046 1.097 1.459

A full training time profile for diamond is provided in Appendix I.

Appendix I Training time profile

Table 5 provides a full training time profile for diamond.

Table 5: Detailed breakdown of training time. Profiling was performed on an Nvidia RTX 4090 with the default hyperparameters specified in Appendices D and E. These measurements are indicative only, since exact durations depend on the machine, the environment, and the training stage.
Single update                                      Time (ms)    Detail (ms)
Total                                              543          88 + 115 + 340
    Diffusion model update                         88           -
    Reward/Termination model update                115          -
    Actor-Critic model update                      340          15 × 20.4 + 34
        Imagination step (× 15)                    20.4         12.7 + 7.0 + 0.7
            Next observation prediction            12.7         3 × 4.2
                Denoising step (× 3)               4.2          -
            Reward/Termination prediction          7.0          -
            Action prediction                      0.7          -
        Loss computation and backward              34           -
Epoch                                              Time (s)     Detail (s)
Total                                              217          35 + 46 + 136
    Diffusion model                                35           400 × 88 × 10⁻³
    Reward/Termination model                       46           400 × 115 × 10⁻³
    Actor-Critic model                             136          400 × 340 × 10⁻³
Run                                                Time (days)  Detail (days)
Total                                              2.9          2.5 + 0.4
    Training time                                  2.5          1000 × 217 / (24 × 3600)
    Other (collection, evaluation, checkpointing)  0.4          -

Appendix J Broader comparison to model-free and search-based methods

Table 6 provides scores for model-free and search-based methods, including the current best performing methods on the Atari 100k benchmark, EfficientZero (Ye et al., 2021) and bbf (Schwarzer et al., 2023). Both rely on techniques that are out of scope for our approach: EfficientZero performs computationally expensive lookahead via Monte-Carlo tree search, and bbf combines periodic network resets with hyperparameter scheduling. While lookahead search and these more advanced reinforcement learning techniques still provide greater performance overall, diamond promisingly outperforms both methods on some games.

Table 6: Raw scores and human-normalized metrics for search-based and model-free methods.
Search-based methods: MuZero, EfficientZero. Model-free methods: CURL, SPR, SR-SPR, BBF.
Game Human MuZero EfficientZero CURL SPR SR-SPR BBF diamond (ours)
Alien 7127.7 530.0 808.5 711.0 841.9 1107.8 1173.2 744.1
Amidar 1719.5 38.8 148.6 113.7 179.7 203.4 244.6 225.8
Assault 742.0 500.1 1263.1 500.9 565.6 1088.9 2098.5 1526.4
Asterix 8503.3 1734.0 25557.8 567.2 962.5 903.1 3946.1 3698.5
BankHeist 753.1 192.5 351.0 65.3 345.4 531.7 732.9 19.7
BattleZone 37187.5 7687.5 13871.2 8997.8 14834.1 17671.0 24459.8 4702.0
Boxing 12.1 15.1 52.7 0.9 35.7 45.8 85.8 86.9
Breakout 30.5 48.0 414.1 2.6 19.6 25.5 370.6 132.5
ChopperCommand 7387.8 1350.0 1117.3 783.5 946.3 2362.1 7549.3 1369.8
CrazyClimber 35829.4 56937.0 83940.2 9154.4 36700.5 45544.1 58431.8 99167.8
DemonAttack 1971.0 3527.0 13003.9 646.5 517.6 2814.4 13341.4 288.1
Freeway 29.6 21.8 21.8 28.3 19.3 25.4 25.5 33.3
Frostbite 4334.7 255.0 296.3 1226.5 1170.7 2584.8 2384.8 274.1
Gopher 2412.5 1256.0 3260.3 400.9 660.6 712.4 1331.2 5897.9
Hero 30826.4 3095.0 9315.9 4987.7 5858.6 8524.0 7818.6 5621.8
Jamesbond 302.8 87.5 517.0 331.0 366.5 389.1 1129.6 427.4
Kangaroo 3035.0 62.5 724.1 740.2 3617.4 3631.7 6614.7 5382.2
Krull 2665.5 4890.8 5663.3 3049.2 3681.6 5911.8 8223.4 8610.1
KungFuMaster 22736.3 18813.0 30944.8 8155.6 14783.2 18649.4 18991.7 18713.6
MsPacman 6951.6 1265.6 1281.2 1064.0 1318.4 1574.1 2008.3 1958.2
Pong 14.6 -6.7 20.1 -18.5 -5.4 2.9 16.7 20.4
PrivateEye 69571.3 56.3 96.7 81.9 86.0 97.9 40.5 114.3
Qbert 13455.0 3952.0 13781.9 727.0 866.3 4044.1 4447.1 4499.3
RoadRunner 7845.0 2500.0 17751.3 5006.1 12213.1 13463.4 33426.8 20673.2
Seaquest 42054.7 208.0 1100.2 315.2 558.1 819.0 1232.5 551.2
UpNDown 11693.2 2896.9 17264.2 2646.4 10859.2 112450.3 12101.7 3856.3
#Superhuman (↑) N/A 5 14 2 6 9 12 11
Mean (↑) 1.000 0.562 1.943 0.261 0.616 1.271 2.247 1.459
IQM (↑) 1.000 0.288 1.047 0.113 0.337 0.700 1.139 0.641

Appendix K Quantitative analysis of autoregressive model drift

Figure 8 provides a quantitative measure of the compounding error demonstrated qualitatively in Figure 3 for DDPM- and EDM-based world models.


Figure 8: Average pixel drift between an imagined trajectory and the corresponding reference trajectory collected with an expert in Breakout. The trajectories are each 1000 timesteps long, starting from the same frame and following the same sequence of actions. Each line displays the average, with shaded standard deviation, over 400 reference trajectories held out from the training data. DDPM becomes more stable with an increasing number of denoising steps, but remains less stable than 1-step EDM even with 10 denoising steps. The drift we observe for EDM corresponds to differences in the imagined trajectory rather than the pathological color shift seen in Figure 3a.
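As a rough illustration, the snippet below computes one plausible version of this drift metric: the per-timestep mean absolute pixel difference between imagined and reference trajectories, averaged over runs. The exact metric used for Figure 8 is not reproduced here, so this is a sketch rather than the evaluation code.

```python
import torch

def pixel_drift(imagined, reference):
    """Per-timestep pixel drift between imagined and reference trajectories.

    imagined, reference: tensors of shape (N, T, C, H, W) holding N trajectories
    of T frames each. Returns the mean and standard deviation over the N runs
    of the per-frame mean absolute pixel error.
    """
    err = (imagined.float() - reference.float()).abs()   # (N, T, C, H, W)
    per_step = err.mean(dim=(2, 3, 4))                   # (N, T)
    return per_step.mean(dim=0), per_step.std(dim=0)     # each of shape (T,)
```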

Appendix L Quantitative ablation on reducing the number of denoising steps

Table 7 provides a quantitative ablation of the effect of reducing the number of denoising steps used for our EDM diffusion world model from 3 (used for Table 1) to 1, for diamond’s 10 highest performing games. Note that the 1-step results correspond to a single seed only, and so have higher variance. Nonetheless, these results provide some signal that agents trained with 1 denoising step perform worse than with our default choice of 3, particularly on Boxing, despite the apparent similarity in Figure 8. This additional evidence supports our qualitative analysis in Section 5.2.

Table 7: Quantitative ablation on reducing the number of denoising steps from 3 (default) to 1.
Game Random Human diamond ($n=3$) diamond ($n=1$)
Amidar 5.8 1719.5 225.8 191.8
Assault 222.4 742.0 1526.4 782.5
Asterix 210.0 8503.3 3698.5 6687.0
Boxing 0.1 12.1 86.9 41.9
Breakout 1.7 30.5 132.5 50.8
CrazyClimber 10780.5 35829.4 99167.8 87233.0
Kangaroo 52.0 3035.0 5382.2 1710.0
Krull 1598.0 2665.5 8610.1 9105.1
Pong -20.7 14.6 20.4 20.9
RoadRunner 11.5 7845.0 20673.2 5084.0
Mean HNS (↑) 0.000 1.000 3.052 1.962
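The mean HNS row follows the standard human-normalized score definition, $\mathrm{HNS} = (\mathrm{score}_\mathrm{agent} - \mathrm{score}_\mathrm{random}) / (\mathrm{score}_\mathrm{human} - \mathrm{score}_\mathrm{random})$, averaged over games. The short Python check below reproduces the $n=3$ value from the raw scores in Table 7.

```python
# (random, human, diamond n=3) raw scores from Table 7.
scores = {
    "Amidar": (5.8, 1719.5, 225.8),
    "Assault": (222.4, 742.0, 1526.4),
    "Asterix": (210.0, 8503.3, 3698.5),
    "Boxing": (0.1, 12.1, 86.9),
    "Breakout": (1.7, 30.5, 132.5),
    "CrazyClimber": (10780.5, 35829.4, 99167.8),
    "Kangaroo": (52.0, 3035.0, 5382.2),
    "Krull": (1598.0, 2665.5, 8610.1),
    "Pong": (-20.7, 14.6, 20.4),
    "RoadRunner": (11.5, 7845.0, 20673.2),
}
# Human-normalized score per game, then the mean over the 10 games.
hns = [(agent - rnd) / (human - rnd) for rnd, human, agent in scores.values()]
print(round(sum(hns) / len(hns), 3))  # 3.052, matching the Mean HNS row
```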

Appendix M Early investigations on visual quality in more complex environments

In the main body of the paper, we evaluated the utility of diamond for training RL agents in a world model on the well-established Atari 100k benchmark (Kaiser et al., 2019), and demonstrated that diamond’s diffusion world model can be applied to model a more complex 3D environment from the game Counter-Strike: Global Offensive. In this section, we provide early experiments investigating the effectiveness of diamond’s diffusion world model by directly evaluating the visual quality of the trajectories it generates. The two environments we consider are presented in Section M.1 below.

M.1 Environments

CS:GO. We use the Counter-Strike: Global Offensive dataset introduced by Pearce and Zhu (2022). Specifically, we use the Clean dataset, containing 190k frames (3.3 hours) of high-skill human gameplay captured on the Dust II map, with observations and actions (mouse and keyboard) recorded at 16 Hz. We use 150k frames (2.6 hours) for training and 40k frames (0.7 hours) for evaluation. We resize observations to 64×64 pixels and use no augmentation.

Motorway driving. We use the dataset from Santana and Hotz (2016) (https://github.com/commaai/research), which contains camera footage and metadata captured from human drivers on US motorways. We select only trajectories captured in daylight, and exclude the first and last 5 minutes of each trajectory (typically spent traveling to/from a motorway), leaving 4.4 hours of data. We use five trajectories for training (3.6 hours) and two for testing (0.8 hours). We downsample the dataset to 10 Hz, resize observations to 64×64, and use the (normalized) steering angle and acceleration as actions. During training, we apply data augmentation consisting of shift and scale, contrast, brightness, saturation, and mirroring, as sketched below.
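A minimal torchvision-style sketch of this augmentation pipeline is given below. The specific jitter and scale parameters are placeholders rather than the values used in our experiments; the main point is that mirroring must also flip the sign of the steering action.

```python
import torch
import torchvision.transforms as T

# Photometric and geometric jitter; parameter values are illustrative only.
jitter = T.Compose([
    T.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)),  # shift & scale
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
])

def augment(frame, steering, acceleration):
    """frame: (C, 64, 64) tensor in [0, 1]; steering and acceleration are normalized scalars."""
    frame = jitter(frame)
    if torch.rand(()) < 0.5:                   # mirroring
        frame = torch.flip(frame, dims=[-1])   # flip image left-right
        steering = -steering                   # a mirrored image reverses the steering direction
    return frame, steering, acceleration
```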

We note that the purpose of our investigation is to train and evaluate diamond’s diffusion model on these static datasets, and that we do not perform reinforcement learning, since there is no standard reinforcement learning protocol for these environments.

M.2 Diffusion Model Architectures

We consider two potential diffusion model architectures, summarized in Figure 9.

Figure 9: We tested two architectures for diamond’s diffusion model which condition on previous image observations in different ways. To illustrate differences with typical video generation models, we also visualize a U-Net 3D (Çiçek et al., 2016) which diffuses a block of frames simultaneously.

Frame-stacking. The simplest way to condition on previous observations is to concatenate the previous $L$ frames with the next noised frame, $\operatorname{concat}[\mathbf{x}_t^{\tau}, \mathbf{x}_{t-1}^{0}, \dots, \mathbf{x}_{t-L}^{0}]$, which is compatible with a standard U-Net 2D (Ronneberger et al., 2015). This architecture is particularly attractive due to its lightweight construction, requiring minimal additional parameters and compute compared to typical image diffusion. This is the architecture we used for the main body of the paper.
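The sketch below shows this conditioning in PyTorch. A single convolution stands in for the full U-Net 2D; only the channel-wise concatenation of the noised frame with the $L$ previous clean frames is meant to be representative.

```python
import torch
import torch.nn as nn

class FrameStackDenoiser(nn.Module):
    """Minimal stand-in for a frame-stack conditioned denoiser (not the full U-Net)."""

    def __init__(self, channels=3, history_len=4):
        super().__init__()
        # Input channels: the noised next frame plus the L previous clean frames.
        self.net = nn.Conv2d(channels * (history_len + 1), channels, kernel_size=3, padding=1)

    def forward(self, noised_next, prev_frames):
        # prev_frames: list of L clean frames [x^0_{t-1}, ..., x^0_{t-L}].
        x = torch.cat([noised_next] + prev_frames, dim=1)  # concatenate along channels
        return self.net(x)

# Example usage:
# denoiser = FrameStackDenoiser()
# out = denoiser(torch.randn(1, 3, 64, 64), [torch.randn(1, 3, 64, 64)] * 4)
```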

Cross-attention. The U-Net 3D (Çiçek et al., 2016), also displayed for comparison in Figure 9, is a leading architecture in video diffusion (Ho et al., 2022). We adapted this design to an autoregressive cross-attention architecture, formed of a core U-Net 2D that receives only a single noised frame as direct input but cross-attends to the activations of a separate history encoder network. This encoder is a lightweight version of the U-Net 2D architecture. Parameters are shared across all $L$ encoders, and each receives the relative environment timestep embedding as input. The final design differs from the U-Net 3D, which diffuses all frames jointly, shares parameters across networks, and uses self-, rather than cross-, attention.
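Schematically, the cross-attention conditioning looks like the following sketch, in which features of the noised frame attend to the history encoder's activations. Dimensions and layer choices are illustrative and do not reflect the exact architecture.

```python
import torch
import torch.nn as nn

class CrossAttentionConditioning(nn.Module):
    """Sketch: the noised frame's features cross-attend to history-encoder activations."""

    def __init__(self, dim=128, heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, frame_feats, history_feats):
        # frame_feats:   (B, H*W, dim)    flattened features of the single noised frame
        # history_feats: (B, L*H*W, dim)  activations of the shared history encoder
        attended, _ = self.attn(query=frame_feats, key=history_feats, value=history_feats)
        return frame_feats + attended  # residual connection back into the U-Net 2D
```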

M.3 Metrics, Baselines and Compute

Metrics. To evaluate the visual quality of generated trajectories, we use the standard Fréchet Video Distance (FVD) (Unterthiner et al., 2018) as implemented by Skorokhodov et al. (2022). This is computed between 1024 real videos (taken from the test set) and 1024 generated videos, each 16 frames long (1-2 seconds). Models condition on $L=6$ previous real frames and the real action sequence. On the same data, we also report the Fréchet Inception Distance (FID) (Heusel et al., 2017), which measures the visual quality of individual observations, ignoring the temporal dimension. For these same sets of videos, we also compute the LPIPS loss (Zhang et al., 2018) between each pair of real/generated observations (Yan et al., 2023). Sampling rate is the number of observations that a single Nvidia RTX A6000 GPU can generate per second when sampling sequentially.
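For reference, the pairwise LPIPS term can be computed with a publicly available package such as lpips, as sketched below; FID and FVD use their standard reference implementations and are not reproduced here. The function name and batching are our own, and exact preprocessing may differ from our evaluation code.

```python
import torch
import lpips  # pip install lpips

lpips_fn = lpips.LPIPS(net='alex')  # perceptual distance network

def trajectory_lpips(real, generated):
    """real, generated: (T, 3, H, W) tensors scaled to [-1, 1], frame-aligned.

    Returns the LPIPS averaged over the T real/generated observation pairs.
    """
    with torch.no_grad():
        return lpips_fn(real, generated).mean().item()
```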

Baselines. We compare against two well-established world model methods, DreamerV3 (Hafner et al., 2023) and iris (Micheli et al., 2023), adapting the original implementations to train on a static dataset. We ensured baselines used a similar number of parameters to diamond. Two variants of iris are reported: image observations are discretized into $K=16$ tokens (as in the original work) or into $K=64$ tokens (achieved with one fewer down/up-sampling layer in the autoencoder, see Appendix E of Micheli et al. (2023)), which provides the potential for modeling higher-fidelity visuals.

Compute. All models (baselines and diamond) were trained for 120k updates with a batch size of 64, on up to 4× A6000 GPUs. Each training run took between 1 and 2 days.

M.4 Analysis

Table 8: Results for 3D environments. These metrics compare observations from real trajectories and generated trajectories. The generated trajectories are conditioned on an initial set of $L=6$ observations and a real sequence of actions.
                                       CS:GO                       Driving
Method                           FID ↓  FVD ↓  LPIPS ↓    FID ↓  FVD ↓  LPIPS ↓    Sample rate (Hz) ↑    Parameters (#)
DreamerV3 106.8 509.1 0.173 167.5 733.7 0.160 266.7 181M
IRIS ($K=16$) 24.5 110.1 0.129 51.4 368.7 0.188 4.2 123M
IRIS ($K=64$) 22.8 85.7 0.116 44.3 276.9 0.148 1.5 111M
diamond frame-stack (ours) 9.6 34.8 0.107 16.7 80.3 0.058 7.4 122M
diamond cross-attention (ours) 11.6 81.4 0.125 35.2 299.9 0.119 2.5 184M

Table 8 reports metrics on the visual quality of generated trajectories, along with sampling rates and parameter counts, for the frame-stack and cross-attention diamond architectures, compared to baseline methods. diamond outperforms the baselines across all visual quality metrics. This is consistent with the wider video generation literature, where diffusion models currently lead, as discussed in Section 7. The simpler frame-stacking architecture performs better than cross-attention, which is surprising given the prevalence of cross-attention in the video generation literature. We believe the inductive bias provided by directly feeding in the previous frames may be well suited to autoregressive generation. Overall, these results indicate diamond frame-stack > diamond cross-attention ≈ IRIS 64 > IRIS 16 > DreamerV3, which matches our intuition from visual inspection.

In terms of sampling rate, diamond frame-stack (with 20 denoising steps) is faster than iris ($K=16$). iris suffers a further 2.8× slowdown for the $K=64$ version, confirming that its sampling time is bottlenecked by the number of tokens $K$. DreamerV3, on the other hand, is an order of magnitude faster; this derives from its independent, rather than joint, sampling procedure, and the flip side is the low visual quality of its trajectories.

Figure 10 below shows selected examples of the trajectories produced by diamond in CS:GO and motorway driving. The trajectories are plausible, often even over reasonably long time horizons. In CS:GO, the model generates the correct geometry of the level as it passes through the doorway into a new area of the map. In motorway driving, a car is plausibly imagined overtaking on the left.

Figure 10: Example trajectories sampled every 25 timesteps from diamond (frame stack) for the modern 3D first-person shooter CS:GO (top row), and real-world motorway driving (bottom row).

While the above experiments use real sequences of actions from the dataset, we also investigated how robust diamond (frame stack) is to novel, user-input actions. Figure 11 shows the effect of actions in motorway driving: conditioned on the same $L=6$ real frames, we generate trajectories conditioned on five different action sequences. In general the effects are as intended, e.g. steering straight/left/right moves the camera as expected. Interestingly, when ‘slow down’ is input, the distance to the car in front decreases, since the model predicts that the traffic ahead has come to a standstill. Figure 12 shows similar sequences for CS:GO. For the common actions (mouse movements and fire), the effects are as expected, though they are unstable beyond a few frames, since such a sequence of actions is unlikely to have been seen in the demonstration dataset. We note that these issues, the causal confusion and instabilities, are a symptom of training world models on offline data rather than an inherent weakness of diamond.

Figure 11: Effect of fixed actions on sampled trajectories in motorway driving. Conditioned on the same initial observations, we roll out the model while applying different actions. Interestingly, the model has learnt to associate the ‘Slow down’ and ‘Speed up’ actions with the whole traffic slowing down or speeding up.
Figure 12: Effect of fixed actions on sampled trajectories in CS:GO. Conditioned on the same initial observation, we roll out the model while applying different actions. Whilst these have the intended effect in the immediate frames, the observations can degenerate over longer roll-outs. For instance, it would have been very unlikely for the human demonstrator to look directly at the ground in this game state, so the world model is unable to generate a plausible trajectory here, and instead snaps onto another area of the map where looking down does make sense.