def _patch_asyncio_event_loop_del():
    """
    Silence one specific, harmless error raised while an event loop is finalized.

    In some Spaces runtime/container combinations, ``BaseEventLoop.__del__``
    fails with a ``ValueError`` mentioning "Invalid file descriptor" when a loop
    is garbage-collected. The original finalizer is wrapped so that only
    ValueErrors containing that text are swallowed; every other exception is
    re-raised. The patch is best-effort: any failure while installing it is
    ignored and the original finalizer stays in place.
    """
    try:
        from asyncio import base_events

        loop_cls = base_events.BaseEventLoop
        wrapped_del = getattr(loop_cls, "__del__", None)
        if wrapped_del is None:
            # No finalizer to wrap.
            return

        def _quiet_del(self):
            try:
                wrapped_del(self)
            except ValueError as err:
                if "Invalid file descriptor" in str(err):
                    return  # the known-harmless teardown case
                raise

        loop_cls.__del__ = _quiet_del
    except Exception:
        # Deliberately best-effort: never let this cosmetic patch break startup.
        pass


_patch_asyncio_event_loop_del()

# Import order is kept exactly as in the original file.
# NOTE(review): `spaces` is imported before torch; ZeroGPU setups are usually
# sensitive to this ordering — confirm before reordering.
import spaces
import os
import sys
import uuid
import shutil
import random

import gradio as gr
import torch
from omegaconf import OmegaConf
from torchvision.io import write_video
from einops import rearrange
from huggingface_hub import snapshot_download

from pipeline import (
    CausalDiffusionInferencePipeline,
    CausalInferencePipeline,
)
from utils.dataset import TextDataset
from utils.misc import set_seed
from demo_utils.memory import get_cuda_free_memory_gb, DynamicSwapInstaller

# -------------------------------------------------------------------
# Download checkpoints once when the Space starts
# -------------------------------------------------------------------
# (repo_id, local_dir) pairs, fetched in this order at import time.
_CHECKPOINT_REPOS = (
    ("Wan-AI/Wan2.1-T2V-1.3B", "./checkpoints/Wan2.1-T2V-1.3B"),
    ("KlingTeam/VideoReward", "./checkpoints/Videoreward"),
    # NOTE(review): this downloads the whole Self-Forcing repo into a
    # *directory* named "ode_init.pt" — confirm that is what the config expects.
    ("gdhe17/Self-Forcing", "./checkpoints/ode_init.pt"),
    ("JaydenLu666/Reward-Forcing-T2V-1.3B", "./checkpoints/Reward-Forcing-T2V-1.3B"),
)
for _repo_id, _local_dir in _CHECKPOINT_REPOS:
    snapshot_download(repo_id=_repo_id, local_dir=_local_dir)

# === Paths ===
CONFIG_PATH = "configs/reward_forcing.yaml"
CHECKPOINT_PATH = "checkpoints/Reward-Forcing-T2V-1.3B/rewardforcing.pt"
PROMPT_DIR = "prompts/gradio_inputs"
OUTPUT_ROOT = "videos"

os.makedirs(PROMPT_DIR, exist_ok=True)
os.makedirs(OUTPUT_ROOT, exist_ok=True)

# -------------------------------------------------------------------
# Global cached objects
# -------------------------------------------------------------------
# Built once by initialize_pipeline() and reused across requests.
PIPELINE = None
PIPELINE_DEVICE = None
CHECKPOINT_STEP = None


def initialize_pipeline(progress: gr.Progress | None = None):
    """
    Load config, instantiate the pipeline and load the checkpoint only once.

    The config is ``configs/default_config.yaml`` merged with ``CONFIG_PATH``
    (values from ``CONFIG_PATH`` win). A ``CausalInferencePipeline`` is built
    when the config defines ``denoising_step_list``; otherwise a
    ``CausalDiffusionInferencePipeline`` is built. Generator weights come from
    ``CHECKPOINT_PATH`` and the pipeline is cast to bfloat16. The result is
    stored in the module globals and returned on every later call.

    Args:
        progress: optional Gradio progress tracker used for UI updates.

    Returns:
        The cached pipeline object.
    """
    global PIPELINE, PIPELINE_DEVICE, CHECKPOINT_STEP

    if PIPELINE is not None:
        if progress is not None:
            progress(0.16, desc="Init: cached pipeline already available")
        return PIPELINE

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if progress is not None:
        progress(0.02, desc="Init: loading configuration")
    config = OmegaConf.load(CONFIG_PATH)

    if progress is not None:
        progress(0.04, desc="Init: loading default configuration")
    default_config = OmegaConf.load("configs/default_config.yaml")
    config = OmegaConf.merge(default_config, config)

    if progress is not None:
        progress(0.07, desc="Init: creating inference pipeline")
    if hasattr(config, "denoising_step_list"):
        pipeline = CausalInferencePipeline(config, device=device)
    else:
        pipeline = CausalDiffusionInferencePipeline(config, device=device)

    if progress is not None:
        progress(0.11, desc="Init: loading reward forcing checkpoint")
    state_dict = torch.load(
        CHECKPOINT_PATH,
        map_location="cpu",
        weights_only=True,
    )
    # The whole file is loaded as the generator's state dict.
    pipeline.generator.load_state_dict(state_dict)

    # Name of the checkpoint's parent directory (last "_"-separated part);
    # used as an output sub-folder name.
    checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH))
    checkpoint_step = checkpoint_step.split("_")[-1]

    if progress is not None:
        progress(0.15, desc="Init: converting pipeline dtype")
    pipeline = pipeline.to(dtype=torch.bfloat16)

    PIPELINE = pipeline
    PIPELINE_DEVICE = device
    CHECKPOINT_STEP = checkpoint_step

    if progress is not None:
        progress(0.18, desc="Init: pipeline cached")
    return PIPELINE


def prepare_pipeline_for_inference(
    device, low_memory: bool, logs: str, progress: gr.Progress | None = None
):
    """
    Move required modules to the right device before inference.

    Reuses the globally initialized pipeline (initializing it if needed).

    Args:
        device: target torch device.
        low_memory: when True, the text encoder is handed to
            ``DynamicSwapInstaller.install_model`` instead of being moved
            wholesale to ``device``.
        logs: log text accumulated so far; new lines are appended to it.
        progress: optional Gradio progress tracker.

    Returns:
        ``(pipeline, logs)``: the cached pipeline and the updated log text.
    """
    global PIPELINE_DEVICE

    # Only forward the tracker when initialization really has to run. For a
    # cached pipeline initialize_pipeline() reports 0.16, which would make the
    # progress bar jump backwards after callers have already reported 0.2.
    pipeline = initialize_pipeline(progress=progress if PIPELINE is None else None)
    logs += "Preparing cached pipeline for inference...\n"

    if low_memory:
        if progress is not None:
            progress(0.22, desc="Init: preparing text encoder (dynamic swap)")
        logs += "Low-memory mode enabled: installing dynamic swap for text encoder...\n"
        DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
    else:
        if progress is not None:
            progress(0.22, desc="Init: moving text encoder to device")
        logs += "Moving text encoder to device...\n"
        pipeline.text_encoder.to(device=device)

    if progress is not None:
        progress(0.27, desc="Init: moving generator to device")
    logs += "Moving generator to device...\n"
    pipeline.generator.to(device=device)

    if progress is not None:
        progress(0.32, desc="Init: moving VAE to device")
    logs += "Moving VAE to device...\n"
    pipeline.vae.to(device=device)

    PIPELINE_DEVICE = device

    if progress is not None:
        progress(0.36, desc="Init: pipeline ready")
    return pipeline, logs


def reward_forcing_inference(
    prompt_txt_path: str,
    num_output_frames: int,
    use_ema: bool,
    output_root: str,
    seed: int,
    progress: gr.Progress,
):
    """
    Inline / simplified version of inference.py.

    - single GPU
    - text-to-video only
    - one .txt file = N prompts, but returns only the first generated video

    Args:
        prompt_txt_path: path of the prompt file read by ``TextDataset``.
        num_output_frames: size of the frame axis of the sampled noise.
        use_ema: currently has no effect (see NOTE in the save step).
        output_root: root directory for generated videos.
        seed: RNG seed; -1 picks a random seed in [0, 2**32 - 1].
        progress: Gradio progress tracker.

    Returns:
        ``(output_path, logs)`` for the first generated video, or
        ``(None, logs)`` when no video was produced.
    """
    logs = ""

    # --------------------- Device & randomness ---------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)
    set_seed(seed)
    logs += f"Seed: {seed}\n"

    free_vram = get_cuda_free_memory_gb(device)
    logs += f"Free VRAM {free_vram} GB\n"
    # Below 40 GB free, the text encoder is dynamically swapped and the
    # pipeline is called with low_memory=True.
    low_memory = free_vram < 40

    # Inference only: autograd is switched off (and not re-enabled here).
    torch.set_grad_enabled(False)

    # --------------------- Phase 1: cached init / device prep ---------------------
    progress(0.01, desc="Init: checking cached pipeline")
    logs += "Loading cached pipeline...\n"
    initialize_pipeline(progress=progress)

    progress(0.2, desc="Init: preparing pipeline for inference")
    logs += "Preparing pipeline for inference...\n"
    pipeline, logs = prepare_pipeline_for_inference(
        device=device,
        low_memory=low_memory,
        logs=logs,
        progress=progress,
    )

    progress(0.4, desc="Preparing dataset")
    logs += "Preparing dataset (TextDataset)...\n"
    dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
    num_prompts = len(dataset)
    logs += f"Number of prompts: {num_prompts}\n"

    from torch.utils.data import DataLoader, SequentialSampler

    sampler = SequentialSampler(dataset)
    dataloader = DataLoader(
        dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False
    )

    # --------------------- Clean output folder ---------------------
    # NOTE: this folder is shared by all requests and wiped on every call.
    progress(0.5, desc="Preparing output directory")
    output_folder = os.path.join(
        output_root, f"rewardforcing-{num_output_frames}f", CHECKPOINT_STEP
    )
    shutil.rmtree(output_folder, ignore_errors=True)
    os.makedirs(output_folder, exist_ok=True)
    logs += f"Output directory: {output_folder}\n"

    # --------------------- Phase 2: inference loop ---------------------
    progress(0.55, desc="Starting video generation")
    for batch_data in progress.tqdm(
        dataloader,
        total=num_prompts,
        desc="Video generation",
        unit="prompt",
    ):
        # Unpack the dataset batch first, then read fields from the unpacked
        # mapping (previously "idx" was read before unpacking, which fails for
        # the list branch).
        if isinstance(batch_data, dict):
            batch = batch_data
        elif isinstance(batch_data, list):
            batch = batch_data[0]
        else:
            batch = batch_data
        idx = batch["idx"].item()

        all_video = []

        # TEXT-TO-VIDEO only
        prompt = batch["prompts"][0]
        extended_prompt = batch.get("extended_prompts", [None])[0]
        prompts = [extended_prompt] if extended_prompt else [prompt]
        initial_latent = None

        sampled_noise = torch.randn(
            [1, num_output_frames, 16, 60, 104],
            device=device,
            dtype=torch.bfloat16,
        )

        logs += f"Generating for prompt: {prompt[:80]}...\n"

        # WAN2 inference
        video, _latents = pipeline.inference(
            noise=sampled_noise,
            text_prompts=prompts,
            return_latents=True,
            initial_latent=initial_latent,
            low_memory=low_memory,
        )

        current_video = rearrange(video, "b t c h w -> b t h w c").cpu()
        all_video.append(current_video)
        video = 255.0 * torch.cat(all_video, dim=1)

        pipeline.vae.model.clear_cache()

        if idx < num_prompts:
            # NOTE(review): `use_ema` is not applied anywhere — the checkpoint
            # is loaded wholesale into the generator in initialize_pipeline().
            # Wire it up (e.g. select EMA weights) or remove the UI checkbox.
            safe_name = prompt[:50].replace("/", "_").replace("\\", "_")
            output_path = os.path.join(output_folder, f"{safe_name}.mp4")
            write_video(output_path, video[0], fps=16)
            logs += f"Saved video: {output_path}\n"

            progress(1.0, desc="Done")
            # Only the first generated video is returned.
            return output_path, logs

    logs += "[WARN] No video generated.\n"
    return None, logs


@spaces.GPU(duration=200)
def gradio_generate(
    prompt: str,
    duration: str,
    use_ema: bool,
    seed: int,
    progress=gr.Progress(),
):
    """
    Gradio click handler: run inference for a single prompt.

    Writes the prompt to a temporary .txt file, runs
    ``reward_forcing_inference`` and returns ``(video_path, logs)``. The
    temporary prompt file is removed afterwards, whether or not inference
    succeeded.

    Raises:
        gr.Error: if the prompt is empty or no video was produced.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a text prompt 🙂")

    # Duration → number of frames
    num_output_frames = 21 if duration == "5s (21 frames)" else 120

    # The seed box may deliver None when left empty; treat that as "random"
    # instead of crashing on int(None).
    seed = -1 if seed is None else int(seed)

    # The prompt file is "one line = one prompt", so a multi-line textbox entry
    # would be split into several prompts (and only the first video returned).
    # Join the lines into one prompt; single-line prompts are unchanged.
    prompt_line = " ".join(
        part.strip() for part in prompt.splitlines() if part.strip()
    )

    os.makedirs(PROMPT_DIR, exist_ok=True)
    prompt_id = uuid.uuid4().hex[:8]
    prompt_path = os.path.join(PROMPT_DIR, f"prompt_{prompt_id}.txt")

    with open(prompt_path, "w", encoding="utf-8") as f:
        f.write(prompt_line + "\n")

    try:
        video_path, logs = reward_forcing_inference(
            prompt_txt_path=prompt_path,
            num_output_frames=num_output_frames,
            use_ema=use_ema,
            output_root=OUTPUT_ROOT,
            seed=seed,
            progress=progress,
        )
    finally:
        # The file is only an input for this request; without this cleanup
        # PROMPT_DIR grows by one file per request.
        try:
            os.remove(prompt_path)
        except OSError:
            pass

    if video_path is None or not os.path.exists(video_path):
        raise gr.Error("No video generated. Check logs for details.")

    return video_path, logs


# -------------------------------------------------------------------
# Gradio UI
# -------------------------------------------------------------------
examples = [
    [
        "A golden retriever runs across a beach, stops to pick up a red ball, then runs back toward the camera as waves crash behind it, cinematic lighting"
    ],
    [
        "A small snowman slowly melting under the sun, gradually collapsing and turning into a puddle of water, realistic style"
    ],
    [
        "A glass of red wine being poured, but the liquid turns into blue smoke as it fills the glass, surreal, highly detailed"
    ],
    [
        "A futuristic city at sunset with flying cars, rendered in watercolor painting style, soft colors and visible brush strokes"
    ],
    [
        "A slow cinematic zoom into a candle flame flickering in the dark, with subtle shadows moving on the wall"
    ],
    [
        "A cat playing with a floating holographic butterfly, trying to catch it as it moves around, soft lighting"
    ],
    [
        "A bustling medieval marketplace with people walking, merchants selling goods, and flags waving in the wind, detailed and lively"
    ],
    [
        "A burning ice cube slowly melting while still on fire, high detail, realistic physics"
    ],
    [
        "A butterfly emerging from a cocoon, slowly unfolding its wings and taking flight, macro shot"
    ],
    [
        "Colorful ink swirling in water, forming and dissolving shapes continuously, slow motion"
    ],
]

with gr.Blocks(title="Reward Forcing — Text-to-Video Demo") as demo:
    gr.Markdown(
        """
# 🎬 Reward Forcing — Text-to-Video Demo

Generate short videos from text prompts using a model trained with the **Reward Forcing** method.

Reward Forcing is a recent research technique that improves how well a video model follows a written description by guiding training with learned reward signals.

You can learn more here: https://reward-forcing.github.io

👉 Type a prompt, click **Generate**, and the video will appear below.
Longer and more detailed prompts usually produce better results.

💡 This model performs best on **detailed prompts with multiple actions or transformations**.

🎲 Set a fixed seed for reproducible results, or use **-1** for a random seed each time.

> ⏳ The first run may take a little longer while the model loads — generation is faster afterwards.
"""
    )

    with gr.Row():
        prompt_in = gr.Textbox(
            label="Prompt",
            placeholder="A cinematic shot of late-summer wheat fields moving in the wind...",
            lines=4,
        )

    with gr.Row():
        duration = gr.Radio(
            ["5s (21 frames)", "30s (120 frames)"],
            value="5s (21 frames)",
            label="Duration",
        )
        use_ema = gr.Checkbox(value=True, label="Use EMA weights (--use_ema)")
        seed_in = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)

    generate_btn = gr.Button("🚀 Generate Video", variant="primary")

    with gr.Row():
        video_out = gr.Video(label="Generated Video")
        logs_out = gr.Textbox(label="Logs", lines=12, interactive=False)

    gr.Examples(
        examples=examples,
        inputs=prompt_in,
        label="Example prompts",
    )

    generate_btn.click(
        fn=gradio_generate,
        inputs=[prompt_in, duration, use_ema, seed_in],
        outputs=[video_out, logs_out],
    )

demo.queue()

if __name__ == "__main__":
    demo.launch(ssr_mode=False)