Spaces:
Runtime error
Runtime error
| import torch | |
| import torchvision.transforms as transforms | |
| import gradio as gr | |
| from PIL import Image | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
def get_model_name(name, batch_size, learning_rate, epoch):
    """Build the checkpoint name that encodes a run's hyperparameters.

    Args:
        name: Identifier of the network (for example its ``name`` attribute).
        batch_size: Batch size to embed in the name.
        learning_rate: Learning rate to embed in the name.
        epoch: Epoch number to embed in the name.

    Returns:
        A string of the form ``model_<name>_bs<batch_size>_lr<learning_rate>_epoch<epoch>``.
    """
    return f"model_{name}_bs{batch_size}_lr{learning_rate}_epoch{epoch}"
class LargeNet(nn.Module):
    """Small CNN: two conv + max-pool stages followed by two linear layers.

    The first linear layer is sized for a 3-channel 128x128 input:
    128 -> conv 5x5 -> 124 -> pool -> 62 -> conv 5x5 -> 58 -> pool -> 29,
    which leaves 10 feature maps of 29x29. ``forward`` returns raw logits of
    shape ``[batch, 8]``; no softmax is applied.
    """

    # Number of features entering fc1 (channels * height * width after the
    # second pooling stage). Kept in one place so __init__ and forward agree.
    _FLAT_FEATURES = 10 * 29 * 29

    def __init__(self):
        super().__init__()
        # Read by load_model() to build the checkpoint file name.
        self.name = "large"
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 32)
        self.fc2 = nn.Linear(32, 8)

    def forward(self, x):
        """Return the 8 class logits for each image in ``x``.

        Args:
            x: Image tensor; the layer sizes assume ``[batch, 3, 128, 128]``.

        Returns:
            Tensor of shape ``[batch, 8]``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, self._FLAT_FEATURES)
        x = F.relu(self.fc1(x))
        # The former ``x.squeeze(1)`` was dropped: dimension 1 has size 8, so
        # the squeeze never changed the tensor and its comment was misleading.
        return self.fc2(x)
# Preprocessing applied to each input image before it is fed to the network.
transform = transforms.Compose(
    [
        # Fixed 128x128 size; LargeNet's first linear layer is sized for it.
        transforms.Resize(size=(128, 128)),
        # PIL image -> float tensor.
        transforms.ToTensor(),
        # (x - 0.5) / 0.5, i.e. normalize to [-1, 1].
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
def load_model():
    """Create a LargeNet and load its saved weights for CPU inference.

    The checkpoint path is produced by ``get_model_name`` from the network
    name and fixed hyperparameter values (batch size 128, learning rate
    0.001, epoch 29); it is a relative path, so it is resolved against the
    current working directory.

    Returns:
        The network with the checkpoint's state dict loaded, in eval mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    net = LargeNet()
    model_path = get_model_name(net.name, batch_size=128, learning_rate=0.001, epoch=29)
    # map_location="cpu" remaps tensors that were saved from a CUDA device;
    # without it torch.load fails on a host that has no GPU.
    state = torch.load(model_path, map_location="cpu")
    net.load_state_dict(state)
    net.eval()
    return net
# NOTE(review): LargeNet produces 8 logits but only 7 names are listed here.
# The missing name (and its position) cannot be determined from this file;
# confirm against the training dataset's class order.
class_names = ["Gasoline_Can", "Pebbels", "pliers", "Screw_Driver", "Toolbox", "Wrench", "other"]

# Model shared by all predict() calls; created on first use so the checkpoint
# is read from disk once instead of on every request.
_cached_model = None


def _get_model():
    """Return the shared model, loading it the first time it is needed."""
    global _cached_model
    if _cached_model is None:
        _cached_model = load_model()
    return _cached_model


def predict(image):
    """Classify one image and return the predicted class name.

    Args:
        image: A PIL image (the Gradio input is configured with
            ``type="pil"``), or ``None`` when nothing was uploaded.

    Returns:
        The entry of ``class_names`` for the highest-scoring logit, or a
        placeholder string when the winning index has no name.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    model = _get_model()
    # The network's first convolution takes 3 channels, so grayscale, palette
    # and RGBA uploads are converted to RGB before preprocessing.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
    _, pred = torch.max(output, 1)

    index = pred.item()
    if index >= len(class_names):
        # Reachable because the network has more outputs than there are names.
        return f"unknown (class index {index})"
    return class_names[index]
# Gradio UI: takes one uploaded image (handed to predict() as a PIL image)
# and shows the returned class name in a label component.
interface = gr.Interface(
    title="Mechanical Tools Classifier",
    description="Upload an image to classify it as one of the mechanical tools.",
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="label",
)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()