| import os |
|
|
| import torch |
| from facexlib.utils.face_restoration_helper import FaceRestoreHelper |
| from gfpgan import GFPGANv1Clean, GFPGANer |
| from torch.hub import get_dir |
|
|
|
|
class MyGFPGANer(GFPGANer):
    """Helper for restoration with GFPGAN.

    It will detect and crop faces, and then resize the faces to 512x512.
    GFPGAN is used to restore the resized faces.
    The background is upsampled with the bg_upsampler.
    Finally, the faces will be pasted back to the upsampled background image.

    Args:
        model_path (str): Path to the restoration checkpoint. It is handed
            directly to ``torch.load``; the checkpoint must contain a
            ``params_ema`` or ``params`` entry holding the state dict.
        upscale (float): The upscale of the final output. Default: 2.
        arch (str): The restoration architecture.
            Option: clean | RestoreFormer. Default: clean.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Only used by the ``clean`` arch. Default: 2.
        bg_upsampler (nn.Module): The upsampler for the background.
            Default: None.
        device (torch.device | str | None): Device the networks run on.
            When None, CUDA is used if available, otherwise CPU.
            Default: None.

    Raises:
        ValueError: If ``arch`` is not one of the supported options.
    """

    def __init__(
        self,
        model_path,
        upscale=2,
        arch="clean",
        channel_multiplier=2,
        bg_upsampler=None,
        device=None,
    ):
        # NOTE: GFPGANer.__init__ is not invoked here; the attributes below
        # are assigned directly by this constructor.
        self.upscale = upscale
        self.bg_upsampler = bg_upsampler

        # Pick the device: honour an explicit choice, else prefer CUDA.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # Build the restoration network for the requested architecture.
        if arch == "clean":
            self.gfpgan = GFPGANv1Clean(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=False,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True,
            )
        elif arch == "RestoreFormer":
            # Imported lazily so the module is only required when this
            # architecture is actually selected.
            from gfpgan.archs.restoreformer_arch import RestoreFormer

            self.gfpgan = RestoreFormer()
        else:
            # Fail fast with a clear message instead of an AttributeError on
            # self.gfpgan further down.
            raise ValueError(
                f"Unsupported arch {arch!r}; expected 'clean' or 'RestoreFormer'."
            )

        # The face helper is pointed at the torch hub checkpoint directory
        # as its model root path.
        model_dir = os.path.join(get_dir(), "checkpoints")
        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model="retinaface_resnet50",
            save_ext="png",
            use_parse=True,
            device=self.device,
            model_rootpath=model_dir,
        )

        # Deserialize onto the CPU so checkpoints saved from CUDA tensors also
        # load on CPU-only hosts; the network is moved to self.device below.
        # SECURITY NOTE: torch.load unpickles the file -- only load
        # checkpoints from trusted sources.
        loadnet = torch.load(model_path, map_location="cpu")
        # Prefer the EMA weights when the checkpoint provides them.
        keyname = "params_ema" if "params_ema" in loadnet else "params"
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)
|
|