import numpy as np
import time
import torch
import threading
from typing import Dict, Any, Optional, Tuple, Union, List
from enum import Enum
from tensor_core import TensorCoreArray
from multithread_storage import MultithreadStorage
from config import DB_URL

class VectorOperation(Enum):
    """Vector operations accepted by ``AIAccelerator.vector_operation``.

    Each member's value is the lowercase name printed when the operation
    completes.
    """

    # Element-wise operations on two vectors.
    ADD = "add"
    SUBTRACT = "subtract"
    MULTIPLY = "multiply"
    DIVIDE = "divide"

    # Products of two vectors.
    DOT_PRODUCT = "dot_product"
    CROSS_PRODUCT = "cross_product"

    # Operations on a single vector.
    NORMALIZE = "normalize"
    MAGNITUDE = "magnitude"


class AIAccelerator:
    """
    Simulated AI accelerator that performs tensor math on the host with NumPy
    while tracking a virtual GPU's cores, VRAM budget and loaded models.

    Tensors and model state are persisted through ``self.storage`` (a
    ``MultithreadStorage`` unless a shared instance is supplied). Matrix work
    is split by a ``GPUParallelDistributor`` and executed on a
    ``TensorCoreArray``.
    """

    def __init__(self, vram=None, num_sms: int = 800, cuda_cores_per_sm: int = 128,
                 tensor_cores_per_sm: int = 3000, storage=None, num_gpus: int = 8):
        """Initialize the accelerator, its storage backend and core bookkeeping.

        Args:
            vram: Optional virtual VRAM object. If it has a ``vram_state`` dict,
                its ``is_unlimited`` / ``total_size`` keys are honoured by
                :meth:`pre_allocate_vram`.
            num_sms: Number of simulated streaming multiprocessors.
            cuda_cores_per_sm: CUDA cores per SM.
            tensor_cores_per_sm: Tensor cores per SM.
            storage: Shared storage instance. When omitted, a new
                ``MultithreadStorage`` connected to ``DB_URL`` is created.
            num_gpus: Number of GPUs handed to the parallel distributor.
        """
        from electron_speed import drift_velocity
        from gpu_parallel_distributor import GPUParallelDistributor
        import logging

        self.storage = storage  # Use the shared storage instance when given
        if self.storage is None:
            self.storage = MultithreadStorage(db_url=DB_URL)
            logging.info("Initialized MultithreadStorage with database URL: %s", DB_URL)

        # Distributor that splits operations across the simulated GPUs
        self.gpu_distributor = GPUParallelDistributor(num_gpus=num_gpus)

        self.vram = vram
        self.num_sms = num_sms

        # Core configuration
        self.cuda_cores_per_sm = cuda_cores_per_sm
        self.tensor_cores_per_sm = tensor_cores_per_sm
        self.total_cuda_cores = num_sms * cuda_cores_per_sm
        self.total_tensor_cores = num_sms * tensor_cores_per_sm
        # Combined core count reported by matrix_multiply, vector_operation and get_stats
        self.total_cores = self.total_cuda_cores + self.total_tensor_cores

        # Workload distribution thresholds used by process_tensor_operation
        self.cuda_threshold = 0.3
        self.tensor_threshold = 0.7

        logging.info(
            "Initialized AI Accelerator with %s CUDA cores and %s tensor cores",
            self.total_cuda_cores, self.total_tensor_cores,
        )

        # Registries and monitors
        self.model_registry: Dict[str, Dict[str, Any]] = {}  # model_id -> registration info
        self.tensor_registry: Dict[str, Dict[str, Any]] = {}  # tensor metadata
        self.core_utilization = {
            'cuda': 0.0,
            'tensor': 0.0,
            'last_update': time.time()
        }
        self.tokenizer_registry: Dict[str, Any] = {}
        self.resource_monitor = {
            'vram_used': 0,
            'active_tensors': 0,
            'loaded_models': set()
        }

        # Matrix bookkeeping used by allocate_matrix / load_matrix / get_matrix
        self.matrix_registry: Dict[str, str] = {}
        self.matrix_counter = 0

        # Tensor core array; bandwidth is derived from electron_speed.drift_velocity
        self.tensor_core_array = TensorCoreArray(
            num_tensor_cores=self.total_tensor_cores,
            bits=32,
            bandwidth_tbps=drift_velocity / 1e-12
        )
        self.tensor_cores_initialized = False
        self._vram_allocated = 0

        # Operation statistics
        self.operations_performed = 0
        self.total_compute_time = 0.0
        self.flops_performed = 0

        # Caches
        self.activation_cache: Dict[str, str] = {}
        self.weight_cache: Dict[str, Any] = {}

    def pre_allocate_vram(self, size_bytes: int) -> bool:
        """Reserve ``size_bytes`` of virtual VRAM.

        Returns:
            True when the reservation was recorded, or when no VRAM object is
            attached. False, with nothing recorded, when the reservation would
            exceed ``vram.vram_state['total_size']``.
        """
        if not self.vram:
            return True  # No VRAM restrictions

        # An explicitly unlimited VRAM always accepts the reservation
        if hasattr(self.vram, 'vram_state') and self.vram.vram_state.get('is_unlimited', False):
            self._vram_allocated += size_bytes
            return True

        total_size = float('inf')
        if hasattr(self.vram, 'vram_state'):
            total_size = self.vram.vram_state.get('total_size', float('inf'))

        if self._vram_allocated + size_bytes > total_size:
            return False

        self._vram_allocated += size_bytes
        return True

    def has_model(self, model_id: str) -> bool:
        """Return True if ``model_id`` is registered locally AND reported loaded by storage."""
        if not model_id:
            return False
        return model_id in self.model_registry and self.storage.is_model_loaded(model_id)

    def load_model(self, model_id: str, model=None, processor=None) -> bool:
        """Load a model through storage and register it with the accelerator.

        Dict models are passed to storage unchanged. Other objects are
        summarized as their type name plus a serialized ``config`` attribute.
        The original ``model`` object is kept in the registry under ``"model"``
        so that :meth:`inference` can run it.

        Returns:
            True on success. False if storage rejects the model or any error
            occurs (the error is printed).
        """
        try:
            if not self.storage:
                raise RuntimeError("No storage available")

            if model is None or isinstance(model, dict):
                model_data = model
            else:
                model_data = {
                    "model_type": type(model).__name__,
                    "config": self._serialize_model_config(getattr(model, 'config', None)),
                    "loaded_at": time.time()
                }

            if not self.storage.load_model(model_id, model_data=model_data):
                return False

            self.model_registry[model_id] = {
                "model": model,  # read by inference()
                "model_data": model_data,
                "processor": processor,
                "loaded_at": time.time()
            }

            self.resource_monitor['loaded_models'].add(model_id)
            if hasattr(self.storage, 'resource_monitor'):
                self.storage.resource_monitor['loaded_models'].add(model_id)

            return True

        except Exception as e:
            print(f"Error loading model {model_id}: {str(e)}")
            return False

    async def load_model_async(self, model_id: str, model: Dict[str, Any],
                               processor: Any = None,
                               model_config: Optional[Dict[str, Any]] = None) -> bool:
        """Store a layer-structured model's tensors in storage and register it.

        This loader has its own name because two methods named ``load_model``
        in one class body would shadow each other.

        Args:
            model_id: Unique identifier for the model; used as the tensor-id prefix.
            model: Dict with a ``"layers"`` mapping of layer name to
                ``{"weight": array, optional "bias": array}``, plus an optional
                ``"architecture"`` entry.
            processor: Optional preprocessing/postprocessing object to register.
            model_config: Optional model configuration to register.

        Returns:
            True on success. False on any error (the error is printed).
        """
        try:
            if not self.storage:
                raise RuntimeError("No storage available")

            layers = model.get("layers", {})
            thread_id = threading.get_ident()
            weights: Dict[str, Dict[str, str]] = {}
            for layer_name, layer_data in layers.items():
                # NOTE(review): store_tensor is awaited here but called synchronously
                # elsewhere in this class -- confirm which form the storage exposes.
                weight_id = f"{model_id}/{layer_name}/weight"
                if not await self.storage.store_tensor(
                    tensor_id=weight_id,
                    data=layer_data["weight"],
                    metadata={"model_id": model_id, "layer": layer_name, "type": "weight"},
                    thread_id=thread_id
                ):
                    raise RuntimeError(f"Failed to store weights for layer {layer_name}")
                weights[layer_name] = {"weight": weight_id}

                if "bias" in layer_data:
                    bias_id = f"{model_id}/{layer_name}/bias"
                    if not await self.storage.store_tensor(
                        tensor_id=bias_id,
                        data=layer_data["bias"],
                        metadata={"model_id": model_id, "layer": layer_name, "type": "bias"},
                        thread_id=thread_id
                    ):
                        raise RuntimeError(f"Failed to store bias for layer {layer_name}")
                    weights[layer_name]["bias"] = bias_id

            # Reserve VRAM before registering, so a failed reservation does not
            # leave the model listed as loaded.
            if hasattr(self.vram, 'pre_allocate_vram'):
                total_size = sum(
                    np.prod(layer["weight"].shape) * 4  # assumes float32 weights
                    for layer in layers.values()
                )
                if not self.vram.pre_allocate_vram(total_size):
                    raise RuntimeError("Insufficient VRAM for model weights")

            self.model_registry[model_id] = {
                'weights': weights,
                'config': model_config or {},
                'architecture': model.get("architecture", {}),
                'loaded_at': time.time(),
                'processor': processor
            }

            self.resource_monitor['loaded_models'].add(model_id)
            if hasattr(self.storage, 'resource_monitor'):
                self.storage.resource_monitor['loaded_models'].add(model_id)

            return True

        except Exception as e:
            print(f"Error loading model {model_id}: {str(e)}")
            return False

    def _serialize_model_config(self, config: Any) -> Any:
        """Recursively convert a model config into plain Python values.

        Conversion rules:
            - ``None`` stays ``None``.
            - Scalars are returned as-is.
            - Lists and tuples become lists; dicts are converted value by value.
            - Objects with ``__dict__`` become dicts of their public attributes.
            - Anything else becomes its ``str()``.
            - ``Florence2LanguageConfig`` is reduced to a fixed set of fields.
        """
        if config is None:
            return None

        if config.__class__.__name__ == "Florence2LanguageConfig":
            try:
                return {
                    "type": "Florence2LanguageConfig",
                    "model_type": getattr(config, "model_type", ""),
                    "architectures": getattr(config, "architectures", []),
                    "hidden_size": getattr(config, "hidden_size", 0),
                    "num_attention_heads": getattr(config, "num_attention_heads", 0),
                    "num_hidden_layers": getattr(config, "num_hidden_layers", 0),
                    "intermediate_size": getattr(config, "intermediate_size", 0),
                    "max_position_embeddings": getattr(config, "max_position_embeddings", 0),
                    "layer_norm_eps": getattr(config, "layer_norm_eps", 1e-12),
                    "vocab_size": getattr(config, "vocab_size", 0)
                }
            except Exception as e:
                print(f"Warning: Error serializing Florence2LanguageConfig: {e}")
                return {"type": "Florence2LanguageConfig", "error": str(e)}

        if isinstance(config, (int, float, str, bool)):
            return config

        if isinstance(config, (list, tuple)):
            return [self._serialize_model_config(item) for item in config]

        if isinstance(config, dict):
            return {k: self._serialize_model_config(v) for k, v in config.items()}

        if hasattr(config, '__dict__'):
            config_dict = {}
            for key, value in config.__dict__.items():
                if key.startswith('_'):
                    continue  # skip private attributes
                try:
                    config_dict[key] = self._serialize_model_config(value)
                except Exception as e:
                    print(f"Warning: Error serializing attribute {key}: {e}")
                    config_dict[key] = str(value)
            return config_dict

        try:
            return str(config)
        except Exception as e:
            return f"<Unserializable object of type {type(config).__name__}: {str(e)}>"

    def store_model_state(self, model_name: str, model_info: Dict[str, Any]) -> bool:
        """Serialize ``model_info``, register it, and persist info plus a loaded flag.

        The registry entry is updated even when no storage is attached.

        Returns:
            True only if both ``store_state`` calls succeed.
        """
        try:
            serializable_info = self._serialize_model_config(model_info)
            self.model_registry[model_name] = serializable_info

            if self.storage:
                info_success = self.storage.store_state(
                    "models",
                    f"{model_name}/info",
                    serializable_info
                )
                state_success = self.storage.store_state(
                    "models",
                    f"{model_name}/state",
                    {"loaded": True, "timestamp": time.time()}
                )
                if info_success and state_success:
                    self.resource_monitor['loaded_models'].add(model_name)
                    return True

            return False
        except Exception as e:
            print(f"Error storing model state: {str(e)}")
            return False

    def initialize_tensor_cores(self):
        """Initialize the tensor core array and verify it with a 2x2 test matmul.

        Requires ``self.vram`` to be set.

        Returns:
            True once the test multiplication returns a non-None result
            (cached in ``self.tensor_cores_initialized``). False otherwise.
        """
        if self.tensor_cores_initialized:
            return True

        try:
            if not hasattr(self, 'tensor_core_array') or self.tensor_core_array is None:
                raise RuntimeError("Tensor core array not properly initialized")

            if hasattr(self.tensor_core_array, 'initialize'):
                self.tensor_core_array.initialize()

            if self.vram is None:
                raise RuntimeError("VRAM not properly configured")

            test_input = np.array([[1.0, 2.0], [3.0, 4.0]])
            try:
                test_result = self.tensor_core_array.matmul(test_input, test_input)
                if test_result is not None:
                    self.tensor_cores_initialized = True
                    return True
            except Exception as e:
                print(f"Failed to perform tensor core test: {str(e)}")
                self.tensor_cores_initialized = False
            return False

        except Exception as e:
            print(f"Failed to initialize tensor cores: {str(e)}")
            self.tensor_cores_initialized = False
            return False

    def _combine_results(self, results: List[Dict[str, Any]], operation_type: str) -> Dict[str, Any]:
        """Merge partial results from the CUDA and tensor-core paths.

        Behaviour by input:
            - An empty list yields an error dict.
            - A single result is returned unchanged.
            - Several results have their ``'data'`` concatenated along axis 0,
              in the order given.
        """
        if not results:
            return {'data': [], 'status': 'error', 'message': 'No results to combine'}

        if len(results) == 1:
            return results[0]

        # process_tensor_operation slices its input along the first axis for
        # every operation type, so the pieces are always rejoined on axis 0.
        combined_data = np.concatenate([result.get('data', []) for result in results], axis=0)

        return {
            'data': combined_data,
            'status': 'success',
            'operation': operation_type
        }

    async def process_tensor_operation(self, tensor_data: Dict[str, Any]) -> Dict[str, Any]:
        """Split an operation between CUDA and tensor cores and run both shares concurrently.

        Args:
            tensor_data: Dict with a sliceable, sized ``'data'`` entry and an
                optional ``'operation'`` name (default ``'matmul'``).

        Returns:
            The combined result dict from :meth:`_combine_results`.
        """
        import asyncio

        data = tensor_data['data']
        data_size = len(data)
        operation_type = tensor_data.get('operation', 'matmul')

        if operation_type in ['matmul', 'conv2d']:
            # Matrix work: the CUDA share is capped by cuda_threshold and by the
            # tensor utilization recorded on the previous call.
            cuda_ratio = min(self.cuda_threshold, 1 - self.core_utilization['tensor'])
            tensor_ratio = 1 - cuda_ratio
        else:
            # Other work: the tensor share is capped by tensor_threshold and by
            # the CUDA utilization recorded on the previous call.
            tensor_ratio = min(self.tensor_threshold, 1 - self.core_utilization['cuda'])
            cuda_ratio = 1 - tensor_ratio

        # The first cuda_workload items go to CUDA cores, the rest to tensor cores
        cuda_workload = int(data_size * cuda_ratio)
        tensor_workload = data_size - cuda_workload

        pending = []
        if cuda_workload > 0:
            pending.append(self.gpu_distributor.distribute_cuda_ops(
                {**tensor_data, 'data': data[:cuda_workload]},
                cuda_workload / self.total_cuda_cores,
                self.total_cuda_cores
            ))
        if tensor_workload > 0:
            pending.append(self.gpu_distributor.distribute_tensor_ops(
                {**tensor_data, 'data': data[cuda_workload:]},
                tensor_workload / self.total_tensor_cores,
                self.total_tensor_cores
            ))

        # gather awaits both shares concurrently and keeps them in submission
        # order (CUDA first), which _combine_results relies on.
        results = list(await asyncio.gather(*pending))

        combined_results = self._combine_results(results, operation_type)

        self.core_utilization.update({
            'cuda': cuda_ratio,
            'tensor': tensor_ratio,
            'last_update': time.time()
        })

        return combined_results

    def set_vram(self, vram):
        """Set the VRAM reference."""
        self.vram = vram

    def _next_matrix_id(self, prefix: str = "matrix") -> str:
        """Return ``"<prefix>_<counter>"`` and advance the shared matrix counter."""
        matrix_id = f"{prefix}_{self.matrix_counter}"
        self.matrix_counter += 1
        return matrix_id

    def _record_stats(self, start_time: float, flops: int = 0) -> float:
        """Count one finished operation started at ``start_time``; return its elapsed seconds."""
        compute_time = time.time() - start_time
        self.total_compute_time += compute_time
        self.operations_performed += 1
        self.flops_performed += flops
        return compute_time

    def allocate_matrix(self, shape: Tuple[int, ...], dtype=np.float32,
                        name: Optional[str] = None) -> str:
        """Store a zero matrix of ``shape``/``dtype`` and return its id.

        Raises:
            RuntimeError: If no VRAM is attached or storage rejects the tensor.
        """
        if not self.vram:
            raise RuntimeError("VRAM not available")

        if name is None:
            name = self._next_matrix_id()

        matrix_data = np.zeros(shape, dtype=dtype)

        if self.storage.store_tensor(name, matrix_data):
            self.matrix_registry[name] = name
            return name
        raise RuntimeError(f"Failed to allocate matrix {name}")

    def load_matrix(self, matrix_data: np.ndarray, name: Optional[str] = None) -> str:
        """Store ``matrix_data`` and return its id (auto-named ``matrix_<n>`` if omitted).

        Raises:
            RuntimeError: If storage rejects the tensor.
        """
        if name is None:
            name = self._next_matrix_id()

        if self.storage.store_tensor(name, matrix_data):
            self.matrix_registry[name] = name
            return name
        raise RuntimeError(f"Failed to load matrix {name}")

    def get_matrix(self, matrix_id: str) -> Optional[np.ndarray]:
        """Return a registered matrix from storage, or None if the id is unknown."""
        if matrix_id not in self.matrix_registry:
            return None
        return self.storage.load_tensor(matrix_id)

    def matrix_multiply(self, matrix_a_id: str, matrix_b_id: str,
                        result_id: Optional[str] = None) -> Optional[str]:
        """Multiply two stored 2-D matrices by row chunks and store the product.

        The product is stored under ``result_id``, or ``result_<n>`` when
        ``result_id`` is omitted.

        Returns:
            The id of the stored product, or None on any error (the error is
            printed).
        """
        start_time = time.time()

        matrix_a = self.get_matrix(matrix_a_id)
        matrix_b = self.get_matrix(matrix_b_id)

        if matrix_a is None or matrix_b is None:
            print(f"Error: Could not retrieve matrices {matrix_a_id} or {matrix_b_id}")
            return None

        try:
            if matrix_a.ndim != 2 or matrix_b.ndim != 2:
                print(f"Error: Matrix multiplication requires 2-D operands, got "
                      f"{matrix_a.shape} x {matrix_b.shape}")
                return None

            if matrix_a.shape[-1] != matrix_b.shape[0]:
                print(f"Error: Matrix dimensions incompatible for multiplication: "
                      f"{matrix_a.shape} x {matrix_b.shape}")
                return None

            m, k = matrix_a.shape
            n = matrix_b.shape[1]

            operation = {"type": "matmul", "inputs": {"A": matrix_a, "B": matrix_b}}
            distributed_ops = self.gpu_distributor.distribute_operation(operation)

            # Each chunk covers rows [start_row, end_row) of the product
            result_array = np.zeros((m, n))
            for chunk_op in distributed_ops:
                chunk_result = self.tensor_core_array.matmul(
                    chunk_op["inputs"]["A"],
                    chunk_op["inputs"]["B"]
                )
                result_array[chunk_op["start_row"]:chunk_op["end_row"]] = chunk_result

            if result_id is None:
                result_id = self._next_matrix_id("result")
            result_matrix_id = self.load_matrix(result_array, result_id)

            flops = 2 * m * n * k  # multiply-add count for an (m,k) x (k,n) product
            compute_time = self._record_stats(start_time, flops)

            print(f"Matrix multiplication completed: {matrix_a.shape} x {matrix_b.shape} "
                  f"= {result_array.shape} in {compute_time:.4f}s")
            print(f"Simulated {flops:,} FLOPs across {self.total_cores} cores")

            return result_matrix_id

        except Exception as e:
            print(f"Error in matrix multiplication: {e}")
            return None

    def vector_operation(self, operation: "VectorOperation", vector_a_id: str,
                         vector_b_id: Optional[str] = None,
                         result_id: Optional[str] = None) -> Optional[str]:
        """Apply a :class:`VectorOperation` to stored vectors and store the result.

        The result is stored under ``result_id``, or ``vector_result_<n>`` when
        ``result_id`` is omitted.

        Returns:
            The id of the stored result, or None on any error (the error is
            printed).
        """
        start_time = time.time()

        vector_a = self.get_matrix(vector_a_id)
        if vector_a is None:
            print(f"Error: Could not retrieve vector {vector_a_id}")
            return None

        vector_b = None
        if vector_b_id:
            vector_b = self.get_matrix(vector_b_id)
            if vector_b is None:
                print(f"Error: Could not retrieve vector {vector_b_id}")
                return None

        try:
            # Operations that need a second operand -> wording for the error message
            needs_b = {
                VectorOperation.ADD: "addition",
                VectorOperation.SUBTRACT: "subtraction",
                VectorOperation.MULTIPLY: "multiplication",
                VectorOperation.DIVIDE: "division",
                VectorOperation.DOT_PRODUCT: "dot product",
                VectorOperation.CROSS_PRODUCT: "cross product",
            }
            label = needs_b.get(operation)
            if label is not None and vector_b is None:
                raise ValueError(f"Vector B required for {label}")

            if operation == VectorOperation.ADD:
                result, flops = vector_a + vector_b, vector_a.size
            elif operation == VectorOperation.SUBTRACT:
                result, flops = vector_a - vector_b, vector_a.size
            elif operation == VectorOperation.MULTIPLY:
                result, flops = vector_a * vector_b, vector_a.size
            elif operation == VectorOperation.DIVIDE:
                result, flops = vector_a / vector_b, vector_a.size
            elif operation == VectorOperation.DOT_PRODUCT:
                result = np.dot(vector_a.flatten(), vector_b.flatten())
                flops = 2 * vector_a.size
            elif operation == VectorOperation.CROSS_PRODUCT:
                if vector_a.size != 3 or vector_b.size != 3:
                    raise ValueError("Cross product requires 3D vectors")
                result = np.cross(vector_a.flatten(), vector_b.flatten())
                flops = 6
            elif operation == VectorOperation.NORMALIZE:
                magnitude = np.linalg.norm(vector_a)
                # A zero vector is returned unchanged to avoid dividing by zero
                result = vector_a if magnitude == 0 else vector_a / magnitude
                flops = vector_a.size + 1
            elif operation == VectorOperation.MAGNITUDE:
                result = np.array([np.linalg.norm(vector_a)])
                flops = vector_a.size + 1
            else:
                raise ValueError(f"Unknown vector operation: {operation}")

            if result_id is None:
                result_id = self._next_matrix_id("vector_result")
            result_vector_id = self.load_matrix(result, result_id)

            compute_time = self._record_stats(start_time, flops)

            print(f"Vector operation {operation.value} completed in {compute_time:.4f}s")
            print(f"Simulated {flops:,} FLOPs across {self.total_cores} cores")

            return result_vector_id

        except Exception as e:
            print(f"Error in vector operation: {e}")
            return None

    def inference(self, model_id: str, input_tensor_id: str) -> Optional[np.ndarray]:
        """Run a registered ``torch.nn.Module`` on a stored input tensor via the vGPU.

        The output is stored under ``"<model_id>_output_<timestamp>"``.

        Returns:
            The output as a NumPy array, or None on any failure (the error is
            printed).
        """
        start_time = time.time()
        try:
            # Validate the model before touching storage or the vGPU
            if not self.has_model(model_id):
                print(f"Model {model_id} not loaded")
                return None

            model_info = self.model_registry[model_id]
            model = model_info.get("model")
            if not isinstance(model, torch.nn.Module):
                print(f"Invalid model type for {model_id}")
                return None

            input_data = self.storage.load_tensor(input_tensor_id)
            if input_data is None:
                print(f"Could not load input tensor {input_tensor_id}")
                return None

            from torch_vgpu import to_vgpu
            input_tensor = to_vgpu(torch.from_numpy(input_data), vram=self.vram)

            model = model.to(input_tensor.device)
            model.eval()

            with torch.no_grad():
                if "preprocess" in model_info:
                    input_tensor = model_info["preprocess"](input_tensor)

                output = model(input_tensor)

                if "postprocess" in model_info:
                    output = model_info["postprocess"](output)

            output_np = output.cpu().numpy()
            output_id = f"{model_id}_output_{time.time()}"
            self.storage.store_tensor(output_id, output_np)

            # Accumulate elapsed time, not an absolute timestamp
            self._record_stats(start_time)

            return output_np

        except Exception as e:
            print(f"Error during inference: {str(e)}")
            return None

    def get_stats(self) -> Dict[str, Any]:
        """Return operation counters, core count, loaded models and storage status."""
        return {
            "operations_performed": self.operations_performed,
            "total_compute_time": self.total_compute_time,
            "flops_performed": self.flops_performed,
            "avg_ops_per_second": self.operations_performed / max(self.total_compute_time, 0.001),
            "tensor_cores_initialized": self.tensor_cores_initialized,
            "total_cores": self.total_cores,
            "loaded_models": list(self.resource_monitor['loaded_models']),
            "storage_status": self.storage.get_connection_status() if self.storage else None
        }

