Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| """ | |
| Demo showcasing parameter-efficient fine-tuning of Stable Diffusion via Dreambooth leveraging 🤗 PEFT (https://github.com/huggingface/peft) | |
| The code in this repo is partly adapted from the following repositories: | |
| https://huggingface.co/spaces/hysts/LoRA-SD-training | |
| https://huggingface.co/spaces/multimodalart/dreambooth-training | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import pathlib | |
| import gradio as gr | |
| import torch | |
| from typing import List | |
| from inference import InferencePipeline | |
| from trainer import Trainer | |
| from uploader import upload | |
# Header text rendered at the top of the app.
TITLE = "# LoRA + Dreambooth Training and Inference Demo 🎨"
# Fixed the "Dissfusion" typo in this user-facing description.
DESCRIPTION = "Demo showcasing parameter-efficient fine-tuning of Stable Diffusion via Dreambooth leveraging 🤗 PEFT (https://github.com/huggingface/peft)."

# Space this app was originally published as. SPACE_ID falls back to it
# unless the SPACE_ID environment variable is set (e.g. on a duplicate).
ORIGINAL_SPACE_ID = "smangrul/peft-lora-sd-dreambooth"
SPACE_ID = os.getenv("SPACE_ID", ORIGINAL_SPACE_ID)

# Banner shown when IS_SHARED_UI is set; links to this Space's duplicate page.
SHARED_UI_WARNING = f"""# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
"""

# Turn "Settings" into a link to the Space settings page only when running on
# Spaces (SYSTEM == "spaces") under a SPACE_ID other than the original one.
if os.getenv("SYSTEM") == "spaces" and SPACE_ID != ORIGINAL_SPACE_ID:
    SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
else:
    SETTINGS = "Settings"

# Banner shown when torch reports no CUDA device.
CUDA_NOT_AVAILABLE_WARNING = f"""# Attention - Running on CPU.
<center>
You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
"T4 small" is sufficient to run this demo.
</center>
"""
def show_warning(warning_text: str) -> gr.Blocks:
    """Render a Markdown warning banner inside a boxed sub-layout.

    Args:
        warning_text: Markdown content of the banner.

    Returns:
        The ``gr.Blocks`` container that holds the banner.
    """
    with gr.Blocks() as banner, gr.Box():
        gr.Markdown(warning_text)
    return banner
def update_output_files() -> dict:
    """Return a Gradio update listing the trained artifacts under ``results/``.

    Weight files (``*.pt``) are listed first, then config files (``*.json``),
    each group sorted and converted to POSIX path strings. When nothing is
    found, the component value is set to ``None`` instead of an empty list.
    """
    results_dir = pathlib.Path("results")
    artifacts = [
        *sorted(results_dir.glob("*.pt")),
        *sorted(results_dir.glob("*.json")),
    ]
    posix_paths = [artifact.as_posix() for artifact in artifacts]
    return gr.update(value=posix_paths if posix_paths else None)
def create_training_demo(trainer: Trainer, pipe: InferencePipeline) -> gr.Blocks:
    """Build the "Train" tab: training data, hyperparameters and run controls.

    Args:
        trainer: Object whose ``run`` method is invoked with the form values
            and whose ``check_if_running`` method feeds the status message.
        pipe: Inference pipeline; its ``clear`` method is called right before
            each training run starts.

    Returns:
        The ``gr.Blocks`` holding the training UI.
    """
    with gr.Blocks() as demo:
        base_model = gr.Dropdown(
            choices=[
                "CompVis/stable-diffusion-v1-4",
                "runwayml/stable-diffusion-v1-5",
                "stabilityai/stable-diffusion-2-1-base",
                "dreamlike-art/dreamlike-photoreal-2.0",
            ],
            value="runwayml/stable-diffusion-v1-5",
            label="Base Model",
            visible=True,
        )
        # Only "512" is offered, so the selector is kept hidden.
        resolution = gr.Dropdown(choices=["512"], value="512", label="Resolution", visible=False)

        with gr.Row():
            with gr.Box():
                gr.Markdown("Training Data")
                concept_images = gr.Files(label="Images for your concept")
                class_images = gr.Files(label="Class images")
                concept_prompt = gr.Textbox(label="Concept Prompt", max_lines=1)
                gr.Markdown(
                    """
                    - Upload images of the style you are planning on training on.
                    - For a concept prompt, use a unique, made up word to avoid collisions.
                    - Guidelines for getting good results:
                        - Dreambooth for an `object` or `style`:
                            - 5-10 images of the object from different angles
                            - 500-800 iterations should be good enough.
                            - Prior preservation is recommended.
                            - `class_prompt`:
                                - `a photo of object`
                                - `style`
                            - `concept_prompt`:
                                - `<concept prompt> object`
                                - `<concept prompt> style`
                                - `a photo of <concept prompt> object`
                                - `a photo of <concept prompt> style`
                        - Dreambooth for a `Person/Face`:
                            - 15-50 images of the person from different angles, lighting, and expressions.
                              Have considerable photos with close up faces.
                            - 800-1200 iterations should be good enough.
                            - good defaults for hyperparams
                                - Model - `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1-base`
                                - Use/check Prior preservation.
                                - Number of class images to use - 200
                                - Prior Loss Weight - 1
                                - LoRA Rank for unet - 16
                                - LoRA Alpha for unet - 20
                                - lora dropout - 0
                                - LoRA Bias for unet - `all`
                                - LoRA Rank for CLIP - 16
                                - LoRA Alpha for CLIP - 17
                                - LoRA Bias for CLIP - `all`
                                - lora dropout for CLIP - 0
                                - Uncheck `FP16` and `8bit-Adam` (don't use them for faces)
                            - `class_prompt`: Use the gender related word of the person
                                - `man`
                                - `woman`
                                - `boy`
                                - `girl`
                            - `concept_prompt`: just the unique, made up word, e.g., `srm`
                            - Choose `all` for `lora_bias` and `lora_text_encoder_bias`
                        - Dreambooth for a `Scene`:
                            - 15-50 images of the scene from different angles, lighting, and expressions.
                            - 800-1200 iterations should be good enough.
                            - Prior preservation is recommended.
                            - `class_prompt`:
                                - `scene`
                                - `landscape`
                                - `city`
                                - `beach`
                                - `mountain`
                            - `concept_prompt`:
                                - `<concept prompt> scene`
                                - `<concept prompt> landscape`
                        - Experiment with various values for lora dropouts, enabling/disabling fp16 and 8bit-Adam
                    """
                )
            with gr.Box():
                gr.Markdown("Training Parameters")
                num_training_steps = gr.Number(label="Number of Training Steps", value=1000, precision=0)
                learning_rate = gr.Number(label="Learning Rate", value=0.0001)
                gradient_checkpointing = gr.Checkbox(label="Whether to use gradient checkpointing", value=True)
                train_text_encoder = gr.Checkbox(label="Train Text Encoder", value=True)
                with_prior_preservation = gr.Checkbox(label="Prior Preservation", value=True)
                class_prompt = gr.Textbox(
                    label="Class Prompt", max_lines=1, placeholder='Example: "a photo of object"'
                )
                num_class_images = gr.Number(label="Number of class images to use", value=50, precision=0)
                prior_loss_weight = gr.Number(label="Prior Loss Weight", value=1.0, precision=1)
                lora_r = gr.Number(label="LoRA Rank for unet", value=4, precision=0)
                lora_alpha = gr.Number(
                    label="LoRA Alpha for unet. scaling factor = lora_alpha/lora_r", value=4, precision=0
                )
                lora_dropout = gr.Number(label="lora dropout", value=0.00)
                lora_bias = gr.Dropdown(
                    choices=["none", "all", "lora_only"],
                    value="none",
                    label="LoRA Bias for unet. This enables bias params to be trainable based on the bias type",
                    visible=True,
                )
                lora_text_encoder_r = gr.Number(label="LoRA Rank for CLIP", value=4, precision=0)
                lora_text_encoder_alpha = gr.Number(
                    label="LoRA Alpha for CLIP. scaling factor = lora_alpha/lora_r", value=4, precision=0
                )
                lora_text_encoder_dropout = gr.Number(label="lora dropout for CLIP", value=0.00)
                lora_text_encoder_bias = gr.Dropdown(
                    choices=["none", "all", "lora_only"],
                    value="none",
                    label="LoRA Bias for CLIP. This enables bias params to be trainable based on the bias type",
                    visible=True,
                )
                gradient_accumulation = gr.Number(label="Number of Gradient Accumulation", value=1, precision=0)
                fp16 = gr.Checkbox(label="FP16", value=True)
                use_8bit_adam = gr.Checkbox(label="Use 8bit Adam", value=True)
                gr.Markdown(
                    """
                    - It will take about 20-30 minutes to train for 1000 steps with a T4 GPU.
                    - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
                    - Note that your trained models will be deleted when the second training is started. You can upload your trained model in the "Upload" tab.
                    """
                )
                run_button = gr.Button("Start Training")

        with gr.Box():
            with gr.Row():
                check_status_button = gr.Button("Check Training Status")
                with gr.Column():
                    with gr.Box():
                        gr.Markdown("Message")
                        training_status = gr.Markdown()
            output_files = gr.Files(label="Trained Weight Files and Configs")

        # Gradio passes these component values to Trainer.run positionally,
        # in exactly this order.
        training_inputs = [
            base_model,
            resolution,
            num_training_steps,
            concept_images,
            concept_prompt,
            class_images,
            learning_rate,
            gradient_accumulation,
            fp16,
            use_8bit_adam,
            gradient_checkpointing,
            train_text_encoder,
            with_prior_preservation,
            prior_loss_weight,
            class_prompt,
            num_class_images,
            lora_r,
            lora_alpha,
            lora_bias,
            lora_dropout,
            lora_text_encoder_r,
            lora_text_encoder_alpha,
            lora_text_encoder_bias,
            lora_text_encoder_dropout,
        ]

        def start_training(*args):
            # pipe.clear used to be bound as a separate click handler. Gradio
            # dispatches each handler independently, so the clear was not
            # guaranteed to finish before training began. Run both in order here.
            pipe.clear()
            return trainer.run(*args)

        run_button.click(
            fn=start_training,
            inputs=training_inputs,
            outputs=[
                training_status,
                output_files,
            ],
            queue=False,
        )
        check_status_button.click(fn=trainer.check_if_running, inputs=None, outputs=training_status, queue=False)
        check_status_button.click(fn=update_output_files, inputs=None, outputs=output_files, queue=False)
    return demo
def find_weight_files() -> List[str]:
    """List every ``*.pt`` file found anywhere below this script's directory.

    Returns:
        POSIX-style paths relative to the script directory, ordered by
        sorting the discovered ``Path`` objects.
    """
    script_dir = pathlib.Path(__file__).parent

    def _as_relative(weight_path: pathlib.Path) -> str:
        return weight_path.relative_to(script_dir).as_posix()

    return list(map(_as_relative, sorted(script_dir.rglob("*.pt"))))
def reload_lora_weight_list() -> dict:
    """Return a Gradio update that refreshes the LoRA weight dropdown choices."""
    refreshed_choices = find_weight_files()
    return gr.update(choices=refreshed_choices)
def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
    """Build the "Test" tab for generating images from a base model + LoRA weights.

    Args:
        pipe: Inference pipeline. Its ``run`` method receives (base model, LoRA
            weight file, prompt, negative prompt, seed, number of steps, CFG
            scale) positionally, and its return value is shown in the result
            image.

    Returns:
        The ``gr.Blocks`` holding the inference UI.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                base_model = gr.Dropdown(
                    choices=[
                        "CompVis/stable-diffusion-v1-4",
                        "runwayml/stable-diffusion-v1-5",
                        "stabilityai/stable-diffusion-2-1-base",
                        "dreamlike-art/dreamlike-photoreal-2.0",
                    ],
                    value="runwayml/stable-diffusion-v1-5",
                    label="Base Model",
                    visible=True,
                )
                reload_button = gr.Button("Reload Weight List")
                lora_weight_name = gr.Dropdown(
                    choices=find_weight_files(), value="lora/lora_disney.pt", label="LoRA Weight File"
                )
                prompt = gr.Textbox(label="Prompt", max_lines=1, placeholder='Example: "style of sks, baby lion"')
                negative_prompt = gr.Textbox(
                    label="Negative Prompt", max_lines=1, placeholder='Example: "blurry, botched, low quality"'
                )
                seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, value=1)
                with gr.Accordion("Other Parameters", open=False):
                    # At least one denoising step is required. Zero steps is not a
                    # meaningful request, and schedulers that divide by the step
                    # count can fail on it.
                    num_steps = gr.Slider(label="Number of Steps", minimum=1, maximum=1000, step=1, value=50)
                    guidance_scale = gr.Slider(label="CFG Scale", minimum=0, maximum=50, step=0.1, value=7)
                run_button = gr.Button("Generate")
                gr.Markdown(
                    """
                    - After training, you can press "Reload Weight List" button to load your trained model names.
                    - Few repos to refer for ideas:
                        - https://huggingface.co/smangrul/smangrul
                        - https://huggingface.co/smangrul/painting-in-the-style-of-smangrul
                        - https://huggingface.co/smangrul/erenyeager
                    """
                )
            with gr.Column():
                result = gr.Image(label="Result")

        reload_button.click(fn=reload_lora_weight_list, inputs=None, outputs=lora_weight_name)

        # One shared, ordered input list for every generation trigger. It is
        # passed positionally to pipe.run.
        generation_inputs = [
            base_model,
            lora_weight_name,
            prompt,
            negative_prompt,
            seed,
            num_steps,
            guidance_scale,
        ]
        # Pressing Enter in the prompt box, clicking Generate, and changing the
        # seed each trigger a new generation.
        prompt.submit(fn=pipe.run, inputs=generation_inputs, outputs=result, queue=False)
        run_button.click(fn=pipe.run, inputs=generation_inputs, outputs=result, queue=False)
        seed.change(fn=pipe.run, inputs=generation_inputs, outputs=result, queue=False)
    return demo
def create_upload_demo() -> gr.Blocks:
    """Build the "Upload" tab for pushing trained weights to the Hugging Face Hub.

    Clicking "Upload" calls ``upload(model_name, hf_token)`` and renders its
    return value as Markdown in the "Message" box.

    Returns:
        The ``gr.Blocks`` holding the upload UI.
    """
    with gr.Blocks() as demo:
        model_name = gr.Textbox(label="Model Name")
        # This token has write permission on the user's account. Mask it
        # instead of echoing it back in plain text.
        hf_token = gr.Textbox(label="Hugging Face Token (with write permission)", type="password")
        upload_button = gr.Button("Upload")
        with gr.Box():
            gr.Markdown("Message")
            result = gr.Markdown()
        gr.Markdown(
            """
            - You can upload your trained model to your private Model repo (i.e. https://huggingface.co/{your_username}/{model_name}).
            - You can find your Hugging Face token [here](https://huggingface.co/settings/tokens).
            """
        )
        upload_button.click(fn=upload, inputs=[model_name, hf_token], outputs=result)
    return demo
pipe = InferencePipeline()
trainer = Trainer()

with gr.Blocks(css="style.css") as demo:
    # Environment banners, rendered above the title in this order.
    _banners = []
    if os.getenv("IS_SHARED_UI"):
        _banners.append(SHARED_UI_WARNING)
    if not torch.cuda.is_available():
        _banners.append(CUDA_NOT_AVAILABLE_WARNING)
    for _banner in _banners:
        show_warning(_banner)

    gr.Markdown(TITLE)
    gr.Markdown(DESCRIPTION)

    # (tab label, builder) pairs. Each builder renders into the TabItem that
    # is active when it is called.
    _tabs = (
        ("Train", lambda: create_training_demo(trainer, pipe)),
        ("Test", lambda: create_inference_demo(pipe)),
        ("Upload", create_upload_demo),
    )
    with gr.Tabs():
        for _label, _build in _tabs:
            with gr.TabItem(_label):
                _build()

# Queueing is off by default for events (default_enabled=False). Launch
# without a public share link.
demo.queue(default_enabled=False)
demo.launch(share=False)