| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """WorldModel transformer for frame generation.""" |
|
|
| from typing import Optional, List |
| import math |
|
|
| import einops as eo |
| import torch |
| from torch import nn, Tensor |
| import torch.nn.functional as F |
| from tensordict import TensorDict |
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.models.modeling_utils import ModelMixin |
|
|
| from .attn import Attn, MergedQKVAttn, CrossAttention |
| from .nn import AdaLN, MLP, NoiseConditioner, ada_gate, ada_rmsnorm, rms_norm |
| from .quantize import quantize_model |
| from .cache import CachedDenoiseStepEmb, CachedCondHead |
|
|
|
|
def patch_cached_noise_conditioning(model) -> None:
    """Swap the model's noise-conditioning modules for their cached wrappers.

    ``model.denoise_step_emb`` is wrapped in a ``CachedDenoiseStepEmb`` built
    from ``model.config.scheduler_sigmas``, and every transformer block's
    ``cond_head`` is wrapped in a ``CachedCondHead`` that is handed that same
    wrapper instance. The model is modified in place.
    """
    # One wrapper instance is shared by the model and by every block's head.
    model.denoise_step_emb = shared_step_emb = CachedDenoiseStepEmb(
        model.denoise_step_emb, model.config.scheduler_sigmas
    )
    for block in model.transformer.blocks:
        block.cond_head = CachedCondHead(block.cond_head, shared_step_emb)
|
|
|
|
def patch_Attn_merge_qkv(model) -> None:
    """Replace every plain ``Attn`` submodule with a ``MergedQKVAttn`` built from it.

    Modules that already are ``MergedQKVAttn`` are left alone. The model is
    modified in place.
    """
    # Collect the targets first so the module tree is not edited while it is
    # being walked.
    targets = [
        (path, module)
        for path, module in model.named_modules()
        if isinstance(module, Attn) and not isinstance(module, MergedQKVAttn)
    ]
    for path, module in targets:
        model.set_submodule(path, MergedQKVAttn(module, model.config))
|
|
|
|
def patch_MLPFusion_split(model) -> None:
    """Replace every ``MLPFusion`` submodule with a ``SplitMLPFusion`` built from it.

    Modules that already are ``SplitMLPFusion`` are left alone. The model is
    modified in place.
    """
    # Collect the targets first so the module tree is not edited while it is
    # being walked.
    targets = [
        (path, module)
        for path, module in model.named_modules()
        if isinstance(module, MLPFusion) and not isinstance(module, SplitMLPFusion)
    ]
    for path, module in targets:
        model.set_submodule(path, SplitMLPFusion(module))
|
|
|
|
def _apply_inference_patches(model) -> None:
    """Apply all inference-time module patches to ``model``, in a fixed order.

    Order: cached noise conditioning, merged-QKV attention, split MLP fusion.
    """
    patches = (
        patch_cached_noise_conditioning,
        patch_Attn_merge_qkv,
        patch_MLPFusion_split,
    )
    for apply_patch in patches:
        apply_patch(model)
|
|
|
|
class CFG(nn.Module):
    """Classifier-free-guidance switch between an embedding and a null embedding.

    Holds a learnable ``null_emb`` of shape [1, 1, d_model] (initialised to
    zeros) that stands in for the real embedding whenever conditioning is
    dropped.
    """

    def __init__(self, d_model: int, dropout: float):
        super().__init__()
        self.dropout = dropout
        self.null_emb = nn.Parameter(torch.zeros(1, 1, d_model))

    def forward(
        self, x: torch.Tensor, is_conditioned: Optional[bool] = None
    ) -> torch.Tensor:
        """Return ``x``, the null embedding, or a per-sample mix of the two.

        x: [B, L, D]
        is_conditioned:
            - None: random per-sample dropout with probability ``self.dropout``
            - bool: honoured only in eval mode; the whole batch is kept (True)
              or replaced by the null embedding (False)

        In training mode the random-dropout path is always taken, whatever
        ``is_conditioned`` is.
        """
        batch, length, _ = x.shape
        null_tokens = self.null_emb.expand(batch, length, -1)

        # Deterministic path: eval mode with an explicit flag.
        if not self.training and is_conditioned is not None:
            if is_conditioned:
                return x
            return null_tokens

        # Random-dropout path. A zero rate short-circuits before any random
        # numbers are drawn.
        if self.dropout == 0.0:
            return x
        drop_mask = torch.rand(batch, 1, 1, device=x.device) < self.dropout
        return torch.where(drop_mask, null_tokens, x)
|
|
|
|
class ControllerInputEmbedding(nn.Module):
    """Projects concatenated controller inputs (mouse, buttons, scroll) to ``d_model``."""

    def __init__(self, n_buttons: int, d_model: int, mlp_ratio: int = 4):
        super().__init__()
        # Input width is n_buttons plus 3 extra features for mouse and scroll.
        in_features = n_buttons + 3
        hidden_features = d_model * mlp_ratio
        self.mlp = MLP(in_features, hidden_features, d_model)

    def forward(self, mouse: Tensor, button: Tensor, scroll: Tensor):
        """Concatenate the three inputs on the last axis and run them through the MLP.

        ``mouse`` must be rank 3; the other inputs must be concatenable with it
        along the last dimension.
        """
        assert mouse.dim() == 3
        controls = torch.cat((mouse, button, scroll), dim=-1)
        return self.mlp(controls)
|
|
|
|
class MLPFusion(nn.Module):
    """Fuses per-group conditioning into tokens, equivalent to an MLP on cat([x, cond]).

    The concatenation is never materialised: the first layer's weight is
    split into its token half and its conditioning half, and the two
    projections are summed. Only the ``weight`` tensors of ``mlp.fc1`` and
    ``mlp.fc2`` are used; no bias term is applied.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.mlp = MLP(2 * d_model, d_model, d_model)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Fuse ``cond`` into ``x``.

        x:    [B, L * G, D] tokens, viewed here as L groups of G tokens
        cond: [B, L, D] one conditioning vector per group
        Returns a tensor shaped like ``x``.
        """
        batch, _, dim = x.shape
        n_groups = cond.shape[1]

        # Columns [:D] of fc1 act on the tokens, columns [D:] on the conditioning.
        w_tokens, w_cond = self.mlp.fc1.weight.chunk(2, dim=1)

        grouped = x.view(batch, n_groups, -1, dim)
        token_part = F.linear(grouped, w_tokens)
        # The conditioning projection is broadcast over the tokens of each group.
        cond_part = F.linear(cond, w_cond).unsqueeze(2)
        hidden = F.silu(token_part + cond_part)

        fused = F.linear(hidden, self.mlp.fc2.weight)
        return fused.flatten(1, 2)
|
|
|
|
class SplitMLPFusion(nn.Module):
    """Unpacked form of MLPFusion: three bias-free linears, no cat (quant-friendly).

    The packed first-layer weight of the source module is split column-wise
    into a token projection (``fc1_x``) and a conditioning projection
    (``fc1_c``); ``fc2`` is a copy of the source's second layer weight.
    """

    def __init__(self, src: MLPFusion):
        super().__init__()
        packed_fc1 = src.mlp.fc1
        src_fc2 = src.mlp.fc2
        width = src_fc2.in_features
        # New layers live on the same device and in the same dtype as the source.
        factory = {
            "bias": False,
            "device": src_fc2.weight.device,
            "dtype": src_fc2.weight.dtype,
        }

        self.fc1_x = nn.Linear(width, width, **factory)
        self.fc1_c = nn.Linear(width, width, **factory)
        self.fc2 = nn.Linear(width, width, **factory)

        with torch.no_grad():
            w_tokens, w_cond = packed_fc1.weight.chunk(2, dim=1)
            for layer, weight in (
                (self.fc1_x, w_tokens),
                (self.fc1_c, w_cond),
                (self.fc2, src_fc2.weight),
            ):
                layer.weight.copy_(weight)

        # Inherit the train/eval mode of the module being replaced.
        self.train(src.training)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Fuse ``cond`` (one vector per group) into the tokens ``x``.

        ``x`` is regrouped to [B, L, -1, D] with L taken from ``cond.shape[1]``;
        the result is flattened back to the shape of ``x``.
        """
        batch, _, dim = x.shape
        n_groups = cond.shape[1]
        grouped = x.reshape(batch, n_groups, -1, dim)
        hidden = F.silu(self.fc1_x(grouped) + self.fc1_c(cond).unsqueeze(2))
        return self.fc2(hidden).flatten(1, 2)
|
|
|
|
class CondHead(nn.Module):
    """Per-layer conditioning head: (+ bias_in) -> SiLU -> n_cond bias-free Linears.

    Produces ``n_cond`` modulation tensors from one conditioning tensor, one
    per projection in ``cond_proj``. The learnable input bias exists only
    when ``noise_conditioning == "wan"``; otherwise ``bias_in`` is ``None``.
    """

    n_cond = 6

    def __init__(self, d_model: int, noise_conditioning: str = "wan"):
        super().__init__()
        if noise_conditioning == "wan":
            self.bias_in = nn.Parameter(torch.zeros(d_model))
        else:
            self.bias_in = None
        projections = [
            nn.Linear(d_model, d_model, bias=False) for _ in range(self.n_cond)
        ]
        self.cond_proj = nn.ModuleList(projections)

    def forward(self, cond):
        """Return a tuple of ``n_cond`` tensors, each a projection of SiLU(cond [+ bias_in])."""
        if self.bias_in is not None:
            cond = cond + self.bias_in
        activated = F.silu(cond)
        return tuple(proj(activated) for proj in self.cond_proj)
|
|
|
|
class WorldDiTBlock(nn.Module):
    """One transformer block: self-attention, optional prompt cross-attention,
    optional controller fusion, and an MLP.

    Prompt cross-attention is created only when ``prompt_conditioning`` is not
    None and ``layer_idx`` is a multiple of ``prompt_conditioning_period``.
    Controller fusion is created whenever ``layer_idx`` is a multiple of
    ``ctrl_conditioning_period``. ``n_heads`` is accepted but not read here.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        mlp_ratio: int,
        layer_idx: int,
        prompt_conditioning: Optional[str],
        prompt_conditioning_period: int,
        prompt_embedding_dim: int,
        ctrl_conditioning_period: int,
        noise_conditioning: str,
        config,
    ):
        super().__init__()
        self.config = config
        self.attn = Attn(config, layer_idx)
        self.mlp = MLP(d_model, d_model * mlp_ratio, d_model)
        self.cond_head = CondHead(d_model, noise_conditioning)

        uses_prompt = (
            prompt_conditioning is not None
            and layer_idx % prompt_conditioning_period == 0
        )
        if uses_prompt:
            self.prompt_cross_attn = CrossAttention(config, prompt_embedding_dim)
        else:
            self.prompt_cross_attn = None

        uses_ctrl = layer_idx % ctrl_conditioning_period == 0
        if uses_ctrl:
            self.ctrl_mlpfusion = MLPFusion(d_model)
        else:
            self.ctrl_mlpfusion = None

    def forward(self, x, pos_ids, cond, ctx, v, kv_cache=None):
        """
        0) Causal frame attention (adaptive norm + gate, residual)
        1) Frame->prompt cross attention (only if this block has it)
        2) Controller MLP fusion (only if this block has it)
        3) MLP (adaptive norm + gate, residual)

        ``ctx`` is a mapping read with the keys "prompt_emb", "prompt_pad_mask"
        and "ctrl_emb". Returns ``(x, v)``, where ``v`` is whatever the
        attention module returned as its second output.
        """
        # Six modulation tensors: three for the attention branch, three for the MLP.
        attn_s, attn_b, attn_g, mlp_s, mlp_b, mlp_g = self.cond_head(cond)

        # 0) self-attention branch
        attn_out, v = self.attn(
            ada_rmsnorm(x, attn_s, attn_b), pos_ids, v, kv_cache=kv_cache
        )
        x = ada_gate(attn_out, attn_g) + x

        # 1) prompt cross-attention branch
        if self.prompt_cross_attn is not None:
            queries = rms_norm(x)
            prompt_ctx = rms_norm(ctx["prompt_emb"])
            cross_out = self.prompt_cross_attn(
                queries,
                context=prompt_ctx,
                context_pad_mask=ctx["prompt_pad_mask"],
            )
            x = cross_out + x

        # 2) controller fusion branch
        if self.ctrl_mlpfusion is not None:
            fused = self.ctrl_mlpfusion(rms_norm(x), rms_norm(ctx["ctrl_emb"]))
            x = fused + x

        # 3) MLP branch
        mlp_out = self.mlp(ada_rmsnorm(x, mlp_s, mlp_b))
        x = ada_gate(mlp_out, mlp_g) + x

        return x, v
|
|
|
|
class WorldDiT(nn.Module):
    """Stack of WorldDiTBlocks with some parameters shared across layers.

    Every block after the first reuses the first block's ``attn.rope`` object.
    When ``config.noise_conditioning`` is "dit_air" or "wan", every block's
    ``cond_head.cond_proj`` weights are also tied to the first block's.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        block_kwargs = dict(
            d_model=config.d_model,
            n_heads=config.n_heads,
            mlp_ratio=config.mlp_ratio,
            prompt_conditioning=config.prompt_conditioning,
            prompt_conditioning_period=config.prompt_conditioning_period,
            prompt_embedding_dim=config.prompt_embedding_dim,
            ctrl_conditioning_period=config.ctrl_conditioning_period,
            noise_conditioning=config.noise_conditioning,
            config=config,
        )
        self.blocks = nn.ModuleList(
            [
                WorldDiTBlock(layer_idx=layer, **block_kwargs)
                for layer in range(config.n_layers)
            ]
        )

        lead = self.blocks[0]
        tie_cond_proj = config.noise_conditioning in ("dit_air", "wan")
        for follower in self.blocks[1:]:
            if tie_cond_proj:
                # Point each projection at the lead block's weight Parameter.
                pairs = zip(follower.cond_head.cond_proj, lead.cond_head.cond_proj)
                for own_proj, lead_proj in pairs:
                    own_proj.weight = lead_proj.weight
            follower.attn.rope = lead.attn.rope

    def forward(self, x, pos_ids, cond, ctx, kv_cache=None):
        """Run ``x`` through every block in order and return the final hidden states.

        The second output of each block (``v``) is fed to the next block; the
        first block receives ``None``.
        """
        carried_v = None
        for block in self.blocks:
            x, carried_v = block(x, pos_ids, cond, ctx, carried_v, kv_cache=kv_cache)
        return x
|
|
|
|
class WorldModel(ModelMixin, ConfigMixin):
    """
    WORLD: Wayfarer Operator-driven Rectified-flow Long-context Diffuser.

    Denoises a frame given:
    - All previous frames (via KV cache)
    - The prompt embedding
    - The controller input embedding
    - The current noise level
    """

    _supports_gradient_checkpointing = False
    _keep_in_fp32_modules = ["denoise_step_emb", "rope"]

    @register_to_config
    def __init__(
        self,
        d_model: int = 2560,
        n_heads: int = 40,
        n_kv_heads: Optional[int] = 20,
        n_layers: int = 22,
        mlp_ratio: int = 5,
        channels: int = 16,
        height: int = 16,
        width: int = 16,
        patch: tuple = (2, 2),
        tokens_per_frame: int = 256,
        n_frames: int = 512,
        local_window: int = 16,
        global_window: int = 128,
        global_attn_period: int = 4,
        global_pinned_dilation: int = 8,
        global_attn_offset: int = -1,
        value_residual: bool = False,
        gated_attn: bool = True,
        n_buttons: int = 256,
        ctrl_conditioning: Optional[str] = "mlp_fusion",
        ctrl_conditioning_period: int = 3,
        ctrl_cond_dropout: float = 0.0,
        prompt_conditioning: Optional[str] = "cross_attention",
        prompt_conditioning_period: int = 3,
        prompt_embedding_dim: int = 2048,
        prompt_cond_dropout: float = 0.0,
        noise_conditioning: str = "wan",
        # NOTE(review): mutable list default. Nothing in this file mutates it,
        # but confirm no external caller does before relying on that.
        scheduler_sigmas: Optional[List[float]] = [
            1.0,
            0.9483006596565247,
            0.8379597067832947,
            0.0,
        ],
        base_fps: int = 60,
        causal: bool = True,
        mlp_gradient_checkpointing: bool = True,
        block_gradient_checkpointing: bool = True,
        rope_impl: str = "ortho",
    ):
        """Build the embeddings, transformer stack, patch projections and position buffers.

        The arguments are read back through ``self.config`` below and by the
        submodules, which receive ``self.config``. Many of them are not used
        directly in this constructor.
        """
        super().__init__()

        self.denoise_step_emb = NoiseConditioner(d_model)
        self.ctrl_emb = ControllerInputEmbedding(n_buttons, d_model, mlp_ratio)

        # Classifier-free-guidance modules. They are constructed here, but
        # forward() below does not call them.
        # NOTE(review): presumably used by external sampling/training code; confirm.
        if self.config.ctrl_conditioning is not None:
            self.ctrl_cfg = CFG(self.config.d_model, self.config.ctrl_cond_dropout)
        if self.config.prompt_conditioning is not None:
            self.prompt_cfg = CFG(
                self.config.prompt_embedding_dim, self.config.prompt_cond_dropout
            )

        self.transformer = WorldDiT(self.config)
        self.patch = tuple(patch)

        C, D = channels, d_model
        # Non-overlapping patches: the kernel size equals the stride.
        self.patchify = nn.Conv2d(
            C, D, kernel_size=self.patch, stride=self.patch, bias=False
        )
        # One output value per channel per pixel of a patch.
        self.unpatchify = nn.Linear(D, C * math.prod(self.patch), bias=True)
        self.out_norm = AdaLN(d_model)

        # Position ids for the tokens of a single frame. All three buffers are
        # non-persistent, so they are not part of the state dict.
        T = tokens_per_frame
        idx = torch.arange(T, dtype=torch.long)
        # Allocated uninitialised; forward() overwrites it with the frame timestamp.
        self.register_buffer(
            "_t_pos_1f", torch.empty(T, dtype=torch.long), persistent=False
        )
        # Row index (idx // width) and column index (idx % width) of each token.
        # NOTE(review): assumes `width` is the number of tokens per row, while
        # forward() derives the token grid as W // patch width. Confirm they agree.
        self.register_buffer(
            "_y_pos_1f", idx.div(width, rounding_mode="floor"), persistent=False
        )
        self.register_buffer("_x_pos_1f", idx.remainder(width), persistent=False)

    def forward(
        self,
        x: Tensor,
        sigma: Tensor,
        frame_timestamp: Tensor,
        prompt_emb: Optional[Tensor] = None,
        prompt_pad_mask: Optional[Tensor] = None,
        mouse: Optional[Tensor] = None,
        button: Optional[Tensor] = None,
        scroll: Optional[Tensor] = None,
        kv_cache=None,
    ):
        """
        Denoise one latent frame.

        Only B == 1 and N == 1 are supported; this is asserted below.

        Args:
            x: [B, N, C, H, W] - latent frames
            sigma: [B, N] - noise levels
            frame_timestamp: [B, N] - frame indices
            prompt_emb: [B, P, D] - prompt embeddings
            prompt_pad_mask: [B, P] - padding mask for prompts
            mouse: [B, N, 2] - mouse velocity
            button: [B, N, n_buttons] - button states
            scroll: [B, N, 1] - scroll wheel sign (-1, 0, 1)
            kv_cache: StaticKVCache instance

        Returns:
            Tensor of shape [B, N, C, H, W].

        Note: only ``button`` is checked for None, but ``mouse`` and ``scroll``
        are passed to the controller embedding unconditionally, so all three
        must be supplied.
        """
        B, N, C, H, W = x.shape
        ph, pw = self.patch
        assert (H % ph == 0) and (W % pw == 0), "H, W must be divisible by patch"
        Hp, Wp = H // ph, W // pw
        torch._assert(
            Hp * Wp == self.config.tokens_per_frame,
            f"{Hp} * {Wp} != {self.config.tokens_per_frame}",
        )

        torch._assert(
            B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1"
        )
        # In-place write: every token of the frame gets the same temporal
        # position, taken from frame_timestamp[0, 0].
        self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f))
        # [None] adds the leading batch dimension of size 1.
        pos_ids = TensorDict(
            {
                "t_pos": self._t_pos_1f[None],
                "y_pos": self._y_pos_1f[None],
                "x_pos": self._x_pos_1f[None],
            },
            batch_size=[1, self._t_pos_1f.numel()],
        )
        cond = self.denoise_step_emb(sigma)

        assert button is not None
        # The transformer blocks read these keys from the context dict.
        ctx = {
            "ctrl_emb": self.ctrl_emb(mouse, button, scroll),
            "prompt_emb": prompt_emb,
            "prompt_pad_mask": prompt_pad_mask,
        }

        # Model width, read from the unpatchify layer's input size.
        D = self.unpatchify.in_features
        # Frames are folded into the batch axis for the 2D convolution, then
        # flattened into one token sequence per batch element.
        x = self.patchify(x.reshape(B * N, C, H, W))
        x = eo.rearrange(x.view(B, N, D, Hp, Wp), "b n d hp wp -> b (n hp wp) d")
        x = self.transformer(x, pos_ids, cond, ctx, kv_cache)
        x = F.silu(self.out_norm(x, cond))
        # Project each token to a patch of pixels and reassemble the frame.
        x = eo.rearrange(
            self.unpatchify(x),
            "b (n hp wp) (c ph pw) -> b n c (hp ph) (wp pw)",
            n=N,
            hp=Hp,
            wp=Wp,
            ph=ph,
            pw=pw,
        )

        return x

    def quantize(self, quant_type: str):
        """Quantize this model by delegating to ``quantize_model(self, quant_type)``."""
        quantize_model(self, quant_type)

    def apply_inference_patches(self):
        """Apply the module-level inference patches to this model in place."""
        _apply_inference_patches(self)
|
|