import numpy as np
import json
import threading
from typing import Dict, Any, Optional
from http_storage import LocalStorage

class MemoryManager:
    """Bookkeeping for (simulated) GPU memory allocations.

    Every allocation gets a unique virtual address and a physical address
    range. Physical space is handed out by a bump pointer that only moves
    forward, so live allocations never overlap; freed space is not reused.
    Memory pools reserve a contiguous physical region when created, and
    pool allocations are carved out of that region.
    """

    def __init__(self, hal, storage=None):
        """Bind the manager to a HAL and a persistence backend.

        Args:
            hal: Hardware abstraction layer. This class reads ``initialized``
                and ``gpu_id`` and calls ``get_v2_core``,
                ``write_global_memory`` and ``read_global_memory`` on it.
            storage: Optional backend exposing
                ``store_state(table, key, record)``. Defaults to a new
                ``LocalStorage()``.
        """
        self.hal = hal
        # Persistence backend; injectable so callers/tests can supply one.
        self.storage = storage if storage is not None else LocalStorage()
        # Reentrant: allocate() holds the lock while calling
        # _get_next_physical_address(), which acquires it again.
        self.lock = threading.RLock()

        # In-memory state (records are also persisted via self.storage).
        self.allocated_blocks = {}         # virtual address -> block record
        self.memory_pools = {}             # pool_id -> pool record
        self.stream_buffers = {}           # stream_id -> [physical address, ...]
        self.virtual_to_physical_map = {}  # virtual address -> physical address
        self.next_virtual_address = 0
        # Next unreserved physical byte; never decreases.
        self.next_physical_address = 0
        self.tensor_cache = {}             # virtual address -> tensor metadata

    def create_memory_pool(self, pool_id, size_bytes):
        """Create a new VGPU memory pool for tensor operations.

        Reserves ``size_bytes`` of physical address space for the pool and
        records it in the ``gpu_memory_pools`` storage table.

        Raises:
            ValueError: if ``size_bytes`` is negative or ``pool_id`` exists.
        """
        size_bytes = int(size_bytes)
        if size_bytes < 0:
            raise ValueError(f"Pool size must be non-negative, got {size_bytes}.")
        with self.lock:
            if pool_id in self.memory_pools:
                raise ValueError(f"Memory pool {pool_id} already exists")

            self.storage.store_state('gpu_memory_pools', pool_id, {
                'pool_id': pool_id,
                'size_bytes': size_bytes,
                'used_bytes': 0,
                'pool_type': 'tensor',
                'gpu_id': self.hal.gpu_id
            })

            # Give the pool its own physical region so blocks in different
            # pools (and non-pool allocations) never alias each other.
            base = self.next_physical_address
            self.next_physical_address += size_bytes

            self.memory_pools[pool_id] = {
                'size': size_bytes,
                'used': 0,
                'blocks': [],
                'base': base,
            }

    def _get_next_physical_address(self, size_bytes, pool_id=None):
        """Reserve ``size_bytes`` of physical memory; return its start address.

        With ``pool_id`` the range is carved from that pool's region and the
        block is recorded in the ``gpu_memory_blocks`` storage table.
        Otherwise it comes from the global bump pointer.

        Raises:
            MemoryError: if the pool does not exist or lacks free space.
        """
        with self.lock:
            if pool_id is None:
                addr = self.next_physical_address
                self.next_physical_address += size_bytes
                return addr

            pool = self.memory_pools.get(pool_id)
            if pool is None or pool['used'] + size_bytes > pool['size']:
                raise MemoryError(f"Not enough memory in pool {pool_id}")

            # Blocks are never returned to a pool, so 'used' is the offset of
            # the first free byte inside the pool's region.
            addr = pool.get('base', 0) + pool['used']
            pool['used'] += size_bytes
            pool['blocks'].append({'addr': addr, 'size': size_bytes})

            # When called from allocate() (which holds the lock),
            # next_virtual_address is the address about to be assigned.
            self.storage.store_state('gpu_memory_blocks', f"{pool_id}_{addr}", {
                'virtual_address': str(self.next_virtual_address),
                'physical_address': str(addr),
                'size_bytes': size_bytes,
                'pool_id': pool_id,
                'allocation_type': 'tensor'
            })
            return addr

    def allocate(self, size_bytes, chip_id=0, pool_id=None, stream_id=None):
        """Allocate ``size_bytes`` on ``chip_id`` and return its virtual address.

        Args:
            size_bytes: Number of bytes (non-negative integer).
            chip_id: Chip the block belongs to.
            pool_id: Optional memory pool to allocate from.
            stream_id: Optional stream; the physical address is appended to
                ``stream_buffers[stream_id]``.

        Returns:
            The new virtual address (int).

        Raises:
            RuntimeError: if the HAL is not initialized.
            ValueError: if ``size_bytes`` is negative.
            MemoryError: if the requested pool is missing or full.
        """
        if not self.hal.initialized:
            raise RuntimeError("HAL not initialized. Cannot allocate memory.")
        size_bytes = int(size_bytes)
        if size_bytes < 0:
            raise ValueError(f"Allocation size must be non-negative, got {size_bytes}.")

        with self.lock:
            physical_address = self._get_next_physical_address(size_bytes, pool_id)

            virtual_address = self.next_virtual_address
            # Advance by at least one so zero-byte allocations still get a
            # unique virtual address (it is the key of allocated_blocks).
            self.next_virtual_address += max(size_bytes, 1)

            self.allocated_blocks[virtual_address] = {
                'physical_address': physical_address,
                'size': size_bytes,
                'chip_id': chip_id,
                'stream_id': stream_id,
                'data': [0] * size_bytes  # Simulate memory content
            }
            self.virtual_to_physical_map[virtual_address] = physical_address

            if stream_id is not None:
                self.stream_buffers.setdefault(stream_id, []).append(physical_address)

        print(f"Allocated {size_bytes} bytes: Virtual Address {virtual_address} -> Physical Address {physical_address} on Chip {chip_id}.")
        return virtual_address

    def allocate_tensor(self, shape, dtype, pool_id=None, stream_id=None):
        """Allocate memory for a tensor and cache its metadata.

        The size is ``prod(shape) * itemsize(dtype)``. An empty shape means a
        scalar (one element).

        Returns:
            The virtual address of the tensor's storage.
        """
        # dtype=int64 keeps the product integral (np.prod(()) is 1.0 otherwise);
        # int() turns the numpy scalar into a plain Python int.
        num_elements = int(np.prod(shape, dtype=np.int64))
        size_bytes = num_elements * np.dtype(dtype).itemsize
        addr = self.allocate(size_bytes, pool_id=pool_id, stream_id=stream_id)

        self.tensor_cache[addr] = {
            'shape': shape,
            'dtype': dtype,
            'address': addr,
            'pool_id': pool_id,
            'stream_id': stream_id
        }
        return addr

    def get_tensor_info(self, address):
        """Return tensor metadata for ``address``, or ``None`` if unknown."""
        return self.tensor_cache.get(address)

    def free(self, virtual_address):
        """Release the block at ``virtual_address``.

        Drops the block record, its address mapping, any cached tensor
        metadata and its stream-buffer entry. The physical range is not
        reused, and pool ``used`` counters are not decremented. Prints a
        warning for an unknown address.
        """
        with self.lock:
            block_info = self.allocated_blocks.pop(virtual_address, None)
            if block_info is not None:
                self.virtual_to_physical_map.pop(virtual_address, None)
                self.tensor_cache.pop(virtual_address, None)
                stream_id = block_info.get('stream_id')
                buffers = self.stream_buffers.get(stream_id) if stream_id is not None else None
                if buffers and block_info['physical_address'] in buffers:
                    buffers.remove(block_info['physical_address'])

        if block_info is None:
            print(f"Warning: Attempted to free unallocated memory at virtual address {virtual_address}.")
            return
        print(f"Freed {block_info['size']} bytes: Virtual Address {virtual_address} (Physical {block_info['physical_address']}) on Chip {block_info['chip_id']}.")

    def _lookup_block(self, virtual_address, chip_id):
        """Return the block record for ``virtual_address`` on ``chip_id``.

        Raises:
            ValueError: if the address is not allocated on that chip.
        """
        block_info = self.allocated_blocks.get(virtual_address)
        if block_info is None or block_info['chip_id'] != chip_id:
            raise ValueError(f"Memory at virtual address {virtual_address} on Chip {chip_id} not allocated or invalid.")
        return block_info

    def _get_v2_core(self, chip_id):
        """Return the HAL's v2 core for ``chip_id``, or ``None`` if unavailable.

        Any exception from the HAL lookup, a ``None`` core, or a core without
        an ``mmu`` attribute means "fall back to legacy global memory".
        """
        try:
            core = self.hal.get_v2_core(chip_id)
        except Exception:
            return None
        if core is None or getattr(core, 'mmu', None) is None:
            return None
        return core

    def write_data(self, virtual_address, data, chip_id=0):
        """Write the byte values in ``data`` at the start of an allocated block.

        Uses the chip's v2 MMU when available, otherwise the HAL's legacy
        global-memory writes. The simulated ``data`` copy is updated as well.

        Raises:
            ValueError: if the block is unknown or ``data`` is too large.
        """
        block_info = self._lookup_block(virtual_address, chip_id)
        physical_address = block_info['physical_address']
        allocated_size = block_info['size']
        if len(data) > allocated_size:
            raise ValueError(f"Data size ({len(data)}) exceeds allocated size ({allocated_size}) at virtual address {virtual_address}.")

        v2_core = self._get_v2_core(chip_id)
        if v2_core is not None:
            # Only the core lookup falls back; write errors propagate instead
            # of being silently retried through the legacy path.
            for i, byte_val in enumerate(data):
                v2_core.mmu.write(physical_address + i, [byte_val], v2_core.clk)
            print(f"[v2] Wrote {len(data)} bytes to v2 MMU at physical {physical_address} on Chip {chip_id}.")
        else:
            for i, byte_val in enumerate(data):
                self.hal.write_global_memory(chip_id, physical_address + i, byte_val)
            print(f"Wrote {len(data)} bytes to virtual address {virtual_address} (physical {physical_address}) on Chip {chip_id}.")
        block_info['data'][:len(data)] = data  # Update simulated content

    def read_data(self, virtual_address, size_bytes, chip_id=0):
        """Read ``size_bytes`` values from the start of an allocated block.

        Uses the chip's v2 MMU when available, otherwise the HAL's legacy
        global-memory reads. The simulated ``data`` copy is refreshed.

        Returns:
            A list of the values read.

        Raises:
            ValueError: if the block is unknown or ``size_bytes`` is negative
                or larger than the block.
        """
        block_info = self._lookup_block(virtual_address, chip_id)
        physical_address = block_info['physical_address']
        allocated_size = block_info['size']
        if size_bytes < 0:
            # A negative slice bound below would delete simulated content.
            raise ValueError(f"Read size ({size_bytes}) must be non-negative.")
        if size_bytes > allocated_size:
            raise ValueError(f"Read size ({size_bytes}) exceeds allocated size ({allocated_size}) at virtual address {virtual_address}.")

        v2_core = self._get_v2_core(chip_id)
        if v2_core is not None:
            values = [v2_core.mmu.read(physical_address + i)[0] for i in range(size_bytes)]
            print(f"[v2] Read {size_bytes} bytes from v2 MMU at physical {physical_address} on Chip {chip_id}.")
        else:
            values = [self.hal.read_global_memory(chip_id, physical_address + i) for i in range(size_bytes)]
            print(f"Read {size_bytes} bytes from virtual address {virtual_address} (physical {physical_address}) on Chip {chip_id}.")
        block_info['data'][:size_bytes] = values  # Update simulated content
        return values

    def get_physical_address(self, virtual_address):
        """Return the physical address mapped to ``virtual_address``.

        Raises:
            ValueError: if the address is not mapped.
        """
        try:
            return self.virtual_to_physical_map[virtual_address]
        except KeyError:
            raise ValueError(f"Virtual address {virtual_address} not mapped.") from None


