Update pipeline.py

pipeline.py  +45 -59  (CHANGED)
@@ -292,11 +292,9 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
292 |          unscale_lora_layers(self.text_encoder_2, lora_scale)
293 | 
294 |          dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
295 | -        text_ids = torch.zeros(
296 | -        text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
297 | -        negative_text_ids = torch.zeros(batch_size, negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
298 | 
299 | -        return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds
300 | 
301 |      def check_inputs(
302 |          self,
@@ -485,13 +483,11 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
485 |          self,
486 |          prompt: Union[str, List[str]] = None,
487 |          prompt_2: Optional[Union[str, List[str]]] = None,
488 |          height: Optional[int] = None,
489 |          width: Optional[int] = None,
490 | -        negative_prompt: Optional[Union[str, List[str]]] = None,
491 | -        negative_prompt_2: Optional[Union[str, List[str]]] = None,
492 |          num_inference_steps: int = 8,
493 |          timesteps: List[int] = None,
494 | -        eta: Optional[float] = 0.0,
495 |          guidance_scale: float = 3.5,
496 |          device: Optional[int] = None,
497 |          num_images_per_prompt: Optional[int] = 1,
@@ -499,14 +495,13 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
499 |          latents: Optional[torch.FloatTensor] = None,
500 |          prompt_embeds: Optional[torch.FloatTensor] = None,
501 |          pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
502 | -        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
503 | -        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
504 |          output_type: Optional[str] = "pil",
505 |          cfg: Optional[bool] = True,
506 |          return_dict: bool = True,
507 |          joint_attention_kwargs: Optional[Dict[str, Any]] = None,
508 |          max_sequence_length: int = 512,
509 | -        **kwargs,
510 |      ):
511 |          height = height or self.default_sample_size * self.vae_scale_factor
512 |          width = width or self.default_sample_size * self.vae_scale_factor
@@ -518,9 +513,8 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
518 |              height,
519 |              width,
520 |              prompt_embeds=prompt_embeds,
521 | -            negative_prompt_embeds=negative_prompt_embeds,
522 |              pooled_prompt_embeds=pooled_prompt_embeds,
523 | -
524 |              max_sequence_length=max_sequence_length,
525 |          )
526 | 
@@ -546,21 +540,16 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
546 |              pooled_prompt_embeds,
547 |              text_ids,
548 |              negative_prompt_embeds,
549 | -            negative_pooled_prompt_embeds
550 | -            negative_text_ids,
551 |          ) = self.encode_prompt(
552 |              prompt=prompt,
553 |              prompt_2=prompt_2,
554 |              num_images_per_prompt=num_images_per_prompt,
555 |              max_sequence_length=max_sequence_length,
556 | -            do_classifier_free_guidance=self.do_classifier_free_guidance,
557 |              device=device,
558 |              negative_prompt=negative_prompt,
559 | -            negative_prompt_2=negative_prompt_2,
560 |              prompt_embeds=prompt_embeds,
561 | -            negative_prompt_embeds=negative_prompt_embeds,
562 |              pooled_prompt_embeds=pooled_prompt_embeds,
563 | -            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
564 |              lora_scale=lora_scale,
565 |          )
566 | 
@@ -607,67 +596,64 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
607 |              for i, t in enumerate(timesteps):
608 |                  if self.interrupt:
609 |                      continue
610 | -
611 |                  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
612 | -
613 | -                timestep = t.expand(latent_model_input.shape[0])
614 | 
615 | -
616 | -                    guidance = torch.tensor([guidance_scale], device=device)
617 | -                    guidance = guidance.expand(latents.shape[0])
618 | -                else:
619 | -                    guidance = None
620 | -
621 | -                noise_pred_text = self.transformer(
622 |                      hidden_states=latent_model_input,
623 |                      timestep=timestep / 1000,
624 | -
625 | -                    pooled_projections=pooled_prompt_embeds.shape[1],
626 |                      encoder_hidden_states=prompt_embeds,
627 |                      txt_ids=text_ids,
628 |                      img_ids=latent_image_ids,
629 |                      joint_attention_kwargs=self.joint_attention_kwargs,
630 |                      return_dict=False,
631 |                  )[0]
632 | -
633 | -                    hidden_states=latents,
634 | -                    timestep=timestep / 1000,
635 | -                    guidance=guidance,
636 | -                    pooled_projections=negative_pooled_prompt_embeds.shape[1],
637 | -                    encoder_hidden_states=negative_prompt_embeds,
638 | -                    txt_ids=negative_text_ids,
639 | -                    img_ids=latent_image_ids,
640 | -                    joint_attention_kwargs=self.joint_attention_kwargs,
641 | -                    return_dict=False,
642 | -                )[0]
643 | -
644 |                  if self.do_classifier_free_guidance:
645 | -                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(
646 | -                    noise_pred = noise_pred_uncond + self.
647 | -
648 | -
649 |                  # compute the previous noisy sample x_t -> x_t-1
650 |                  latents_dtype = latents.dtype
651 |                  latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
652 | -
653 |                  if latents.dtype != latents_dtype:
654 |                      if torch.backends.mps.is_available():
655 |                          # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
656 |                          latents = latents.to(latents_dtype)
657 | 
658 | -                # call the callback, if provided
659 |                  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
660 |                      progress_bar.update()
661 | 
662 | -
663 | -
664 |          self.maybe_free_model_hooks()
665 | -
666 | -
667 | -
668 | -
669 | -
670 | -        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
671 | -        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
672 | -        image = vae.decode(latents, return_dict=False)[0]
673 | -        return self.image_processor.postprocess(image, output_type=output_type)[0]
292 |          unscale_lora_layers(self.text_encoder_2, lora_scale)
293 | 
294 |          dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
295 | +        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
296 | 
297 | +        return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds
298 | 
299 |      def check_inputs(
300 |          self,
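A note on the `text_ids` rework in the hunk above: the removed code built per-sample 3-D id tensors (see the old `negative_text_ids` line in the first hunk), while the new code keeps a single zeroed `(seq_len, 3)` block shared by every sample in the batch. A standalone sketch of the shapes involved; the batch size, sequence length and hidden width below are illustrative assumptions, not values taken from the diff:

```python
import torch

# Illustrative shapes only: (batch, seq_len, hidden_dim) as a stand-in for T5 prompt embeddings.
prompt_embeds = torch.randn(2, 512, 4096, dtype=torch.bfloat16)

# New style: one zeroed (seq_len, 3) block of text position ids, shared across the batch.
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device="cpu", dtype=prompt_embeds.dtype)
print(text_ids.shape)       # torch.Size([512, 3])

# Old style (removed above): a separate (batch, seq_len, 3) tensor per prompt.
old_style_ids = torch.zeros(prompt_embeds.shape[0], prompt_embeds.shape[1], 3)
print(old_style_ids.shape)  # torch.Size([2, 512, 3])
```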
483 |          self,
484 |          prompt: Union[str, List[str]] = None,
485 |          prompt_2: Optional[Union[str, List[str]]] = None,
486 | +        negative_prompt: Union[str, List[str]] = None,
487 |          height: Optional[int] = None,
488 |          width: Optional[int] = None,
489 |          num_inference_steps: int = 8,
490 |          timesteps: List[int] = None,
491 |          guidance_scale: float = 3.5,
492 |          device: Optional[int] = None,
493 |          num_images_per_prompt: Optional[int] = 1,
495 |          latents: Optional[torch.FloatTensor] = None,
496 |          prompt_embeds: Optional[torch.FloatTensor] = None,
497 |          pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
498 |          output_type: Optional[str] = "pil",
499 |          cfg: Optional[bool] = True,
500 |          return_dict: bool = True,
501 |          joint_attention_kwargs: Optional[Dict[str, Any]] = None,
502 | +        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
503 | +        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
504 |          max_sequence_length: int = 512,
505 |      ):
506 |          height = height or self.default_sample_size * self.vae_scale_factor
507 |          width = width or self.default_sample_size * self.vae_scale_factor
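The two `callback_on_step_end*` parameters added above follow the usual diffusers step-callback convention. Judging from how the denoising loop further down invokes them, the callback receives `(pipeline, step_index, timestep, callback_kwargs)` and must return a dict, since the loop immediately calls `.pop("latents", ...)` and `.pop("prompt_embeds", ...)` on the result (note the declared annotation `Callable[[int, int, Dict], None]` does not reflect that). A minimal sketch of a compatible callback; the logging body is only an illustration:

```python
import torch

def on_step_end(pipeline, step: int, timestep: torch.Tensor, callback_kwargs: dict) -> dict:
    # callback_kwargs contains the tensors named in callback_on_step_end_tensor_inputs.
    latents = callback_kwargs["latents"]
    print(f"step {step}: latents {tuple(latents.shape)}, std={latents.std().item():.4f}")
    # Returning a dict lets the loop pick up overrides; an unmodified return keeps the run unchanged.
    return {"latents": latents}
```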
513 |              height,
514 |              width,
515 |              prompt_embeds=prompt_embeds,
516 |              pooled_prompt_embeds=pooled_prompt_embeds,
517 | +            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
518 |              max_sequence_length=max_sequence_length,
519 |          )
520 | 
540 |              pooled_prompt_embeds,
541 |              text_ids,
542 |              negative_prompt_embeds,
543 | +            negative_pooled_prompt_embeds
544 |          ) = self.encode_prompt(
545 |              prompt=prompt,
546 |              prompt_2=prompt_2,
547 |              num_images_per_prompt=num_images_per_prompt,
548 |              max_sequence_length=max_sequence_length,
549 |              device=device,
550 |              negative_prompt=negative_prompt,
551 |              prompt_embeds=prompt_embeds,
552 |              pooled_prompt_embeds=pooled_prompt_embeds,
553 |              lora_scale=lora_scale,
554 |          )
555 | 
596 |              for i, t in enumerate(timesteps):
597 |                  if self.interrupt:
598 |                      continue
599 | +
600 | +                # expand the latents if we are doing classifier free guidance
601 |                  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
602 | +                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
603 | +                timestep = t.expand(latent_model_input.shape[0])
604 | 
605 | +                noise_pred = self.transformer(
606 |                      hidden_states=latent_model_input,
607 |                      timestep=timestep / 1000,
608 | +                    pooled_projections=pooled_prompt_embeds,
609 |                      encoder_hidden_states=prompt_embeds,
610 |                      txt_ids=text_ids,
611 |                      img_ids=latent_image_ids,
612 |                      joint_attention_kwargs=self.joint_attention_kwargs,
613 |                      return_dict=False,
614 |                  )[0]
615 | +
616 |                  if self.do_classifier_free_guidance:
617 | +                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
618 | +                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
619 | +
620 |                  # compute the previous noisy sample x_t -> x_t-1
621 |                  latents_dtype = latents.dtype
622 |                  latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
623 | +
624 |                  if latents.dtype != latents_dtype:
625 |                      if torch.backends.mps.is_available():
626 |                          # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
627 |                          latents = latents.to(latents_dtype)
628 | +
629 | +                if callback_on_step_end is not None:
630 | +                    callback_kwargs = {}
631 | +                    for k in callback_on_step_end_tensor_inputs:
632 | +                        callback_kwargs[k] = locals()[k]
633 | +                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
634 | +
635 | +                    latents = callback_outputs.pop("latents", latents)
636 | +                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
637 | 
638 |                  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
639 |                      progress_bar.update()
640 | +
641 | +                if XLA_AVAILABLE:
642 | +                    xm.mark_step()
643 | 
644 | +        if output_type == "latent":
645 | +            image = latents
646 | +
647 | +        else:
648 | +            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
649 | +            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
650 | +            image = self.vae.decode(latents, return_dict=False)[0]
651 | +            image = self.image_processor.postprocess(image, output_type=output_type)
652 | +
653 | +        # Offload all models
654 |          self.maybe_free_model_hooks()
655 | +
656 | +        if not return_dict:
657 | +            return (image,)
658 | +
659 | +        return FluxPipelineOutput(images=image)
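For reference, the guidance update the new loop performs after `noise_pred.chunk(2)` is the standard classifier-free-guidance combination. A self-contained numeric check with dummy tensors (shapes are arbitrary and not taken from the pipeline):

```python
import torch

guidance_scale = 3.5                    # same default as in the __call__ signature above
noise_pred = torch.randn(2, 1024, 64)   # [uncond, text] predictions stacked along dim 0 (dummy shape)

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# Algebraically the same as the extrapolation form (1 - s) * uncond + s * text.
expected = (1 - guidance_scale) * noise_pred_uncond + guidance_scale * noise_pred_text
assert torch.allclose(guided, expected, atol=1e-5)
```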
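Finally, a rough usage sketch of the updated `__call__` surface, combining the new `negative_prompt` argument with the step callback. Only the parameter names come from the diff; the import path, checkpoint id, dtype and device are assumptions for illustration:

```python
import torch
from pipeline import FluxWithCFGPipeline  # assumed: the module updated in this commit

# Assumed base checkpoint and hardware; adjust to whatever the Space actually loads.
pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

def on_step_end(pipeline, step, timestep, callback_kwargs):
    print(f"finished step {step}")
    return {}  # nothing overridden; the loop's .pop(..., default) keeps the current tensors

result = pipe(
    prompt="a watercolor fox in a misty forest",
    negative_prompt="blurry, oversaturated",
    num_inference_steps=8,
    guidance_scale=3.5,
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
)
result.images[0].save("fox.png")  # FluxPipelineOutput when return_dict=True
```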