coralLight committed on
Commit
b083acf
·
1 Parent(s): 215a1c6
Files changed (6)
  1. NoiseTransformer.py +26 -0
  2. README.md +17 -3
  3. SVDNoiseUnet.py +430 -0
  4. app.py +755 -54
  5. free_lunch_utils.py +304 -0
  6. requirements.txt +10 -6
NoiseTransformer.py ADDED
@@ -0,0 +1,26 @@
+ import torch.nn as nn
+
+ from torch.nn import functional as F
+ from timm import create_model
+
+
+ __all__ = ['NoiseTransformer']
+
+ class NoiseTransformer(nn.Module):
+     def __init__(self, resolution=(128,96)):
+         super().__init__()
+         self.upsample = lambda x: F.interpolate(x, [224,224])
+         self.downsample = lambda x: F.interpolate(x, [resolution[0],resolution[1]])
+         self.upconv = nn.Conv2d(7,4,(1,1),(1,1),(0,0))
+         self.downconv = nn.Conv2d(4,3,(1,1),(1,1),(0,0))
+         # self.upconv = nn.Conv2d(7,4,(1,1),(1,1),(0,0))
+         self.swin = create_model("swin_tiny_patch4_window7_224", pretrained=True)
+
+
+     def forward(self, x, residual=False):
+         if residual:
+             x = self.upconv(self.downsample(self.swin.forward_features(self.downconv(self.upsample(x))))) + x
+         else:
+             x = self.upconv(self.downsample(self.swin.forward_features(self.downconv(self.upsample(x)))))
+
+         return x
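For readers skimming the module above: the two lambdas simply resize SD latents to Swin's 224x224 input and back, and the 1x1 convolutions adapt channel counts around the backbone. A minimal sketch of the resizing step (the 128x96 latent shape is an illustrative assumption for a 1024x768 image; the Swin call itself additionally needs a timm version whose `forward_features` returns a 4-D feature map):

```python
import torch
import torch.nn.functional as F

latent = torch.randn(1, 4, 128, 96)         # SD latent, assumed here to come from a 1024x768 image
to_swin = F.interpolate(latent, [224, 224])  # what self.upsample does
back = F.interpolate(to_swin, [128, 96])     # what self.downsample does
print(to_swin.shape, back.shape)             # (1, 4, 224, 224) and (1, 4, 128, 96)
```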
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: Hyperparameters Are All You Need Sd Version
+ title: Hyperparameters-are-all-you-need
  emoji: 🖼
  colorFrom: purple
  colorTo: red
@@ -8,7 +8,21 @@ sdk_version: 5.44.0
  app_file: app.py
  pinned: false
  license: apache-2.0
- short_description: a few steps training-free diffusion ODE solver
+ short_description: training-free few-step diffusion solver
+ header: default
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ **Abstract:** The diffusion model is a state-of-the-art generative model that generates an image by applying a neural network iteratively, and this generation process can be viewed as solving an ordinary differential equation (ODE) or a stochastic differential equation (SDE). Based on an analysis of the truncation error of the diffusion ODE and SDE, our study proposes a training-free algorithm that generates high-quality 512 x 512 and 1024 x 1024 images in eight steps, with flexible guidance scales. To the best of our knowledge, it is the first algorithm to sample a 1024 x 1024 image in 8 steps with FID performance comparable to the latest distillation models, without any additional training. It can also generate a 512 x 512 image in 8 steps with better FID than the state-of-the-art ODE solver DPM++ 2M run for 20 steps. We validate the eight-step algorithm on the COCO 2014, COCO 2017, and LAION datasets, where our best FID scores are 15.7, 22.35, and 17.52, compared with 17.3, 23.75, and 17.33 for DPM++ 2M and 19.07, 25.50, and 18.06 for the state-of-the-art AMED-plugin solver. We also apply the algorithm to five-step inference without additional training; the best FID scores on the same datasets are 19.18, 23.24, and 19.61, respectively, comparable to the AMED-plugin solver in eight steps, SDXL-Turbo in four steps, and the state-of-the-art diffusion distillation model Flash Diffusion in five steps. Finally, we validate the algorithm on synthesizing 1024 x 1024 images within 6 steps, where its FID is only a small distance from the latest distillation algorithms.
+
+ This demo is a simplified version of the approach described in the paper ["Hyperparameters are all you need: Using five-step inference for an original diffusion model to generate images comparable to the latest distillation model"](https://arxiv.org/abs/2510.02390).
+
+ ```
+ @misc{hyper,
+       title={Hyperparameters are all you need: Using five-step inference for an original diffusion model to generate images comparable to the latest distillation model},
+       author={Zilai Li},
+       year={2025},
+       eprint={2510.02390},
+       archivePrefix={arXiv},
+       primaryClass={eess.IV}
+ }
+ ```
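The few-step recipe in this commit is driven by hand-tuned noise schedules rather than extra training: `app.py` below builds a Karras-style power schedule, maps its endpoints back to timesteps, and re-spaces those timesteps with a much flatter exponent (rho=1.2). A minimal, self-contained sketch of that schedule function (the numeric sigma endpoints here are illustrative assumptions, not the values the demo derives from the model):

```python
import torch

def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0):
    """Karras et al. (2022) schedule: interpolate in sigma**(1/rho) space."""
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# app.py first spaces 8 sigmas with rho=5.0, then re-spaces the matching
# timesteps with rho=1.2 to obtain the 5- or 6-step schedule.
print(get_sigmas_karras(8, 0.03, 14.6, rho=5.0))  # sigma range is illustrative for SD 1.x
```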
SVDNoiseUnet.py ADDED
@@ -0,0 +1,430 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import einops
4
+
5
+ from torch.nn import functional as F
6
+ from torch.jit import Final
7
+ from timm.layers import use_fused_attn
8
+ from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, get_act_layer
9
+ from abc import abstractmethod
10
+ from NoiseTransformer import NoiseTransformer
11
+ from einops import rearrange
12
+ __all__ = ['SVDNoiseUnet', 'SVDNoiseUnet_Concise']
13
+
14
+ class Attention(nn.Module):
15
+ fused_attn: Final[bool]
16
+
17
+ def __init__(
18
+ self,
19
+ dim: int,
20
+ num_heads: int = 8,
21
+ qkv_bias: bool = False,
22
+ qk_norm: bool = False,
23
+ attn_drop: float = 0.,
24
+ proj_drop: float = 0.,
25
+ norm_layer: nn.Module = nn.LayerNorm,
26
+ ) -> None:
27
+ super().__init__()
28
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
29
+ self.num_heads = num_heads
30
+ self.head_dim = dim // num_heads
31
+ self.scale = self.head_dim ** -0.5
32
+ self.fused_attn = use_fused_attn()
33
+
34
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
35
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
36
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
37
+ self.attn_drop = nn.Dropout(attn_drop)
38
+ self.proj = nn.Linear(dim, dim)
39
+ self.proj_drop = nn.Dropout(proj_drop)
40
+
41
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
42
+ B, N, C = x.shape
43
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
44
+ q, k, v = qkv.unbind(0)
45
+ q, k = self.q_norm(q), self.k_norm(k)
46
+
47
+ if self.fused_attn:
48
+ x = F.scaled_dot_product_attention(
49
+ q, k, v,
50
+ dropout_p=self.attn_drop.p if self.training else 0.,
51
+ )
52
+ else:
53
+ q = q * self.scale
54
+ attn = q @ k.transpose(-2, -1)
55
+ attn = attn.softmax(dim=-1)
56
+ attn = self.attn_drop(attn)
57
+ x = attn @ v
58
+
59
+ x = x.transpose(1, 2).reshape(B, N, C)
60
+ x = self.proj(x)
61
+ x = self.proj_drop(x)
62
+ return x
63
+
64
+
65
+ class SVDNoiseUnet(nn.Module):
66
+ def __init__(self, in_channels=4, out_channels=4, resolution=(128,96)): # resolution = size // 8
67
+ super(SVDNoiseUnet, self).__init__()
68
+
69
+ _in_1 = int(resolution[0] * in_channels // 2)
70
+ _out_1 = int(resolution[0] * out_channels // 2)
71
+
72
+ _in_2 = int(resolution[1] * in_channels // 2)
73
+ _out_2 = int(resolution[1] * out_channels // 2)
74
+ self.mlp1 = nn.Sequential(
75
+ nn.Linear(_in_1, 64),
76
+ nn.ReLU(inplace=True),
77
+ nn.Linear(64, _out_1),
78
+ )
79
+ self.mlp2 = nn.Sequential(
80
+ nn.Linear(_in_2, 64),
81
+ nn.ReLU(inplace=True),
82
+ nn.Linear(64, _out_2),
83
+ )
84
+
85
+ self.mlp3 = nn.Sequential(
86
+ nn.Linear(_in_2, _out_2),
87
+ )
88
+
89
+ self.attention = Attention(_out_2)
90
+
91
+ self.bn = nn.BatchNorm1d(256)
92
+ self.bn2 = nn.BatchNorm1d(192)
93
+
94
+ self.mlp4 = nn.Sequential(
95
+ nn.Linear(_out_2, 1024),
96
+ nn.ReLU(inplace=True),
97
+ nn.Linear(1024, _out_2),
98
+ )
99
+ self.ffn = nn.Sequential(
100
+ nn.Linear(256, 384), # Expand
101
+ nn.ReLU(inplace=True),
102
+ nn.Linear(384, 192) # Reduce to target size
103
+ )
104
+ self.ffn2 = nn.Sequential(
105
+ nn.Linear(256, 384), # Expand
106
+ nn.ReLU(inplace=True),
107
+ nn.Linear(384, 192) # Reduce to target size
108
+ )
109
+ # self.adaptive_pool = nn.AdaptiveAvgPool2d((256, 192))
110
+
111
+ def forward(self, x, residual=False):
112
+ b, c, h, w = x.shape
113
+ x = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2) # x -> [1, 256, 256]
114
+ U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
115
+ U_T = U.permute(0, 2, 1)
116
+ U_out = self.ffn(self.mlp1(U_T))
117
+ U_out = self.bn(U_out)
118
+ U_out = U_out.transpose(1, 2)
119
+ U_out = self.ffn2(U_out) # [b, 256, 256] -> [b, 256, 192]
120
+ U_out = self.bn2(U_out)
121
+ U_out = U_out.transpose(1, 2)
122
+ # U_out = self.bn(U_out)
123
+ V_out = self.mlp2(V)
124
+ s_out = self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
125
+ out = U_out + V_out + s_out
126
+ # print(out.size())
127
+ out = out.squeeze(1)
128
+ out = self.attention(out).mean(1)
129
+ out = self.mlp4(out) + s
130
+ diagonal_out = torch.diag_embed(out)
131
+ padded_diag = F.pad(diagonal_out, (0, 0, 0, 64), mode='constant', value=0) # Shape: [b, 1, 256, 192]
132
+ pred = U @ padded_diag @ V
133
+ return einops.rearrange(pred, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
134
+
135
+ class SVDNoiseUnet64(nn.Module):
136
+ def __init__(self, in_channels=4, out_channels=4, resolution=64): # resolution = size // 8
137
+ super(SVDNoiseUnet64, self).__init__()
138
+
139
+ _in = int(resolution * in_channels // 2)
140
+ _out = int(resolution * out_channels // 2)
141
+ self.mlp1 = nn.Sequential(
142
+ nn.Linear(_in, 64),
143
+ nn.ReLU(inplace=True),
144
+ nn.Linear(64, _out),
145
+ )
146
+ self.mlp2 = nn.Sequential(
147
+ nn.Linear(_in, 64),
148
+ nn.ReLU(inplace=True),
149
+ nn.Linear(64, _out),
150
+ )
151
+
152
+ self.mlp3 = nn.Sequential(
153
+ nn.Linear(_in, _out),
154
+ )
155
+
156
+ self.attention = Attention(_out)
157
+
158
+ self.bn = nn.BatchNorm2d(_out)
159
+
160
+ self.mlp4 = nn.Sequential(
161
+ nn.Linear(_out, 1024),
162
+ nn.ReLU(inplace=True),
163
+ nn.Linear(1024, _out),
164
+ )
165
+
166
+ def forward(self, x, residual=False):
167
+ b, c, h, w = x.shape
168
+ x = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2) # x -> [1, 256, 256]
169
+ U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
170
+ U_T = U.permute(0, 2, 1)
171
+ out = self.mlp1(U_T) + self.mlp2(V) + self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
172
+ out = self.attention(out).mean(1)
173
+ out = self.mlp4(out) + s
174
+ pred = U @ torch.diag_embed(out) @ V
175
+ return einops.rearrange(pred, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
176
+
177
+
178
+
179
+ class SVDNoiseUnet128(nn.Module):
180
+ def __init__(self, in_channels=4, out_channels=4, resolution=128): # resolution = size // 8
181
+ super(SVDNoiseUnet128, self).__init__()
182
+
183
+ _in = int(resolution * in_channels // 2)
184
+ _out = int(resolution * out_channels // 2)
185
+ self.mlp1 = nn.Sequential(
186
+ nn.Linear(_in, 64),
187
+ nn.ReLU(inplace=True),
188
+ nn.Linear(64, _out),
189
+ )
190
+ self.mlp2 = nn.Sequential(
191
+ nn.Linear(_in, 64),
192
+ nn.ReLU(inplace=True),
193
+ nn.Linear(64, _out),
194
+ )
195
+
196
+ self.mlp3 = nn.Sequential(
197
+ nn.Linear(_in, _out),
198
+ )
199
+
200
+ self.attention = Attention(_out)
201
+
202
+ self.bn = nn.BatchNorm2d(_out)
203
+
204
+ self.mlp4 = nn.Sequential(
205
+ nn.Linear(_out, 1024),
206
+ nn.ReLU(inplace=True),
207
+ nn.Linear(1024, _out),
208
+ )
209
+
210
+ def forward(self, x, residual=False):
211
+ b, c, h, w = x.shape
212
+ x = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2) # x -> [1, 256, 256]
213
+ U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
214
+ U_T = U.permute(0, 2, 1)
215
+ out = self.mlp1(U_T) + self.mlp2(V) + self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
216
+ out = self.attention(out).mean(1)
217
+ out = self.mlp4(out) + s
218
+ pred = U @ torch.diag_embed(out) @ V
219
+ return einops.rearrange(pred, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
220
+
221
+
222
+
223
+ class SVDNoiseUnet_Concise(nn.Module):
224
+ def __init__(self, in_channels=4, out_channels=4, resolution=64):
225
+ super(SVDNoiseUnet_Concise, self).__init__()
226
+
227
+
228
+ from diffusers.models.normalization import AdaGroupNorm
229
+
230
+ class NPNet(nn.Module):
231
+ def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
232
+ super(NPNet, self).__init__()
233
+
234
+ assert model_id in ['SD1.5', 'DreamShaper', 'DiT']
235
+
236
+ self.model_id = model_id
237
+ self.device = device
238
+ self.pretrained_path = pretrained_path
239
+
240
+ (
241
+ self.unet_svd,
242
+ self.unet_embedding,
243
+ self.text_embedding,
244
+ self._alpha,
245
+ self._beta
246
+ ) = self.get_model()
247
+ def save_model(self, save_path: str):
248
+ """
249
+ Save this NPNet so that get_model() can later reload it.
250
+ """
251
+ torch.save({
252
+ "unet_svd": self.unet_svd.state_dict(),
253
+ "unet_embedding": self.unet_embedding.state_dict(),
254
+ "embeeding": self.text_embedding.state_dict(), # matches get_model’s key
255
+ "alpha": self._alpha,
256
+ "beta": self._beta,
257
+ }, save_path)
258
+ print(f"NPNet saved to {save_path}")
259
+ def get_model(self):
260
+
261
+ unet_embedding = NoiseTransformer(resolution=(128,96)).to(self.device).to(torch.float32)
262
+ unet_svd = SVDNoiseUnet(resolution=(128,96)).to(self.device).to(torch.float32)
263
+
264
+ if self.model_id == 'DiT':
265
+ text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
266
+ else:
267
+ text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
268
+
269
+ # initialize random _alpha and _beta when no checkpoint is provided
270
+ _alpha = torch.randn(1, device=self.device)
271
+ _beta = torch.randn(1, device=self.device)
272
+
273
+ if '.pth' in self.pretrained_path:
274
+ gloden_unet = torch.load(self.pretrained_path)
275
+ unet_svd.load_state_dict(gloden_unet["unet_svd"],strict=True)
276
+ unet_embedding.load_state_dict(gloden_unet["unet_embedding"],strict=True)
277
+ text_embedding.load_state_dict(gloden_unet["embeeding"],strict=True)
278
+ _alpha = gloden_unet["alpha"]
279
+ _beta = gloden_unet["beta"]
280
+
281
+ print("Load Successfully!")
282
+
283
+ return unet_svd, unet_embedding, text_embedding, _alpha, _beta
284
+
285
+ else:
286
+ return unet_svd, unet_embedding, text_embedding, _alpha, _beta
287
+
288
+
289
+ def forward(self, initial_noise, prompt_embeds):
290
+
291
+ prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
292
+ text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)
293
+
294
+ encoder_hidden_states_svd = initial_noise
295
+ encoder_hidden_states_embedding = initial_noise + text_emb
296
+
297
+ golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())
298
+
299
+ golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
300
+ 2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding
301
+
302
+ return golden_noise
303
+
304
+
305
+ class NPNet64(nn.Module):
306
+ def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
307
+ super(NPNet64, self).__init__()
308
+ self.model_id = model_id
309
+ self.device = device
310
+ self.pretrained_path = pretrained_path
311
+
312
+ (
313
+ self.unet_svd,
314
+ self.unet_embedding,
315
+ self.text_embedding,
316
+ self._alpha,
317
+ self._beta
318
+ ) = self.get_model()
319
+
320
+ def save_model(self, save_path: str):
321
+ """
322
+ Save this NPNet so that get_model() can later reload it.
323
+ """
324
+ torch.save({
325
+ "unet_svd": self.unet_svd.state_dict(),
326
+ "unet_embedding": self.unet_embedding.state_dict(),
327
+ "embeeding": self.text_embedding.state_dict(), # matches get_model’s key
328
+ "alpha": self._alpha,
329
+ "beta": self._beta,
330
+ }, save_path)
331
+ print(f"NPNet saved to {save_path}")
332
+
333
+ def get_model(self):
334
+
335
+ unet_embedding = NoiseTransformer(resolution=(64,64)).to(self.device).to(torch.float32)
336
+ unet_svd = SVDNoiseUnet64(resolution=64).to(self.device).to(torch.float32)
337
+ _alpha = torch.randn(1, device=self.device)
338
+ _beta = torch.randn(1, device=self.device)
339
+
340
+ text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
341
+
342
+
343
+ if '.pth' in self.pretrained_path:
344
+ gloden_unet = torch.load(self.pretrained_path)
345
+ unet_svd.load_state_dict(gloden_unet["unet_svd"])
346
+ unet_embedding.load_state_dict(gloden_unet["unet_embedding"])
347
+ text_embedding.load_state_dict(gloden_unet["embeeding"])
348
+ _alpha = gloden_unet["alpha"]
349
+ _beta = gloden_unet["beta"]
350
+
351
+ print("Load Successfully!")
352
+
353
+ return unet_svd, unet_embedding, text_embedding, _alpha, _beta
354
+
355
+
356
+ def forward(self, initial_noise, prompt_embeds):
357
+
358
+ prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
359
+ text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)
360
+
361
+ encoder_hidden_states_svd = initial_noise
362
+ encoder_hidden_states_embedding = initial_noise + text_emb
363
+
364
+ golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())
365
+
366
+ golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
367
+ 2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding
368
+
369
+ return golden_noise
370
+
371
+ class NPNet128(nn.Module):
372
+ def __init__(self, model_id, pretrained_path=True, device='cuda') -> None:
373
+ super(NPNet128, self).__init__()
374
+
375
+ assert model_id in ['SDXL', 'DreamShaper', 'DiT']
376
+
377
+ self.model_id = model_id
378
+ self.device = device
379
+ self.pretrained_path = pretrained_path
380
+
381
+ (
382
+ self.unet_svd,
383
+ self.unet_embedding,
384
+ self.text_embedding,
385
+ self._alpha,
386
+ self._beta
387
+ ) = self.get_model()
388
+
389
+ def get_model(self):
390
+
391
+ unet_embedding = NoiseTransformer(resolution=(128,128)).to(self.device).to(torch.float32)
392
+ unet_svd = SVDNoiseUnet128(resolution=128).to(self.device).to(torch.float32)
393
+
394
+ if self.model_id == 'DiT':
395
+ text_embedding = AdaGroupNorm(1024 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
396
+ else:
397
+ text_embedding = AdaGroupNorm(2048 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
398
+
399
+
400
+ if '.pth' in self.pretrained_path:
401
+ gloden_unet = torch.load(self.pretrained_path)
402
+ unet_svd.load_state_dict(gloden_unet["unet_svd"])
403
+ unet_embedding.load_state_dict(gloden_unet["unet_embedding"])
404
+ text_embedding.load_state_dict(gloden_unet["embeeding"])
405
+ _alpha = gloden_unet["alpha"]
406
+ _beta = gloden_unet["beta"]
407
+
408
+ print("Load Successfully!")
409
+
410
+ return unet_svd, unet_embedding, text_embedding, _alpha, _beta
411
+
412
+ else:
413
+ raise FileNotFoundError("No Pretrained Weights Found!")  # assert on a non-empty string would always pass silently
414
+
415
+
416
+ def forward(self, initial_noise, prompt_embeds):
417
+
418
+ prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
419
+ text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)
420
+
421
+ encoder_hidden_states_svd = initial_noise
422
+ encoder_hidden_states_embedding = initial_noise + text_emb
423
+
424
+ golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())
425
+
426
+ golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
427
+ 2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding
428
+
429
+ return golden_noise
430
+
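A minimal sketch of how the `NPNet64` golden-noise module above is meant to be driven. Shapes are assumptions for SD 1.x (4x64x64 latents, 77x768 CLIP prompt embeddings); without a `.pth` checkpoint the alpha/beta scalars stay random, the pretrained Swin weights are downloaded on first use, and the call assumes a timm version whose `forward_features` output matches `NoiseTransformer`'s 7-channel adapter:

```python
import torch
from SVDNoiseUnet import NPNet64

device = "cuda" if torch.cuda.is_available() else "cpu"
npnet = NPNet64(model_id="SD1.5", device=device)        # untrained: alpha/beta are random
latent = torch.randn(1, 4, 64, 64, device=device)        # initial noise for a 512x512 image
prompt_embeds = torch.randn(1, 77, 768, device=device)   # stand-in for CLIP text-encoder output
with torch.no_grad():
    golden = npnet(latent, prompt_embeds)                 # same shape as `latent`
print(golden.shape)
```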
app.py CHANGED
@@ -2,58 +2,763 @@ import gradio as gr
2
  import numpy as np
3
  import random
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
8
-
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
 
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
 
 
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  MAX_IMAGE_SIZE = 1024
22
 
 
23
 
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
  def infer(
26
  prompt,
27
  negative_prompt,
28
  seed,
29
  randomize_seed,
30
- width,
31
- height,
32
  guidance_scale,
33
  num_inference_steps,
34
  progress=gr.Progress(track_tqdm=True),
35
  ):
36
  if randomize_seed:
37
  seed = random.randint(0, MAX_SEED)
38
 
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
 
53
 
54
  examples = [
55
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
  "A delicious ceviche cheesecake slice",
58
  ]
59
 
@@ -66,7 +771,7 @@ css = """
66
 
67
  with gr.Blocks(css=css) as demo:
68
  with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
 
71
  with gr.Row():
72
  prompt = gr.Text(
@@ -79,7 +784,13 @@ with gr.Blocks(css=css) as demo:
79
 
80
  run_button = gr.Button("Run", scale=0, variant="primary")
81
 
82
- result = gr.Image(label="Result", show_label=False)
83
 
84
  with gr.Accordion("Advanced Settings", open=False):
85
  negative_prompt = gr.Text(
@@ -99,22 +810,15 @@ with gr.Blocks(css=css) as demo:
99
 
100
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
 
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
 
119
  with gr.Row():
120
  guidance_scale = gr.Slider(
@@ -122,15 +826,13 @@ with gr.Blocks(css=css) as demo:
122
  minimum=0.0,
123
  maximum=10.0,
124
  step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
  )
127
 
128
- num_inference_steps = gr.Slider(
 
 
129
  label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
  )
135
 
136
  gr.Examples(examples=examples, inputs=[prompt])
@@ -142,12 +844,11 @@ with gr.Blocks(css=css) as demo:
142
  negative_prompt,
143
  seed,
144
  randomize_seed,
145
- width,
146
- height,
147
  guidance_scale,
148
  num_inference_steps,
149
  ],
150
- outputs=[result, seed],
151
  )
152
 
153
  if __name__ == "__main__":
 
2
  import numpy as np
3
  import random
4
 
5
+ import spaces  # required for ZeroGPU Spaces
6
+ from diffusers import (
7
+ StableDiffusionPipeline,
8
+ DPMSolverMultistepScheduler
9
+ )
10
+ from PIL import Image
11
+ # from huggingface_hub import login
12
+ from SVDNoiseUnet import NPNet64
13
+ import functools
14
+ import random
15
+ from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d
16
  import torch
17
+ import torch.nn as nn
18
+ from einops import rearrange
19
+ from torchvision.utils import make_grid
20
+ import time
21
+ from pytorch_lightning import seed_everything
22
+ from torch import autocast
23
+ from contextlib import contextmanager, nullcontext
24
+ import accelerate
25
+ import torchsde
26
+ from tqdm import tqdm, trange
+ import json  # used by load_preference_json below
+ from itertools import islice  # used by chunk below
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
+ model_repo_id = "sinkinai/Counterfeit-V3.0" # Replace with the model you would like to use
29
 
30
+ precision_scope = autocast
31
+
32
+ def extract_into_tensor(a, t, x_shape):
33
+ b, *_ = t.shape
34
+ out = a.gather(-1, t)
35
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
36
 
 
 
37
 
38
+ def append_zero(x):
39
+ return torch.cat([x, x.new_zeros([1])])
40
+
41
+ # New helper to load a list-of-dicts preference JSON
42
+ # JSON schema: [ { 'human_preference': [int], 'prompt': str, 'file_path': [str] }, ... ]
43
+ def load_preference_json(json_path: str) -> list[dict]:
44
+ """Load records from a JSON file formatted as a list of preference dicts."""
45
+ with open(json_path, 'r') as f:
46
+ data = json.load(f)
47
+ return data
48
+
49
+ # New helper to extract just the prompts from the preference JSON
50
+ # Returns a flat list of all 'prompt' values
51
+
52
+ def extract_prompts_from_pref_json(json_path: str) -> list[str]:
53
+ """Load a JSON of preference records and return only the prompts."""
54
+ records = load_preference_json(json_path)
55
+ return [rec['prompt'] for rec in records]
56
+
57
+ # Example usage:
58
+ # prompts = extract_prompts_from_pref_json("path/to/preference.json")
59
+ # print(prompts)
60
+
61
+ def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu',need_append_zero = True):
62
+ """Constructs the noise schedule of Karras et al. (2022)."""
63
+ ramp = torch.linspace(0, 1, n)
64
+ min_inv_rho = sigma_min ** (1 / rho)
65
+ max_inv_rho = sigma_max ** (1 / rho)
66
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
67
+ return append_zero(sigmas).to(device) if need_append_zero else sigmas.to(device)
68
+
69
+ def extract_into_tensor(a, t, x_shape):
70
+ b, *_ = t.shape
71
+ out = a.gather(-1, t)
72
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
73
+
74
+ def append_zero(x):
75
+ return torch.cat([x, x.new_zeros([1])])
76
+
77
+ def append_dims(x, target_dims):
78
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
79
+ dims_to_append = target_dims - x.ndim
80
+ if dims_to_append < 0:
81
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
82
+ return x[(...,) + (None,) * dims_to_append]
83
+
84
+ class CFGDenoiser(nn.Module):
85
+ def __init__(self, model):
86
+ super().__init__()
87
+ self.inner_model = model
88
+
89
+ def get_golden_noised(self, x, sigma,sigma_nxt, uncond, cond, cond_scale,need_distill_uncond=False,tmp_list = [], uncond_list = [],noise_training_list={}):
90
+ x_in = torch.cat([x] * 2)
91
+ sigma_in = torch.cat([sigma] * 2)
92
+ sigma_nxt = torch.cat([sigma_nxt] * 2)
93
+ cond_in = torch.cat([uncond, cond])
94
+ _, ret = self.inner_model.get_customed_golden_noise(x_in
95
+ , 1.0
96
+ , sigma_in, sigma_nxt
97
+ , True
98
+ , encoder_hidden_states=cond_in.to(device=x.device, dtype=x.dtype)
99
+ , noise_training_list=noise_training_list).chunk(2)
100
+
101
+ return ret
102
+
103
+ def forward(self, x, sigma, uncond, cond, cond_scale,need_distill_uncond=False,tmp_list = [], uncond_list = []):
104
+
105
+ x_in = torch.cat([x] * 2)
106
+ sigma_in = torch.cat([sigma] * 2)
107
+ cond_in = torch.cat([uncond, cond])
108
+ uncond, cond = self.inner_model(x_in, sigma_in, tmp_list, encoder_hidden_states=cond_in.to(device=x.device, dtype=x.dtype)).chunk(2)
109
+ if need_distill_uncond:
110
+ uncond_list.append(uncond)
111
+ return uncond + (cond - uncond) * cond_scale
112
+
113
+
114
+ class DiscreteSchedule(nn.Module):
115
+ """A mapping between continuous noise levels (sigmas) and a list of discrete noise
116
+ levels."""
117
+
118
+ def __init__(self, sigmas, quantize):
119
+ super().__init__()
120
+ self.register_buffer('sigmas', sigmas)
121
+ self.register_buffer('log_sigmas', sigmas.log())
122
+ self.quantize = quantize
123
+
124
+ @property
125
+ def sigma_min(self):
126
+ return self.sigmas[0]
127
+
128
+ @property
129
+ def sigma_max(self):
130
+ return self.sigmas[-1]
131
+
132
+ def get_sigmas(self, n=None):
133
+ if n is None:
134
+ return append_zero(self.sigmas.flip(0))
135
+ t_max = len(self.sigmas) - 1
136
+ t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
137
+ return append_zero(self.t_to_sigma(t))
138
+
139
+ def sigma_to_t(self, sigma, quantize=None):
140
+ quantize = self.quantize if quantize is None else quantize
141
+ log_sigma = sigma.log()
142
+ dists = log_sigma - self.log_sigmas[:, None]
143
+ if quantize:
144
+ return dists.abs().argmin(dim=0).view(sigma.shape)
145
+ low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
146
+ high_idx = low_idx + 1
147
+ low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
148
+ w = (low - log_sigma) / (low - high)
149
+ w = w.clamp(0, 1)
150
+ t = (1 - w) * low_idx + w * high_idx
151
+ return t.view(sigma.shape)
152
+
153
+ def t_to_sigma(self, t):
154
+ t = t.float()
155
+ low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
156
+ log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
157
+ return log_sigma.exp()
158
+
159
+ class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
160
+ """A wrapper for discrete schedule DDPM models that output eps (the predicted
161
+ noise)."""
162
+
163
+ def __init__(self, model, alphas_cumprod, quantize = False):
164
+ super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
165
+ self.inner_model = model
166
+ # self.alphas_cumprod = alphas_cumprod.flip(0)
167
+ # Prepare a reversed version of alphas_cumprod for backward scheduling
168
+ self.register_buffer('alphas_cumprod', alphas_cumprod)
169
+ # self.register_buffer('alphas_cumprod_prev', append_zero(alphas_cumprod[:-1]))
170
+ self.sigma_data = 1.
171
+
172
+ def get_scalings(self, sigma):
173
+ c_out = -sigma
174
+ c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
175
+ return c_out, c_in
176
+
177
+ def get_eps(self, *args, **kwargs):
178
+ return self.inner_model(*args, **kwargs)
179
+
180
+ def get_alphact_and_sigma(self, timesteps, x_0, noise):
181
+ high_idx = torch.ceil(timesteps).int()
182
+ low_idx = torch.floor(timesteps).int()
183
+
184
+ nxt_ts = timesteps - timesteps.new_ones(timesteps.shape[0])
185
+
186
+ w = (timesteps - low_idx) / (high_idx - low_idx)
187
+
188
+ beta_1 = torch.tensor([1e-4],dtype=torch.float32)
189
+ beta_T = torch.tensor([0.02],dtype=torch.float32)
190
+ ddpm_max_step = torch.tensor([1000.0],dtype=torch.float32)
191
+
192
+ beta_t: torch.Tensor = (beta_T - beta_1) / ddpm_max_step * timesteps + beta_1
193
+ beta_t_prev: torch.Tensor = (beta_T - beta_1) / ddpm_max_step * nxt_ts + beta_1
194
+
195
+ alpha_t = beta_t.new_ones(beta_t.shape[0]) - beta_t
196
+ alpha_t_prev = beta_t.new_ones(beta_t.shape[0]) - beta_t_prev
197
+
198
+ dir_xt = (1. - alpha_t_prev).sqrt() * noise
199
+ x_prev = alpha_t_prev.sqrt() * x_0 + dir_xt + noise
200
+
201
+ alpha_cumprod_t_floor = self.alphas_cumprod[low_idx]  # the buffer registered above is named alphas_cumprod
202
+ alpha_cumprod_t = (alpha_cumprod_t_floor * alpha_t) #.unsqueeze(1)
203
+ sqrt_alpha_cumprod_t = torch.sqrt(alpha_cumprod_t)
204
+ sigmas = torch.sqrt(alpha_cumprod_t.new_ones(alpha_cumprod_t.shape[0]) - alpha_cumprod_t)
205
+
206
+ # Fix broadcasting
207
+ sqrt_alpha_cumprod_t = sqrt_alpha_cumprod_t[:, None, None]
208
+ sigmas = sigmas[:, None, None]
209
+ return alpha_cumprod_t, sigmas
210
+
211
+ def get_c_ins(self,sigmas): # use to adjust loss
212
+ ret = list()
213
+ for sigma in sigmas:
214
+ _, c_in = self.get_scalings(sigma=sigma)
215
+ ret.append(c_in)
216
+ return ret
217
+
218
+
219
+ def get_customed_golden_noise(self, input, unconditional_guidance_scale:float, sigma, sigma_nxt, need_cond = True,noise_training_list = {}, **kwargs):
220
+ """The caller should ensure the input is pure noise.
+ This is a customized golden noise, not the one proposed in the paper;
+ the variant proposed in the paper may be implemented in the future."""
223
+ c_out, c_in = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
224
+
225
+ sigma_fn = lambda t: t.neg().exp()
226
+ t_fn = lambda sigma: sigma.log().neg()
227
+ if need_cond:
228
+ _, tmp_img = (input * c_in).chunk(2)
229
+ else :
230
+ tmp_img = input * c_in
231
+
232
+ eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs).sample
233
+ x_0 = input + eps * c_out
234
+
235
+ x_0_uncond, x_0 = x_0.chunk(2)
236
+ x_0 = x_0_uncond + unconditional_guidance_scale * (x_0 - x_0_uncond)
237
+ x_0 = torch.cat([x_0] * 2)
238
+
239
+
240
+ t, t_next = t_fn(sigma), t_fn(sigma_nxt)
241
+ h = t_next - t
242
+
243
+ x = (append_dims(sigma_fn(t_next) / sigma_fn(t),input.ndim)) * input - append_dims((-h).expm1(),input.ndim) * x_0
244
+
245
+ c_out_2, c_in_2 = [append_dims(x, input.ndim) for x in self.get_scalings(sigma_nxt)]
246
+
247
+ eps_ret = self.get_eps(x * c_in_2, self.sigma_to_t(sigma_nxt), **kwargs).sample
248
+ org_golden_noise = True
249
+ x_1 = x + eps_ret * c_out_2
250
+ if org_golden_noise:
251
+ ret = (x + append_dims((-h).expm1(),input.ndim) * x_1) / (append_dims(sigma_fn(t_next) / sigma_fn(t),input.ndim))
252
+ else :
253
+ e_t_uncond_ret, e_t_ret = eps_ret.chunk(2)
254
+ e_t_ret = e_t_uncond_ret + 1.0 * (e_t_ret - e_t_uncond_ret)
255
+ e_t_ret = torch.cat([e_t_ret] * 2)
256
+ ret = x_0 + e_t_ret * append_dims(sigma,input.ndim)
257
+
258
+ noise_training_list['org_noise'] = input * c_in
259
+ noise_training_list['golden_noise'] = ret * c_in
260
+ # noise_training_list.append(tmp_dict)
261
+ return ret
262
+
263
+
264
+ def forward(self, input, sigma, tmp_list=[], need_cond = True, **kwargs):
265
+
266
+ c_out, c_in = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
267
+ if need_cond:
268
+ _, tmp_img = (input * c_in).chunk(2)
269
+ else :
270
+ tmp_img = input * c_in
271
+ # print(tmp_img.max())
272
+ eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs).sample
273
+ tmp_x0 = input + eps * c_out
274
+ tmp_dict = {'tmp_z': tmp_img, 'tmp_x0': tmp_x0}
275
+ tmp_list.append(tmp_dict)
276
+ return tmp_x0 #input + eps * c_out
277
+
278
+ def get_special_sigmas_with_timesteps(self,timesteps):
279
+ low_idx, high_idx, w = np.minimum(np.floor(timesteps), 999).astype(np.int64), np.minimum(np.ceil(timesteps), 999).astype(np.int64), torch.from_numpy(timesteps - np.floor(timesteps))  # integer indices are required to index alphas_cumprod
280
+ self.alphas_cumprod = self.alphas_cumprod.to('cpu')
281
+ alphas = (1 - w) * self.alphas_cumprod[low_idx] + w * self.alphas_cumprod[high_idx]
282
+ return ((1 - alphas) / alphas) ** 0.5
283
+
284
+ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
285
+ """Calculates the noise level (sigma_down) to step down to and the amount
286
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
287
+ if not eta:
288
+ return sigma_to, 0.
289
+ sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
290
+ sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
291
+ return sigma_down, sigma_up
292
+
293
+ def to_d(x, sigma, denoised):
294
+ """Converts a denoiser output to a Karras ODE derivative."""
295
+ return (x - denoised) / append_dims(sigma, x.ndim)
296
+
297
+ class BatchedBrownianTree:
298
+ """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
299
+ def __init__(self, x, t0, t1, seed=None, **kwargs):
300
+ t0, t1, self.sign = self.sort(t0, t1)
301
+ w0 = kwargs.get('w0', torch.zeros_like(x))
302
+ if seed is None:
303
+ seed = torch.randint(0, 2 ** 63 - 1, []).item()
304
+ self.batched = True
305
+ try:
306
+ assert len(seed) == x.shape[0]
307
+ w0 = w0[0]
308
+ except TypeError:
309
+ seed = [seed]
310
+ self.batched = False
311
+ self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
312
+
313
+ @staticmethod
314
+ def sort(a, b):
315
+ return (a, b, 1) if a < b else (b, a, -1)
316
+
317
+ def __call__(self, t0, t1):
318
+ t0, t1, sign = self.sort(t0, t1)
319
+ w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
320
+ return w if self.batched else w[0]
321
+
322
+ class BrownianTreeNoiseSampler:
323
+ """A noise sampler backed by a torchsde.BrownianTree.
324
+
325
+ Args:
326
+ x (Tensor): The tensor whose shape, device and dtype to use to generate
327
+ random samples.
328
+ sigma_min (float): The low end of the valid interval.
329
+ sigma_max (float): The high end of the valid interval.
330
+ seed (int or List[int]): The random seed. If a list of seeds is
331
+ supplied instead of a single integer, then the noise sampler will
332
+ use one BrownianTree per batch item, each with its own seed.
333
+ transform (callable): A function that maps sigma to the sampler's
334
+ internal timestep.
335
+ """
336
+ def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
337
+ self.transform = transform
338
+ t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
339
+ self.tree = BatchedBrownianTree(x, t0, t1, seed)
340
+
341
+ def __call__(self, sigma, sigma_next):
342
+ t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
343
+ return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
344
+
345
+ @torch.no_grad()
346
+ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
347
+ """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
348
+ extra_args = {} if extra_args is None else extra_args
349
+ s_in = x.new_ones([x.shape[0]])
350
+ intermediates = {'x_inter': [x],'pred_x0': []}
351
+ for i in trange(len(sigmas) - 1, disable=disable):
352
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
353
+ eps = torch.randn_like(x) * s_noise
354
+ sigma_hat = sigmas[i] * (gamma + 1)
355
+ if gamma > 0:
356
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
357
+ denoised = model(x, sigma_hat * s_in, **extra_args)
358
+ d = to_d(x, sigma_hat, denoised)
359
+ if callback is not None:
360
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
361
+ dt = sigmas[i + 1] - sigma_hat
362
+ # Euler method
363
+ x = x + d * dt
364
+ intermediates['pred_x0'].append(denoised)
365
+ intermediates['x_inter'].append(x)
366
+ return intermediates, x
367
+
368
+ @torch.no_grad()
369
+ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
370
+ """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
371
+ extra_args = {} if extra_args is None else extra_args
372
+ s_in = x.new_ones([x.shape[0]])
373
+ intermediates = {'x_inter': [x],'pred_x0': []}
374
+ for i in trange(len(sigmas) - 1, disable=disable):
375
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
376
+ eps = torch.randn_like(x) * s_noise
377
+ sigma_hat = sigmas[i] * (gamma + 1)
378
+ if gamma > 0:
379
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
380
+ denoised = model(x, sigma_hat * s_in, **extra_args)
381
+ d = to_d(x, sigma_hat, denoised)
382
+ if callback is not None:
383
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
384
+ dt = sigmas[i + 1] - sigma_hat
385
+ if sigmas[i + 1] == 0:
386
+ # Euler method
387
+ x = x + d * dt
388
+ else:
389
+ # Heun's method
390
+ x_2 = x + d * dt
391
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
392
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
393
+ d_prime = (d + d_2) / 2
394
+ x = x + d_prime * dt
395
+ intermediates['pred_x0'].append(denoised_2)
396
+ intermediates['x_inter'].append(x)
397
+ return intermediates, x
398
+
399
+ @torch.no_grad()
400
+ def sample_dpmpp_ode(model
401
+ , x
402
+ , sigmas
403
+ , need_golden_noise = False
404
+ , extra_args=None
405
+ , callback=None
406
+ , disable=None
407
+ , start_free_step = 1
408
+ , pipe = None
409
+ , tmp_list=[]
410
+ , need_distill_uncond=False
411
+ , uncond_list=[]
412
+ , noise_training_list={}):
413
+ """DPM-Solver++."""
414
+ extra_args = {} if extra_args is None else extra_args
415
+ s_in = x.new_ones([x.shape[0]])
416
+ sigma_fn = lambda t: t.neg().exp()
417
+ t_fn = lambda sigma: sigma.log().neg()
418
+ old_denoised = None
419
+
420
+ if False: #need_golden_noise:
421
+ x = model.get_golden_noised(x=x,sigma=sigmas[0] * s_in, sigma_nxt=(sigmas[0] - 0.28) * s_in, noise_training_list=noise_training_list,**extra_args)
422
+ intermediates = {'x_inter': [x],'pred_x0': []}
423
+
424
+ for i in trange(len(sigmas) - 1, disable=disable):
425
+
426
+ # macs, params = profile(model, inputs=(x, sigmas[i] * s_in,*extra_args.values(),need_distill_uncond,tmp_list,uncond_list, ))
427
+ denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list,need_distill_uncond=need_distill_uncond,uncond_list=uncond_list, **extra_args)
428
+ if callback is not None:
429
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
430
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
431
+ h = t_next - t
432
+
433
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
434
+ intermediates['pred_x0'].append(denoised)
435
+ intermediates['x_inter'].append(x)
436
+
437
+ # print(denoised_d.max())
438
+
439
+ # intermediates['noise'].append(denoised_d)
440
+ return intermediates,x
441
+
442
+ @torch.no_grad()
443
+ def sample_dpmpp_sde(model
444
+ , need_golden_noise
445
+ , x
446
+ , sigmas
447
+ , extra_args=None
448
+ , callback=None
449
+ , disable=None
450
+ , eta=1.
451
+ , s_noise=1.
452
+ , noise_sampler=None
453
+ , r=1 / 2):
454
+ """DPM-Solver++ (stochastic)."""
455
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
456
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
457
+ extra_args = {} if extra_args is None else extra_args
458
+ s_in = x.new_ones([x.shape[0]])
459
+ sigma_fn = lambda t: t.neg().exp()
460
+ t_fn = lambda sigma: sigma.log().neg()
461
+ if need_golden_noise:
462
+ x = model.get_golden_noised(x=x,sigma=sigmas[0] * s_in, sigma_nxt=sigmas[1] * s_in,**extra_args)
463
+
464
+ intermediates = {'x_inter': [x],'pred_x0': []}
465
+
466
+ for i in trange(len(sigmas) - 1, disable=disable):
467
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
468
+ if callback is not None:
469
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
470
+ if sigmas[i + 1] == 0:
471
+ # Euler method
472
+ d = to_d(x, sigmas[i], denoised)
473
+ dt = sigmas[i + 1] - sigmas[i]
474
+ x = x + d * dt
475
+ intermediates['pred_x0'].append(denoised)
476
+ intermediates['x_inter'].append(x)
477
+ else:
478
+ # DPM-Solver++
479
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
480
+ h = t_next - t
481
+ s = t + h * r
482
+ fac = 1 / (2 * r)
483
+
484
+ # Step 1
485
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
486
+ s_ = t_fn(sd)
487
+ x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
488
+ x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
489
+ denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
490
+
491
+ # Step 2
492
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
493
+ t_next_ = t_fn(sd)
494
+ denoised_d = (1 - fac) * denoised + fac * denoised_2
495
+ x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
496
+ intermediates['pred_x0'].append(x)
497
+ x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
498
+ intermediates['x_inter'].append(x)
499
+ return intermediates, x
500
+
501
+ @torch.no_grad()
502
+ def sample_dpmpp_2m(model
503
+ , x
504
+ , sigmas
505
+ , need_golden_noise = True
506
+ , extra_args=None
507
+ , callback=None
508
+ , disable=None
509
+ , tmp_list=[]
510
+ , need_distill_uncond=False
511
+ , uncond_list=[]
512
+ , stop_t = None):
513
+ """DPM-Solver++(2M)."""
514
+ extra_args = {} if extra_args is None else extra_args
515
+ s_in = x.new_ones([x.shape[0]])
516
+ sigma_fn = lambda t: t.neg().exp()
517
+ t_fn = lambda sigma: sigma.log().neg()
518
+ old_denoised = None
519
+ if need_golden_noise:
520
+ x = model.get_golden_noised(x=x,sigma=sigmas[0] * s_in, sigma_nxt=sigmas[1] * s_in,**extra_args)
521
+
522
+ intermediates = {'x_inter': [x],'pred_x0': []}
523
+
524
+ for i in trange(len(sigmas) - 1, disable=disable):
525
+
526
+ # macs, params = profile(model, inputs=(x, sigmas[i] * s_in,*extra_args.values(),need_distill_uncond,tmp_list,uncond_list, ))
527
+ denoised = model(x, sigmas[i] * s_in, tmp_list=tmp_list,need_distill_uncond=need_distill_uncond,uncond_list=uncond_list, **extra_args)
528
+ if callback is not None:
529
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
530
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
531
+ h = t_next - t
532
+ if old_denoised is None or sigmas[i + 1] == 0:
533
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
534
+ intermediates['pred_x0'].append(denoised)
535
+ intermediates['x_inter'].append(x)
536
+ else:
537
+ h_last = t - t_fn(sigmas[i - 1])
538
+ r = h_last / h
539
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
540
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
541
+ intermediates['x_inter'].append(x)
542
+ intermediates['pred_x0'].append(denoised)
543
+ # print(denoised_d.max())
544
+ old_denoised = denoised
545
+ if i is not None and i == stop_t:
546
+ return intermediates, x
547
+ # intermediates['noise'].append(denoised_d)
548
+ return intermediates,x
549
+
550
+
551
+ # Adapted from pipelines.StableDiffusionPipeline.encode_prompt
552
+ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
553
+ captions = []
554
+ for caption in prompt_batch:
555
+ if random.random() < proportion_empty_prompts:
556
+ captions.append("")
557
+ elif isinstance(caption, str):
558
+ captions.append(caption)
559
+ elif isinstance(caption, (list, np.ndarray)):
560
+ # take a random caption if there are multiple
561
+ captions.append(random.choice(caption) if is_train else caption[0])
562
+
563
+ with torch.no_grad():
564
+ text_inputs = tokenizer(
565
+ captions,
566
+ padding="max_length",
567
+ max_length=tokenizer.model_max_length,
568
+ truncation=True,
569
+ return_tensors="pt",
570
+ )
571
+ text_input_ids = text_inputs.input_ids
572
+ prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0]
573
+
574
+ return prompt_embeds
575
+
576
+ def chunk(it, size):
577
+ it = iter(it)
578
+ return iter(lambda: tuple(islice(it, size)), ())
579
+
580
+ def convert_caption_json_to_str(json):
581
+ caption = json["caption"]
582
+ return caption
583
+
584
+ torch_dtype = torch.float32
585
+ pipe = StableDiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
586
+ pipe.to(device=device, torch_dtype=torch_dtype)
587
+ # neg_emb = pipe.load_textual_inversion("./EasyNegative.safetensors", device=device, dtype=torch_dtype)
588
+
589
+ pipe = pipe.to(device)
590
+ register_free_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
591
+ register_free_crossattn_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
592
  MAX_SEED = np.iinfo(np.int32).max
593
  MAX_IMAGE_SIZE = 1024
594
+ noise_scheduler = pipe.scheduler
595
+ alpha_schedule = noise_scheduler.alphas_cumprod.to(device=device, dtype=torch_dtype)
596
+ model_wrap = DiscreteEpsDDPMDenoiser(pipe.unet, alpha_schedule, quantize=False)
597
+ accelerator = accelerate.Accelerator()
598
+
599
+ def generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps):
600
+ """Helper function to generate image with specific number of steps"""
601
+ prompts = [prompt]
602
+ if num_inference_steps <= 10:
603
+ register_free_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
604
+ register_free_crossattn_upblock2d(pipe, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
605
+ else:
606
+ register_free_upblock2d(pipe, b1=1, b2=1, s1=1, s2=1)
607
+ register_free_crossattn_upblock2d(pipe, b1=1, b2=1, s1=1, s2=1)
608
+ if randomize_seed:
609
+ seed = random.randint(0, MAX_SEED)
610
+
611
+ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
612
+ prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
613
+ return {"prompt_embeds": prompt_embeds}
614
+
615
+ compute_embeddings_fn = functools.partial(
616
+ compute_embeddings,
617
+ proportion_empty_prompts=0,
618
+ text_encoder=pipe.text_encoder,
619
+ tokenizer=pipe.tokenizer,
620
+ )
621
+ generator = torch.Generator().manual_seed(seed)
622
+
623
+ intermediate_photos = list()
624
+ # prompts = prompts[0]
625
+
626
+ # if isinstance(prompts, tuple) or isinstance(prompts, str):
627
+ # prompts = list(prompts)
628
+ if isinstance(prompts, str):
629
+ prompts = prompts #+ 'high quality, best quality, masterpiece, 4K, highres, extremely detailed, ultra-detailed'
630
+ prompts = (prompts,)
631
+ if isinstance(prompts, tuple) or isinstance(prompts, str):
632
+ prompts = list(prompts)
633
+
634
+ encoded_text = compute_embeddings_fn(prompts)
635
+ uc = compute_embeddings_fn(1 * [""])
636
+ uc = uc.pop("prompt_embeds") if uc is not None else None
637
+ c = encoded_text.pop("prompt_embeds")
638
+ shape = [4, height // 8, width // 8]
639
+ start_free_step = num_inference_steps
640
+ fir_stage_sigmas_ct = None
641
+ sec_stage_sigmas_ct = None
642
+ # sigmas = model_wrap.get_sigmas(opt.ddim_steps).to(device=device)
643
+ if num_inference_steps == 5:
644
+ sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()
645
+ sigmas = get_sigmas_karras(8, sigma_min, sigma_max, rho=5.0, device=device)# 6.0 if 5 else 10.0
646
+
647
+ ct_start, ct_end = model_wrap.sigma_to_t(sigmas[0]), model_wrap.sigma_to_t(sigmas[6])
648
+ # sigma_kct_start, sigma_kct_end = sigmas[0].item(), sigmas[5].item()
649
+ ct = get_sigmas_karras(6, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
650
+ sigmas_ct = model_wrap.get_special_sigmas_with_timesteps(ct).to(device=device)
651
+ elif num_inference_steps == 6:
652
+ sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()
653
+ sigmas = get_sigmas_karras(8, sigma_min, sigma_max,rho=5.0, device=device)# 6.0 if 5 else 10.0
654
+
655
+ ct_start, ct_end = model_wrap.sigma_to_t(sigmas[0]), model_wrap.sigma_to_t(sigmas[6])
656
+ # sigma_kct_start, sigma_kct_end = sigmas[0].item(), sigmas[5].item()
657
+ ct = get_sigmas_karras(7, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
658
+ sigmas_ct = model_wrap.get_special_sigmas_with_timesteps(ct).to(device=device)
659
+
660
+ elif num_inference_steps == 8:
661
+ sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()
662
+
663
+ sigmas = get_sigmas_karras(12, sigma_min, sigma_max,rho=7.0, device=device)# 6.0 if 5 else 10.0
664
+ ct_start, ct_end = model_wrap.sigma_to_t(sigmas[0]), model_wrap.sigma_to_t(sigmas[9])
665
+
666
+ ct = get_sigmas_karras(num_inference_steps +1, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
667
+ sigmas_ct = model_wrap.get_special_sigmas_with_timesteps(ct).to(device=device)
668
+ else:
669
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(
670
+ pipe.scheduler.config
671
+ )
672
+ image = pipe(prompt=prompts
673
+ ,num_inference_steps=num_inference_steps
674
+ ,guidance_scale=guidance_scale
675
+ ,height=height
676
+ ,width=width).images[0]
677
+ return image
678
+ ts = []
679
+ for sigma in sigmas_ct:
680
+ t = model_wrap.sigma_to_t(sigma)
681
+ ts.append(t)
682
+
683
+ c_in = model_wrap.get_c_ins(sigmas=sigmas_ct)
684
+ x = torch.randn([1, *shape], device=device) * sigmas_ct[0]
685
+ # if opt.is_acgn:
686
+ # x = x.half()
687
+ model_wrap_cfg = CFGDenoiser(model_wrap)
688
+ extra_args = {'cond': c, 'uncond': uc, 'cond_scale': guidance_scale}
689
+ noise_training_list = {}
690
+ with torch.no_grad():
691
+ # with precision_scope("cuda" if torch.cuda.is_available() else "cpu"):
692
+ if not (num_inference_steps == 8 or num_inference_steps == 7):
693
+ guide_distill, samples_ddim = sample_dpmpp_ode(model_wrap_cfg
694
+ , x
695
+ , sigmas_ct
696
+ , need_golden_noise = False
697
+ , extra_args=extra_args
698
+ , disable=not accelerator.is_main_process
699
+ , tmp_list=intermediate_photos
700
+ , noise_training_list=noise_training_list)
701
+ # , stop_t=4)
702
+ else:
703
+ guide_distill, samples_ddim = sample_dpmpp_2m(model_wrap_cfg
704
+ , x
705
+ , sigmas_ct
706
+ , need_golden_noise = False
707
+ , extra_args=extra_args
708
+ , disable=not accelerator.is_main_process
709
+ , tmp_list=intermediate_photos)
710
+ # print('2m')
711
+
712
+ x_samples_ddim = pipe.vae.decode(samples_ddim / pipe.vae.config.scaling_factor).sample
713
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
714
+
715
+ if True: # not opt.skip_save:
716
+ for x_sample in x_samples_ddim:
717
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
718
+ image = Image.fromarray(x_sample.astype(np.uint8))
719
+ # base_count += 1
720
+
721
+ # image = pipe(
722
+ # prompt=prompt,
723
+ # negative_prompt=negative_prompt,
724
+ # guidance_scale=guidance_scale,
725
+ # num_inference_steps=num_inference_steps,
726
+ # width=width,
727
+ # height=height,
728
+ # generator=generator,
729
+ # ).images[0]
730
 
731
+ return image
732
 
733
+ @spaces.GPU  # run on ZeroGPU
734
  def infer(
735
  prompt,
736
  negative_prompt,
737
  seed,
738
  randomize_seed,
739
+ resolution,
 
740
  guidance_scale,
741
  num_inference_steps,
742
  progress=gr.Progress(track_tqdm=True),
743
  ):
744
  if randomize_seed:
745
  seed = random.randint(0, MAX_SEED)
746
+
747
+ # Parse resolution string into width and height
748
+ width, height = map(int, resolution.split('x'))
749
+
750
+ # Generate image with selected steps
751
+ image_quick = generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps)
752
+
753
+ # Generate image with 50 steps for high quality
754
+ image_50_steps = generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, 50)
755
 
756
+ return image_quick, image_50_steps, seed
757
 
758
 
759
  examples = [
760
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
761
+ "((masterpiece,best quality)) , 1girl, ((school uniform)),brown blazer, black skirt,small breasts,necktie,red plaid skirt,looking at viewer",
762
  "A delicious ceviche cheesecake slice",
763
  ]
764
 
 
771
 
772
  with gr.Blocks(css=css) as demo:
773
  with gr.Column(elem_id="col-container"):
774
+ gr.Markdown(" # Hyperparameters are all you need")
775
 
776
  with gr.Row():
777
  prompt = gr.Text(
 
784
 
785
  run_button = gr.Button("Run", scale=0, variant="primary")
786
 
787
+ with gr.Row():
788
+ with gr.Column():
789
+ gr.Markdown("### Our Fast Inference Result")
790
+ result = gr.Image(label="Quick Result", show_label=False)
791
+ with gr.Column():
792
+ gr.Markdown("### Original 50-Step Result")
793
+ result_50_steps = gr.Image(label="50 Steps Result", show_label=False)
794
 
795
  with gr.Accordion("Advanced Settings", open=False):
796
  negative_prompt = gr.Text(
 
810
 
811
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
812
 
813
+ resolution = gr.Dropdown(
814
+ choices=[
815
+ "512x512",
816
+ "1024x768",
817
+ "768x1024"
818
+ ],
819
+ value="1024x768",
820
+ label="Resolution",
821
+ )
 
822
 
823
  with gr.Row():
824
  guidance_scale = gr.Slider(
 
826
  minimum=0.0,
827
  maximum=10.0,
828
  step=0.1,
829
+ value=7.5, # default classifier-free guidance scale for Stable Diffusion
830
  )
831
 
832
+ num_inference_steps = gr.Dropdown(
833
+ choices=[5, 6, 8],
834
+ value=8,
835
  label="Number of inference steps",
 
836
  )
837
 
838
  gr.Examples(examples=examples, inputs=[prompt])
 
844
  negative_prompt,
845
  seed,
846
  randomize_seed,
847
+ resolution,
 
848
  guidance_scale,
849
  num_inference_steps,
850
  ],
851
+ outputs=[result, result_50_steps, seed],
852
  )
853
 
854
  if __name__ == "__main__":
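
The sampling loop in app.py wraps the k-diffusion denoiser in a `CFGDenoiser` and passes `cond`, `uncond`, and `cond_scale` through `extra_args`. The wrapper's definition is not shown in this hunk; the sketch below is only an illustration of what a k-diffusion-style classifier-free-guidance wrapper with that call signature typically looks like, not the code from this commit.

```python
import torch
import torch.nn as nn


class CFGDenoiserSketch(nn.Module):
    """Illustrative CFG wrapper; the (x, sigma, cond, uncond, cond_scale) signature is assumed."""

    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model

    def forward(self, x, sigma, cond, uncond, cond_scale):
        # Run the unconditional and conditional branches in a single batched forward pass.
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond_out, cond_out = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        # Classifier-free guidance: push the prediction away from the unconditional output.
        return uncond_out + (cond_out - uncond_out) * cond_scale
```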
free_lunch_utils.py ADDED
@@ -0,0 +1,304 @@
1
+ import torch
2
+ import torch.fft as fft
3
+ from diffusers.utils import is_torch_version
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+
7
+ def isinstance_str(x: object, cls_name: str):
8
+ """
9
+ Checks whether x has any class *named* cls_name in its ancestry.
10
+ Doesn't require access to the class's implementation.
11
+
12
+ Useful for patching!
13
+ """
14
+
15
+ for _cls in x.__class__.__mro__:
16
+ if _cls.__name__ == cls_name:
17
+ return True
18
+
19
+ return False
20
+
21
+
22
+ def Fourier_filter(x, threshold, scale):
23
+ dtype = x.dtype
24
+ x = x.type(torch.float32)
25
+ # FFT
26
+ x_freq = fft.fftn(x, dim=(-2, -1))
27
+ x_freq = fft.fftshift(x_freq, dim=(-2, -1))
28
+
29
+ B, C, H, W = x_freq.shape
30
+ # build the frequency-domain mask on the same device as the input to avoid device mismatches
31
+ mask = torch.ones((B, C, H, W), device=x_freq.device)
32
+
33
+ crow, ccol = H // 2, W // 2
34
+ mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
35
+ x_freq = x_freq * mask
36
+
37
+ # IFFT
38
+ x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
39
+ x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
40
+
41
+ x_filtered = x_filtered.type(dtype)
42
+ return x_filtered
43
+
44
+
45
+ def register_upblock2d(model):
46
+ def up_forward(self):
47
+ def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
48
+ for resnet in self.resnets:
49
+ # pop res hidden states
50
+ res_hidden_states = res_hidden_states_tuple[-1]
51
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
52
+ #print(f"in upblock2d, hidden states shape: {hidden_states.shape}")
53
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
54
+
55
+ if self.training and self.gradient_checkpointing:
56
+
57
+ def create_custom_forward(module):
58
+ def custom_forward(*inputs):
59
+ return module(*inputs)
60
+
61
+ return custom_forward
62
+
63
+ if is_torch_version(">=", "1.11.0"):
64
+ hidden_states = torch.utils.checkpoint.checkpoint(
65
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
66
+ )
67
+ else:
68
+ hidden_states = torch.utils.checkpoint.checkpoint(
69
+ create_custom_forward(resnet), hidden_states, temb
70
+ )
71
+ else:
72
+ hidden_states = resnet(hidden_states, temb)
73
+
74
+ if self.upsamplers is not None:
75
+ for upsampler in self.upsamplers:
76
+ hidden_states = upsampler(hidden_states, upsample_size)
77
+
78
+ return hidden_states
79
+
80
+ return forward
81
+
82
+ for i, upsample_block in enumerate(model.unet.up_blocks):
83
+ if isinstance_str(upsample_block, "UpBlock2D"):
84
+ upsample_block.forward = up_forward(upsample_block)
85
+
86
+
87
+ def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
88
+ def up_forward(self):
89
+ def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
90
+ for resnet in self.resnets:
91
+ # pop res hidden states
92
+ res_hidden_states = res_hidden_states_tuple[-1]
93
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
94
+ #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
95
+
96
+ # --------------- FreeU code -----------------------
97
+ # Only operate on the first two stages
98
+ if hidden_states.shape[1] == 1280:
99
+ hidden_states[:,:640] = hidden_states[:,:640] * self.b1
100
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
101
+ if hidden_states.shape[1] == 640:
102
+ hidden_states[:,:320] = hidden_states[:,:320] * self.b2
103
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
104
+ # ---------------------------------------------------------
105
+
106
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
107
+
108
+ if self.training and self.gradient_checkpointing:
109
+
110
+ def create_custom_forward(module):
111
+ def custom_forward(*inputs):
112
+ return module(*inputs)
113
+
114
+ return custom_forward
115
+
116
+ if is_torch_version(">=", "1.11.0"):
117
+ hidden_states = torch.utils.checkpoint.checkpoint(
118
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
119
+ )
120
+ else:
121
+ hidden_states = torch.utils.checkpoint.checkpoint(
122
+ create_custom_forward(resnet), hidden_states, temb
123
+ )
124
+ else:
125
+ hidden_states = resnet(hidden_states, temb)
126
+
127
+ if self.upsamplers is not None:
128
+ for upsampler in self.upsamplers:
129
+ hidden_states = upsampler(hidden_states, upsample_size)
130
+
131
+ return hidden_states
132
+
133
+ return forward
134
+
135
+ for i, upsample_block in enumerate(model.unet.up_blocks):
136
+ if isinstance_str(upsample_block, "UpBlock2D"):
137
+ upsample_block.forward = up_forward(upsample_block)
138
+ setattr(upsample_block, 'b1', b1)
139
+ setattr(upsample_block, 'b2', b2)
140
+ setattr(upsample_block, 's1', s1)
141
+ setattr(upsample_block, 's2', s2)
142
+
143
+
144
+ def register_crossattn_upblock2d(model):
145
+ def up_forward(self):
146
+ def forward(
147
+ hidden_states: torch.FloatTensor,
148
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
149
+ temb: Optional[torch.FloatTensor] = None,
150
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
151
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
152
+ upsample_size: Optional[int] = None,
153
+ attention_mask: Optional[torch.FloatTensor] = None,
154
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
155
+ ):
156
+ for resnet, attn in zip(self.resnets, self.attentions):
157
+ # pop res hidden states
158
+ #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
159
+ res_hidden_states = res_hidden_states_tuple[-1]
160
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
161
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
162
+
163
+ if self.training and self.gradient_checkpointing:
164
+
165
+ def create_custom_forward(module, return_dict=None):
166
+ def custom_forward(*inputs):
167
+ if return_dict is not None:
168
+ return module(*inputs, return_dict=return_dict)
169
+ else:
170
+ return module(*inputs)
171
+
172
+ return custom_forward
173
+
174
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
175
+ hidden_states = torch.utils.checkpoint.checkpoint(
176
+ create_custom_forward(resnet),
177
+ hidden_states,
178
+ temb,
179
+ **ckpt_kwargs,
180
+ )
181
+ hidden_states = torch.utils.checkpoint.checkpoint(
182
+ create_custom_forward(attn, return_dict=False),
183
+ hidden_states,
184
+ encoder_hidden_states,
185
+ None, # timestep
186
+ None, # class_labels
187
+ cross_attention_kwargs,
188
+ attention_mask,
189
+ encoder_attention_mask,
190
+ **ckpt_kwargs,
191
+ )[0]
192
+ else:
193
+ hidden_states = resnet(hidden_states, temb)
194
+ hidden_states = attn(
195
+ hidden_states,
196
+ encoder_hidden_states=encoder_hidden_states,
197
+ cross_attention_kwargs=cross_attention_kwargs,
198
+ attention_mask=attention_mask,
199
+ encoder_attention_mask=encoder_attention_mask,
200
+ return_dict=False,
201
+ )[0]
202
+
203
+ if self.upsamplers is not None:
204
+ for upsampler in self.upsamplers:
205
+ hidden_states = upsampler(hidden_states, upsample_size)
206
+
207
+ return hidden_states
208
+
209
+ return forward
210
+
211
+ for i, upsample_block in enumerate(model.unet.up_blocks):
212
+ if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
213
+ upsample_block.forward = up_forward(upsample_block)
214
+
215
+
216
+ def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
217
+ def up_forward(self):
218
+ def forward(
219
+ hidden_states: torch.FloatTensor,
220
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
221
+ temb: Optional[torch.FloatTensor] = None,
222
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
223
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
224
+ upsample_size: Optional[int] = None,
225
+ attention_mask: Optional[torch.FloatTensor] = None,
226
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
227
+ ):
228
+ for resnet, attn in zip(self.resnets, self.attentions):
229
+ # pop res hidden states
230
+ #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
231
+ res_hidden_states = res_hidden_states_tuple[-1]
232
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
233
+
234
+ # --------------- FreeU code -----------------------
235
+ # Only operate on the first two stages
236
+ if hidden_states.shape[1] == 1280:
237
+ hidden_states[:,:640] = hidden_states[:,:640] * self.b1
238
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
239
+ if hidden_states.shape[1] == 640:
240
+ hidden_states[:,:320] = hidden_states[:,:320] * self.b2
241
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
242
+ # ---------------------------------------------------------
243
+
244
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
245
+
246
+ if self.training and self.gradient_checkpointing:
247
+
248
+ def create_custom_forward(module, return_dict=None):
249
+ def custom_forward(*inputs):
250
+ if return_dict is not None:
251
+ return module(*inputs, return_dict=return_dict)
252
+ else:
253
+ return module(*inputs)
254
+
255
+ return custom_forward
256
+
257
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
258
+ hidden_states = torch.utils.checkpoint.checkpoint(
259
+ create_custom_forward(resnet),
260
+ hidden_states,
261
+ temb,
262
+ **ckpt_kwargs,
263
+ )
264
+ hidden_states = torch.utils.checkpoint.checkpoint(
265
+ create_custom_forward(attn, return_dict=False),
266
+ hidden_states,
267
+ encoder_hidden_states,
268
+ None, # timestep
269
+ None, # class_labels
270
+ cross_attention_kwargs,
271
+ attention_mask,
272
+ encoder_attention_mask,
273
+ **ckpt_kwargs,
274
+ )[0]
275
+ else:
276
+ hidden_states = resnet(hidden_states, temb)
277
+ # hidden_states = attn(
278
+ # hidden_states,
279
+ # encoder_hidden_states=encoder_hidden_states,
280
+ # cross_attention_kwargs=cross_attention_kwargs,
281
+ # encoder_attention_mask=encoder_attention_mask,
282
+ # return_dict=False,
283
+ # )[0]
284
+ hidden_states = attn(
285
+ hidden_states,
286
+ encoder_hidden_states=encoder_hidden_states,
287
+ cross_attention_kwargs=cross_attention_kwargs,
288
+ )[0]
289
+
290
+ if self.upsamplers is not None:
291
+ for upsampler in self.upsamplers:
292
+ hidden_states = upsampler(hidden_states, upsample_size)
293
+
294
+ return hidden_states
295
+
296
+ return forward
297
+
298
+ for i, upsample_block in enumerate(model.unet.up_blocks):
299
+ if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
300
+ upsample_block.forward = up_forward(upsample_block)
301
+ setattr(upsample_block, 'b1', b1)
302
+ setattr(upsample_block, 'b2', b2)
303
+ setattr(upsample_block, 's1', s1)
304
+ setattr(upsample_block, 's2', s2)
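
These `register_*` helpers monkey-patch the up-blocks of a diffusers UNet, so they are applied to an already loaded pipeline before inference. A minimal usage sketch follows; it is not part of this commit, and the checkpoint name and dtype are only illustrative.

```python
import torch
from diffusers import StableDiffusionPipeline

from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d

# Load any Stable Diffusion 1.x/2.x pipeline; the checkpoint below is only an example.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# FreeU: b1/b2 rescale the backbone features, s1/s2 rescale the low-frequency part of the
# skip connections (the values below match the defaults of the functions above).
register_free_upblock2d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2)
register_free_crossattn_upblock2d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2)

image = pipe("Astronaut in a jungle, cold color palette, detailed, 8k").images[0]
image.save("freeu_sample.png")
```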
requirements.txt CHANGED
@@ -1,6 +1,10 @@
1
- accelerate
2
- diffusers
3
- invisible_watermark
4
- torch
5
- transformers
6
- xformers
 
1
+ tqdm
2
+ einops
3
+ pytorch_lightning
4
+ accelerate>=0.20.0
5
+ torchsde
6
+ pycocotools
7
+ diffusers==0.32.2
8
+ timm
9
+ transformers==4.49
10
+ torch>=2.0.0