import numpy as np
from typing import Optional, Tuple, Union, List

from .virtual_gpu_device import VirtualGPUDevice
from .device_utils import get_gpu_device, get_default_device

class HeliumTensor:
    """Base tensor class for Helium framework with virtual GPU support.

    ``self.data`` holds either a host ``np.ndarray`` or a ``str`` tensor id
    referring to storage on ``self.device``; the accessors below branch on
    which of the two it is.
    """

    def __init__(
        self,
        data: Union[np.ndarray, "HeliumTensor", str],
        device: Optional[Union[str, VirtualGPUDevice]] = None,
        requires_grad: bool = False
    ):
        """Wrap ``data`` as a tensor on ``device``.

        Args:
            data: A host array (anything ``np.asarray`` accepts), another
                ``HeliumTensor`` (its data is shared when it is already on
                ``device``, transferred otherwise), or the id of a tensor
                that already exists on ``device``.
            device: Target device object or device name; ``None`` selects
                ``get_default_device()``.
            requires_grad: Whether ``backward`` should record a gradient.

        Raises:
            ValueError: If ``data`` is a tensor id that ``device`` does not
                report as existing (or ``device`` has no ``tensor_exists``).
        """
        self.device = device if device is not None else get_default_device()
        if isinstance(self.device, str):
            self.device = get_gpu_device(self.device)

        self.requires_grad = requires_grad
        self.grad: Optional[Union[np.ndarray, str]] = None
        # Called at the end of backward(); a no-op unless reassigned.
        self._backward_fn = lambda: None

        if isinstance(data, HeliumTensor):
            if data.device == self.device:
                # Same device: share the underlying storage instead of copying.
                self.data = data.data
            elif isinstance(data.data, str):
                # Source lives on another device: pull it to host, then upload here.
                self.data = self.device.to_gpu(data.device.from_gpu(data.data))
            else:
                self.data = self.device.to_gpu(data.data)
        elif isinstance(data, str):
            # Data is a tensor ID from virtual GPU
            if hasattr(self.device, 'tensor_exists') and self.device.tensor_exists(data):
                self.data = data
            else:
                raise ValueError(f"Tensor {data} not found on device {self.device}")
        else:
            # Data is a numpy array (or array-like)
            self.data = self.device.to_gpu(np.asarray(data))

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of the tensor, looked up on the device when data is a tensor id."""
        if isinstance(self.data, str):
            return self.device.get_tensor(self.data).shape
        return self.data.shape

    @property
    def dtype(self):
        """Element dtype, looked up on the device when data is a tensor id."""
        if isinstance(self.data, str):
            return self.device.get_tensor(self.data).dtype
        return self.data.dtype

    def numpy(self) -> np.ndarray:
        """Get tensor data as numpy array.

        Device-resident data is fetched with ``device.from_gpu``; host data is
        returned as-is (not copied).
        """
        if isinstance(self.data, str):
            return self.device.from_gpu(self.data)
        return self.data

    def __repr__(self) -> str:
        shape = self.shape
        dtype = self.dtype
        device_str = f"gpu{self.device.device_id}" if hasattr(self.device, 'device_id') else str(self.device)
        return f"HeliumTensor(shape={shape}, dtype={dtype}, device={device_str})"

    def to(self, device) -> "HeliumTensor":
        """Return this tensor on ``device`` (a device object or device name).

        Returns ``self`` when already on that device. Otherwise returns a new
        tensor whose data is transferred via the devices' ``from_gpu`` /
        ``to_gpu``, keeping the same ``requires_grad`` flag.
        """
        if isinstance(device, str):
            device = get_gpu_device(device)

        if device == self.device:
            return self

        # Let __init__ perform the cross-device transfer so the logic lives in
        # one place and the data is uploaded exactly once.
        return HeliumTensor(self, device, requires_grad=self.requires_grad)

    def detach(self) -> "HeliumTensor":
        """Create a new tensor detached from current compute graph.

        The result shares this tensor's data and has ``requires_grad=False``.
        """
        # Passing ``self`` on the same device makes __init__ share ``self.data``
        # directly, for host arrays as well as device tensor ids.
        return HeliumTensor(self, self.device, requires_grad=False)

    def backward(self, gradient: Optional[Union[np.ndarray, "HeliumTensor"]] = None):
        """Seed ``self.grad`` and run this tensor's backward function.

        Does nothing when ``requires_grad`` is False.

        Args:
            gradient: Upstream gradient. ``None`` seeds ones with this
                tensor's shape and dtype. A ``HeliumTensor`` is first moved to
                this tensor's device. Host arrays are copied so that a later
                ``zero_grad`` (which zeroes host gradients in place) cannot
                modify the caller's array or another tensor's data.
        """
        if not self.requires_grad:
            return

        if gradient is None:
            if isinstance(self.data, str):
                # Match the tensor's dtype, as np.ones_like does in the host branch.
                gradient = self.device.to_gpu(np.ones(self.shape, dtype=self.dtype))
            else:
                gradient = np.ones_like(self.data)
        else:
            if isinstance(gradient, HeliumTensor):
                if gradient.device != self.device:
                    gradient = gradient.to(self.device)
                gradient = gradient.data
            if not isinstance(gradient, str):
                # np.array copies by default; also coerces array-likes.
                gradient = np.array(gradient)
            # NOTE(review): a device-id gradient taken from a tensor already on
            # this device is shared, not copied, and zero_grad() deletes it.

        self.grad = gradient
        self._backward_fn()

    def zero_grad(self):
        """Zero out gradients.

        A gradient stored as a device tensor id is deleted from the device and
        ``grad`` is reset to ``None``; a host-array gradient is zeroed in place.
        Does nothing when no gradient is set.
        """
        if self.grad is None:
            return
        if isinstance(self.grad, str):
            self.device.delete_tensor(self.grad)
            self.grad = None
        else:
            self.grad.fill(0)
