import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np

from .remote_storage import RemoteStorageManager
from .ftl import AdvancedFTL

@dataclass
class QoSParameters:
    """Quality of Service parameters for one traffic profile.

    ``PCIeInterface`` keys these by profile id and uses them to split the
    link bandwidth (``max_gbps``) between profiles.
    """
    priority: int          # 0-7, higher is more important; allocate_vram marks pages hot when >= 6
    bandwidth_min: float   # Minimum guaranteed bandwidth in GB/s; reserved before weighted sharing (NOTE(review): compared directly with max_gbps — confirm units)
    latency_max: float    # Maximum acceptable latency in microseconds; in the visible code it is only logged as 'latency_target'
    bandwidth_weight: float # Weight for bandwidth allocation: leftover bandwidth is shared in proportion weight / sum(weights)

@dataclass
class DMARequest:
    """DMA transfer request details.

    Queued on ``PCIeInterface`` and executed in fixed-size batches. If
    ``callback`` is set, it is invoked with no arguments once every batch of
    the request has completed.
    """
    source_addr: int   # virtual byte address of the source; translated via the FTL page map
    dest_addr: int     # virtual byte address of the destination
    size: int          # transfer length in bytes
    priority: int      # higher values are processed first
    is_async: bool     # NOTE(review): not read by the visible PCIeInterface code
    # Was ``Optional[callable]``: ``callable`` is the builtin predicate
    # function, not a type, so type checkers could not validate callers.
    callback: Optional[Callable[[], None]] = None

class PCIeInterface:
    """Simulated PCIe link.

    Manages lane bonding, QoS bandwidth allocation, a DMA request queue and
    FTL-backed VRAM allocation, and logs state and metrics to remote storage.
    """
    # Per-generation link parameters:
    #   bandwidth    - per-lane figure; multiplied by lane count and encoding
    #                  to derive max_gbps. NOTE(review): the values look like
    #                  PCIe per-lane GT/s; whether max_gbps ends up in Gb/s or
    #                  GB/s is not pinned down in this file — confirm units.
    #   encoding     - line-code efficiency (128b/130b for 4.0/5.0; 242/256 for
    #                  6.0, presumably the FLIT payload ratio — confirm).
    #   base_latency - fixed latency added to every transfer_time() result.
    #                  Units are unstated — TODO confirm (it is added directly
    #                  to a time computed as GB / bandwidth).
    PCIE_VERSIONS = {
        '4.0': {'bandwidth': 16.0, 'encoding': 128/130, 'base_latency': 0.5},
        '5.0': {'bandwidth': 32.0, 'encoding': 128/130, 'base_latency': 0.4},
        '6.0': {'bandwidth': 64.0, 'encoding': 242/256, 'base_latency': 0.3}
    }

    def __init__(self, version='6.0', lanes=16, max_gbps=None):
        self.version = version
        self.lanes = lanes
        self.spec = self.PCIE_VERSIONS[version]
        self.max_gbps = max_gbps or self.spec['bandwidth'] * lanes * self.spec['encoding']
        
        # Initialize storage components
        self.storage = RemoteStorageManager()
        self.total_vram = 16 * 1024 * 1024 * 1024  # 16GB default
        self.page_size = 4096  # 4KB pages
        self.block_size = 256 * self.page_size  # 1MB blocks
        
        # Initialize FTL
        total_blocks = self.total_vram // self.block_size
        pages_per_block = self.block_size // self.page_size
        self.ftl = AdvancedFTL(total_blocks=total_blocks, pages_per_block=pages_per_block)
        
        # Initialize interface state in remote storage
        self._init_interface_state()
        
        # Lane bonding and management
        self.active_lanes = lanes
        self.lane_groups: List[int] = self._initialize_lane_groups()
        self.lane_errors = [0] * lanes
        
        # QoS and bandwidth management
        self.active_transfers: Dict[int, DMARequest] = {}
        self.qos_profiles: Dict[int, QoSParameters] = {}
        self.bandwidth_allocations: Dict[int, float] = {}
        
        # DMA engine
        self.dma_queue: List[DMARequest] = []
        self.dma_active = False
        self.dma_batch_size = 1024 * 1024  # 1MB batches

    def _init_interface_state(self):
        """Initialize interface state in remote storage"""
        interface_state = {
            'version': self.version,
            'lanes': self.lanes,
            'max_gbps': self.max_gbps,
            'active_lanes': self.active_lanes,
            'lane_groups': self.lane_groups,
            'lane_errors': self.lane_errors,
            'qos_profiles': {},
            'bandwidth_allocations': {},
            'timestamp': datetime.now().isoformat()
        }
        
        # Store initial state
        self.storage.store_interface_state(interface_state)

    def _initialize_lane_groups(self) -> List[int]:
        """Initialize lane groups for bonding"""
        groups = []
        lanes_per_group = 4
        for i in range(0, self.lanes, lanes_per_group):
            groups.append(lanes_per_group)
        return groups

    def add_qos_profile(self, profile_id: int, params: QoSParameters):
        """Add or update QoS profile"""
        self.qos_profiles[profile_id] = params
        self._rebalance_bandwidth()

    def _rebalance_bandwidth(self):
        """Rebalance bandwidth allocations based on QoS profiles and log to remote DB"""
        total_weight = sum(p.bandwidth_weight for p in self.qos_profiles.values())
        available_bandwidth = self.max_gbps
        
        for profile_id, params in self.qos_profiles.items():
            # Ensure minimum bandwidth
            self.bandwidth_allocations[profile_id] = params.bandwidth_min
            available_bandwidth -= params.bandwidth_min
        
        # Distribute remaining bandwidth by weight
        if available_bandwidth > 0 and total_weight > 0:
            for profile_id, params in self.qos_profiles.items():
                extra = (params.bandwidth_weight / total_weight) * available_bandwidth
                self.bandwidth_allocations[profile_id] += extra
                
                # Log QoS metrics to remote storage
                qos_data = {
                    'timestamp': datetime.now().isoformat(),
                    'profile_id': profile_id,
                    'bandwidth_allocated': self.bandwidth_allocations[profile_id],
                    'bandwidth_used': 0.0,  # Will be updated as bandwidth is used
                    'latency_measured': 0.0,  # Will be updated as transfers occur
                    'latency_target': params.latency_max
                }
                self.storage.store_qos_metrics(qos_data)

    def _log_transfer(self, size_bytes: int, direction: str, qos_profile_id: Optional[int], 
                      transfer_time: float, bandwidth: float):
        """Log transfer details to remote storage"""
        transfer_data = {
            'timestamp': datetime.now().isoformat(),
            'size_bytes': size_bytes,
            'direction': direction,
            'qos_profile_id': qos_profile_id,
            'transfer_time': transfer_time,
            'lanes_active': self.active_lanes,
            'bandwidth_achieved': bandwidth
        }
        self.storage.store_transfer(transfer_data)

    def transfer_time(self, size_bytes: int, qos_profile_id: Optional[int] = None) -> float:
        """Calculate transfer time with QoS consideration"""
        # Get effective bandwidth based on QoS
        effective_bandwidth = self.max_gbps
        if qos_profile_id is not None and qos_profile_id in self.bandwidth_allocations:
            effective_bandwidth = self.bandwidth_allocations[qos_profile_id]
            
        # Calculate transfer time
        gb = size_bytes / 1e9
        transfer_time = gb / effective_bandwidth
        
        # Add encoding overhead
        transfer_time /= self.spec['encoding']
        
        # Add base latency
        total_time = transfer_time + self.spec['base_latency']
        
        # Log to remote DB
        self._log_transfer(size_bytes, 'calculate', qos_profile_id, total_time, effective_bandwidth)
        
        return total_time

    def initiate_dma_transfer(self, request: DMARequest) -> bool:
        """Initialize DMA transfer with QoS awareness"""
        self.dma_queue.append(request)
        if not self.dma_active:
            self._process_dma_queue()
        return True

    def _process_dma_queue(self):
        """Process DMA queue with QoS prioritization"""
        if not self.dma_queue:
            self.dma_active = False
            return
            
        self.dma_active = True
        # Sort by priority
        self.dma_queue.sort(key=lambda x: x.priority, reverse=True)
        
        while self.dma_queue:
            request = self.dma_queue[0]
            # Process in batches for better efficiency
            remaining = request.size
            while remaining > 0:
                batch_size = min(remaining, self.dma_batch_size)
                self._execute_dma_batch(request, batch_size)
                remaining -= batch_size
                
            if request.callback:
                request.callback()
            self.dma_queue.pop(0)
            
    def _execute_dma_batch(self, request: DMARequest, batch_size: int):
        """Execute a single DMA batch transfer with remote logging"""
        start_time = time.time()
        
        # Validate addresses using FTL
        source_phys = self.ftl.get_phys(request.source_addr // self.page_size)
        dest_phys = self.ftl.get_phys(request.dest_addr // self.page_size)
        
        if source_phys is None or dest_phys is None:
            raise RuntimeError("Invalid memory address in DMA transfer")
        
        transfer_time = self.transfer_time(batch_size)
        
        # Simulate DMA transfer
        time.sleep(transfer_time)
        
        # Log DMA operation to remote storage
        dma_data = {
            'timestamp': datetime.now().isoformat(),
            'source_addr': request.source_addr,
            'dest_addr': request.dest_addr,
            'size_bytes': batch_size,
            'priority': request.priority,
            'completion_time': time.time() - start_time,
            'status': 'completed'
        }
        self.storage.store_dma_operation(dma_data)

    def allocate_vram(self, size: int, qos: Optional[QoSParameters] = None) -> Optional[int]:
        """
        Allocate VRAM with optional QoS parameters
        Args:
            size: Size in bytes to allocate
            qos: Quality of Service parameters
        Returns:
            Virtual address or None if allocation fails
        """
        try:
            # Round up to nearest page size
            pages_needed = (size + self.page_size - 1) // self.page_size
            
            # Get a free block from FTL
            block_id = self.ftl.get_free_block()
            if block_id is None:
                # Try garbage collection
                self._run_garbage_collection()
                block_id = self.ftl.get_free_block()
                if block_id is None:
                    raise RuntimeError("Out of VRAM")
            
            # Calculate virtual address
            virt_addr = block_id * self.block_size
            
            # Map pages in FTL
            for i in range(pages_needed):
                lba = (virt_addr // self.page_size) + i
                phys = (block_id * self.ftl.pages_per_block) + i
                # Mark as hot if high priority QoS
                is_hot = qos and qos.priority >= 6
                self.ftl.map(lba, phys, is_hot)
            
            return virt_addr
            
        except Exception as e:
            self.storage.log_error("VRAM allocation failed", str(e))
            return None
    
    def free_vram(self, virt_addr: int, size: int) -> bool:
        """
        Free allocated VRAM
        Args:
            virt_addr: Virtual address to free
            size: Size in bytes to free
        Returns:
            True if successful
        """
        try:
            # Calculate pages to free
            start_page = virt_addr // self.page_size
            pages_to_free = (size + self.page_size - 1) // self.page_size
            
            # Invalidate pages in FTL
            for i in range(pages_to_free):
                lba = start_page + i
                phys = self.ftl.get_phys(lba)
                if phys is not None:
                    block_id = phys // self.ftl.pages_per_block
                    self.ftl.garbage_collect(block_id)
            
            return True
            
        except Exception as e:
            self.storage.log_error("VRAM free failed", str(e))
            return False
    
    def _run_garbage_collection(self) -> None:
        """Run garbage collection on VRAM blocks"""
        stats = self.ftl.get_stats()
        if stats.get('free_blocks', 0) > stats.get('total_blocks', 0) * 0.1:
            return  # Still enough free blocks
            
        # Find and collect blocks with most invalid pages
        for block in range(stats.get('total_blocks', 0)):
            self.ftl.garbage_collect(block)

    def get_vram_stats(self) -> Dict[str, Any]:
        """Get VRAM statistics"""
        ftl_stats = self.ftl.get_stats()
        stats = {
            "total_vram": self.total_vram,
            "page_size": self.page_size,
            "block_size": self.block_size,
            "used_blocks": ftl_stats.get('total_blocks', 0) - ftl_stats.get('free_blocks', 0),
            "free_blocks": ftl_stats.get('free_blocks', 0),
            "wear_leveling": ftl_stats.get('avg_erase_count', 0),
            "cache_hit_ratio": (
                ftl_stats.get('cache_hits', 0) / 
                max(ftl_stats.get('cache_hits', 0) + ftl_stats.get('cache_misses', 0), 1)
            ) * 100
        }
        
        # Add PCIe stats
        stats.update({
            "pcie_bandwidth": self.max_gbps,
            "active_lanes": self.active_lanes,
            "lane_errors": sum(self.lane_errors)
        })
        
        return stats

    def optimize_lanes(self) -> None:
        """Optimize lane configuration based on errors and performance"""
        error_threshold = 10
        for i, errors in enumerate(self.lane_errors):
            if errors > error_threshold:
                self._disable_lane(i)
                self._rebalance_lanes()

    def _disable_lane(self, lane_idx: int) -> None:
        """Disable a problematic lane"""
        group_idx = lane_idx // 4
        if group_idx < len(self.lane_groups):
            self.lane_groups[group_idx] -= 1
            self.active_lanes -= 1
            self._update_max_bandwidth()

    def _update_max_bandwidth(self) -> None:
        """Update maximum bandwidth based on active lanes"""
        lane_bandwidth = self.PCIE_VERSIONS[self.version]['bandwidth']
        self.max_gbps = lane_bandwidth * self.active_lanes * self.spec['encoding']
        self._rebalance_bandwidth()
