""" |
|
|
Real Model Loader for Hugging Face Models |
|
|
Manages model loading, caching, and inference |
|
|
Works with public HuggingFace models without requiring authentication |
|
|
""" |
|
|
|
|
|
import os |
|
|
import logging |
|
|
from typing import Dict, Any, Optional, List |
|
|
from functools import lru_cache |
|
|
|
|
|
|
|
|
import torch |
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModel, |
|
|
AutoModelForSequenceClassification, |
|
|
AutoModelForTokenClassification, |
|
|
pipeline |
|
|
) |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
HF_TOKEN = os.getenv("HF_TOKEN", None) |
|
|
|
|
|
if HF_TOKEN: |
|
|
logger.info("HF_TOKEN found - will use for gated models if needed") |
|
|
else: |
|
|
logger.info("HF_TOKEN not found - using public models only (this is normal)") |
|
|
|
|
|
|
|
|


class ModelLoader:
    """
    Manages loading and caching of Hugging Face models
    Implements lazy loading and GPU optimization
    """

    def __init__(self):
        """Initialize the model loader with GPU support if available"""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.loaded_models = {}
        self.model_configs = self._get_model_configs()

        logger.info(f"Model Loader initialized on device: {self.device}")
        logger.info(f"PyTorch version: {torch.__version__}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")

        logger.info(f"Model configurations loaded: {len(self.model_configs)} models")
        for key in self.model_configs:
            logger.info(f"  - {key}: {self.model_configs[key]['model_id']}")

    def _get_model_configs(self) -> Dict[str, Dict[str, Any]]:
        """
        Configuration for real Hugging Face models
        Maps tasks to actual model names on Hugging Face Hub
        """
        return {
            "document_classifier": {
                "model_id": "emilyalsentzer/Bio_ClinicalBERT",
                "task": "text-classification",
                "description": "Clinical document type classification"
            },
            "clinical_ner": {
                "model_id": "d4data/biomedical-ner-all",
                "task": "ner",
                "description": "Biomedical named entity recognition"
            },
            "clinical_generation": {
                "model_id": "microsoft/BioGPT-Large",
                "task": "text-generation",
                "description": "Clinical text generation and summarization"
            },
            "medical_qa": {
                "model_id": "deepset/roberta-base-squad2",
                "task": "question-answering",
                "description": "Medical question answering"
            },
            "general_medical": {
                "model_id": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
                "task": "feature-extraction",
                "description": "General medical text understanding"
            },
            "drug_interaction": {
                "model_id": "allenai/scibert_scivocab_uncased",
                "task": "feature-extraction",
                "description": "Drug interaction detection"
            },
            "radiology_generation": {
                "model_id": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
                "task": "feature-extraction",
                "description": "Radiology report analysis"
            },
            "clinical_summarization": {
                "model_id": "google/bigbird-pegasus-large-pubmed",
                "task": "summarization",
                "description": "Clinical document summarization"
            }
        }
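
    # Illustrative sketch (not part of the configured model set): supporting a new
    # task only requires another entry in the dict above. The model_id below is a
    # placeholder, not a verified Hub repository.
    #
    #     "clinical_sentiment": {
    #         "model_id": "<org>/<clinical-sentiment-model>",
    #         "task": "text-classification",
    #         "description": "Sentiment of clinical narrative text"
    #     }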

    def load_model(self, model_key: str) -> Optional[Any]:
        """
        Load a model by key, with caching

        Most HuggingFace models are public and don't require authentication.
        HF_TOKEN is only needed for private/gated models.
        """
        try:
            if model_key in self.loaded_models:
                logger.info(f"Using cached model: {model_key}")
                return self.loaded_models[model_key]

            if model_key not in self.model_configs:
                logger.warning(f"Unknown model key: {model_key}, using fallback")
                model_key = "general_medical"

            config = self.model_configs[model_key]
            model_id = config["model_id"]
            task = config["task"]

            logger.info(f"Loading model: {model_id} for task: {task}")

            try:
                pipeline_kwargs = {
                    "task": task,
                    "model": model_id,
                    "device": 0 if self.device == "cuda" else -1,
                    "trust_remote_code": True
                }

                if HF_TOKEN:
                    pipeline_kwargs["token"] = HF_TOKEN

                model_pipeline = pipeline(**pipeline_kwargs)

                self.loaded_models[model_key] = model_pipeline
                logger.info(f"Successfully loaded model: {model_id}")
                return model_pipeline

            except Exception as e:
                error_msg = str(e).lower()

                if "401" in error_msg or "unauthorized" in error_msg or "authentication" in error_msg:
                    if not HF_TOKEN:
                        logger.error(f"Model {model_id} requires authentication but HF_TOKEN not available")
                        logger.error("This model is gated/private. Using public alternative or fallback.")
                    else:
                        logger.error(f"Model {model_id} authentication failed even with HF_TOKEN")
                else:
                    logger.error(f"Failed to load model {model_id}: {str(e)}")

                try:
                    logger.info(f"Trying alternative loading method for {model_id}...")
tokenizer_kwargs = {"model_id": model_id, "trust_remote_code": True} |
|
|
model_kwargs = {"pretrained_model_name_or_path": model_id, "trust_remote_code": True} |
|
|
|
|
|

                    if HF_TOKEN:
                        tokenizer_kwargs["token"] = HF_TOKEN
                        model_kwargs["token"] = HF_TOKEN

                    tokenizer = AutoTokenizer.from_pretrained(**tokenizer_kwargs)
                    model = AutoModel.from_pretrained(**model_kwargs).to(self.device)

                    self.loaded_models[model_key] = {
                        "tokenizer": tokenizer,
                        "model": model,
                        "type": "custom"
                    }
                    logger.info(f"Successfully loaded {model_id} with alternative method")
                    return self.loaded_models[model_key]

                except Exception as inner_e:
                    logger.error(f"Alternative loading also failed for {model_id}: {str(inner_e)}")
                    logger.info(f"Model {model_key} unavailable - will use fallback analysis")
                    return None

        except Exception as e:
            logger.error(f"Model loading failed for {model_key}: {str(e)}")
            return None

    def run_inference(
        self,
        model_key: str,
        input_text: str,
        task_params: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Run inference on a loaded model.

        Returns a dict with "success" and "result" on success, or an "error"
        message on failure; "model_key" is always included.
        """
        try:
            model = self.load_model(model_key)

            if model is None:
                return {
                    "error": "Model not available",
                    "model_key": model_key
                }

            task_params = task_params or {}

            # Pipeline object: call it directly. Work on a copy of task_params so
            # an explicit "max_length" is not passed twice.
            if hasattr(model, '__call__') and not isinstance(model, dict):
                params = dict(task_params)
                max_length = params.pop("max_length", 512)

                result = model(
                    input_text[:4000],
                    max_length=max_length,
                    truncation=True,
                    **params
                )

                return {
                    "success": True,
                    "result": result,
                    "model_key": model_key
                }

            elif isinstance(model, dict) and model.get("type") == "custom":
                tokenizer = model["tokenizer"]
                model_obj = model["model"]

                inputs = tokenizer(
                    input_text[:512],
                    return_tensors="pt",
                    truncation=True,
                    max_length=512
                ).to(self.device)

                with torch.no_grad():
                    outputs = model_obj(**inputs)

                return {
                    "success": True,
                    "result": {
                        "embeddings": outputs.last_hidden_state.mean(dim=1).cpu().tolist(),
                        "pooled": outputs.pooler_output.cpu().tolist() if hasattr(outputs, 'pooler_output') else None
                    },
                    "model_key": model_key
                }

            else:
                return {
                    "error": "Unknown model type",
                    "model_key": model_key
                }

        except Exception as e:
            logger.error(f"Inference failed for {model_key}: {str(e)}")
            return {
                "error": str(e),
                "model_key": model_key
            }
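
    # Illustrative usage sketch (assumes the chosen model can be downloaded in the
    # current environment; the result shape depends on the task):
    #
    #     loader = ModelLoader()
    #     response = loader.run_inference("clinical_ner", "Patient given 5 mg warfarin daily.")
    #     if response.get("success"):
    #         print(response["result"])
    #     else:
    #         print("Inference unavailable:", response["error"])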

    def clear_cache(self, model_key: Optional[str] = None):
        """Clear model cache to free memory"""
        if model_key:
            if model_key in self.loaded_models:
                del self.loaded_models[model_key]
                logger.info(f"Cleared cache for model: {model_key}")
        else:
            self.loaded_models.clear()
            logger.info("Cleared all model caches")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def test_model_loading(self) -> Dict[str, Any]:
        """Test loading all configured models to verify AI functionality"""
        results = {
            "total_models": len(self.model_configs),
            "models_loaded": 0,
            "models_failed": 0,
            "errors": [],
            "device": self.device,
            "pytorch_version": torch.__version__
        }

        for model_key, config in self.model_configs.items():
            try:
                logger.info(f"Testing model: {model_key} ({config['model_id']})")

                test_input = "Test ECG analysis request"
                result = self.run_inference(model_key, test_input, {"max_new_tokens": 50})

                if result.get("success"):
                    results["models_loaded"] += 1
                    logger.info(f"✅ {model_key}: Loaded successfully")
                else:
                    results["models_failed"] += 1
                    error_msg = result.get("error", "Unknown error")
                    results["errors"].append(f"{model_key}: {error_msg}")
                    logger.warning(f"⚠️ {model_key}: {error_msg}")

            except Exception as e:
                results["models_failed"] += 1
                error_msg = f"Exception during loading: {str(e)}"
                results["errors"].append(f"{model_key}: {error_msg}")
                logger.error(f"❌ {model_key}: {error_msg}")

        logger.info(f"Model loading test complete: {results['models_loaded']}/{results['total_models']} successful")
        return results


_model_loader = None


def get_model_loader() -> ModelLoader:
    """Get singleton model loader instance"""
    global _model_loader
    if _model_loader is None:
        _model_loader = ModelLoader()
    return _model_loader
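

# Minimal usage sketch, assuming this module is run directly with network access
# to the Hugging Face Hub; model downloads may take several minutes and some
# models may fail or fall back depending on the environment.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    loader = get_model_loader()
    report = loader.test_model_loading()
    print(f"{report['models_loaded']}/{report['total_models']} models loaded on {report['device']}")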