"""
Model Versioning and Input Caching System

Tracks model versions and their performance, and implements intelligent caching.

Features:
- Model version tracking with metadata
- Performance metrics per model version
- A/B testing framework
- Automated rollback capabilities
- SHA256 input fingerprinting
- Intelligent caching with invalidation
- Cache performance analytics

Author: MiniMax Agent
Date: 2025-10-29
Version: 1.0.0
"""

import hashlib
import json
import logging
import os
from collections import defaultdict, deque
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple

# Module-level logger used by every class and function in this file.
logger = logging.getLogger(__name__)


class ModelStatus(Enum):
    """Model deployment status"""

    # Version currently selected as the active one for its model id.
    ACTIVE = "active"
    # Registered without being promoted to active.
    TESTING = "testing"
    # Was active and has since been replaced by another version.
    DEPRECATED = "deprecated"
    # Not assigned anywhere in this module; available to callers.
    RETIRED = "retired"


@dataclass
class ModelVersion:
    """Model version metadata

    One record per (model_id, version) pair held by the registry.
    """

    model_id: str
    version: str
    model_name: str
    model_path: str
    # ISO-8601 timestamp string produced at registration time.
    deployment_date: str
    status: ModelStatus
    metadata: Dict[str, Any]
    # Metric name -> current blended value for this version.
    performance_metrics: Dict[str, float]

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict copy of this record.

        ``status`` is replaced by the enum's string value so the result
        contains only primitive/JSON-friendly status data.
        """
        data = asdict(self)
        data["status"] = self.status.value
        return data


@dataclass
class CacheEntry:
    """Cache entry with metadata"""

    # Full lookup key under which the entry is stored in the cache dict.
    cache_key: str
    input_hash: str
    result_data: Dict[str, Any]
    # ISO-8601 timestamp strings.
    created_at: str
    last_accessed: str
    access_count: int
    model_version: str
    # Size charged against the cache budget for this entry.
    size_bytes: int

    def to_dict(self) -> Dict[str, Any]:
        """Return a recursively copied plain-dict view of the entry."""
        return asdict(self)


class ModelRegistry:
    """
    Registry for tracking model versions and performance

    Supports version comparison and automated rollback

    Attributes:
        models: ``model_id -> {version -> ModelVersion}``.
        active_versions: ``model_id -> version`` currently marked active.
        performance_history: ``"model_id:version"`` -> deque keeping the most
            recent 1000 raw metric reports for that version.
    """

    def __init__(self):
        self.models: Dict[str, Dict[str, ModelVersion]] = defaultdict(dict)
        self.active_versions: Dict[str, str] = {}
        self.performance_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))

        logger.info("Model Registry initialized")

    def _has_version(self, model_id: str, version: str) -> bool:
        """Return True when ``version`` of ``model_id`` is registered."""
        # Test membership of model_id first: indexing the defaultdict with an
        # unknown id would silently create an empty entry for it.
        return model_id in self.models and version in self.models[model_id]

    def register_model(
        self,
        model_id: str,
        version: str,
        model_name: str,
        model_path: str,
        metadata: Optional[Dict[str, Any]] = None,
        set_active: bool = False
    ) -> ModelVersion:
        """Register a new model version

        Args:
            model_id: Identifier of the model family.
            version: Version string; an existing entry with the same
                id/version is replaced (a warning is logged).
            model_name: Human-readable model name.
            model_path: Location/identifier the model is loaded from.
            metadata: Optional free-form metadata; defaults to an empty dict.
            set_active: When True the version is promoted to active
                immediately, otherwise it starts in TESTING.

        Returns:
            The stored ``ModelVersion`` record.
        """
        # .get() does not trigger the defaultdict factory, so this probe has
        # no side effect on the registry.
        if version in self.models.get(model_id, {}):
            logger.warning(f"Overwriting existing registration for model {model_id} v{version}")

        # Re-registering the version that is already active must not leave
        # active_versions pointing at an entry whose status says TESTING.
        stays_active = set_active or self.active_versions.get(model_id) == version

        model_version = ModelVersion(
            model_id=model_id,
            version=version,
            model_name=model_name,
            model_path=model_path,
            deployment_date=datetime.utcnow().isoformat(),
            status=ModelStatus.ACTIVE if stays_active else ModelStatus.TESTING,
            metadata=metadata or {},
            performance_metrics={}
        )

        self.models[model_id][version] = model_version

        if set_active:
            self.set_active_version(model_id, version)

        logger.info(f"Registered model {model_id} v{version}")

        return model_version

    def set_active_version(self, model_id: str, version: str) -> None:
        """Set active version for a model

        The previously active version (if any, and if different) is marked
        DEPRECATED.

        Raises:
            ValueError: If the model id / version pair is not registered.
        """
        if not self._has_version(model_id, version):
            raise ValueError(f"Model {model_id} v{version} not found")

        versions = self.models[model_id]

        prev_version = self.active_versions.get(model_id)
        if prev_version is not None and prev_version != version and prev_version in versions:
            versions[prev_version].status = ModelStatus.DEPRECATED

        self.active_versions[model_id] = version
        versions[version].status = ModelStatus.ACTIVE

        logger.info(f"Set active version: {model_id} -> v{version}")

    def get_active_version(self, model_id: str) -> Optional[ModelVersion]:
        """Get currently active model version, or None when no version is active"""
        version = self.active_versions.get(model_id)
        if version is None:
            return None

        return self.models[model_id].get(version)

    def record_performance(
        self,
        model_id: str,
        version: str,
        metrics: Dict[str, float]
    ) -> None:
        """Record performance metrics for a model version

        The raw report is appended to ``performance_history`` and each metric
        is folded into the version's ``performance_metrics``. Unknown
        model/version pairs are logged and ignored.
        """
        if not self._has_version(model_id, version):
            logger.warning(f"Cannot record performance for unknown model {model_id} v{version}")
            return

        performance_record = {
            "timestamp": datetime.utcnow().isoformat(),
            "model_id": model_id,
            "version": version,
            # Copy so later mutation of the caller's dict cannot rewrite history.
            "metrics": dict(metrics)
        }

        self.performance_history[f"{model_id}:{version}"].append(performance_record)

        aggregated = self.models[model_id][version].performance_metrics
        for metric_name, value in metrics.items():
            if metric_name in aggregated:
                # 50/50 blend of the stored value and the new sample. This
                # weights recent reports more heavily than a true mean would.
                aggregated[metric_name] = (aggregated[metric_name] + value) / 2
            else:
                aggregated[metric_name] = value

    def compare_versions(
        self,
        model_id: str,
        version1: str,
        version2: str,
        metric: str = "accuracy"
    ) -> Dict[str, Any]:
        """Compare performance between two model versions

        Returns:
            A dict with the metric value per version, the absolute
            ``difference`` (version2 - version1) and ``improvement_percent``
            relative to version1 (0.0 when version1's value is not positive).
            On lookup failure a dict with a single ``"error"`` key is returned
            instead. A metric that was never recorded counts as 0.0.
        """
        if model_id not in self.models:
            return {"error": f"Model {model_id} not found"}

        versions = self.models[model_id]
        v1 = versions.get(version1)
        v2 = versions.get(version2)

        if v1 is None or v2 is None:
            return {"error": "One or both versions not found"}

        v1_metric = v1.performance_metrics.get(metric, 0.0)
        v2_metric = v2.performance_metrics.get(metric, 0.0)
        difference = v2_metric - v1_metric

        return {
            "model_id": model_id,
            "versions": {
                version1: v1_metric,
                version2: v2_metric
            },
            "difference": difference,
            "improvement_percent": (difference / v1_metric * 100) if v1_metric > 0 else 0.0,
            "metric": metric
        }

    def rollback_to_version(self, model_id: str, version: str) -> bool:
        """Rollback to a previous model version

        Returns:
            True when the version was activated, False when it is unknown.
        """
        if not self._has_version(model_id, version):
            logger.error(f"Cannot rollback: model {model_id} v{version} not found")
            return False

        logger.warning(f"Rolling back {model_id} to v{version}")
        self.set_active_version(model_id, version)

        return True

    def get_model_inventory(self) -> Dict[str, Any]:
        """Get complete model inventory

        Returns:
            ``model_id -> {"active_version", "total_versions", "versions"}``
            where ``active_version`` is ``"none"`` if nothing is active and
            ``versions`` maps each version string to ``ModelVersion.to_dict()``.
        """
        return {
            model_id: {
                "active_version": self.active_versions.get(model_id, "none"),
                "total_versions": len(versions),
                "versions": {
                    ver: model.to_dict() for ver, model in versions.items()
                }
            }
            for model_id, versions in self.models.items()
        }

    def auto_rollback_if_degraded(
        self,
        model_id: str,
        metric: str = "accuracy",
        threshold_drop: float = 0.05
    ) -> bool:
        """Automatically rollback if performance degraded significantly

        Compares the active version against the most recently deployed
        DEPRECATED version and rolls back when ``metric`` dropped by more than
        ``threshold_drop`` (a fraction, e.g. 0.05 for 5%).

        Returns:
            True when a rollback was performed, False otherwise (including
            when either side has no usable value for ``metric``).
        """
        current_version = self.active_versions.get(model_id)
        if current_version is None:
            return False

        current_model = self.models[model_id][current_version]

        # Without any report for the active version there is nothing to
        # compare. Defaulting to 0.0 would look like a 100% drop and roll back
        # every freshly promoted version.
        if metric not in current_model.performance_metrics:
            return False
        current_metric = current_model.performance_metrics[metric]

        previous_versions = [
            (ver, model) for ver, model in self.models[model_id].items()
            if model.status == ModelStatus.DEPRECATED
        ]
        if not previous_versions:
            return False

        # ISO-8601 strings order chronologically, so max() picks the most
        # recently deployed deprecated version.
        prev_version, prev_model = max(
            previous_versions,
            key=lambda item: item[1].deployment_date
        )

        prev_metric = prev_model.performance_metrics.get(metric, 0.0)
        if prev_metric == 0.0:
            # No baseline (and avoids dividing by zero below).
            return False

        drop_percent = (prev_metric - current_metric) / prev_metric

        if drop_percent > threshold_drop:
            logger.warning(
                f"Performance degradation detected for {model_id}: "
                f"{metric} dropped {drop_percent*100:.1f}%. "
                f"Rolling back to v{prev_version}"
            )
            return self.rollback_to_version(model_id, prev_version)

        return False


class InputCache:
    """
    Intelligent caching system with SHA256 fingerprinting

    Caches analysis results to avoid reprocessing identical files

    Entries are keyed by ``"<input_hash>:<model_version>"``, expire after
    ``ttl_hours`` and are evicted least-recently-used first when the
    estimated total size would exceed the configured limit.
    """

    def __init__(
        self,
        max_cache_size_mb: int = 1000,
        ttl_hours: int = 24
    ):
        self.cache: Dict[str, CacheEntry] = {}
        self.max_cache_size_bytes = max_cache_size_mb * 1024 * 1024
        self.current_cache_size = 0
        self.ttl_hours = ttl_hours

        # Counters reported by get_statistics().
        self.hits = 0
        self.misses = 0
        self.evictions = 0

        logger.info(f"Input Cache initialized (max size: {max_cache_size_mb}MB, TTL: {ttl_hours}h)")

    def compute_hash(self, file_path: str) -> str:
        """Compute SHA256 hash of file

        Returns:
            The hex digest, or an empty string when the file cannot be read
            (the failure is logged).
        """
        sha256_hash = hashlib.sha256()

        try:
            with open(file_path, "rb") as f:
                # Stream in chunks so large files are never loaded whole.
                for byte_block in iter(lambda: f.read(65536), b""):
                    sha256_hash.update(byte_block)
        except (OSError, TypeError, ValueError) as e:
            # OSError: missing/unreadable file; TypeError/ValueError: a path
            # argument that open() rejects outright.
            logger.error(f"Failed to compute hash for {file_path}: {str(e)}")
            return ""

        return sha256_hash.hexdigest()

    def compute_data_hash(self, data: bytes) -> str:
        """Compute SHA256 hash of data bytes"""
        return hashlib.sha256(data).hexdigest()

    def get(
        self,
        input_hash: str,
        model_version: str
    ) -> Optional[Dict[str, Any]]:
        """Retrieve cached result

        Returns:
            The stored result dict (the same object that was stored, not a
            copy), or None on a miss or when the entry outlived its TTL.
        """
        cache_key = f"{input_hash}:{model_version}"

        entry = self.cache.get(cache_key)
        if entry is None:
            self.misses += 1
            return None

        now = datetime.utcnow()
        created_time = datetime.fromisoformat(entry.created_at)
        if now - created_time > timedelta(hours=self.ttl_hours):
            # Expired: drop it and report a miss.
            self._evict(cache_key)
            self.misses += 1
            return None

        entry.last_accessed = now.isoformat()
        entry.access_count += 1

        self.hits += 1
        logger.info(f"Cache hit: {cache_key[:16]}...")

        return entry.result_data

    @staticmethod
    def _estimate_size(result_data: Dict[str, Any]) -> int:
        """Estimate the storage footprint of a result in bytes."""
        try:
            # default=str only affects values json cannot encode natively, so
            # plain JSON data yields the same size as json.dumps(result_data).
            payload = json.dumps(result_data, default=str)
        except (TypeError, ValueError):
            # Unsupported key types or circular references: fall back to the
            # repr so a size estimate never prevents caching.
            payload = repr(result_data)
        return len(payload.encode())

    def put(
        self,
        input_hash: str,
        model_version: str,
        result_data: Dict[str, Any]
    ) -> None:
        """Store result in cache

        An existing entry under the same key is replaced. Results whose
        estimated size exceeds the whole cache budget are not stored.
        """
        cache_key = f"{input_hash}:{model_version}"
        size_bytes = self._estimate_size(result_data)

        # Release the superseded entry first so its size is not counted twice.
        replaced = self.cache.pop(cache_key, None)
        if replaced is not None:
            self.current_cache_size -= replaced.size_bytes

        if size_bytes > self.max_cache_size_bytes:
            # No amount of eviction can make room for this entry.
            logger.warning(
                f"Cache skip: {cache_key[:16]}... ({size_bytes} bytes exceeds cache capacity)"
            )
            return

        # Guard on self.cache: once it is empty _evict_lru() cannot free
        # anything more and the loop must stop.
        while self.cache and self.current_cache_size + size_bytes > self.max_cache_size_bytes:
            self._evict_lru()

        timestamp = datetime.utcnow().isoformat()
        self.cache[cache_key] = CacheEntry(
            cache_key=cache_key,
            input_hash=input_hash,
            result_data=result_data,
            created_at=timestamp,
            last_accessed=timestamp,
            access_count=0,
            model_version=model_version,
            size_bytes=size_bytes
        )
        self.current_cache_size += size_bytes

        logger.info(f"Cache stored: {cache_key[:16]}... ({size_bytes} bytes)")

    def invalidate_model_version(self, model_version: str) -> None:
        """Invalidate all cache entries for a model version"""
        # Collect keys first; evicting while iterating the dict would fail.
        keys_to_remove = [
            key for key, entry in self.cache.items()
            if entry.model_version == model_version
        ]

        for key in keys_to_remove:
            self._evict(key)

        logger.info(f"Invalidated {len(keys_to_remove)} cache entries for model v{model_version}")

    def _evict(self, cache_key: str) -> None:
        """Evict a specific cache entry (no-op for unknown keys)"""
        entry = self.cache.pop(cache_key, None)
        if entry is None:
            return

        self.current_cache_size -= entry.size_bytes
        self.evictions += 1

    def _evict_lru(self) -> None:
        """Evict least recently used entry"""
        if not self.cache:
            return

        # last_accessed is an ISO-8601 string, so string order is time order.
        lru_key = min(
            self.cache,
            key=lambda k: self.cache[k].last_accessed
        )

        self._evict(lru_key)
        logger.debug(f"LRU eviction: {lru_key[:16]}...")

    def get_statistics(self) -> Dict[str, Any]:
        """Get cache performance statistics"""
        total_requests = self.hits + self.misses
        hit_rate = self.hits / total_requests if total_requests > 0 else 0.0

        # A zero-sized cache is a valid configuration; report 0% instead of
        # dividing by zero.
        if self.max_cache_size_bytes > 0:
            utilization = self.current_cache_size / self.max_cache_size_bytes * 100
        else:
            utilization = 0.0

        return {
            "total_entries": len(self.cache),
            "cache_size_mb": self.current_cache_size / (1024 * 1024),
            "max_size_mb": self.max_cache_size_bytes / (1024 * 1024),
            "utilization_percent": utilization,
            "total_requests": total_requests,
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate_percent": hit_rate * 100,
            "evictions": self.evictions,
            "ttl_hours": self.ttl_hours
        }

    def clear(self) -> None:
        """Clear all cache entries (hit/miss/eviction counters are kept)"""
        entry_count = len(self.cache)
        self.cache.clear()
        self.current_cache_size = 0

        logger.info(f"Cache cleared: {entry_count} entries removed")


class ModelVersioningSystem:
    """
    Complete model versioning and caching system

    Integrates model registry with input caching
    """

    def __init__(
        self,
        cache_size_mb: int = 1000,
        cache_ttl_hours: int = 24
    ):
        self.model_registry = ModelRegistry()
        self.input_cache = InputCache(cache_size_mb, cache_ttl_hours)

        # Seed the registry so every built-in model id has an active version.
        self._initialize_default_models()

        logger.info("Model Versioning System initialized")

    def _initialize_default_models(self) -> None:
        """Initialize default model versions"""
        # (model_id, version, display name, model path)
        default_models = [
            ("document_classifier", "1.0.0", "Bio_ClinicalBERT", "emilyalsentzer/Bio_ClinicalBERT"),
            ("clinical_ner", "1.0.0", "Biomedical NER", "d4data/biomedical-ner-all"),
            ("clinical_generation", "1.0.0", "BioGPT-Large", "microsoft/BioGPT-Large"),
            ("medical_qa", "1.0.0", "RoBERTa-SQuAD2", "deepset/roberta-base-squad2"),
            ("general_medical", "1.0.0", "PubMedBERT", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"),
            ("drug_interaction", "1.0.0", "SciBERT", "allenai/scibert_scivocab_uncased"),
            ("clinical_summarization", "1.0.0", "BigBird-Pegasus", "google/bigbird-pegasus-large-pubmed")
        ]

        for model_id, version, name, path in default_models:
            self.model_registry.register_model(
                model_id=model_id,
                version=version,
                model_name=name,
                model_path=path,
                metadata={"initialized": "2025-10-29"},
                set_active=True
            )

    def process_with_cache(
        self,
        input_path: str,
        model_id: str,
        process_func: Callable[[str], Dict[str, Any]]
    ) -> Tuple[Dict[str, Any], bool]:
        """
        Process input with caching

        Args:
            input_path: Path of the file to process; its SHA256 digest is the
                basis of the cache key.
            model_id: Model whose active version the result belongs to.
            process_func: Called with ``input_path`` to compute the result on
                a cache miss (or when caching is not possible).

        Returns: (result, from_cache)
        """
        active_model = self.model_registry.get_active_version(model_id)
        if active_model is None:
            logger.warning(f"No active version for model {model_id}")
            return process_func(input_path), False

        input_hash = self.input_cache.compute_hash(input_path)
        if not input_hash:
            # File could not be hashed, so the result cannot be cached.
            return process_func(input_path), False

        # Scope the fingerprint by model id. Several models share the same
        # version string (all defaults are "1.0.0"), so keying on hash+version
        # alone would hand one model's cached output to another model. The
        # version argument stays the bare version so that
        # InputCache.invalidate_model_version(version) keeps matching entries.
        scoped_hash = f"{model_id}:{input_hash}"

        cached_result = self.input_cache.get(scoped_hash, active_model.version)
        if cached_result is not None:
            logger.info(f"Returning cached result for {model_id}")
            return cached_result, True

        result = process_func(input_path)
        self.input_cache.put(scoped_hash, active_model.version, result)

        return result, False

    def get_system_status(self) -> Dict[str, Any]:
        """Get complete system status (registry summary, cache statistics, timestamp)"""
        registry = self.model_registry
        return {
            "model_registry": {
                "total_models": len(registry.models),
                "active_models": len(registry.active_versions),
                "inventory": registry.get_model_inventory()
            },
            "cache": self.input_cache.get_statistics(),
            "timestamp": datetime.utcnow().isoformat()
        }


# Lazily created module-level instance; filled in by get_versioning_system().
_versioning_system: Optional["ModelVersioningSystem"] = None


def get_versioning_system() -> ModelVersioningSystem:
    """Get singleton versioning system instance

    The instance is built on the first call and reused afterwards. No lock is
    taken around the creation.
    """
    global _versioning_system
    if _versioning_system is None:
        _versioning_system = ModelVersioningSystem()
    return _versioning_system