# handler.py — inference endpoint for the Thaweewat/inception_512_augv1 model
# (author: Thaweewat; last update commit bc5a8a0, verified)
from typing import Dict, List, Any
import io
import torch
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from PIL import Image
import gc
import os
import base64
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for OME detection.

    Wraps a timm Inception-v4 single-logit classifier. The endpoint contract
    is that ``__call__`` never raises: any failure degrades to a
    zero-confidence ``[{"label": "OME", "score": 0.0}]`` response.
    """

    def __init__(self, path=""):
        """
        Initialize the endpoint handler with the OME detection model.

        Args:
            path (str): Path to the model weights (local repository directory
                as mounted by HF Endpoints, or empty to pull from the Hub).
        """
        # Inference is pinned to CPU to keep the endpoint's memory footprint small.
        self.device = torch.device("cpu")
        state_dict_path = os.path.join(path, "pytorch_model.bin")
        if os.path.isdir(path) and os.path.exists(state_dict_path):
            # HF Endpoints mounts the repo locally: build the architecture and
            # load the raw state dict into it.
            print(f"Loading model from local path: {path}")
            self.model = timm.create_model("inception_v4", num_classes=1)
            # weights_only=True restricts unpickling to tensors, preventing
            # arbitrary code execution from a tampered checkpoint.
            state_dict = torch.load(
                state_dict_path, map_location=self.device, weights_only=True
            )
            self.model.load_state_dict(state_dict)
        else:
            # Fall back to pulling the pretrained weights from the Hub.
            print("Loading model from Hugging Face Hub: Thaweewat/inception_512_augv1")
            self.model = timm.create_model(
                "hf_hub:Thaweewat/inception_512_augv1", pretrained=True
            )
        self.model.to(self.device)
        self.model.eval()
        # Resolve the model-specific preprocessing config and build the
        # transform ONCE here, instead of rebuilding it on every request.
        self.config = resolve_data_config({}, model=self.model)
        self.transform = create_transform(**self.config)
        self._free_memory()

    @staticmethod
    def _free_memory():
        """Release cached allocator memory (if CUDA exists) and force a GC pass."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    @staticmethod
    def _error_result():
        """Return a fresh zero-confidence fallback payload (never shared/mutated)."""
        return [{"label": "OME", "score": 0.0}]

    def preprocess_image(self, image):
        """
        Center-crop, resize and normalize an image for model inference.

        Args:
            image (PIL.Image): Input image.

        Returns:
            torch.Tensor: Preprocessed image tensor (C, H, W).
        """
        width, height = image.size
        # Square center crop on the smaller dimension.
        crop_size = min(width, height)
        left = (width - crop_size) // 2
        top = (height - crop_size) // 2
        image = image.crop((left, top, left + crop_size, top + crop_size))
        # Resize to the model's 512x512 input if the crop isn't already that size.
        if crop_size != 512:
            image = image.resize((512, 512), Image.LANCZOS)
        image = image.convert('RGB')
        # Apply the timm transform resolved for this model in __init__.
        return self.transform(image)

    def _load_image(self, inputs):
        """
        Decode the request payload into a PIL image.

        Accepts a URL string, a base64-encoded string (with a file-path
        fallback), raw bytes, or an already-constructed PIL image.

        Returns:
            PIL.Image or None: The decoded image, or None if decoding failed
            (failures are logged, not raised).
        """
        if isinstance(inputs, str) and inputs.startswith(('http://', 'https://')):
            import requests
            # Timeout prevents a stalled download from hanging the endpoint;
            # raise_for_status turns an HTTP error page into an exception
            # (caught by __call__) instead of feeding HTML to Image.open.
            response = requests.get(inputs, timeout=30)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content))
        if isinstance(inputs, str):
            # Assume base64 first; fall back to treating it as a file path.
            try:
                return Image.open(io.BytesIO(base64.b64decode(inputs)))
            except Exception as e:
                print(f"Error decoding base64: {e}")
                try:
                    return Image.open(inputs)
                except Exception as e2:
                    print(f"Error opening as file: {e2}")
                    return None
        if isinstance(inputs, bytes):
            return Image.open(io.BytesIO(inputs))
        if isinstance(inputs, Image.Image):
            return inputs
        print(f"Unsupported input type: {type(inputs)}")
        return None

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the input data and return predictions.

        Args:
            data (Dict[str, Any]): Request payload; ``data["inputs"]`` may be a
                URL, a base64 string, raw bytes, or a PIL image.

        Returns:
            List[Dict[str, Any]]: ``[{"label": "OME", "score": float}]`` —
            always this shape, with score 0.0 on any failure.
        """
        try:
            if "inputs" not in data:
                print("No 'inputs' found in data")
                return self._error_result()
            image = self._load_image(data["inputs"])
            if image is None:
                return self._error_result()
            image_tensor = self.preprocess_image(image)
            # Gradients disabled: inference only, saves memory.
            with torch.no_grad():
                image_tensor = image_tensor.unsqueeze(0).to(self.device)
                output = self.model(image_tensor)
                # Some timm models return (logits, aux); keep the main head.
                if isinstance(output, tuple):
                    output = output[0]
                # If the head unexpectedly has multiple classes, score class 0.
                if output.ndim > 1 and output.shape[1] > 1:
                    output = output[:, 0]
                prediction = torch.sigmoid(output).item()
            del image_tensor
            self._free_memory()
            # The model's sigmoid is high for NORMAL ears, so the OME
            # confidence is the inverted probability: high score => OME likely.
            ome_score = float(1 - prediction)
            return [{"label": "OME", "score": ome_score}]
        except Exception as e:
            print(f"Error processing image: {str(e)}")
            import traceback
            traceback.print_exc()
            return self._error_result()