from __future__ import annotations

import gc
import json
import pathlib

import gradio as gr
import PIL.Image
import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig, LoraModel, set_peft_model_state_dict


class InferencePipeline:
    def __init__(self):
        self.pipe = None
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.weight_path = None

    def clear(self) -> None:
        # Drop the current pipeline and release GPU memory so a new
        # checkpoint can be loaded without exhausting VRAM.
        self.weight_path = None
        del self.pipe
        self.pipe = None
        torch.cuda.empty_cache()
        gc.collect()

    @staticmethod
    def get_lora_weight_path(name: str) -> tuple[pathlib.Path, pathlib.Path]:
        # A LoRA checkpoint "foo.pt" ships with a matching "foo_config.json".
        curr_dir = pathlib.Path(__file__).parent
        return curr_dir / name, curr_dir / f'{name.replace(".pt", "_config.json")}'

    def load_and_set_lora_ckpt(self, pipe, weight_path, config_path, dtype):
        with open(config_path) as f:
            lora_config = json.load(f)
        lora_checkpoint_sd = torch.load(weight_path, map_location=self.device)

        # The checkpoint stores UNet and text-encoder LoRA weights in a single
        # state dict; text-encoder keys are prefixed with "text_encoder_".
        unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
        text_encoder_lora_ds = {
            k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
        }

        unet_config = LoraConfig(**lora_config["peft_config"])
        pipe.unet = LoraModel(unet_config, pipe.unet)
        set_peft_model_state_dict(pipe.unet, unet_lora_ds)

        if "text_encoder_peft_config" in lora_config:
            text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
            pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder)
            set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora_ds)

        if dtype in (torch.float16, torch.bfloat16):
            pipe.unet.half()
            pipe.text_encoder.half()
        pipe.to(self.device)
        return pipe

    def load_pipe(self, model_id: str, lora_filename: str) -> None:
        weight_path, config_path = self.get_lora_weight_path(lora_filename)
        if weight_path == self.weight_path:
            # The requested checkpoint is already loaded; reuse the pipeline.
            return
        self.clear()
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
        pipe = self.load_and_set_lora_ckpt(pipe, weight_path, config_path, torch.float16)
        self.pipe = pipe
        self.weight_path = weight_path

    def run(
        self,
        base_model: str,
        lora_weight_name: str,
        prompt: str,
        negative_prompt: str,
        seed: int,
        n_steps: int,
        guidance_scale: float,
    ) -> PIL.Image.Image:
        if not torch.cuda.is_available():
            raise gr.Error("CUDA is not available.")

        self.load_pipe(base_model, lora_weight_name)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        out = self.pipe(
            prompt,
            num_inference_steps=n_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            negative_prompt=negative_prompt if negative_prompt else None,
        )  # type: ignore
        return out.images[0]
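

# A minimal usage sketch of the class above. The base-model id, LoRA filename,
# and prompt are placeholder assumptions for illustration, not values taken
# from this Space; any checkpoint/config pair produced by a matching PEFT LoRA
# training script should work the same way.
if __name__ == "__main__":
    pipeline = InferencePipeline()
    image = pipeline.run(
        base_model="runwayml/stable-diffusion-v1-5",  # hypothetical base model id
        lora_weight_name="my_lora.pt",  # hypothetical checkpoint next to this file
        prompt="a photo of a sks dog in a bucket",  # illustrative DreamBooth-style prompt
        negative_prompt="",
        seed=0,
        n_steps=25,
        guidance_scale=7.5,
    )
    image.save("out.png")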