| import cv2 |
| import numpy as np |
| import torch |
| from loguru import logger |
|
|
| from iopaint.helper import download_model |
| from iopaint.plugins.base_plugin import BasePlugin |
| from iopaint.schema import RunPluginRequest, RealESRGANModel |
|
|
|
|
class RealESRGANUpscaler(BasePlugin):
    """Plugin that upscales images with one of the Real-ESRGAN models.

    The ``RealESRGANer`` wrapper is built eagerly in ``__init__`` and can be
    replaced at runtime through :meth:`switch_model`.
    """

    name = "RealESRGAN"
    support_gen_image = True

    def __init__(self, name, device, no_half=False):
        """Create the plugin and load the requested Real-ESRGAN model.

        Args:
            name: Key of the model to load (one of the ``RealESRGANModel``
                members handled in ``_init_model``). An unknown key raises
                ``ValueError``.
            device: Device passed to ``RealESRGANer``. Half precision is only
                requested when ``str(device)`` contains ``"cuda"``.
            no_half: When true, never request half precision.
        """
        super().__init__()
        self.model_name = name
        self.device = device
        self.no_half = no_half
        self._init_model(name)

    def _init_model(self, name):
        """Resolve the weights for ``name`` and build ``self.model``.

        ``self.model`` is assigned only after the lookup, the weight download
        and the ``RealESRGANer`` construction have all succeeded. A failure
        therefore leaves any previously loaded model untouched.

        Raises:
            ValueError: If ``name`` is not one of the supported models.
        """
        # Imported inside the method, presumably so this module stays
        # importable when realesrgan is missing (check_dep reports that case).
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact

        # "model" entries are factories (lambdas) so that only the selected
        # architecture is instantiated.
        REAL_ESRGAN_MODELS = {
            RealESRGANModel.realesr_general_x4v3: {
                "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
                "scale": 4,
                "model": lambda: SRVGGNetCompact(
                    num_in_ch=3,
                    num_out_ch=3,
                    num_feat=64,
                    num_conv=32,
                    upscale=4,
                    act_type="prelu",
                ),
                "model_md5": "91a7644643c884ee00737db24e478156",
            },
            RealESRGANModel.RealESRGAN_x4plus: {
                "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
                "scale": 4,
                "model": lambda: RRDBNet(
                    num_in_ch=3,
                    num_out_ch=3,
                    num_feat=64,
                    num_block=23,
                    num_grow_ch=32,
                    scale=4,
                ),
                "model_md5": "99ec365d4afad750833258a1a24f44ca",
            },
            RealESRGANModel.RealESRGAN_x4plus_anime_6B: {
                "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
                "scale": 4,
                "model": lambda: RRDBNet(
                    num_in_ch=3,
                    num_out_ch=3,
                    num_feat=64,
                    num_block=6,
                    num_grow_ch=32,
                    scale=4,
                ),
                "model_md5": "d58ce384064ec1591c2ea7b79dbf47ba",
            },
        }
        if name not in REAL_ESRGAN_MODELS:
            raise ValueError(f"Unknown RealESRGAN model name: {name}")
        model_info = REAL_ESRGAN_MODELS[name]

        model_path = download_model(model_info["url"], model_info["model_md5"])
        logger.info(f"RealESRGAN model path: {model_path}")

        # Half precision is requested only for CUDA devices, and never when
        # the caller opted out with no_half. The expression is already a bool.
        use_half = "cuda" in str(self.device) and not self.no_half

        # tile / tile_pad / pre_pad are passed straight through to RealESRGANer.
        self.model = RealESRGANer(
            scale=model_info["scale"],
            model_path=model_path,
            model=model_info["model"](),
            half=use_half,
            tile=512,
            tile_pad=10,
            pre_pad=10,
            device=self.device,
        )

    def switch_model(self, new_model_name: str):
        """Load ``new_model_name`` unless it is already the active model.

        ``self.model_name`` is updated only after ``_init_model`` returns.
        If loading fails, the plugin keeps reporting (and using) the
        previous model.
        """
        if self.model_name == new_model_name:
            return
        self._init_model(new_model_name)
        self.model_name = new_model_name

    def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
        """Upscale an RGB image by ``req.scale``.

        The input is converted RGB -> BGR before inference. The array from
        ``forward`` is returned as-is, i.e. in whatever channel order
        ``RealESRGANer.enhance`` yields for BGR input. NOTE(review): presumably
        BGR, so confirm the caller converts back if it expects RGB.
        """
        bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
        logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
        result = self.forward(bgr_np_img, req.scale)
        logger.info(f"RealESRGAN output shape: {result.shape}")
        return result

    @torch.inference_mode()
    def forward(self, bgr_np_img, scale: float):
        """Run the loaded model on ``bgr_np_img`` and return the upscaled array.

        ``scale`` is passed to ``enhance`` as ``outscale``. ``enhance``
        returns a tuple whose first element is the output image.
        """
        upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
        return upsampled

    def check_dep(self):
        """Return an install hint if ``realesrgan`` cannot be imported, else ``None``."""
        try:
            import realesrgan  # noqa: F401 -- availability probe only
        except ImportError:
            return "RealESRGAN is not installed, please install it first. pip install realesrgan"
        return None
|
|