| | import gc
|
| |
|
| | import numpy as np
|
| | import PIL.Image
|
| | import torch
|
| | import torchvision
|
| | from controlnet_aux import (
|
| | CannyDetector,
|
| | ContentShuffleDetector,
|
| | HEDdetector,
|
| | LineartAnimeDetector,
|
| | LineartDetector,
|
| | MidasDetector,
|
| | MLSDdetector,
|
| | NormalBaeDetector,
|
| | OpenposeDetector,
|
| | PidiNetDetector,
|
| | )
|
| | from controlnet_aux.util import HWC3
|
| |
|
| | from cv_utils import resize_image
|
| | from depth_estimator import DepthEstimator
|
| | from image_segmentor import ImageSegmentor
|
| |
|
| | from kornia.core import Tensor
|
| |
|
| |
|
| |
|
| |
|
# Hugging Face repository holding the pretrained annotator checkpoints.
_ANNOTATOR_REPO = "lllyasviel/Annotators"

# Detector singletons, created once when this module is imported.
# Preprocessor.load() selects among these objects by reference.
Midas = MidasDetector.from_pretrained(_ANNOTATOR_REPO)
MLSD = MLSDdetector.from_pretrained(_ANNOTATOR_REPO)
Canny = CannyDetector()  # Takes no checkpoint.
OPENPOSE = OpenposeDetector.from_pretrained(_ANNOTATOR_REPO)
|
| |
|
| |
|
class Preprocessor:
    """Apply one of the module-level ControlNet annotators to a PIL image.

    Call :meth:`load` with a detector name first, then call the instance
    with an image. ``detect_resolution`` and ``image_resolution`` (both
    default 512) are passed to ``resize_image`` before and after detection.
    Any other keyword arguments are forwarded to the detector unchanged.
    """

    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        # Currently selected detector and its name; unset until load() runs.
        self.model = None
        self.name = ""

    def load(self, name: str) -> None:
        """Select the detector called *name*.

        Re-loading the already selected name is a no-op.

        Args:
            name: One of ``"Midas"``, ``"MLSD"``, ``"Openpose"`` or ``"Canny"``.

        Raises:
            ValueError: If *name* is not a known detector. The previously
                selected detector is kept in that case.
        """
        if name == self.name:
            return

        # The detectors are module-level singletons, so switching only
        # rebinds a reference.
        detectors = {
            "Midas": Midas,
            "MLSD": MLSD,
            "Openpose": OPENPOSE,
            "Canny": Canny,
        }
        try:
            model = detectors[name]
        except KeyError:
            raise ValueError(
                f"Unknown preprocessor {name!r}; expected one of {sorted(detectors)}"
            ) from None

        self.model = model
        # Collect unreachable Python objects first, so any tensors they held
        # are released before CUDA is asked to drop its cached blocks.
        gc.collect()
        torch.cuda.empty_cache()
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the loaded detector on *image* and return the result as RGB.

        Args:
            image: Input image.
            **kwargs: ``detect_resolution`` (default 512) is the resolution
                passed to ``resize_image`` before detection, and
                ``image_resolution`` (default 512) the one used afterwards.
                All remaining keyword arguments are passed to the detector.

        Returns:
            The detector output as an RGB ``PIL.Image.Image``.

        Raises:
            RuntimeError: If :meth:`load` has not selected a detector yet.
        """
        if self.model is None:
            raise RuntimeError("No preprocessor loaded; call load() first.")

        detect_resolution = kwargs.pop("detect_resolution", 512)
        image_resolution = kwargs.pop("image_resolution", 512)

        array = HWC3(np.array(image))
        array = resize_image(array, resolution=detect_resolution)

        result = self.model(array, **kwargs)
        # The detector's return type is not fixed here. np.array + HWC3
        # normalise it to a 3-channel array before the final resize.
        result = HWC3(np.array(result))
        result = resize_image(result, resolution=image_resolution)
        return PIL.Image.fromarray(result).convert("RGB")
|
| |
|