import spaces
import gradio as gr
from PIL import Image
import math
import io
import base64
import subprocess
import os

from concept_attention import ConceptAttentionFluxPipeline
from huggingface_hub import login

# Load the Hugging Face token from the environment and log in
# (this sets the token globally for the session).
hf_token = os.getenv("AccessToken")
login(token=hf_token)
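# Note: "AccessToken" is the name of the environment variable (or Space secret)
# this script expects; the token is presumably required to download the model
# weights at startup.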

IMG_SIZE = 210
COLUMNS = 5
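# IMG_SIZE is the thumbnail edge length (pixels) for heatmaps. COLUMNS is not
# referenced below; gallery column counts are set dynamically via gr.update.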


def update_default_concepts(prompt):
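    """Return the preset concept list for a known example prompt, or an empty list."""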
    default_concepts = {
        "A dog by a tree": ["dog", "grass", "tree", "background"],
        "A man on the beach": ["man", "dirt", "ocean", "sky"],
        "A hot air balloon": ["balloon", "sky", "water", "tree"]
    }

    return gr.update(value=default_concepts.get(prompt, []))
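

# Load the ConceptAttention pipeline on top of FLUX.1-schnell. Optional
# arguments such as offload_model=True or device="cuda:0" can also be passed
# here; they are left at their defaults.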
pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell")


def convert_pil_to_bytes(img):
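    """Resize a PIL image to IMG_SIZE and return it as a base64-encoded PNG string.

    Despite the name, this returns a str rather than bytes; it is not referenced
    elsewhere in this script.
    """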
    img = img.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return img_str


def encode_image(image, prompt, concepts, seed, layer_start_index, noise_timestep, num_samples):
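    """Run ConceptAttention on a real input image.

    Returns the pipeline's output image plus gallery updates for the
    concept-attention and cross-attention heatmaps.
    """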
    try:
        if not prompt:
            prompt = ""
        prompt = prompt.strip()

        if len(concepts) == 0:
            raise gr.Error("Please enter at least 1 concept", duration=10)

        if len(concepts) > 9:
            raise gr.Error("Please enter at most 9 concepts", duration=10)
        image = image.convert("RGB")

        pipeline_output = pipeline.encode_image(
            image=image,
            prompt=prompt,
            concepts=concepts,
            width=1024,
            height=1024,
            seed=seed,
            num_samples=num_samples,
            noise_timestep=noise_timestep,
            num_steps=4,
            layer_indices=list(range(layer_start_index, 19)),
            softmax=len(concepts) > 1
        )
        output_image = pipeline_output.image

        output_space_heatmaps = pipeline_output.concept_heatmaps
        output_space_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in output_space_heatmaps]
        output_space_maps_and_labels = [(output_space_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]

        cross_attention_heatmaps = pipeline_output.cross_attention_maps
        cross_attention_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in cross_attention_heatmaps]
        cross_attention_maps_and_labels = []
        prompt_tokens = prompt.split()
        for concept_index in range(len(concepts)):
            concept = concepts[concept_index]
            if concept in prompt_tokens:
                cross_attention_maps_and_labels.append(
                    (cross_attention_heatmaps[concept_index], concept)
                )
            else:
                # Show a blank tile instead: this concept does not appear in the
                # prompt, so its cross-attention map is only an artifact of
                # ConceptAttention's causal attention mechanism.
                empty_image = Image.new("RGB", (IMG_SIZE, IMG_SIZE), (39, 39, 42))
                cross_attention_maps_and_labels.append(
                    (empty_image, concept)
                )
        return output_image, \
            gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)), \
            gr.update(value=cross_attention_maps_and_labels, columns=len(cross_attention_maps_and_labels))
    except gr.Error:
        # Return one value per output component (image + two galleries).
        return None, gr.update(value=[], columns=1), gr.update(value=[], columns=1)


def generate_image(prompt, concepts, seed, layer_start_index, timestep_start_index):
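    """Generate an image with FLUX.1-schnell and return it together with gallery
    updates for the concept-attention and cross-attention heatmaps.
    """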
    try:
        if not prompt or not prompt.strip():
            raise gr.Error("Please enter a prompt", duration=10)
        prompt = prompt.strip()

        if len(concepts) == 0:
            raise gr.Error("Please enter at least 1 concept", duration=10)

        if len(concepts) > 9:
            raise gr.Error("Please enter at most 9 concepts", duration=10)
        pipeline_output = pipeline.generate_image(
            prompt=prompt,
            concepts=concepts,
            width=1024,
            height=1024,
            seed=seed,
            timesteps=list(range(timestep_start_index, 4)),
            num_inference_steps=4,
            layer_indices=list(range(layer_start_index, 19)),
            softmax=len(concepts) > 1
        )
        output_image = pipeline_output.image

        output_space_heatmaps = pipeline_output.concept_heatmaps
        output_space_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in output_space_heatmaps]
        output_space_maps_and_labels = [(output_space_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]

        cross_attention_heatmaps = pipeline_output.cross_attention_maps
        cross_attention_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in cross_attention_heatmaps]
        cross_attention_maps_and_labels = []
        prompt_tokens = prompt.split()
        for concept_index in range(len(concepts)):
            concept = concepts[concept_index]
            if concept in prompt_tokens:
                cross_attention_maps_and_labels.append(
                    (cross_attention_heatmaps[concept_index], concept)
                )
            else:
                # Show a blank tile instead: this concept does not appear in the
                # prompt, so its cross-attention map is only an artifact of
                # ConceptAttention's causal attention mechanism.
                empty_image = Image.new("RGB", (IMG_SIZE, IMG_SIZE), (39, 39, 42))
                cross_attention_maps_and_labels.append(
                    (empty_image, concept)
                )
        return output_image, \
            gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)), \
            gr.update(value=cross_attention_maps_and_labels, columns=len(cross_attention_maps_and_labels))
    except gr.Error:
        return None, gr.update(value=[], columns=1), gr.update(value=[], columns=1)


with gr.Blocks(
    css="""
        .container {
            max-width: 1300px;
            margin: 0 auto;
            padding: 20px;
        }
        .application {
            max-width: 1200px;
        }
        .generated-image {
            display: flex;
            align-items: center;
            justify-content: center;
            height: 100%; /* Ensures full height */
        }
        .input {
            height: 47px;
        }
        .input-column-label {}
        .gallery {
            height: 220px;
        }
        .run-button-column {
            width: 100px !important;
        }
        .gallery-container {
            scrollbar-width: thin;
            scrollbar-color: grey black;
        }
        @media (min-width: 1280px) {
            .svg-container {
                min-width: 250px;
                display: flex;
                flex-direction: column;
                padding-top: 340px;
            }
            .callout {
                width: 250px;
            }
            .input-row {
                height: 100px;
            }
            .input-column {
                flex-direction: column;
                gap: 0px;
                height: 100%;
            }
        }
        @media (max-width: 1280px) {
            .svg-container {
                display: none !important;
            }
            .callout {
                display: none;
            }
        }
        .header {
            display: flex;
            flex-direction: column;
        }
        #title {
            font-size: 4.4em;
            color: #F3B13E;
            text-align: center;
            margin: 5px;
        }
        #subtitle {
            font-size: 3.0em;
            color: #FAE2BA;
            text-align: center;
            margin: 5px;
        }
        #abstract {
            text-align: center;
            font-size: 2.0em;
            color: rgb(219, 219, 219);
            margin: 5px;
            margin-top: 10px;
        }
        #links {
            text-align: center;
            font-size: 2.0em;
            margin: 5px;
        }
        #links a {
            color: #93B7E9;
            text-decoration: none;
        }
        .caption-label {
            font-size: 1.15em;
        }
        .gallery label {
            font-size: 1.15em;
        }
    """
) as demo:
    with gr.Row(elem_classes="container", scale=8):
        with gr.Column(elem_classes="application-content", scale=10):
            with gr.Row(scale=3, elem_classes="header"):
                gr.HTML("""
                    <h1 id='title'> ConceptAttention </h1>
                    <h1 id='subtitle'> Visualize Any Concepts in Your Generated Images </h1>
                    <h1 id='abstract'> Interpret diffusion models with precise, high-quality heatmaps. </h1>
                    <h1 id='links'> <a href='https://arxiv.org/abs/2502.04320'> Paper </a> | <a href='https://github.com/helblazer811/ConceptAttention'> Code </a> </h1>
                """)
| with gr.Tab(label="Generate Image"): | |
| with gr.Row(elem_classes="input-row", scale=2): | |
| with gr.Column(scale=4, elem_classes="input-column", min_width=250): | |
| gr.HTML( | |
| "Write a Prompt", | |
| elem_classes="input-column-label" | |
| ) | |
| prompt = gr.Dropdown( | |
| ["A dog by a tree", "A man on the beach", "A hot air balloon"], | |
| container=False, | |
| allow_custom_value=True, | |
| elem_classes="input" | |
| ) | |
| with gr.Column(scale=7, elem_classes="input-column"): | |
| gr.HTML( | |
| "Select or Write Concepts", | |
| elem_classes="input-column-label" | |
| ) | |
| concepts = gr.Dropdown( | |
| ["dog", "grass", "tree", "dragon", "sky", "rock", "cloud", "balloon", "water", "background"], | |
| value=["dog", "grass", "tree", "background"], | |
| multiselect=True, | |
| label="Concepts", | |
| container=False, | |
| allow_custom_value=True, | |
| # scale=4, | |
| elem_classes="input", | |
| max_choices=5 | |
| ) | |
| with gr.Column(scale=1, min_width=100, elem_classes="input-column run-button-column"): | |
| gr.HTML( | |
| "​", | |
| elem_classes="input-column-label" | |
| ) | |
| submit_btn = gr.Button( | |
| "Run", | |
| elem_classes="input" | |
| ) | |
                with gr.Row(elem_classes="gallery-container", scale=8):
                    with gr.Column(scale=1, min_width=250):
                        generated_image = gr.Image(
                            elem_classes="generated-image",
                            show_label=False,
                        )
                    with gr.Column(scale=4):
                        concept_attention_gallery = gr.Gallery(
                            label="Concept Attention (Ours)",
                            show_label=True,
                            rows=1,
                            object_fit="contain",
                            height="200px",
                            elem_classes="gallery",
                            elem_id="concept-attention-gallery",
                        )
                        cross_attention_gallery = gr.Gallery(
                            label="Cross Attention",
                            show_label=True,
                            rows=1,
                            object_fit="contain",
                            height="200px",
                            elem_classes="gallery",
                        )
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
                    layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
                    timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)

                submit_btn.click(
                    fn=generate_image,
                    inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
                    outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
                )

                prompt.change(update_default_concepts, inputs=[prompt], outputs=[concepts])

                # Automatically process the first example on launch
                demo.load(
                    generate_image,
                    inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
                    outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
                )
| with gr.Tab(label="Explain a Real Image"): | |
| with gr.Row(elem_classes="input-row", scale=2): | |
| with gr.Column(scale=4, elem_classes="input-column", min_width=250): | |
| gr.HTML( | |
| "Write a Prompt (Optional)", | |
| elem_classes="input-column-label" | |
| ) | |
                        prompt = gr.Textbox(
                            placeholder="Write a prompt (Optional)",
                            container=False,
                            elem_classes="input"
                        )
                    with gr.Column(scale=7, elem_classes="input-column"):
                        gr.HTML(
                            "Select or Write Concepts",
                            elem_classes="input-column-label"
                        )
                        concepts = gr.Dropdown(
                            ["dog", "grass", "tree", "dragon", "sky", "rock", "cloud", "balloon", "water", "background"],
                            value=["dog", "grass", "tree", "background"],
                            multiselect=True,
                            label="Concepts",
                            container=False,
                            allow_custom_value=True,
                            elem_classes="input",
                            max_choices=5
                        )
                    with gr.Column(scale=1, min_width=100, elem_classes="input-column run-button-column"):
                        gr.HTML(
                            "&#8203;",
                            elem_classes="input-column-label"
                        )
                        submit_btn = gr.Button(
                            "Run",
                            elem_classes="input"
                        )
                with gr.Row(elem_classes="gallery-container", scale=8, equal_height=True):
                    with gr.Column(scale=1, min_width=250):
                        input_image = gr.Image(
                            elem_classes="generated-image",
                            show_label=False,
                            interactive=True,
                            type="pil",
                            image_mode="RGB",
                            scale=1
                        )
                    with gr.Column(scale=2):
                        concept_attention_gallery = gr.Gallery(
                            label="Concept Attention (Ours)",
                            show_label=True,
                            rows=1,
                            object_fit="contain",
                            height="200px",
                            elem_classes="gallery",
                            elem_id="concept-attention-gallery",
                        )
                        cross_attention_gallery = gr.Gallery(
                            label="Cross Attention",
                            show_label=True,
                            rows=1,
                            object_fit="contain",
                            height="200px",
                            elem_classes="gallery",
                        )
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
                    num_samples = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Samples", value=4)
                    layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
                    noise_timestep = gr.Slider(minimum=0, maximum=4, step=1, label="Noise Timestep", value=2)

                submit_btn.click(
                    fn=encode_image,
                    inputs=[input_image, prompt, concepts, seed, layer_start_index, noise_timestep, num_samples],
                    outputs=[input_image, concept_attention_gallery, cross_attention_gallery]
                )
                # Unlike the "Generate Image" tab, this tab does not auto-run
                # on launch, since there is no input image yet.
        with gr.Column(scale=2, min_width=200, elem_classes="svg-column"):
            with gr.Row(scale=8):
                gr.HTML("<div></div>")
            with gr.Row(scale=4, elem_classes="svg-container"):
                concept_attention_callout_svg = gr.HTML(
                    "<img src='/gradio_api/file=ConceptAttentionCallout.svg' class='callout'/>"
                )
                cross_attention_callout_svg = gr.HTML(
                    "<img src='/gradio_api/file=CrossAttentionCallout.svg' class='callout'/>"
                )
            with gr.Row(scale=4):
                gr.HTML("<div></div>")


if __name__ == "__main__":
    if os.path.exists("/data-nvme/zerogpu-offload"):
        # Clear any stale ZeroGPU offload cache from a previous run.
        subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)

    demo.launch(
        allowed_paths=["."]  # serve the callout SVGs from the working directory
    )
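
# A minimal sketch of running this app locally (assumptions: a CUDA GPU with
# enough memory for FLUX.1-schnell, the concept_attention package installed,
# and an HF token exported under the expected variable name):
#
#   export AccessToken=hf_xxx...   # placeholder token value
#   python app.py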