from typing import Dict, List, Any
import io
import torch
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from PIL import Image
import gc
import os
import base64


class EndpointHandler:
    """Hugging Face Inference Endpoints handler for OME detection.

    Wraps a timm ``inception_v4`` binary classifier (single logit) and maps
    its output to ``[{"label": "OME", "score": float}]``.
    """

    def __init__(self, path=""):
        """
        Initialize the endpoint handler with the OME detection model.

        Args:
            path (str): Path to the model weights (can be local or HF Hub ID)
        """
        # Inference is pinned to CPU to keep the endpoint's memory
        # footprint small.
        self.device = torch.device("cpu")

        # In HF Endpoints the repository is mounted locally; prefer local
        # weights when present, otherwise pull the model from the Hub.
        weights_path = os.path.join(path, "pytorch_model.bin")
        if os.path.isdir(path) and os.path.exists(weights_path):
            print(f"Loading model from local path: {path}")
            # num_classes=1 -> a single logit, interpreted via sigmoid
            # in __call__.
            self.model = timm.create_model("inception_v4", num_classes=1)
            state_dict = torch.load(weights_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
        else:
            print("Loading model from Hugging Face Hub: Thaweewat/inception_512_augv1")
            self.model = timm.create_model(
                "hf_hub:Thaweewat/inception_512_augv1", pretrained=True
            )

        self.model.to(self.device)
        self.model.eval()

        # Model-specific preprocessing config (input size, mean/std,
        # interpolation). The transform is built once here instead of on
        # every request — it is invariant across calls.
        self.config = resolve_data_config({}, model=self.model)
        self.transform = create_transform(**self.config)

        # Release any transient allocations made during model construction.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def preprocess_image(self, image):
        """
        Preprocess the image for model inference.

        Center-crops to the largest square, resizes to 512x512, converts to
        RGB, then applies the model's timm transform.

        Args:
            image (PIL.Image): Input image

        Returns:
            torch.Tensor: Preprocessed image tensor
        """
        width, height = image.size
        # Take the largest centered square crop.
        crop_size = min(width, height)
        left = (width - crop_size) // 2
        top = (height - crop_size) // 2
        image = image.crop((left, top, left + crop_size, top + crop_size))
        # Only resample when the crop is not already 512x512.
        if crop_size != 512:
            image = image.resize((512, 512), Image.LANCZOS)
        image = image.convert('RGB')
        # NOTE(review): the timm transform may resize/crop again according
        # to the model's data config; the manual 512x512 step above assumes
        # the config's input size matches — confirm against self.config.
        return self.transform(image)

    def _load_image(self, inputs):
        """Coerce any supported input format into a PIL image.

        Accepts a PIL image, raw bytes, an http(s) URL, a base64 string, or
        a file path. Returns None (after logging) when decoding fails or the
        input type is unsupported.
        """
        if isinstance(inputs, Image.Image):
            # Already a PIL image — use as-is.
            return inputs
        if isinstance(inputs, bytes):
            # Raw binary image payload.
            return Image.open(io.BytesIO(inputs))
        if isinstance(inputs, str):
            if inputs.startswith('http://') or inputs.startswith('https://'):
                import requests
                # Timeout added so a stalled download cannot hang the
                # endpoint worker indefinitely.
                response = requests.get(inputs, timeout=30)
                return Image.open(io.BytesIO(response.content))
            # Assume base64 first; fall back to treating it as a file path.
            try:
                image_bytes = base64.b64decode(inputs)
                return Image.open(io.BytesIO(image_bytes))
            except Exception as e:
                print(f"Error decoding base64: {e}")
                try:
                    return Image.open(inputs)
                except Exception as e2:
                    print(f"Error opening as file: {e2}")
                    return None
        print(f"Unsupported input type: {type(inputs)}")
        return None

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the input data and return predictions.

        Args:
            data (Dict[str, Any]): Input data containing:
                - "inputs": base64 string, URL, file path, raw bytes, or a
                  PIL image.

        Returns:
            List[Dict[str, Any]]: Prediction in the format required by HF
            Endpoints: ``[{"label": "OME", "score": float}]``. The score is
            0.0 on any decoding or inference error.
        """
        try:
            if "inputs" not in data:
                print("No 'inputs' found in data")
                return [{"label": "OME", "score": 0.0}]

            image = self._load_image(data["inputs"])
            if image is None:
                return [{"label": "OME", "score": 0.0}]

            image_tensor = self.preprocess_image(image)

            # Gradients are never needed at inference time; disabling them
            # saves memory.
            with torch.no_grad():
                image_tensor = image_tensor.unsqueeze(0).to(self.device)
                output = self.model(image_tensor)

                # Some models return a tuple of outputs; the logit is first.
                if isinstance(output, tuple):
                    output = output[0]
                # If the head emits multiple classes, score only the first.
                if output.ndim > 1 and output.shape[1] > 1:
                    output = output[:, 0]

                prediction = torch.sigmoid(output).item()

            # Release the input tensor eagerly to cap peak memory per request.
            del image_tensor
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

            # Reversed logic per the model's behavior: the raw sigmoid is
            # high (~1.0) for a NORMAL ear and low (~0.0) when OME is
            # present, so the reported OME confidence is the complement.
            ome_score = float(1 - prediction)
            return [{"label": "OME", "score": ome_score}]
        except Exception as e:
            print(f"Error processing image: {str(e)}")
            import traceback
            traceback.print_exc()
            return [{"label": "OME", "score": 0.0}]