| import random |
|
|
| import gradio as gr |
| import torch |
| import torchvision |
| import torchvision.transforms as transforms |
| from PIL import Image |
| from torch import nn |
| from torchvision.models import mobilenet_v2, resnet18 |
| from torchvision.transforms.functional import InterpolationMode |
|
|
# Number of output classes for each supported dataset. The keys double as the
# dataset identifiers embedded in the checkpoint and data paths built below.
datasets_n_classes = dict(
    Imagenette=10,
    Imagewoof=10,
    Stanford_dogs=120,
)
|
|
# Training-variant names listed for every dataset.
_shared_model_types = [
    "base_200",
    "base_200+100",
    "synthetic_200",
    "augment_noisy_200",
    "augment_noisy_200+100",
]

# Training variants per dataset. Each name becomes part of a checkpoint file
# name; "Stanford_dogs" has no "augment_clean_200" entry.
# Every dataset gets its own list object.
datasets_model_types = {
    "Imagenette": [*_shared_model_types, "augment_clean_200"],
    "Imagewoof": [*_shared_model_types, "augment_clean_200"],
    "Stanford_dogs": [*_shared_model_types],
}
|
|
# Architecture identifiers; each one selects a torchvision constructor and a
# sub-directory of ./models when the checkpoint table is built.
model_arch = [
    "resnet18",
    "mobilenet_v2",
]

# "Training Dataset" choices offered for the "200 Epochs" method.
list_200 = [
    "Original",
    "Synthetic",
    "Original + Synthetic (Noisy)",
    "Original + Synthetic (Clean)",
]

# "Training Dataset" choices offered for the "200 Epochs on Original + 100"
# method.
list_200_100 = [
    "Base+100",
    "AugmentNoisy+100",
]

# UI method label -> list of training-dataset choices shown for that method.
methods_map = {
    "200 Epochs": list_200,
    "200 Epochs on Original + 100": list_200_100,
}
|
|
# Translation table from the labels shown in the UI to the internal tokens
# used to build model keys and file paths.
label_map = {
    # Datasets
    "Imagenette (10 classes)": "Imagenette",
    "Imagewoof (10 classes)": "Imagewoof",
    "Stanford Dogs (120 classes)": "Stanford_dogs",
    # Architectures
    "ResNet-18": "resnet18",
    "MobileNetV2": "mobilenet_v2",
    # Training schedules
    "200 Epochs": "200",
    "200 Epochs on Original + 100": "200+100",
    # Training datasets for the "200 Epochs" schedule
    "Original": "base",
    "Synthetic": "synthetic",
    "Original + Synthetic (Noisy)": "augment_noisy",
    "Original + Synthetic (Clean)": "augment_clean",
    # Training datasets for the "200 Epochs on Original + 100" schedule
    "Base+100": "base",
    "AugmentNoisy+100": "augment_noisy",
}
|
|
# Architecture identifier -> torchvision constructor used to build the network.
_arch_builders = {
    "resnet18": resnet18,
    "mobilenet_v2": mobilenet_v2,
}

# dataset -> {"<arch>_<model_type>": (uninitialised model, checkpoint path)}.
# Networks are created here without weights; the checkpoint at the stored path
# is only read later, by load_model.
dataset_models = dict()
for dataset, n_classes in datasets_n_classes.items():
    models = dict()
    for model_type in datasets_model_types[dataset]:
        for arch in model_arch:
            if arch not in _arch_builders:
                raise ValueError(f"Model architecture unavailable: {arch}")
            model = _arch_builders[arch](weights=None, num_classes=n_classes)
            models[f"{arch}_{model_type}"] = (
                model,
                f"./models/{arch}/{dataset}/{dataset}_{model_type}.pth",
            )
    dataset_models[dataset] = models
|
|
|
|
def get_random_image(dataset, label_map=label_map) -> Image.Image:
    """Return a random validation image of *dataset*, resized to 256x256.

    Args:
        dataset: UI label of the dataset; translated through ``label_map`` to
            the directory name under ``./data``.
        label_map: Mapping from UI labels to internal identifiers.

    Returns:
        A PIL image of size 256x256 drawn uniformly at random from
        ``./data/<dataset>/val``.

    Raises:
        KeyError: If ``dataset`` is not a key of ``label_map``.
    """
    dataset_root = f"./data/{label_map[dataset]}/val"
    # With no transform, ImageFolder hands back the PIL image produced by its
    # default loader, so no tensor round trip is needed before resizing.
    dataset_img = torchvision.datasets.ImageFolder(dataset_root)
    image, _ = dataset_img[random.randrange(len(dataset_img))]
    return image.resize((256, 256))
|
|
|
|
def load_model(model_dict, model_name: str) -> nn.Module:
    """Load the checkpoint registered for *model_name* into its model.

    Args:
        model_dict: Mapping of lower-case model names to
            ``(model, checkpoint_path)`` entries.
        model_name: Name of the requested model; matched case-insensitively.

    Returns:
        The model object stored in ``model_dict`` with the checkpoint weights
        loaded into it (the same object, modified in place).

    Raises:
        ValueError: If ``model_name`` has no entry in ``model_dict``.
    """
    key = model_name.lower()
    if key not in model_dict:
        raise ValueError(
            f"Model {model_name} is not available for image prediction. Please choose from {[name.capitalize() for name in model_dict.keys()]}."
        )

    entry = model_dict[key]
    net, weights_path = entry[0], entry[1]

    # Without CUDA the tensors must be remapped to the CPU; otherwise keep
    # torch.load's default placement (map_location=None).
    target = None if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(weights_path, map_location=target)

    if "setup" not in checkpoint:
        # The file is a bare state dict.
        net.load_state_dict(checkpoint)
        return net

    # The file is a training checkpoint: weights live under "model". When the
    # setup is flagged as distributed, the "module." key prefix is stripped
    # before loading.
    if checkpoint["setup"]["distributed"]:
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint["model"], "module."
        )
    net.load_state_dict(checkpoint["model"])
    return net
|
|
|
|
def postprocess_default(labels, output) -> dict:
    """Turn raw model logits into a label -> probability mapping.

    Args:
        labels: Lookup from class index to class name (anything indexable by
            int, e.g. a dict or a list).
        output: Batched logits; only the first row (``output[0]``) is used.

    Returns:
        A dict of the highest-probability classes, in descending order of
        probability. It holds at most five entries, and fewer when the model
        has fewer than five classes.
    """
    probabilities = nn.functional.softmax(output[0], dim=0)
    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(5, probabilities.shape[0])
    top_prob, top_catid = torch.topk(probabilities, k)
    return {
        labels[class_idx]: prob
        for class_idx, prob in zip(top_catid.tolist(), top_prob.tolist())
    }
|
|
|
|
def classify(
    input_image: Image.Image,
    dataset_type: str,
    arch_type: str,
    methods: str,
    training_ds: str,
    dataset_models=dataset_models,
    label_map=label_map,
) -> dict:
    """Classify *input_image* with the model selected in the UI.

    Args:
        input_image: PIL image to classify.
        dataset_type: UI label of the dataset the model was trained for.
        arch_type: UI label of the model architecture.
        methods: UI label of the training schedule.
        training_ds: UI label of the training dataset variant.
        dataset_models: Per-dataset table of ``(model, checkpoint_path)``
            entries keyed by ``"<arch>_<training_ds>_<methods>"``.
        label_map: Mapping from UI labels to internal identifiers.

    Returns:
        Mapping of class name to probability for the top predictions, as
        produced by ``postprocess_default``.

    Raises:
        ValueError: If any option is ``None``, if no image was provided, or
            (from ``load_model``) if the selected combination has no model.
        KeyError: If an option is not a key of ``label_map``.
    """
    selections = (dataset_type, arch_type, methods, training_ds)
    if any(choice is None for choice in selections):
        raise ValueError("Please select all options.")
    # Reject a missing image before any lookup or model loading is attempted.
    if input_image is None:
        raise ValueError("No image was provided.")

    dataset_type = label_map[dataset_type]
    arch_type = label_map[arch_type]
    methods = label_map[methods]
    training_ds = label_map[training_ds]

    preprocess_input = transforms.Compose(
        [
            transforms.Resize(
                256,
                interpolation=InterpolationMode.BILINEAR,
                antialias=True,
            ),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Normalize is configured for three channels; force RGB so grayscale or
    # RGBA inputs do not fail with a channel mismatch.
    input_tensor: torch.Tensor = preprocess_input(input_image.convert("RGB"))
    input_batch = input_tensor.unsqueeze(0)

    model = load_model(
        dataset_models[dataset_type], f"{arch_type}_{training_ds}_{methods}"
    )

    if torch.cuda.is_available():
        input_batch = input_batch.to("cuda")
        model.to("cuda")

    model.eval()
    with torch.inference_mode():
        output: torch.Tensor = model(input_batch)

    # One class name per line; the line index is the class id.
    with open(f"./data/{dataset_type}.txt", "r", encoding="utf-8") as f:
        labels = {i: line.strip() for i, line in enumerate(f)}
    return postprocess_default(labels, output)
|
|
|
|
def update_methods(method, ds_type):
    """Build the update for the "Training Dataset" choices.

    Args:
        method: Selected training-schedule label (a key of ``methods_map``).
        ds_type: Selected dataset label.

    Returns:
        A ``gr.update`` that installs the choices valid for the selection and
        clears the current value. For Stanford Dogs with "200 Epochs" the last
        entry of ``list_200`` is omitted.
    """
    drop_last_choice = (
        method == "200 Epochs" and ds_type == "Stanford Dogs (120 classes)"
    )
    choices = list_200[:-1] if drop_last_choice else methods_map[method]
    return gr.update(choices=choices, value=None)
|
|
|
|
def downloadModel(
    dataset_type, arch_type, methods, training_ds, dataset_models=dataset_models
):
    """Build the update for the model download button.

    Args:
        dataset_type: UI label of the dataset.
        arch_type: UI label of the model architecture.
        methods: UI label of the training schedule.
        training_ds: UI label of the training dataset variant.
        dataset_models: Per-dataset table of ``(model, checkpoint_path)``
            entries keyed by ``"<arch>_<training_ds>_<methods>"``.

    Returns:
        A ``gr.update`` pointing the button at the selected checkpoint, or a
        "Select Model" placeholder with no file when a selection is missing
        or the combination has no registered model.
    """
    selections = (dataset_type, arch_type, methods, training_ds)
    if any(choice is None for choice in selections):
        return gr.update(label="Select Model", value=None)

    # Translate UI labels to internal identifiers, in argument order.
    dataset_type, arch_type, methods, training_ds = (
        label_map[choice] for choice in selections
    )

    model_key = f"{arch_type}_{training_ds}_{methods}"
    registered = dataset_models[dataset_type]
    if model_key not in registered:
        return gr.update(label="Select Model", value=None)

    return gr.update(
        label=f"Download Model: '{dataset_type}_{arch_type}_{training_ds}_{methods}'",
        value=registered[model_key][1],
    )
|
|
|
|
if __name__ == "__main__":
    # Build the demo UI and launch it.
    # NOTE(review): the nesting of the layout contexts below was reconstructed
    # from statement order -- confirm the column/row placement of each
    # component against the deployed layout.
    with gr.Blocks(title="Generative Augmented Image Classifiers") as demo:
        gr.Markdown(
            """
            # Generative Augmented Image Classifiers
            Main GitHub Repo: [Generative Data Augmentation](https://github.com/zhulinchng/generative-data-augmentation) | Generative Data Augmentation Demo: [Generative Data Augmented](https://huggingface.co/spaces/czl/generative-data-augmentation-demo).
            """
        )
        with gr.Row():
            with gr.Column():
                # Model selection controls.
                dataset_type = gr.Radio(
                    label="Dataset",
                    choices=[
                        "Imagenette (10 classes)",
                        "Imagewoof (10 classes)",
                        "Stanford Dogs (120 classes)",
                    ],
                    value="Imagenette (10 classes)",
                )
                arch_type = gr.Radio(
                    label="Model Architecture",
                    choices=["ResNet-18", "MobileNetV2"],
                    value="ResNet-18",
                    interactive=True,
                )
                methods = gr.Radio(
                    label="Methods",
                    choices=["200 Epochs", "200 Epochs on Original + 100"],
                    value="200 Epochs",
                    interactive=True,
                )
                training_ds = gr.Radio(
                    label="Training Dataset",
                    choices=methods_map["200 Epochs"],
                    value="Original",
                    interactive=True,
                )
                # Changing the dataset or the method refreshes the
                # training-dataset choices.
                for trigger in (dataset_type, methods):
                    trigger.change(
                        fn=update_methods,
                        inputs=[methods, dataset_type],
                        outputs=[training_ds],
                    )
                random_image_output = gr.Image(type="pil", label="Image to Classify")
                with gr.Row():
                    generate_button = gr.Button("Sample Random Image")
                    classify_button_random = gr.Button("Classify")
            with gr.Column():
                output_label_random = gr.Label(num_top_classes=5)
                # The button starts out pointing at the checkpoint that
                # matches the default radio selections.
                download_model = gr.DownloadButton(
                    label=f"Download Model: '{label_map[dataset_type.value]}_{label_map[arch_type.value]}_{label_map[training_ds.value]}_{label_map[methods.value]}'",
                    value=dataset_models[label_map[dataset_type.value]][
                        f"{label_map[arch_type.value]}_{label_map[training_ds.value]}_{label_map[methods.value]}"
                    ][1],
                )
                # Any change to a selection control re-targets the download
                # button.
                for trigger in (dataset_type, arch_type, methods, training_ds):
                    trigger.change(
                        fn=downloadModel,
                        inputs=[dataset_type, arch_type, methods, training_ds],
                        outputs=[download_model],
                    )
                gr.Markdown(
                    """
                    This demo showcases the performance of image classifiers trained on various datasets as part of the project 'Improving Fine-Grained Image Classification Using Diffusion-Based Generated Synthetic Images' dissertation.

                    View the models and files used in this demo [here](https://huggingface.co/spaces/czl/generative-augmented-classifiers/tree/main).

                    Usage Instructions & Documentation [here](https://huggingface.co/spaces/czl/generative-augmented-classifiers/blob/main/README.md).
                    """
                )

        generate_button.click(
            get_random_image,
            inputs=[dataset_type],
            outputs=random_image_output,
        )
        classify_button_random.click(
            classify,
            inputs=[random_image_output, dataset_type, arch_type, methods, training_ds],
            outputs=output_label_random,
        )
    demo.launch(show_error=True)
|
|