from http_storage import LocalStorage
import numpy as np
from typing import Dict, Any, Optional, Union
import time
import threading
import logging
import hashlib

# Module-wide logging setup: emit INFO and above, one record per line as
# "<timestamp> - <level> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

class VirtualVRAM:
    """Virtual VRAM whose blocks and bookkeeping are persisted through a storage backend.

    The backend passed as ``storage`` is used through exactly three calls:

    * ``store_state(component=..., state_id=..., state_data=...)`` -> truthy on success
    * ``store_tensor(block_id, data, metadata)`` -> truthy on success
    * ``load_tensor(block_id)`` -> ``None`` or a ``(data, metadata)`` pair

    All bookkeeping lives in ``self.vram_state``:

    * ``total_size`` / ``allocated`` / ``is_unlimited`` -- capacity accounting
    * ``blocks`` -- ``{block_id: metadata dict}``
    * ``memory_map`` -- ``{virtual_addr: block_id}``

    ``self.state_lock`` is a re-entrant lock because public methods call each
    other (``store_vram_state``, ``free_block``) while already holding it.
    """

    def __init__(self, storage=None, vram_id: Optional[str] = None):
        """Initialize mining-optimized virtual VRAM.

        Args:
            storage: Backend providing ``store_state``, ``store_tensor`` and
                ``load_tensor``. When ``None`` the instance can be created and
                inspected, but persistence and block operations fail.
            vram_id: Identifier used for persisted state and generated block
                ids. A unique one is generated when omitted.
        """
        self.block_size = 256  # Bytes per block
        self.max_blocks = 1024  # Max concurrent blocks
        self.block_lifetime = 0.1  # 100ms block lifetime

        self.storage = storage
        if vram_id is None:
            # Derive a short identifier that is unique per instance and start time.
            seed = f"{id(self)}_{time.time_ns()}".encode()
            vram_id = f"vram_{hashlib.sha256(seed).hexdigest()[:16]}"
        self.vram_id = vram_id

        # Thread safety: `lock` guards `active_blocks`; `state_lock` guards
        # `vram_state` and must be re-entrant (see class docstring).
        self.lock = threading.Lock()
        self.state_lock = threading.RLock()

        # Short-lived block tracking consumed by the cleanup thread:
        # {block_id: (block_data, creation_time)}
        self.active_blocks = {}
        self.allocated_bytes = 0
        self.last_cleanup = time.time()

        self.vram_state = {
            "total_size": self.block_size * self.max_blocks,
            "allocated": 0,
            "is_unlimited": False,
            "blocks": {},
            "memory_map": {},
        }

        # Start the cleanup thread last so it only ever sees a fully
        # initialized instance.
        self.running = True
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

        logging.info(f"Initialized mining-optimized VRAM: {self.block_size}B blocks, {self.max_blocks} max blocks")

    def stop(self, timeout: Optional[float] = 1.0) -> None:
        """Ask the background cleanup thread to exit and wait up to ``timeout`` seconds."""
        self.running = False
        thread = self.cleanup_thread
        if thread.is_alive() and thread is not threading.current_thread():
            thread.join(timeout)

    @property
    def total_size(self) -> Union[int, float]:
        """Get total VRAM size in bytes"""
        return self.vram_state["total_size"]

    @property
    def available_size(self) -> Union[int, float]:
        """Get available VRAM size in bytes (``float('inf')`` when unlimited)"""
        if self.vram_state["is_unlimited"]:
            return float('inf')
        return self.vram_state["total_size"] - self.vram_state["allocated"]

    def _cleanup_loop(self):
        """Background thread to clean up expired memory blocks.

        Runs until ``self.running`` is false. At most once per
        ``block_lifetime`` it drops every ``active_blocks`` entry whose
        creation time is more than ``block_lifetime`` seconds in the past.
        """
        while self.running:
            current_time = time.time()
            if current_time - self.last_cleanup >= self.block_lifetime:
                with self.lock:
                    expired = [
                        block_id
                        for block_id, (_, creation_time) in self.active_blocks.items()
                        if current_time - creation_time > self.block_lifetime
                    ]
                    for block_id in expired:
                        del self.active_blocks[block_id]

                    self.last_cleanup = current_time

            # Sleep briefly to avoid consuming too much CPU
            time.sleep(0.01)

    def _require_storage(self):
        """Return the storage backend, raising ``RuntimeError`` if none is configured."""
        if self.storage is None:
            raise RuntimeError("No storage manager available for VRAM block operations")
        return self.storage

    def store_vram_state(self) -> bool:
        """Store VRAM state in remote storage.

        Returns:
            ``True`` when the backend accepted the state, ``False`` when no
            storage is configured or the store failed (the error is logged).
        """
        if self.storage is None:
            logging.error("No storage manager available for VRAM state persistence")
            return False

        with self.state_lock:
            try:
                # Shallow copy so the sync metadata below is not written into
                # the live state unless the store succeeds.
                safe_state = dict(self.vram_state)
                # Handle infinity values for JSON serialization
                if isinstance(safe_state["total_size"], float) and safe_state["total_size"] == float('inf'):
                    safe_state["total_size"] = "inf"

                sync_count = self.vram_state.get("sync_count", 0) + 1
                safe_state.update({
                    "last_sync": time.time_ns(),
                    "sync_count": sync_count
                })

                success = self.storage.store_state(
                    component="vram",
                    state_id=self.vram_id,
                    state_data=safe_state
                )

                if not success:
                    raise RuntimeError("Failed to store VRAM state")

                # Remember the counter so the next sync continues from it.
                self.vram_state["sync_count"] = sync_count
                return True

            except Exception as e:
                logging.error(f"Error storing VRAM state: {str(e)}")
                return False

    def allocate_block(self, size: int, block_id: Optional[str] = None) -> str:
        """Allocate a block of VRAM.

        Args:
            size: Number of bytes to reserve; must be non-negative.
            block_id: Identifier for the block. Generated when omitted.

        Returns:
            The id of the newly allocated block.

        Raises:
            ValueError: ``size`` is negative or ``block_id`` is already allocated.
            MemoryError: The request does not fit in the remaining capacity.
            RuntimeError: The backend is missing or rejected the block or state.
        """
        if size < 0:
            raise ValueError(f"Block size must be non-negative, got {size}")

        with self.state_lock:
            blocks = self.vram_state["blocks"]

            # Check available space
            if not self.vram_state["is_unlimited"] and self.vram_state["allocated"] + size > self.vram_state["total_size"]:
                raise MemoryError(f"Not enough VRAM available. Requested: {size}, Available: {self.available_size}")

            if block_id is None:
                base_id = f"block_{self.vram_id}_{time.time_ns()}"
                block_id = base_id
                # The clock may not advance between two calls; disambiguate
                # with a numeric suffix instead of overwriting a live block.
                suffix = 0
                while block_id in blocks:
                    suffix += 1
                    block_id = f"{base_id}_{suffix}"
            elif block_id in blocks:
                raise ValueError(f"Block {block_id} is already allocated")

            try:
                storage = self._require_storage()
                now = time.time_ns()

                block_metadata = {
                    "size": size,
                    "vram_id": self.vram_id,
                    "allocated_at": now,
                    "last_accessed": now,
                    "last_modified": now,
                    "access_count": 0,
                    "write_count": 0,
                    "status": "allocated",
                    "flags": {
                        "locked": False,
                        "persistent": False,
                        "cached": False
                    }
                }

                # Store a zero-filled block of the full size to reserve space
                empty_data = np.zeros(size, dtype=np.uint8)
                success = storage.store_tensor(block_id, empty_data, block_metadata)

                if not success:
                    raise RuntimeError(f"Failed to initialize block {block_id}")

                # Update VRAM state
                blocks[block_id] = block_metadata
                self.vram_state["allocated"] += size

                if self.store_vram_state():
                    return block_id

                # Roll back the in-memory accounting on state storage failure.
                # NOTE: the tensor already written to the backend is left in
                # place; no delete call on the backend is known here.
                del blocks[block_id]
                self.vram_state["allocated"] -= size
                raise RuntimeError("Failed to store VRAM state")

            except Exception as e:
                logging.error(f"Error allocating block: {str(e)}")
                raise

    def free_block(self, block_id: str) -> bool:
        """Free a block of VRAM and drop every address mapping that points at it.

        Returns:
            ``True`` when the block was freed and the new state persisted,
            ``False`` when the block is unknown or an error occurred.
        """
        with self.state_lock:
            blocks = self.vram_state["blocks"]
            if block_id not in blocks:
                logging.warning(f"Attempted to free non-existent block {block_id}")
                return False

            try:
                storage = self._require_storage()

                freed_metadata = {
                    **blocks[block_id],
                    "status": "freed",
                    "freed_at": time.time_ns()
                }

                # Record the final state of the block with an empty placeholder
                # array. This is best effort: a rejected write is logged but
                # does not stop the block from being released.
                if not storage.store_tensor(block_id, np.array([]), freed_metadata):
                    logging.warning(f"Failed to store freed marker for block {block_id}")

                # Update VRAM state
                self.vram_state["allocated"] -= blocks[block_id]["size"]
                del blocks[block_id]

                # Update memory mappings
                memory_map = self.vram_state["memory_map"]
                for addr in [a for a, mapped_id in memory_map.items() if mapped_id == block_id]:
                    del memory_map[addr]

                return self.store_vram_state()

            except Exception as e:
                logging.error(f"Error freeing block {block_id}: {str(e)}")
                return False

    def write_block(self, block_id: str, data: np.ndarray) -> bool:
        """Write data to a VRAM block.

        Args:
            block_id: Id of an allocated block.
            data: Array to store; its byte size must not exceed the block size.

        Returns:
            ``True`` on success, ``False`` when the data is too large or the
            backend fails (the error is logged).

        Raises:
            ValueError: The block is not allocated.
            RuntimeError: The block is flagged as locked.
        """
        with self.state_lock:
            blocks = self.vram_state["blocks"]
            if block_id not in blocks:
                raise ValueError(f"Block {block_id} not allocated")

            # Check block flags
            block_info = blocks[block_id]
            flags = block_info.setdefault("flags", {})
            if flags.get("locked", False):
                raise RuntimeError(f"Block {block_id} is locked for writing")

            locked_here = False
            try:
                storage = self._require_storage()
                data = np.asarray(data)

                # Validate before taking the write lock so a rejected write
                # never leaves the block locked.
                data_size = data.nbytes
                block_size = block_info["size"]
                if data_size > block_size:
                    raise ValueError(f"Data size ({data_size}) exceeds block size ({block_size})")

                # Set write lock and publish it
                flags["locked"] = True
                locked_here = True
                self.store_vram_state()

                now = time.time_ns()
                write_metadata = {
                    **block_info,
                    "last_accessed": now,
                    "last_written": now,
                    "data_size": data_size,
                    "shape": data.shape,
                    "dtype": str(data.dtype),
                    "write_count": block_info.get("write_count", 0) + 1,
                    "access_count": block_info.get("access_count", 0) + 1,
                    "status": "written",
                    "flags": {**flags, "locked": False}
                }

                success = storage.store_tensor(block_id, data, write_metadata)
                if not success:
                    raise RuntimeError("Failed to store tensor data")

                # Replaces "flags" with the unlocked copy built above.
                block_info.update(write_metadata)
                locked_here = False
                return self.store_vram_state()

            except Exception as e:
                logging.error(f"Error writing to block {block_id}: {str(e)}")
                if locked_here:
                    # Release the write lock so later writes are not blocked forever.
                    flags["locked"] = False
                    self.store_vram_state()
                return False

    def read_block(self, block_id: str) -> Optional[np.ndarray]:
        """Read data from a VRAM block.

        Returns:
            The stored data, or ``None`` when the backend has nothing for the
            block or an error occurred (the error is logged).

        Raises:
            ValueError: The block is not allocated.
        """
        with self.state_lock:
            blocks = self.vram_state["blocks"]
            if block_id not in blocks:
                raise ValueError(f"Block {block_id} not allocated")

            try:
                result = self._require_storage().load_tensor(block_id)
                if result is None:
                    logging.error(f"Failed to read block {block_id}")
                    return None

                data, _metadata = result

                # Update access tracking
                block = blocks[block_id]
                block["last_accessed"] = time.time_ns()
                block["access_count"] = block.get("access_count", 0) + 1
                self.store_vram_state()

                return data

            except Exception as e:
                logging.error(f"Error reading block {block_id}: {str(e)}")
                return None

    def map_address(self, virtual_addr: str, block_id: str) -> bool:
        """Map virtual address to VRAM block.

        Returns:
            ``True`` when the mapping was recorded and the state persisted.

        Raises:
            ValueError: ``block_id`` is not allocated.
        """
        with self.state_lock:
            if block_id not in self.vram_state["blocks"]:
                raise ValueError(f"Cannot map to non-existent block {block_id}")

            try:
                self.vram_state["memory_map"][virtual_addr] = block_id
                return self.store_vram_state()
            except Exception as e:
                logging.error(f"Error mapping address {virtual_addr}: {str(e)}")
                return False

    def get_block_from_address(self, virtual_addr: str) -> Optional[str]:
        """Get block ID from virtual address (``None`` when unmapped)"""
        return self.vram_state["memory_map"].get(virtual_addr)

    def get_block_info(self, block_id: str) -> Optional[Dict[str, Any]]:
        """Get detailed information about a VRAM block.

        Returns:
            ``None`` for an unknown block; otherwise a new dict combining the
            metadata held by the backend with the in-memory metadata. The
            in-memory values win because they are refreshed on every read
            and write, whereas the backend copy is only rewritten on writes.
        """
        with self.state_lock:
            block = self.vram_state["blocks"].get(block_id)
            if block is None:
                return None

            if self.storage is None:
                return dict(block)

            try:
                result = self.storage.load_tensor(block_id)
                if result is None:
                    return dict(block)

                _, metadata = result
                return {**(metadata or {}), **block}

            except Exception as e:
                logging.error(f"Error getting block info for {block_id}: {str(e)}")
                return dict(block)

    def get_memory_usage(self) -> Dict[str, Any]:
        """Get current VRAM memory usage statistics"""
        with self.state_lock:
            total = self.vram_state["total_size"]
            allocated = self.vram_state["allocated"]
            unlimited = self.vram_state["is_unlimited"]

            # Utilization is reported as 0 when it is undefined (unlimited or
            # zero-capacity VRAM) instead of dividing by zero.
            if unlimited or not total:
                utilization = 0
            else:
                utilization = allocated / total * 100

            return {
                "vram_id": self.vram_id,
                "total_size": total,
                "allocated": allocated,
                "available": self.available_size,
                "active_blocks": len(self.vram_state["blocks"]),
                "mapped_addresses": len(self.vram_state["memory_map"]),
                "utilization": utilization,
                "timestamp": time.time_ns(),
                "is_unlimited": unlimited
            }

    def cleanup_unused_blocks(self, older_than: Optional[float] = None) -> int:
        """Clean up unused VRAM blocks.

        Frees every block that has no address mapping. When ``older_than`` is
        given, only blocks last accessed at least that many seconds ago are
        freed; when omitted, every unmapped block is freed.

        Returns:
            Number of blocks freed.
        """
        with self.state_lock:
            now = time.time_ns()
            cleaned = 0

            # Freeing a block only removes mappings to that block, so the set
            # of mapped ids can be computed once up front.
            mapped_ids = set(self.vram_state["memory_map"].values())

            for block_id, block in list(self.vram_state["blocks"].items()):
                last_access = block.get("last_accessed", 0)

                # Skip blocks accessed more recently than the cutoff
                if older_than is not None and (now - last_access) / 1e9 < older_than:
                    continue

                # Skip blocks that are still mapped
                if block_id in mapped_ids:
                    continue

                if self.free_block(block_id):
                    cleaned += 1

            return cleaned

    @property
    def size_gb(self) -> float:
        """Get VRAM size in GB"""
        return self.total_size / (1024 * 1024 * 1024)

    def get_stats(self) -> Dict[str, Any]:
        """Get VRAM statistics (sizes in GB, block and mapping counts)"""
        gib = 1024 * 1024 * 1024
        with self.state_lock:
            total = self.vram_state["total_size"]
            allocated = self.vram_state["allocated"]
            return {
                "total_gb": self.size_gb,
                "used_gb": allocated / gib,
                "free_gb": (total - allocated) / gib,
                "num_blocks": len(self.vram_state["blocks"]),
                "mappings": len(self.vram_state["memory_map"])
            }

    def __str__(self) -> str:
        stats = self.get_stats()
        return f"VirtualVRAM(total={stats['total_gb']:.1f}GB, used={stats['used_gb']:.1f}GB, blocks={stats['num_blocks']})"

    def __repr__(self) -> str:
        return self.__str__()
