Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import spaces | |
| import re | |
| import yaml | |
| import tempfile | |
| import subprocess | |
| from pathlib import Path | |
| from dataclasses import dataclass | |
| import gradio as gr | |
| from src.flux.xflux_pipeline import XFluxPipeline | |
| from huggingface_hub import login | |
| import multiprocessing | |
# Force the "spawn" start method for all multiprocessing use in this process,
# overriding any method that was already selected.
multiprocessing.set_start_method("spawn", force=True)

# Log in to the Hugging Face Hub when a token is provided through the
# HF_TOKEN environment variable; otherwise only report that none was found.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    print("No Hugging Face token found.")
else:
    login(token=hf_token)
@dataclass
class Config:
    """Settings for the demo application.

    Declared as a dataclass so that instances can be created from keyword
    arguments (``parse_args`` builds one with ``Config(**vars(args))``).
    The defaults remain readable as plain class attributes, e.g.
    ``Config.name``.
    """

    # Model name.
    name: str = "flux-dev"
    # Device to use.
    device: str = "cpu"
    # Offload the model to CPU when it is not in use.
    offload: bool = False
    # Create a public link to the demo.
    share: bool = False
    # Folder with checkpoints in safetensors format.
    ckpt_dir: str = "."
# Build the shared pipeline once, at import time, from the Config class-level
# defaults, then move it to CUDA when available and to the CPU otherwise.
xflux_pipeline = XFluxPipeline(
    Config.name,
    Config.device,
    Config.offload,
)
xflux_pipeline.to(
    device="cuda" if torch.cuda.is_available() else "cpu",
)
def generate(*args, **kwargs):
    """Forward a generation request to the module-level pipeline.

    Gradio event listeners call their handler with the values of the
    ``inputs`` components as positional arguments, so the wrapper has to
    accept ``*args``; keyword arguments are still forwarded for callers that
    use them.

    Returns:
        Whatever ``xflux_pipeline.gradio_generate`` returns.
    """
    return xflux_pipeline.gradio_generate(*args, **kwargs)
def parse_args() -> Config:
    """Parse the command-line options into a ``Config``.

    Recognised options: ``--name``, ``--device`` (defaults to ``"cuda"`` when
    CUDA is available, else ``"cpu"``), ``--offload``, ``--share`` and
    ``--ckpt_dir``.

    Returns:
        A ``Config`` built from the parsed options; ``Config`` must accept
        them as keyword arguments.
    """
    # Imported here because the module's import block does not provide it.
    import argparse

    parser = argparse.ArgumentParser(description="Flux")
    parser.add_argument("--name", type=str, default="flux-dev", help="Model name")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use",
    )
    parser.add_argument(
        "--offload", action="store_true", help="Offload model to CPU when not in use"
    )
    parser.add_argument(
        "--share", action="store_true", help="Create a public link to your demo"
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=".",
        help="Folder with checkpoints in safetensors format",
    )
    args = parser.parse_args()
    return Config(**vars(args))
_DIGIT_RUN = re.compile("([0-9]+)")


def _natural_sort_key(text):
    """Return a sort key that orders embedded ASCII digit runs numerically."""
    # Splitting on a single capture group alternates plain text (even
    # positions) with the captured digit runs (odd positions). Deciding by
    # position instead of ``str.isdigit()`` avoids calling ``int()`` on
    # characters such as "²" that ``isdigit()`` accepts but ``int()`` rejects.
    return [
        int(piece) if index % 2 else piece.lower()
        for index, piece in enumerate(_DIGIT_RUN.split(text))
    ]


def list_dirs(path):
    """Yield directory suggestions around ``path``.

    Yields, in order: the parent of the resolved directory (only when that
    parent is a non-empty string), the resolved directory itself, then its
    immediate sub-directories in natural order ("img2" before "img10").
    Entries starting with "." and "__pycache__" are skipped. Where the OS
    path separator is a backslash, results use forward slashes instead.

    Args:
        path: A directory, a file, or a not-yet-existing path. ``None``,
            ``"None"`` and ``""`` yield nothing. A non-existing path falls
            back to its parent (yielding nothing when that does not exist
            either); a file falls back to its containing directory.
    """
    if path is None or path == "None" or path == "":
        return

    if not os.path.exists(path):
        path = os.path.dirname(path)
        if not os.path.exists(path):
            return

    if not os.path.isdir(path):
        # A bare filename has an empty dirname; it lives in the current
        # directory, and os.listdir("") would raise.
        path = os.path.dirname(path) or "."

    subdirs = sorted(
        (
            os.path.join(path, entry)
            for entry in os.listdir(path)
            if not entry.startswith(".")
            and entry != "__pycache__"
            and os.path.isdir(os.path.join(path, entry))
        ),
        key=_natural_sort_key,
    )

    parent = os.path.dirname(path)
    dirs = [parent, path, *subdirs] if parent != "" else [path, *subdirs]

    if os.sep == "\\":
        dirs = [d.replace("\\", "/") for d in dirs]

    yield from dirs
def list_train_data_dirs():
    """Return, as a list, what ``list_dirs`` yields for the current directory."""
    return [*list_dirs(".")]
def update_config(d, u):
    """Recursively merge the overrides ``u`` into the dict ``d`` in place.

    Value handling, in order:
      * dict values are merged recursively into the matching entry of ``d``
        (a missing or non-dict entry is replaced by a new dict);
      * objects exposing a ``value`` attribute (e.g. Gradio components) are
        stored as ``str(v.value)``;
      * non-integral floats (including inf/nan) are stored unchanged, so a
        value such as a learning rate of 1e-5 is not truncated to 0;
      * anything else is stored as ``int(v)`` when that conversion works and
        as ``str(v)`` otherwise.

    Args:
        d: Dict to update; mutated in place.
        u: Dict of overrides.

    Returns:
        ``d`` itself.
    """
    for key, value in u.items():
        if isinstance(value, dict):
            current = d.get(key)
            # Only an existing dict can be merged into; anything else is replaced.
            d[key] = update_config(current if isinstance(current, dict) else {}, value)
        elif hasattr(value, "value"):
            # Convert Gradio components to strings.
            d[key] = str(value.value)
        elif isinstance(value, float) and not value.is_integer():
            d[key] = value
        else:
            try:
                d[key] = int(value)
            except (TypeError, ValueError):
                d[key] = str(value)
    return d
def start_lora_training(
    data_dir: str, output_dir: str, lr: float, steps: int, rank: int
):
    """Launch a LoRA training run through ``accelerate launch``.

    Loads ``train_configs/test_lora.yaml`` from the directory containing this
    script, overrides it with the given values via ``update_config``, writes
    the result to a temporary YAML file and runs
    ``train_flux_lora_deepspeed.py`` with it. The temporary file is removed
    whether or not the run succeeds.

    Args:
        data_dir: Stored as ``data_config.img_dir``.
        output_dir: Stored as ``output_dir``; created when it does not exist.
        lr: Stored as ``learning_rate``.
        steps: Stored as ``max_train_steps``.
        rank: Stored as ``rank``.

    Returns:
        The ``subprocess.CompletedProcess`` of the training command.

    Raises:
        subprocess.CalledProcessError: The training command exited non-zero.
    """
    inputs = {
        "data_config": {
            "img_dir": data_dir,
        },
        "output_dir": output_dir,
        "learning_rate": lr,
        "rank": rank,
        "max_train_steps": steps,
    }

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Creating folder {output_dir} for the output checkpoint file...")

    script_path = Path(__file__).resolve()
    config_path = script_path.parent / "train_configs" / "test_lora.yaml"
    with open(config_path, "r") as file:
        # An empty YAML document loads as None; merge into an empty dict then.
        config = yaml.safe_load(file) or {}
    config = update_config(config, inputs)
    print("Config file is updated...", config)

    # delete=False keeps the file on disk after it is closed, so the child
    # process can read it; it is removed explicitly below.
    temp_file = tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml")
    tmp_config_path = temp_file.name
    try:
        with temp_file:
            yaml.dump(config, temp_file, default_flow_style=False)

        command = [
            "accelerate",
            "launch",
            "train_flux_lora_deepspeed.py",
            "--config",
            tmp_config_path,
        ]
        result = subprocess.run(command, check=True)
    finally:
        # Remove the temporary file even when dumping or training failed.
        Path(tmp_config_path).unlink()

    return result
def create_demo(
    model_type: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
    ckpt_dir: str = "",
):
    """Build the Gradio Blocks UI with an inference tab and a LoRA training tab.

    Args:
        model_type: Model name shown in the page heading.
        device: Accepted for interface compatibility; not used in this body.
        offload: Accepted for interface compatibility; not used in this body.
        ckpt_dir: Directory whose ``*.safetensors`` files populate the
            checkpoint dropdowns.

    Returns:
        The constructed ``gr.Blocks`` demo (not launched).
    """
    # Dropdown choices are given as plain strings rather than Path objects.
    checkpoints = [str(p) for p in sorted(Path(ckpt_dir).glob("*.safetensors"))]

    # NOTE(review): the nesting of the layout blocks below was reconstructed
    # from a flattened copy of this file - confirm it matches the intended UI.
    with gr.Blocks() as demo:
        gr.Markdown(f"# Flux Adapters by XLabs AI - Model: {model_type}")

        with gr.Tab("Inference"):
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")

                    with gr.Accordion("Generation Options", open=False):
                        with gr.Row():
                            width = gr.Slider(512, 2048, 1024, step=16, label="Width")
                            height = gr.Slider(512, 2048, 1024, step=16, label="Height")
                        neg_prompt = gr.Textbox(label="Negative Prompt", value="bad photo")
                        with gr.Row():
                            num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                            timestep_to_start_cfg = gr.Slider(1, 50, 1, step=1, label="timestep_to_start_cfg")
                        with gr.Row():
                            guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
                            true_gs = gr.Slider(1.0, 5.0, 3.5, step=0.1, label="True Guidance", interactive=True)
                        seed = gr.Textbox(-1, label="Seed (-1 for random)")

                    with gr.Accordion("ControlNet Options", open=False):
                        control_type = gr.Dropdown(["canny", "hed", "depth"], label="Control type")
                        control_weight = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Controlnet weight", interactive=True)
                        local_path = gr.Dropdown(
                            checkpoints,
                            label="Controlnet Checkpoint",
                            info="Local Path to Controlnet weights (if no, it will be downloaded from HF)",
                        )
                        controlnet_image = gr.Image(label="Input Controlnet Image", visible=True, interactive=True)

                    with gr.Accordion("LoRA Options", open=False):
                        lora_weight = gr.Slider(0.0, 1.0, 0.9, step=0.1, label="LoRA weight", interactive=True)
                        lora_local_path = gr.Dropdown(
                            checkpoints, label="LoRA Checkpoint", info="Local Path to Lora weights"
                        )

                    with gr.Accordion("IP Adapter Options", open=False):
                        image_prompt = gr.Image(label="image_prompt", visible=True, interactive=True)
                        ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="ip_scale")
                        neg_image_prompt = gr.Image(label="neg_image_prompt", visible=True, interactive=True)
                        neg_ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="neg_ip_scale")
                        ip_local_path = gr.Dropdown(
                            checkpoints,
                            label="IP Adapter Checkpoint",
                            info="Local Path to IP Adapter weights (if no, it will be downloaded from HF)",
                        )

                    generate_btn = gr.Button("Generate")

                with gr.Column():
                    output_image = gr.Image(label="Generated Image")
                    download_btn = gr.File(label="Download full-resolution")

            # The order of this list is the order in which the component
            # values are handed to the click handler.
            inputs = [
                prompt, image_prompt, controlnet_image, width, height, guidance,
                num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
                neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
                lora_weight, local_path, lora_local_path, ip_local_path,
            ]
            generate_btn.click(
                fn=generate,
                inputs=inputs,
                outputs=[output_image, download_btn],
            )

        with gr.Tab("LoRA Finetuning"):
            data_dir = gr.Dropdown(
                list_train_data_dirs(),
                label="Training images (directory containing the training images)",
            )
            output_dir = gr.Textbox(label="Output Path", value="lora_checkpoint")

            with gr.Accordion("Training Options", open=True):
                lr = gr.Textbox(label="Learning Rate", value="1e-5")
                steps = gr.Slider(10000, 20000, 20000, step=100, label="Train Steps")
                rank = gr.Slider(1, 100, 16, step=1, label="LoRa Rank")

            training_btn = gr.Button("Start training")
            training_btn.click(
                fn=start_lora_training,
                inputs=[data_dir, output_dir, lr, steps, rank],
                outputs=[],
            )

    return demo
if __name__ == "__main__":
    # Build the UI from the default Config values and serve it, always
    # requesting a public share link.
    config = Config()
    demo = create_demo(
        model_type=config.name,
        device=config.device,
        offload=config.offload,
        ckpt_dir=config.ckpt_dir,
    )
    demo.launch(share=True)