from typing import Dict, List, Union, Optional, Any, Tuple
import numpy as np
import threading
from queue import Queue
import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))

from virtual_vram import VirtualVRAM
from tensor_core import TensorCore, TensorCoreArray
from http_storage import LocalStorage

class VirtualGPUDevice:
    """Adapter for Virtual GPU integration with Helium.

    Tensors are stored in a ``VirtualVRAM`` instance and referred to by the
    string ids it returns.  The arithmetic methods validate shapes, allocate
    an output tensor and enqueue the computation for a background worker
    thread, which calls the same-named method on ``TensorCoreArray`` with an
    ``out_id`` keyword; they return the output id immediately.  Commands run
    in FIFO order.

    Every op that consumes tensors, and every read (``from_gpu`` /
    ``get_tensor``), first waits for all queued commands
    (``synchronize()``).  This assumes a ``TensorCoreArray`` op has written
    its result under ``out_id`` by the time it returns -- TODO confirm
    against ``tensor_core``.  A failure inside the worker is reported by the
    next ``synchronize()`` (explicit or implicit) as a ``RuntimeError``.

    Call ``close()`` (or use the device as a context manager) to stop the
    worker and free all tensors; ``__del__`` does the same as a fallback.
    """

    def __init__(self, device_id: int = 0, memory_size: Optional[int] = None):
        """Initialize virtual GPU device

        Args:
            device_id: Virtual GPU device ID
            memory_size: VRAM size in GB (None for unlimited)
        """
        self.device_id = device_id

        # Virtual VRAM; size_gb=None requests unlimited memory.
        self.vram = VirtualVRAM(size_gb=memory_size)

        # 8000 virtual tensor cores; memory_size=None requests unlimited memory.
        self.tensor_cores = TensorCoreArray(
            num_tensor_cores=8000,
            memory_size=None,
            device_id=device_id,
        )

        # Metadata ({'tensor_id', 'shape', 'dtype'}) for every tensor
        # allocated through this device, keyed by tensor id.
        self._tensor_cache: Dict[str, Dict[str, Any]] = {}

        # Async command queue plus the list the worker records failures in.
        # Both exist before the thread starts so the worker and close() never
        # observe a half-built object.
        self._command_queue: Queue = Queue()
        self._errors: List[BaseException] = []

        # The worker gets only the objects it needs -- never ``self`` -- so the
        # running thread does not keep the device alive and ``__del__`` can
        # fire once the last user reference is dropped.
        self._worker_thread: Optional[threading.Thread] = threading.Thread(
            target=self._run_worker,
            args=(self._command_queue, self.tensor_cores, self._errors),
            daemon=True,
        )
        self._worker_thread.start()

    @staticmethod
    def _run_worker(command_queue: Queue, tensor_cores: Any,
                    errors: List[BaseException]) -> None:
        """Execute queued ``(op, args, kwargs)`` commands until a ``None`` sentinel.

        Each command calls ``getattr(tensor_cores, op)(*args, **kwargs)``.
        Exceptions (including an unsupported ``op``) are appended to
        ``errors`` instead of killing the thread, and ``task_done()`` is
        called for every item so ``command_queue.join()`` can wait for
        completion.
        """
        while True:
            cmd = command_queue.get()
            try:
                if cmd is None:
                    return
                op, args, kwargs = cmd
                handler = getattr(tensor_cores, op, None)
                if handler is None:
                    raise AttributeError(
                        f"tensor cores do not support operation {op!r}"
                    )
                handler(*args, **kwargs)
            except Exception as exc:
                # Keep the worker alive; the failure surfaces in synchronize().
                errors.append(exc)
            finally:
                command_queue.task_done()

    def _process_commands(self):
        """Run this device's command loop on the calling thread (kept for compatibility)."""
        self._run_worker(self._command_queue, self.tensor_cores, self._errors)

    def synchronize(self) -> None:
        """Wait until every queued command has finished.

        Raises:
            RuntimeError: if any queued command failed since the last
                successful ``synchronize()``; the first failure is chained as
                ``__cause__``.  Pending failures are cleared once reported.
        """
        if self._worker_thread is not None:
            self._command_queue.join()
        if self._errors:
            first = self._errors[0]
            # Clear in place: the worker thread appends to this same list.
            self._errors.clear()
            raise RuntimeError(f"virtual GPU command failed: {first!r}") from first

    def _resolve(self, a: Union[str, "HeliumTensor"]) -> str:
        """Return a tensor id, uploading ``a.numpy()`` when ``a`` is not already an id."""
        # NOTE(review): tensors uploaded here are never freed by the op that
        # uploaded them (pre-existing behaviour); callers never see their ids.
        return a if isinstance(a, str) else self.to_gpu(a.numpy())

    def _load_inputs(self, *infos: Dict[str, Any]) -> Tuple[Any, ...]:
        """Wait for queued work, then load the data of each tensor described by ``infos``.

        Raises:
            RuntimeError: if the device is closed or a queued command failed.
        """
        if self._worker_thread is None:
            raise RuntimeError("virtual GPU device is closed")
        # Inputs may be outputs of still-queued commands; reading them before
        # the queue drains would race with the worker.
        self.synchronize()
        return tuple(self.vram.load_tensor(info['tensor_id']) for info in infos)

    def _submit(self, op: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
        """Queue ``tensor_cores.<op>(*args, **kwargs)`` for the worker thread."""
        self._command_queue.put((op, args, kwargs))

    def allocate(self, shape: Tuple[int, ...], dtype=np.float32) -> str:
        """Allocate memory on virtual GPU

        Args:
            shape: Tensor shape; a bare integer is treated as ``(shape,)``.
            dtype: Anything accepted by ``np.dtype``.

        Returns:
            Tensor ID in virtual GPU memory

        Raises:
            ValueError: if any dimension is negative.
        """
        if isinstance(shape, (int, np.integer)):
            shape = (shape,)
        shape = tuple(int(dim) for dim in shape)
        if any(dim < 0 for dim in shape):
            raise ValueError(f"negative dimension in shape {shape}")
        dtype = np.dtype(dtype)
        # np.prod(()) is the float 1.0; force an integer product so the VRAM
        # allocator always receives a plain Python int byte count.
        size = int(np.prod(shape, dtype=np.int64)) * dtype.itemsize
        tensor_id = self.vram.allocate(size)
        self._tensor_cache[tensor_id] = {
            'tensor_id': tensor_id,
            'shape': shape,
            'dtype': dtype,
        }
        return tensor_id

    def to_gpu(self, data: np.ndarray) -> str:
        """Copy numpy array to virtual GPU memory

        Args:
            data: Array (or array-like) to store.

        Returns:
            Tensor ID under which ``data`` was stored.
        """
        data = np.asarray(data)
        tensor_id = self.allocate(data.shape, data.dtype)
        self.vram.store_tensor(tensor_id, data)
        return tensor_id

    def from_gpu(self, tensor_id: str) -> np.ndarray:
        """Copy data from virtual GPU to CPU

        Waits for all queued commands first so results of asynchronous ops
        are complete before they are read.

        Raises:
            KeyError: if ``tensor_id`` was not allocated by this device.
            RuntimeError: if a queued command failed (see ``synchronize``).
        """
        info = self._tensor_cache[tensor_id]
        self.synchronize()
        data = self.vram.load_tensor(info['tensor_id'])
        return np.asarray(data, dtype=info['dtype']).reshape(info['shape'])

    def matmul(self, a: Union[str, "HeliumTensor"], b: Union[str, "HeliumTensor"]) -> str:
        """Matrix multiplication on virtual GPU

        Both operands must be 2-D with ``a.shape[1] == b.shape[0]``.  The
        output has shape ``(a.shape[0], b.shape[1])`` and ``a``'s dtype.

        Returns:
            Tensor ID of the (asynchronously computed) product.

        Raises:
            ValueError: on non-2-D operands or mismatched inner dimensions.
        """
        a_info = self._tensor_cache[self._resolve(a)]
        b_info = self._tensor_cache[self._resolve(b)]
        a_shape, b_shape = a_info['shape'], b_info['shape']
        if len(a_shape) != 2 or len(b_shape) != 2:
            raise ValueError(
                f"matmul expects 2-D operands, got shapes {a_shape} and {b_shape}"
            )
        if a_shape[1] != b_shape[0]:
            raise ValueError(f"matmul inner dimensions differ: {a_shape} @ {b_shape}")

        a_data, b_data = self._load_inputs(a_info, b_info)
        out_id = self.allocate((a_shape[0], b_shape[1]), a_info['dtype'])
        self._submit('matmul', (a_data, b_data), {'out_id': out_id})
        return out_id

    def _elementwise(self, op: str, a: Union[str, "HeliumTensor"],
                     b: Union[str, "HeliumTensor"]) -> str:
        """Queue binary element-wise ``op``; output uses the broadcast shape and ``a``'s dtype."""
        a_info = self._tensor_cache[self._resolve(a)]
        b_info = self._tensor_cache[self._resolve(b)]
        # Raises ValueError for incompatible shapes; for equal shapes this is
        # just ``a``'s shape.
        out_shape = np.broadcast_shapes(a_info['shape'], b_info['shape'])

        a_data, b_data = self._load_inputs(a_info, b_info)
        out_id = self.allocate(out_shape, a_info['dtype'])
        self._submit(op, (a_data, b_data), {'out_id': out_id})
        return out_id

    def add(self, a: Union[str, "HeliumTensor"], b: Union[str, "HeliumTensor"]) -> str:
        """Element-wise addition on virtual GPU

        Raises:
            ValueError: if the operand shapes cannot be broadcast together.
        """
        return self._elementwise('add', a, b)

    def mul(self, a: Union[str, "HeliumTensor"], b: Union[str, "HeliumTensor"]) -> str:
        """Element-wise multiplication on virtual GPU

        Raises:
            ValueError: if the operand shapes cannot be broadcast together.
        """
        return self._elementwise('multiply', a, b)

    def mul_scalar(self, a: Union[str, "HeliumTensor"], scalar: float) -> str:
        """Scalar multiplication on virtual GPU (output keeps ``a``'s shape and dtype)."""
        a_info = self._tensor_cache[self._resolve(a)]
        (a_data,) = self._load_inputs(a_info)
        out_id = self.allocate(a_info['shape'], a_info['dtype'])
        self._submit('scalar_multiply', (a_data, scalar), {'out_id': out_id})
        return out_id

    def transpose(self, a: Union[str, "HeliumTensor"], axes: Optional[Tuple[int, ...]] = None) -> str:
        """Transpose tensor on virtual GPU

        Args:
            a: Tensor id or HeliumTensor.
            axes: Permutation of the dimensions (negative values allowed);
                ``None`` reverses the dimension order.

        Raises:
            ValueError: if ``axes`` is not a permutation of ``a``'s dimensions.
        """
        a_info = self._tensor_cache[self._resolve(a)]
        shape = a_info['shape']
        ndim = len(shape)
        if axes is None:
            # Default: reverse the dimension order, as np.transpose does.
            axes = tuple(reversed(range(ndim)))
        else:
            requested = tuple(int(ax) for ax in axes)
            axes = tuple(ax + ndim if ax < 0 else ax for ax in requested)
            if sorted(axes) != list(range(ndim)):
                raise ValueError(
                    f"axes {requested} are not a permutation of the dimensions "
                    f"of shape {shape}"
                )
        new_shape = tuple(shape[ax] for ax in axes)

        (a_data,) = self._load_inputs(a_info)
        out_id = self.allocate(new_shape, a_info['dtype'])
        self._submit('transpose', (a_data,), {'axes': axes, 'out_id': out_id})
        return out_id

    def reshape(self, a: Union[str, "HeliumTensor"], new_shape: Tuple[int, ...]) -> str:
        """Reshape tensor on virtual GPU

        Raises:
            ValueError: if ``new_shape`` holds a different number of elements.
        """
        a_info = self._tensor_cache[self._resolve(a)]
        if isinstance(new_shape, (int, np.integer)):
            new_shape = (new_shape,)
        new_shape = tuple(int(dim) for dim in new_shape)

        # Verify shapes are compatible (integer products: np.prod(()) is 1.0).
        if np.prod(new_shape, dtype=np.int64) != np.prod(a_info['shape'], dtype=np.int64):
            raise ValueError("New shape must have same total size as old shape")

        (a_data,) = self._load_inputs(a_info)
        out_id = self.allocate(new_shape, a_info['dtype'])
        self._submit('reshape', (a_data,), {'new_shape': new_shape, 'out_id': out_id})
        return out_id

    def softmax(self, a: Union[str, "HeliumTensor"], axis: int = -1) -> str:
        """Softmax on virtual GPU

        Raises:
            ValueError: if ``axis`` is out of range for ``a``.
        """
        a_info = self._tensor_cache[self._resolve(a)]
        ndim = len(a_info['shape'])
        if not -ndim <= axis < ndim:
            raise ValueError(f"axis {axis} out of range for shape {a_info['shape']}")

        (a_data,) = self._load_inputs(a_info)
        out_id = self.allocate(a_info['shape'], a_info['dtype'])
        self._submit('softmax', (a_data,), {'axis': axis, 'out_id': out_id})
        return out_id

    def get_tensor(self, tensor_id: str) -> np.ndarray:
        """Get tensor data (waits for pending commands).

        Raises:
            KeyError: if the tensor does not exist on this device.
        """
        if tensor_id not in self._tensor_cache:
            raise KeyError(f"Tensor {tensor_id} not found")
        return self.from_gpu(tensor_id)

    def tensor_exists(self, tensor_id: str) -> bool:
        """Check if tensor exists in virtual GPU memory"""
        return tensor_id in self._tensor_cache

    def delete_tensor(self, tensor_id: str):
        """Free tensor memory (no-op for unknown ids).

        Waits for queued commands first so a pending op cannot write into
        memory that has already been released.  Pending command failures are
        left for the next ``synchronize()``.
        """
        if tensor_id not in self._tensor_cache:
            return
        if self._worker_thread is not None:
            self._command_queue.join()
        self.vram.free(self._tensor_cache[tensor_id]['tensor_id'])
        del self._tensor_cache[tensor_id]

    def close(self) -> None:
        """Stop the worker thread and free every tensor allocated by this device.

        Queued commands are allowed to finish first.  Safe to call more than
        once and on a partially initialised instance (``__init__`` raised).
        Failures of queued commands that were never synchronised are dropped.
        """
        worker = getattr(self, '_worker_thread', None)
        if worker is not None:
            self._worker_thread = None
            # Daemon threads are stopped at interpreter shutdown, so joining
            # then could block forever; skip the stop handshake in that case.
            if worker.is_alive() and not sys.is_finalizing():
                self._command_queue.put(None)
                worker.join()

        vram = getattr(self, 'vram', None)
        cache = getattr(self, '_tensor_cache', None)
        if vram is None or not cache:
            return
        for tensor_id in list(cache):
            vram.free(cache[tensor_id]['tensor_id'])
            del cache[tensor_id]

    def __enter__(self) -> "VirtualGPUDevice":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def __del__(self):
        """Cleanup allocated memory and stop worker (fallback for ``close()``)."""
        self.close()
