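# Kolors ControlNet image-to-image sampling script.
# Builds a Canny / Depth (MiDaS) / Pose (DWpose) condition image from the input
# photo, runs the Kolors pipeline (ChatGLM text encoder + SDXL-style UNet/VAE)
# with the matching ControlNet checkpoint, and saves both the condition map and
# the generated image under controlnet/outputs/.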
import os

import cv2
import numpy as np
import torch
from PIL import Image

from diffusers import AutoencoderKL, EulerDiscreteScheduler

from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.controlnet import ControlNetModel
from kolors.models.unet_2d_condition import UNet2DConditionModel

from annotator.midas import MidasDetector
from annotator.dwpose import DWposeDetector
from annotator.util import resize_image, HWC3

# Repository root; this script is assumed to live one level below it (e.g. controlnet/).
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

def process_canny_condition(image, canny_thresholds=(100, 200)):
    """Extract Canny edges from an RGB uint8 array as a 3-channel PIL image."""
    np_image = cv2.Canny(image.copy(), canny_thresholds[0], canny_thresholds[1])
    # cv2.Canny returns a single-channel edge map; stack it to H x W x 3.
    np_image = np_image[:, :, None]
    np_image = np.concatenate([np_image, np_image, np_image], axis=2)
    np_image = HWC3(np_image)
    return Image.fromarray(np_image)


# Lazily constructed so the MiDaS weights are only loaded when Depth mode is used.
model_midas = None

def process_depth_condition_midas(img, res=1024):
    """Run MiDaS depth estimation and return the map resized back to the input size."""
    global model_midas
    h, w, _ = img.shape
    img = resize_image(HWC3(img), res)
    if model_midas is None:
        model_midas = MidasDetector()
    result = HWC3(model_midas(img))
    result = cv2.resize(result, (w, h))
    return Image.fromarray(result)


# Lazily constructed so the DWpose weights are only loaded when Pose mode is used.
model_dwpose = None

def process_dwpose_condition(image, res=1024):
    """Run DWpose keypoint detection and return the skeleton image at the input size."""
    global model_dwpose
    h, w, _ = image.shape
    img = resize_image(HWC3(image), res)
    if model_dwpose is None:
        model_dwpose = DWposeDetector()
    # Detect on the resized image, mirroring the depth path.
    out_res, out_img = model_dwpose(img)
    result = HWC3(out_img)
    result = cv2.resize(result, (w, h))
    return Image.fromarray(result)


def infer(image_path, prompt, model_type='Canny'):
    """Generate an image from `image_path` guided by a `model_type` condition map."""
    # Kolors base weights: ChatGLM text encoder, SDXL-style VAE / UNet, Euler scheduler.
    ckpt_dir = f'{root_dir}/weights/Kolors'
    text_encoder = ChatGLMModel.from_pretrained(
        f'{ckpt_dir}/text_encoder',
        torch_dtype=torch.float16).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
    vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half()

    # ControlNet checkpoint matching the condition type, e.g. weights/Kolors-ControlNet-Canny.
    control_path = f'{root_dir}/weights/Kolors-ControlNet-{model_type}'
    controlnet = ControlNetModel.from_pretrained(control_path, revision=None).half()

    pipe = StableDiffusionXLControlNetImg2ImgPipeline(
        vae=vae,
        controlnet=controlnet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        force_zeros_for_empty_prompt=False,
    )

    # Model CPU offload streams each submodule to the GPU only while it runs;
    # an explicit pipe.to("cuda") beforehand would just pin everything in VRAM.
    pipe.enable_model_cpu_offload()
    
    # Kolors is trained on Chinese text, so the negative prompt is kept in Chinese.
    # English gloss: "nsfw, facial shadows, low resolution, JPEG artifacts, blurry,
    # bad quality, dark face, neon lights".
    negative_prompt = 'nsfw,脸部阴影,低分辨率,jpeg伪影、模糊、糟糕,黑脸,霓虹灯'
    
    MAX_IMG_SIZE = 1024
    controlnet_conditioning_scale = 0.7  # how strongly the condition map steers generation
    control_guidance_end = 0.9           # drop ControlNet guidance for the final 10% of steps
    strength = 1.0                       # img2img strength; 1.0 fully re-noises the init image

    basename = os.path.splitext(os.path.basename(image_path))[0]

    # Force RGB so every downstream conversion sees an H x W x 3 uint8 array,
    # and resize through numpy since the annotator helper expects arrays.
    init_image = Image.open(image_path).convert('RGB')
    init_image = Image.fromarray(resize_image(np.array(init_image), MAX_IMG_SIZE))

    if model_type == 'Canny':
        condi_img = process_canny_condition(np.array(init_image))
    elif model_type == 'Depth':
        condi_img = process_depth_condition_midas(np.array(init_image), MAX_IMG_SIZE)
    elif model_type == 'Pose':
        condi_img = process_dwpose_condition(np.array(init_image), MAX_IMG_SIZE)
    else:
        raise ValueError(f"Unsupported model_type {model_type!r}; expected 'Canny', 'Depth' or 'Pose'")

    # Fixed CPU-side generator seed keeps results reproducible across runs.
    generator = torch.Generator(device="cpu").manual_seed(66)
    image = pipe(
        prompt=prompt,
        image=init_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_end=control_guidance_end,
        strength=strength,
        control_image=condi_img,
        negative_prompt=negative_prompt,
        num_inference_steps=50,
        guidance_scale=6.0,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    
    out_dir = f'{root_dir}/controlnet/outputs'
    os.makedirs(out_dir, exist_ok=True)  # ensure the output directory exists
    condi_img.save(f'{out_dir}/{model_type}_{basename}_condition.jpg')
    image.save(f'{out_dir}/{model_type}_{basename}.jpg')


if __name__ == '__main__':
    import fire
    fire.Fire(infer)
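# Example invocation via fire (the script filename below is illustrative):
#   python sample_controlnet.py --image_path ./inputs/woman.png \
#       --prompt "A woman wearing a red dress in the snow" --model_type Canny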