from typing import Dict, List, Any
import io
import torch
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from PIL import Image
import gc
import os
import base64
| |
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for OME detection.

    Wraps a timm Inception-v4 binary classifier (single-logit head). Accepts
    an image as a URL, base64 string, raw bytes, file path, or PIL image, and
    returns ``[{"label": "OME", "score": float}]``. On any failure the handler
    degrades to a score of 0.0 rather than raising, matching the original
    endpoint contract.
    """

    # Returned on every error path so callers always get a well-formed response.
    _FALLBACK = [{"label": "OME", "score": 0.0}]

    def __init__(self, path: str = ""):
        """Load model weights and build the preprocessing pipeline.

        Args:
            path: Directory containing ``pytorch_model.bin``. If absent, the
                model is pulled from the Hugging Face Hub instead.
        """
        # Endpoint runs on CPU; the CUDA cache calls below are defensive no-ops.
        self.device = torch.device("cpu")

        local_weights = os.path.join(path, "pytorch_model.bin")
        if os.path.isdir(path) and os.path.exists(local_weights):
            print(f"Loading model from local path: {path}")
            # num_classes=1: single logit, sigmoid applied at inference time.
            self.model = timm.create_model("inception_v4", num_classes=1)
            state_dict = torch.load(local_weights, map_location=self.device)
            self.model.load_state_dict(state_dict)
        else:
            print("Loading model from Hugging Face Hub: Thaweewat/inception_512_augv1")
            self.model = timm.create_model(
                "hf_hub:Thaweewat/inception_512_augv1", pretrained=True
            )

        self.model.to(self.device)
        self.model.eval()

        # Resolve the data config and build the transform ONCE here; the
        # original rebuilt create_transform(**config) on every request.
        self.config = resolve_data_config({}, model=self.model)
        self.transform = create_transform(**self.config)

        self._free_memory()

    @staticmethod
    def _free_memory():
        """Best-effort release of CUDA cache (if any) and collected garbage."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def preprocess_image(self, image):
        """Center-crop to a square, resize to 512x512, apply the model transform.

        Args:
            image (PIL.Image): Input image of any size/mode.

        Returns:
            torch.Tensor: Preprocessed image tensor (no batch dimension).
        """
        width, height = image.size
        crop_size = min(width, height)

        # Largest centered square crop.
        left = (width - crop_size) // 2
        top = (height - crop_size) // 2
        image = image.crop((left, top, left + crop_size, top + crop_size))

        # Model was trained at 512x512; skip the resize when already that size.
        if crop_size != 512:
            image = image.resize((512, 512), Image.LANCZOS)

        # Normalize mode (handles grayscale / RGBA / palette inputs).
        image = image.convert("RGB")

        return self.transform(image)

    def _load_image(self, inputs):
        """Decode *inputs* into a PIL image.

        Supports a URL, a base64 string, a filesystem path, raw bytes, or an
        already-open PIL image. Returns None when the input cannot be decoded.
        """
        if isinstance(inputs, str):
            if inputs.startswith(("http://", "https://")):
                import requests

                # Timeout + status check so a dead or erroring URL cannot hang
                # the endpoint or hand an HTML error page to Image.open.
                response = requests.get(inputs, timeout=30)
                response.raise_for_status()
                return Image.open(io.BytesIO(response.content))
            try:
                return Image.open(io.BytesIO(base64.b64decode(inputs)))
            except Exception as e:
                print(f"Error decoding base64: {e}")
                # Fall back to treating the string as a local file path.
                try:
                    return Image.open(inputs)
                except Exception as e2:
                    print(f"Error opening as file: {e2}")
                    return None
        if isinstance(inputs, bytes):
            return Image.open(io.BytesIO(inputs))
        if isinstance(inputs, Image.Image):
            return inputs
        print(f"Unsupported input type: {type(inputs)}")
        return None

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run inference on the payload and return the OME prediction.

        Args:
            data: Payload with an "inputs" key holding the image (URL, base64
                string, file path, raw bytes, or PIL image).

        Returns:
            ``[{"label": "OME", "score": float}]``; score 0.0 on any failure.
        """
        try:
            if "inputs" not in data:
                print("No 'inputs' found in data")
                return list(self._FALLBACK)

            image = self._load_image(data["inputs"])
            if image is None:
                return list(self._FALLBACK)

            image_tensor = self.preprocess_image(image)

            with torch.no_grad():
                batch = image_tensor.unsqueeze(0).to(self.device)
                output = self.model(batch)

            # Some timm models return (logits, aux); keep the primary head.
            if isinstance(output, tuple):
                output = output[0]
            # Multi-logit head: keep the first column as the binary logit.
            if output.ndim > 1 and output.shape[1] > 1:
                output = output[:, 0]

            prediction = torch.sigmoid(output).item()

            del batch, image_tensor
            self._free_memory()

            # NOTE(review): sigmoid output is treated as P(not-OME) and the
            # complement reported — confirm against the training label mapping.
            return [{"label": "OME", "score": float(1 - prediction)}]

        except Exception as e:
            print(f"Error processing image: {str(e)}")
            import traceback

            traceback.print_exc()
            return list(self._FALLBACK)
| |
|