"""
Advanced tensor storage operations extending the base storage system.
"""

import numpy as np
from typing import Dict, Any, Optional, Union, List, Tuple
import threading
from http_storage import LocalStorage
import logging
import json

class TensorOps:
    """Stateless helpers that convert NumPy arrays to and from raw bytes.

    The byte payload is always laid out in C (row-major) order, and the
    accompanying metadata records what is needed to rebuild the array:
    ``shape``, ``dtype`` (as ``str(dtype)``) and ``strides`` of the serialized
    C-contiguous layout.
    """

    @staticmethod
    def serialize_tensor(tensor: np.ndarray) -> Tuple[bytes, Dict[str, Any]]:
        """Serialize *tensor* to C-ordered bytes plus reconstruction metadata.

        Args:
            tensor: Array to serialize. Array-likes are converted with
                ``np.asarray``.

        Returns:
            ``(data, metadata)`` where ``data`` is the raw C-order buffer and
            ``metadata`` holds ``shape``, ``dtype`` and ``strides``. ``strides``
            describes ``data`` (the C-contiguous layout), not the layout of a
            possibly non-contiguous input view.
        """
        tensor = np.asarray(tensor)
        # tobytes() always emits C order, so take strides from a C-contiguous
        # array; a transposed or sliced view would otherwise record strides
        # that do not match the bytes actually written.
        contiguous = tensor if tensor.flags.c_contiguous else tensor.copy(order='C')
        metadata = {
            'shape': tensor.shape,
            'dtype': str(tensor.dtype),
            'strides': contiguous.strides,
        }
        return contiguous.tobytes(), metadata

    @staticmethod
    def deserialize_tensor(data: bytes, metadata: Dict[str, Any]) -> np.ndarray:
        """Rebuild an array from bytes produced by ``serialize_tensor``.

        Args:
            data: Raw C-order buffer (any object supporting the buffer protocol).
            metadata: Mapping with at least ``'dtype'`` and ``'shape'``. A
                list-valued shape (e.g. after a JSON round trip) is accepted.

        Returns:
            A new, writable array that owns its memory.

        Raises:
            ValueError: raised by NumPy if the buffer size is inconsistent with
                the dtype and shape.
        """
        dtype = np.dtype(metadata['dtype'])
        flat = np.frombuffer(data, dtype=dtype)
        # frombuffer over immutable bytes gives a read-only view tied to `data`;
        # copy so callers get an independent array they can modify in place.
        return flat.reshape(tuple(metadata['shape'])).copy()

class TensorStorage(LocalStorage):
    """
    Enhanced storage implementation with tensor operations.

    Extends ``LocalStorage`` with methods that store and load NumPy arrays in
    the ``tensors`` table, and with operations (matmul, add, multiply,
    transpose, reshape, split, concatenate) that read stored tensors, compute
    a result, and persist it under a caller-supplied ID.

    Operations report failure by logging the error and returning ``None``
    (or an empty list for ``split``).
    """

    def __init__(self, db_url: Optional[str] = None):
        super().__init__(db_url)
        self._ops = TensorOps()

    def store_tensor(self, tensor_id: str, tensor: np.ndarray, metadata: Optional[Dict] = None) -> bool:
        """Serialize *tensor* and persist it under *tensor_id*.

        Args:
            tensor_id: Key for the row in the ``tensors`` table.
            tensor: Array to store.
            metadata: Optional extra metadata stored alongside the tensor.
                Keys that collide with the serialization metadata (``shape``,
                ``dtype``, ``strides``) are ignored with a warning, because
                overriding them would make the stored bytes undecodable.

        Returns:
            The success flag reported by ``_store_in_db``; ``False`` if
            serialization or storage raised (the error is logged).
        """
        try:
            data, tensor_meta = self._ops.serialize_tensor(tensor)

            merged_meta = dict(tensor_meta)
            if metadata:
                collisions = [key for key in metadata if key in tensor_meta]
                if collisions:
                    logging.warning(
                        f"Ignoring reserved metadata keys {collisions} for tensor {tensor_id}"
                    )
                merged_meta.update(
                    {key: value for key, value in metadata.items() if key not in tensor_meta}
                )

            success = self._store_in_db('tensors', tensor_id, data, merged_meta)
        except Exception as e:
            logging.error(f"Error storing tensor {tensor_id}: {str(e)}")
            return False

        if success:
            try:
                self._record_store_stats(len(data))
            except Exception as e:
                # The tensor is already persisted; a bookkeeping failure must
                # not be reported to the caller as a failed store.
                logging.warning(f"Stored tensor {tensor_id} but failed to update stats: {str(e)}")
        return success

    def _record_store_stats(self, nbytes: int) -> None:
        """Increment ``tensor_count`` and add *nbytes* to ``total_size`` in ``self.stats``."""
        # NOTE(review): counters are bumped on every successful store, including
        # overwrites of an existing ID if _store_in_db upserts — confirm intent.
        with self._lock:
            # .get(...) so counters missing from the base class start at zero.
            self.stats['tensor_count'] = self.stats.get('tensor_count', 0) + 1
            self.stats['total_size'] = self.stats.get('total_size', 0) + nbytes

    def get_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
        """Load the tensor stored under *tensor_id*.

        Returns:
            The reconstructed array, or ``None`` if no row exists or loading
            failed (the error is logged).
        """
        try:
            row = self.conn.execute("""
                SELECT data, metadata
                FROM tensors
                WHERE id = ?
            """, [tensor_id]).fetchone()

            if not row:
                return None

            data, metadata = row
            # Accept metadata either as JSON text or as an already-decoded value.
            if isinstance(metadata, str):
                metadata = json.loads(metadata)

            return self._ops.deserialize_tensor(data, metadata)

        except Exception as e:
            logging.error(f"Error retrieving tensor {tensor_id}: {str(e)}")
            return None

    def _load_tensors(self, tensor_ids: List[str]) -> Optional[List[np.ndarray]]:
        """Load every ID in *tensor_ids* in order; return ``None`` if any is missing."""
        tensors = []
        for tid in tensor_ids:
            tensor = self.get_tensor(tid)
            # Identity test on purpose: `None in list_of_arrays` compares arrays
            # with `==` and raises ValueError for multi-element arrays.
            if tensor is None:
                return None
            tensors.append(tensor)
        return tensors

    def _compute_and_store(self, result_id: str, op_name: str, compute, tensor_ids: List[str]) -> Optional[str]:
        """Load *tensor_ids*, apply *compute* to them, and store the result.

        Args:
            result_id: ID under which the result is stored.
            op_name: Operation name used in the error log message.
            compute: Callable receiving the loaded arrays positionally and
                returning the result array.
            tensor_ids: IDs of the input tensors, in argument order.

        Returns:
            *result_id* on success; ``None`` if an input is missing, the
            computation raised (logged), or storing failed.
        """
        tensors = self._load_tensors(tensor_ids)
        if tensors is None:
            return None

        try:
            result = compute(*tensors)
        except Exception as e:
            logging.error(f"Error in {op_name}: {str(e)}")
            return None

        return result_id if self.store_tensor(result_id, result) else None

    def matmul(self, tensor_id_a: str, tensor_id_b: str, result_id: str) -> Optional[str]:
        """Store ``np.matmul(a, b)`` under *result_id*; return the ID or ``None``."""
        return self._compute_and_store(
            result_id, "matrix multiplication", np.matmul, [tensor_id_a, tensor_id_b]
        )

    def add(self, tensor_id_a: str, tensor_id_b: str, result_id: str) -> Optional[str]:
        """Store ``a + b`` (NumPy broadcasting rules) under *result_id*."""
        return self._compute_and_store(
            result_id, "tensor addition", lambda a, b: a + b, [tensor_id_a, tensor_id_b]
        )

    def multiply(self, tensor_id_a: str, tensor_id_b: str, result_id: str) -> Optional[str]:
        """Store the element-wise product ``a * b`` under *result_id*."""
        return self._compute_and_store(
            result_id, "tensor multiplication", lambda a, b: a * b, [tensor_id_a, tensor_id_b]
        )

    def transpose(self, tensor_id: str, result_id: str) -> Optional[str]:
        """Store ``np.transpose`` of the tensor (all axes reversed) under *result_id*."""
        return self._compute_and_store(result_id, "tensor transpose", np.transpose, [tensor_id])

    def reshape(self, tensor_id: str, new_shape: Tuple[int, ...], result_id: str) -> Optional[str]:
        """Store the tensor reshaped to *new_shape* under *result_id*."""
        return self._compute_and_store(
            result_id, "tensor reshape", lambda t: t.reshape(new_shape), [tensor_id]
        )

    def split(self, tensor_id: str, indices_or_sections: Union[int, List[int]], axis: int = 0) -> List[str]:
        """Split a stored tensor with ``np.split`` and store each part.

        Parts are stored as ``"{tensor_id}_split_{i}"``. Parts that fail to
        store are skipped (and logged), so the returned list may be shorter
        than the number of parts.

        Returns:
            IDs of the parts that were stored; ``[]`` if the source tensor is
            missing or the split itself failed.
        """
        tensor = self.get_tensor(tensor_id)
        if tensor is None:
            return []

        try:
            parts = np.split(tensor, indices_or_sections, axis=axis)
        except Exception as e:
            logging.error(f"Error in tensor split: {str(e)}")
            return []

        result_ids = []
        for i, part in enumerate(parts):
            part_id = f"{tensor_id}_split_{i}"
            if self.store_tensor(part_id, part):
                result_ids.append(part_id)
            else:
                logging.warning(f"Failed to store split part {part_id}")
        return result_ids

    def concatenate(self, tensor_ids: List[str], result_id: str, axis: int = 0) -> Optional[str]:
        """Store ``np.concatenate`` of the given tensors along *axis* under *result_id*.

        Returns ``None`` if any input is missing, *tensor_ids* is empty, or
        the shapes are incompatible (the NumPy error is logged).
        """
        return self._compute_and_store(
            result_id,
            "tensor concatenation",
            lambda *tensors: np.concatenate(tensors, axis=axis),
            list(tensor_ids),
        )
