"""
Tensor type system with DuckDB-based type and metadata storage
"""
from enum import Enum
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import duckdb
import json
import os
import pathlib

# Ensure a local "data" directory exists in the current working directory.
# NOTE(review): neither TensorTypeDB.DB_FILE nor TensorDB.DB_FILE points into
# this directory (both are bare filenames) — confirm whether it is still needed.
pathlib.Path("data").mkdir(parents=True, exist_ok=True)



class TensorTypeDB:
    """Process-wide singleton owning the DuckDB lookup tables for tensor types.

    The first ``TensorTypeDB()`` call opens ``DB_FILE`` (a bare filename, so it
    is resolved against the current working directory), creates the
    ``tensor_dtypes``, ``tensor_layouts`` and ``device_types`` tables if they
    are missing, and seeds them. Every later call returns that same instance.
    """
    _instance = None
    DB_FILE = "tensor_types.db"

    def __new__(cls):
        if cls._instance is None:
            instance = super().__new__(cls)
            # Initialise before publishing the singleton: if init_db() raises,
            # the next call retries instead of returning an instance with no `conn`.
            instance.init_db()
            cls._instance = instance
        return cls._instance

    def init_db(self):
        """Open (or reopen) the database file and ensure tables and seed rows exist."""
        # Use local DuckDB file
        self.conn = duckdb.connect(self.DB_FILE)

        self._init_tables()
        self._init_types()

    def ensure_connection(self):
        """Reopen the database if the current connection cannot run a trivial query."""
        try:
            self.conn.execute("SELECT 1")
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
            self.init_db()

    def _init_tables(self):
        """Create the three lookup tables if they do not exist yet."""
        # Data types table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS tensor_dtypes (
                id INTEGER PRIMARY KEY,
                name VARCHAR UNIQUE,
                size_bits INTEGER,
                is_floating BOOLEAN,
                is_signed BOOLEAN
            )
        """)

        # Tensor layout table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS tensor_layouts (
                id INTEGER PRIMARY KEY,
                name VARCHAR UNIQUE,
                description TEXT
            )
        """)

        # Device types table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS device_types (
                id INTEGER PRIMARY KEY,
                name VARCHAR UNIQUE,
                description TEXT
            )
        """)

    def _seed_rows(self, table, columns, rows):
        """Make each seed row present in ``table``, matching rows on their ``id``.

        Missing ids are inserted. Ids whose stored values differ from the seed are
        rewritten with a plain UPDATE. Rows that already match are not touched, so
        repeated calls are no-ops. A plain UPDATE is used instead of
        ``INSERT ... ON CONFLICT DO UPDATE SET name = ...`` because ``name`` is
        UNIQUE, and DuckDB's upsert may refuse to assign to a UNIQUE-indexed column.

        Args:
            table: internal table name (never user input).
            columns: column names; the first must be ``id``.
            rows: tuples in ``columns`` order.
        """
        col_list = ", ".join(columns)
        select_sql = f"SELECT {col_list} FROM {table} WHERE id = ?"
        insert_sql = (
            f"INSERT INTO {table} ({col_list}) "
            f"VALUES ({', '.join('?' for _ in columns)})"
        )
        update_sql = (
            f"UPDATE {table} SET "
            + ", ".join(f"{col} = ?" for col in columns[1:])
            + " WHERE id = ?"
        )
        for row in rows:
            current = self.conn.execute(select_sql, [row[0]]).fetchone()
            if current is None:
                self.conn.execute(insert_sql, list(row))
            elif tuple(current) != tuple(row):
                self.conn.execute(update_sql, [*row[1:], row[0]])

    def _init_types(self):
        """Seed the dtype, layout and device lookup tables (idempotent)."""
        # (id, name, size_bits, is_floating, is_signed)
        dtypes = [
            (1, 'float32', 32, True, True),
            (2, 'float64', 64, True, True),
            (3, 'float16', 16, True, True),
            (4, 'bfloat16', 16, True, True),
            (5, 'int32', 32, False, True),
            (6, 'int64', 64, False, True),
            (7, 'int16', 16, False, True),
            (8, 'int8', 8, False, True),
            (9, 'uint8', 8, False, False),
            (10, 'bool', 1, False, False)
        ]

        # (id, name, description)
        layouts = [
            (1, 'NCHW', 'Standard CNN layout (batch, channels, height, width)'),
            (2, 'NHWC', 'TensorFlow preferred layout'),
            (3, 'CHWN', 'Optimized for certain hardware'),
            (4, 'RowMajor', 'Standard matrix layout'),
            (5, 'ColumnMajor', 'Fortran-style layout')
        ]

        # (id, name, description). The ids must match the last `Device` enum in
        # this module (VGPU=1, CPU=2, TPU=3, MULTI=4), which is the one Tensor
        # resolves at call time. The earlier seed (VIRTUAL, DISTRIBUTED) left
        # ids 3-4 without a row and described id 2 (CPU) as "Multi-Device".
        # _seed_rows rewrites such stale rows.
        devices = [
            (1, 'VGPU', 'Virtual GPU Device'),
            (2, 'CPU', 'CPU Device'),
            (3, 'TPU', 'Tensor Processing Unit'),
            (4, 'Multi', 'Multi-Device')
        ]

        self._seed_rows(
            "tensor_dtypes",
            ("id", "name", "size_bits", "is_floating", "is_signed"),
            dtypes,
        )
        self._seed_rows("tensor_layouts", ("id", "name", "description"), layouts)
        self._seed_rows("device_types", ("id", "name", "description"), devices)

class DType(Enum):
    """Tensor data types, each member carrying its ``tensor_dtypes`` row.

    NOTE: ``class DType`` is defined a second time later in this module; that
    later definition rebinds the name, so this one is only current until then.
    """

    def __new__(cls, value):
        member = object.__new__(cls)
        member._value_ = value
        cursor = TensorTypeDB().conn.execute(
            "SELECT * FROM tensor_dtypes WHERE id = ?", [value]
        )
        member._db_value = cursor.fetchone()
        return member

    # Cached row layout: (id, name, size_bits, is_floating, is_signed)
    @property
    def bits(self):
        size_bits = self._db_value[2]
        return size_bits

    @property
    def is_floating(self):
        flag = self._db_value[3]
        return bool(flag)

    @property
    def is_signed(self):
        flag = self._db_value[4]
        return bool(flag)

    FLOAT32 = 1
    FLOAT64 = 2
    FLOAT16 = 3
    BFLOAT16 = 4
    INT32 = 5
    INT64 = 6
    INT16 = 7
    INT8 = 8
    UINT8 = 9
    BOOL = 10

class Layout(Enum):
    """Tensor memory layouts, each member carrying its ``tensor_layouts`` row.

    NOTE: ``class Layout`` is defined a second time later in this module; that
    later definition rebinds the name, so this one is only current until then.
    """

    def __new__(cls, value):
        member = object.__new__(cls)
        member._value_ = value
        cursor = TensorTypeDB().conn.execute(
            "SELECT * FROM tensor_layouts WHERE id = ?", [value]
        )
        member._db_value = cursor.fetchone()
        return member

    @property
    def description(self):
        # Cached row layout: (id, name, description)
        text = self._db_value[2]
        return text

    NCHW = 1
    NHWC = 2
    CHWN = 3
    ROW_MAJOR = 4
    COLUMN_MAJOR = 5

class Device(Enum):
    """Tensor device types, each member carrying its ``device_types`` row.

    NOTE: ``class Device`` is defined a second time later in this module, with
    different members; that later definition rebinds the name, so this one is
    only current until then.
    """

    def __new__(cls, value):
        member = object.__new__(cls)
        member._value_ = value
        type_db = TensorTypeDB()
        # Re-initialise the shared connection first if it no longer answers.
        type_db.ensure_connection()
        cursor = type_db.conn.execute(
            "SELECT * FROM device_types WHERE id = ?", [value]
        )
        member._db_value = cursor.fetchone()
        return member

    @property
    def description(self):
        # Cached row layout: (id, name, description)
        text = self._db_value[2]
        return text

    VIRTUAL = 1
    DISTRIBUTED = 2

@dataclass
class TensorDescriptor:
    """Shape/dtype/layout/device record with plain-dict (de)serialisation.

    NOTE: a second ``TensorDescriptor`` dataclass later in this module rebinds
    the name, so this definition is only current until that point.
    """
    shape: Tuple[int, ...]
    dtype: DType
    layout: Layout
    device: Device
    requires_grad: bool = False

    def to_dict(self) -> Dict:
        """Return the descriptor as a dict, storing each enum by its ``.value``."""
        out = {"shape": self.shape}
        for field_name in ("dtype", "layout", "device"):
            out[field_name] = getattr(self, field_name).value
        out["requires_grad"] = self.requires_grad
        return out

    @classmethod
    def from_dict(cls, d: Dict) -> 'TensorDescriptor':
        """Rebuild a descriptor from ``to_dict`` output; ``requires_grad`` is optional."""
        shape = tuple(d["shape"])
        dtype = DType(d["dtype"])
        layout = Layout(d["layout"])
        device = Device(d["device"])
        return cls(shape, dtype, layout, device, d.get("requires_grad", False))

class TensorDB:
    """DuckDB store for tensor metadata, raw CPU-tensor bytes and an op history.

    Not a singleton: each instance opens its own connection to ``DB_FILE``
    (a bare filename, resolved against the current working directory).
    """
    DB_FILE = "tensors.db"

    def __init__(self):
        self.init_db()

    def init_db(self):
        """Open (or reopen) the database file and create the tables if missing."""
        # Use local DuckDB file
        self.conn = duckdb.connect(self.DB_FILE)

        self._init_tables()

    def ensure_connection(self):
        """Reopen the database if the current connection cannot run a trivial query."""
        try:
            self.conn.execute("SELECT 1")
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
            self.init_db()

    def _init_tables(self):
        """Create the tensors, tensor_data and tensor_ops tables if they do not exist."""
        # Tensor metadata table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS tensors (
                id VARCHAR PRIMARY KEY,
                shape VARCHAR,  -- JSON array
                dtype VARCHAR,
                layout VARCHAR,
                device VARCHAR,
                requires_grad BOOLEAN,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # Tensor data table (for CPU tensors)
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS tensor_data (
                tensor_id VARCHAR PRIMARY KEY,
                data BLOB,
                FOREIGN KEY (tensor_id) REFERENCES tensors(id)
            )
        """)

        # Tensor operations history.
        # NOTE: `id` has no DEFAULT, and DuckDB does not auto-generate
        # INTEGER PRIMARY KEY values, so every INSERT must supply `id` itself.
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS tensor_ops (
                id INTEGER PRIMARY KEY,
                tensor_id VARCHAR,
                operation VARCHAR,
                timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (tensor_id) REFERENCES tensors(id)
            )
        """)

    def _seed_rows(self, table, columns, rows):
        """Make each seed row present in ``table``, matching rows on their ``id``.

        Missing ids are inserted and ids whose stored values differ are rewritten
        with a plain UPDATE; matching rows are left alone, so repeats are no-ops.
        ``INSERT ... ON CONFLICT DO UPDATE SET name = ...`` is avoided because
        ``name`` is UNIQUE and DuckDB's upsert may refuse to assign to it.

        Args:
            table: internal table name (never user input).
            columns: column names; the first must be ``id``.
            rows: tuples in ``columns`` order.
        """
        col_list = ", ".join(columns)
        select_sql = f"SELECT {col_list} FROM {table} WHERE id = ?"
        insert_sql = (
            f"INSERT INTO {table} ({col_list}) "
            f"VALUES ({', '.join('?' for _ in columns)})"
        )
        update_sql = (
            f"UPDATE {table} SET "
            + ", ".join(f"{col} = ?" for col in columns[1:])
            + " WHERE id = ?"
        )
        for row in rows:
            current = self.conn.execute(select_sql, [row[0]]).fetchone()
            if current is None:
                self.conn.execute(insert_sql, list(row))
            elif tuple(current) != tuple(row):
                self.conn.execute(update_sql, [*row[1:], row[0]])

    def _init_types(self):
        """Create and seed dtype/layout/device lookup tables in *this* database.

        Nothing in this class calls it. ``_init_tables`` does not create these
        tables, so they are created here first; otherwise the seeding inserts
        would fail on a database where nothing else created them.
        """
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS tensor_dtypes (
                id INTEGER PRIMARY KEY,
                name VARCHAR UNIQUE,
                size_bits INTEGER,
                is_floating BOOLEAN,
                is_signed BOOLEAN
            )
        """)
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS tensor_layouts (
                id INTEGER PRIMARY KEY,
                name VARCHAR UNIQUE,
                description TEXT
            )
        """)
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS device_types (
                id INTEGER PRIMARY KEY,
                name VARCHAR UNIQUE,
                description TEXT
            )
        """)

        # (id, name, size_bits, is_floating, is_signed)
        dtypes = [
            (1, 'float32', 32, True, True),
            (2, 'float64', 64, True, True),
            (3, 'float16', 16, True, True),
            (4, 'bfloat16', 16, True, True),
            (5, 'int32', 32, False, True),
            (6, 'int64', 64, False, True),
            (7, 'int16', 16, False, True),
            (8, 'int8', 8, False, True),
            (9, 'uint8', 8, False, False),
            (10, 'bool', 1, False, False)
        ]
        # (id, name, description)
        layouts = [
            (1, 'NCHW', 'Standard CNN layout (batch, channels, height, width)'),
            (2, 'NHWC', 'TensorFlow preferred layout'),
            (3, 'CHWN', 'Optimized for certain hardware'),
            (4, 'RowMajor', 'Standard matrix layout'),
            (5, 'ColumnMajor', 'Fortran-style layout')
        ]
        # (id, name, description)
        devices = [
            (1, 'VGPU', 'Virtual GPU Device'),
            (2, 'CPU', 'CPU Device'),
            (3, 'TPU', 'Tensor Processing Unit'),
            (4, 'Multi', 'Multi-Device')
        ]

        self._seed_rows(
            "tensor_dtypes",
            ("id", "name", "size_bits", "is_floating", "is_signed"),
            dtypes,
        )
        self._seed_rows("tensor_layouts", ("id", "name", "description"), layouts)
        self._seed_rows("device_types", ("id", "name", "description"), devices)

class DType(Enum):
    """Tensor data types; each member caches its row from ``tensor_dtypes``.

    The row is looked up by the member's integer value when the enum class is
    created. Property access raises ``LookupError`` if no row was found.
    """
    def __new__(cls, value):
        obj = object.__new__(cls)
        obj._value_ = value
        db = TensorTypeDB()
        # Same reconnect-before-query step the Device enum performs.
        db.ensure_connection()
        # Explicit columns so the positional reads below don't depend on the
        # table's physical column order.
        obj._db_value = db.conn.execute(
            "SELECT id, name, size_bits, is_floating, is_signed "
            "FROM tensor_dtypes WHERE id = ?", [value]
        ).fetchone()
        return obj

    def _row(self):
        """Return the cached (id, name, size_bits, is_floating, is_signed) row.

        Raises:
            LookupError: if ``tensor_dtypes`` had no row for this member's id.
        """
        if self._db_value is None:
            raise LookupError(
                f"no tensor_dtypes row with id {self.value} for DType.{self.name}"
            )
        return self._db_value

    @property
    def bits(self):
        """Element size in bits, as stored in ``size_bits``."""
        return self._row()[2]

    @property
    def is_floating(self):
        return bool(self._row()[3])

    @property
    def is_signed(self):
        return bool(self._row()[4])

    FLOAT32 = 1
    FLOAT64 = 2
    FLOAT16 = 3
    BFLOAT16 = 4
    INT32 = 5
    INT64 = 6
    INT16 = 7
    INT8 = 8
    UINT8 = 9
    BOOL = 10

class Layout(Enum):
    """Tensor memory layouts; each member caches its row from ``tensor_layouts``.

    The row is looked up by the member's integer value when the enum class is
    created. ``description`` raises ``LookupError`` if no row was found.
    """
    def __new__(cls, value):
        obj = object.__new__(cls)
        obj._value_ = value
        db = TensorTypeDB()
        # Same reconnect-before-query step the Device enum performs.
        db.ensure_connection()
        # Explicit columns so positional reads don't depend on column order.
        obj._db_value = db.conn.execute(
            "SELECT id, name, description FROM tensor_layouts WHERE id = ?", [value]
        ).fetchone()
        return obj

    def _row(self):
        """Return the cached (id, name, description) row.

        Raises:
            LookupError: if ``tensor_layouts`` had no row for this member's id.
        """
        if self._db_value is None:
            raise LookupError(
                f"no tensor_layouts row with id {self.value} for Layout.{self.name}"
            )
        return self._db_value

    @property
    def description(self):
        return self._row()[2]

    NCHW = 1
    NHWC = 2
    CHWN = 3
    ROW_MAJOR = 4
    COLUMN_MAJOR = 5

class Tensor:
    """NumPy-backed tensor whose metadata is recorded in a shared TensorDB.

    Construction registers the tensor in the ``tensors`` table and, for tensors
    on ``Device.CPU``, stores the raw array bytes in ``tensor_data``. Device
    moves are appended to ``tensor_ops``.
    """
    # One TensorDB shared by all tensors, opened when this class body executes.
    _db = TensorDB()

    def __init__(self, data: Union[np.ndarray, List, Tuple],
                 dtype: Optional[DType] = None,
                 device: Optional[Device] = None,
                 layout: Optional[Layout] = None,
                 requires_grad: bool = False):
        """Wrap ``data`` and register the new tensor in the database.

        Args:
            data: an ``np.ndarray`` (kept as-is) or anything ``np.asarray``
                accepts, e.g. nested lists/tuples or a scalar.
            dtype: element type; inferred from ``data.dtype`` when omitted.
            device: defaults to ``Device.CPU``.
            layout: defaults to ``Layout.ROW_MAJOR``.
            requires_grad: copied onto the descriptor.
        """
        import uuid

        if not isinstance(data, np.ndarray):
            data = np.asarray(data)

        self.data = data
        # Use a random UUID instead of id(self). Python reuses id() values once
        # objects are freed, and tensors.db persists across runs, so id()-based
        # keys can collide with the PRIMARY KEY on tensors.id.
        self.id = f"tensor_{uuid.uuid4().hex}"

        if dtype is None:
            dtype = self._numpy_to_dtype(data.dtype)
        if device is None:
            device = Device.CPU
        if layout is None:
            layout = Layout.ROW_MAJOR

        self.descriptor = TensorDescriptor(
            shape=data.shape,
            dtype=dtype,
            layout=layout,
            device=device,
            requires_grad=requires_grad
        )

        self._register_tensor()

    def _register_tensor(self):
        """Insert this tensor's metadata and, for CPU tensors, its raw bytes."""
        self._db.ensure_connection()
        conn = self._db.conn
        conn.execute("""
            INSERT INTO tensors
            (id, shape, dtype, layout, device, requires_grad)
            VALUES (?, ?, ?, ?, ?, ?)
        """, [
            self.id,
            json.dumps(self.descriptor.shape),
            self.descriptor.dtype.value,
            self.descriptor.layout.value,
            self.descriptor.device.value,
            self.descriptor.requires_grad
        ])

        if self.descriptor.device == Device.CPU:
            # DuckDB's Python client binds `bytes` parameters as BLOB, so the
            # raw buffer is passed directly (tobytes() yields C-order bytes).
            conn.execute("""
                INSERT INTO tensor_data (tensor_id, data)
                VALUES (?, ?)
            """, [self.id, self.data.tobytes()])

    def _log_operation(self, operation: str):
        """Append ``operation`` to the ``tensor_ops`` history for this tensor."""
        self._db.ensure_connection()
        # tensor_ops.id is a plain INTEGER PRIMARY KEY with no default. DuckDB
        # does not auto-generate such keys, so an INSERT that omits it fails on
        # the NULL key. Assign the next id explicitly.
        self._db.conn.execute("""
            INSERT INTO tensor_ops (id, tensor_id, operation)
            SELECT COALESCE(MAX(id), 0) + 1, ?, ?
            FROM tensor_ops
        """, [self.id, operation])

    @staticmethod
    def _numpy_to_dtype(np_dtype) -> DType:
        """Map a NumPy dtype (or anything ``np.dtype()`` accepts) to ``DType``.

        NOTE(review): types without an entry (e.g. uint16, complex) silently map
        to FLOAT32, so the recorded dtype may not describe the stored bytes.
        """
        dtype_map = {
            np.float32: DType.FLOAT32,
            np.float64: DType.FLOAT64,
            np.float16: DType.FLOAT16,
            np.int32: DType.INT32,
            np.int64: DType.INT64,
            np.int16: DType.INT16,
            np.int8: DType.INT8,
            np.uint8: DType.UINT8,
            np.bool_: DType.BOOL
        }
        return dtype_map.get(np.dtype(np_dtype).type, DType.FLOAT32)

    def to(self, device: Device) -> 'Tensor':
        """Move the tensor to ``device`` in place and return ``self``.

        The move is logged in ``tensor_ops`` and the ``tensors`` row is updated.
        NOTE(review): ``tensor_data`` is only written at construction, so moving
        a non-CPU tensor to CPU does not store its bytes.
        """
        import dataclasses

        if device == self.descriptor.device:
            return self

        self._log_operation(f"move_to_{device.value}")
        # replace() keeps every other descriptor field (is_leaf, grad_fn, _version, ...).
        self.descriptor = dataclasses.replace(self.descriptor, device=device)

        # Update device in database
        self._db.ensure_connection()
        self._db.conn.execute("""
            UPDATE tensors
            SET device = ?
            WHERE id = ?
        """, [device.value, self.id])

        return self

    def cpu(self) -> 'Tensor':
        """Move tensor to CPU"""
        return self.to(Device.CPU)

    def cuda(self) -> 'Tensor':
        """Move tensor to the GPU device.

        Neither Device enum in this module has a CUDA member (the old
        ``Device.CUDA`` lookup raised AttributeError). VGPU, seeded as
        "Virtual GPU Device", is the only GPU-class device, so this is an
        alias for ``virtual()``.
        """
        return self.to(Device.VGPU)

    def virtual(self) -> 'Tensor':
        """Move tensor to the virtual GPU device (``Device.VGPU``)."""
        return self.to(Device.VGPU)

    @property
    def shape(self) -> Tuple[int, ...]:
        return self.descriptor.shape

    @property
    def dtype(self) -> DType:
        return self.descriptor.dtype

    @property
    def device(self) -> Device:
        return self.descriptor.device

    @property
    def layout(self) -> Layout:
        return self.descriptor.layout

    def __repr__(self) -> str:
        return f"Tensor(shape={self.shape}, dtype={self.dtype.value}, device={self.device.value})"

class Device(Enum):
    """Device types; each member caches its row from ``device_types``.

    The row is looked up by the member's integer value when the enum class is
    created. ``description`` raises ``LookupError`` if no row was found.
    """
    def __new__(cls, value):
        obj = object.__new__(cls)
        obj._value_ = value
        db = TensorTypeDB()
        db.ensure_connection()
        # May be None: a row exists only if TensorTypeDB's seed data covers this
        # id (TPU=3 and MULTI=4 are not covered by a VIRTUAL/DISTRIBUTED-only seed).
        obj._db_value = db.conn.execute(
            "SELECT id, name, description FROM device_types WHERE id = ?", [value]
        ).fetchone()
        return obj

    def _row(self):
        """Return the cached (id, name, description) row.

        Raises:
            LookupError: if ``device_types`` had no row for this member's id.
        """
        if self._db_value is None:
            raise LookupError(
                f"no device_types row with id {self.value} for Device.{self.name}"
            )
        return self._db_value

    @property
    def description(self):
        return self._row()[2]

    VGPU = 1
    CPU = 2
    TPU = 3
    MULTI = 4

@dataclass
class TensorDescriptor:
    """Full tensor descriptor: shape, element type, placement and autograd metadata.

    ``to_dict``/``from_dict`` restore the serialisation API of the earlier
    descriptor class that this definition replaces, extended to the new fields.
    """
    shape: Tuple[int, ...]
    dtype: DType
    layout: Layout = Layout.ROW_MAJOR
    device: Device = Device.VGPU
    requires_grad: bool = False
    is_leaf: bool = True
    grad_fn: Optional[str] = None
    _version: int = 1

    def to_dict(self) -> Dict:
        """Return plain Python values; enums are stored by their integer ``.value``."""
        return {
            "shape": self.shape,
            "dtype": self.dtype.value,
            "layout": self.layout.value,
            "device": self.device.value,
            "requires_grad": self.requires_grad,
            "is_leaf": self.is_leaf,
            "grad_fn": self.grad_fn,
            "_version": self._version,
        }

    @classmethod
    def from_dict(cls, d: Dict) -> 'TensorDescriptor':
        """Inverse of ``to_dict``.

        Only ``shape`` and ``dtype`` are required. Every other key falls back to
        the field's default, so dicts written by the older 5-key format still load.
        """
        return cls(
            shape=tuple(d["shape"]),
            dtype=DType(d["dtype"]),
            layout=Layout(d["layout"]) if "layout" in d else Layout.ROW_MAJOR,
            device=Device(d["device"]) if "device" in d else Device.VGPU,
            requires_grad=d.get("requires_grad", False),
            is_leaf=d.get("is_leaf", True),
            grad_fn=d.get("grad_fn"),
            _version=d.get("_version", 1),
        )

@dataclass()
class TensorStorageDescriptor:
    """Physical storage record for a tensor's bytes.

    ``data_ptr``, ``size_bytes``, ``chip_id`` and ``sm_assignments`` must be
    supplied; the remaining fields have defaults.
    """
    # Memory address / pointer of the storage.
    data_ptr: int
    size_bytes: int
    chip_id: int
    # SM id -> list of tensor slice indices assigned to that SM.
    sm_assignments: Dict[int, List[int]]
    stream_id: Optional[int] = None
    is_pinned: bool = False
    cache_mode: str = "write_back"
    _version: int = 1


