from streaming_multiprocessor import StreamingMultiprocessor
from typing import Dict, Any, List, Optional, Tuple
import time
from enum import Enum, auto
import logging
import numpy as np

class ChipType(Enum):
    """Kinds of GPU chip that can be modelled.

    Member values are generated by ``auto()`` (1, 2, 3, ... in declaration
    order) and carry no meaning beyond identity.
    """

    # Tensor-optimized mining chip
    TENSOR = auto()
    # General compute chip
    COMPUTE = auto()
    # Stream processing optimized
    STREAM = auto()
    # Network/interconnect optimized
    NETWORK = auto()

class OpticalInterconnect:
    """Point-to-point optical link modelled as fixed latency plus a bandwidth term.

    Args:
        bandwidth_tbps: Link bandwidth in terabits per second (10**12 bit/s).
        latency_ns: Fixed per-transfer latency in nanoseconds.

    Raises:
        ValueError: If ``bandwidth_tbps`` is not positive or ``latency_ns``
            is negative.
    """

    def __init__(self, bandwidth_tbps=800, latency_ns=1):
        # Validate eagerly so a bad link fails at construction time instead of
        # surfacing later as a ZeroDivisionError inside transfer_time().
        if bandwidth_tbps <= 0:
            raise ValueError(f"bandwidth_tbps must be positive, got {bandwidth_tbps}")
        if latency_ns < 0:
            raise ValueError(f"latency_ns must be non-negative, got {latency_ns}")
        self.bandwidth_tbps = bandwidth_tbps
        self.latency_ns = latency_ns

    def transfer_time(self, data_size_bytes: int) -> float:
        """Return the time in seconds to move ``data_size_bytes`` over the link.

        Computed as ``latency + size / bandwidth``, where the bandwidth is
        converted from terabits/s to bytes/s.

        Raises:
            ValueError: If ``data_size_bytes`` is negative.
        """
        if data_size_bytes < 0:
            raise ValueError(f"data_size_bytes must be non-negative, got {data_size_bytes}")
        # "tbps" is terabits per second: scale to bit/s, then divide by 8 bits/byte.
        bandwidth_bytes_per_s = self.bandwidth_tbps * 1e12 / 8
        return self.latency_ns * 1e-9 + data_size_bytes / bandwidth_bytes_per_s

class GPUChip:
    """Simulated GPU chip that keeps its bookkeeping in ``self.chip_state``.

    ``chip_state`` records chip-to-chip links, PCIe transfers,
    memory-controller requests and power figures. After each mutation it is
    handed to the storage backend via ``store_state`` when one is attached.

    Backends are duck-typed and optional:

    * ``storage`` must provide ``store_state(key, name, state)`` and
      ``store_tensor(block_id, data)``.
    * ``vram`` must provide ``allocate_block(size)``,
      ``map_address(virtual_addr, block_id)`` and ``get_stats()``.

    Methods that need a backend that is not attached raise ``RuntimeError``.
    """

    def __init__(
        self,
        chip_id: int,
        chip_type: ChipType = ChipType.TENSOR,
        storage: Optional[Any] = None,
        vram: Optional[Any] = None,
    ):
        self.chip_id = chip_id
        self.chip_type = chip_type

        # Operating parameters (not read by any method of this class).
        self.switch_freq = 9.80e14  # Hz, nominal switching frequency
        # NOTE: 256 is the SHA-256 *digest* size in bits; the SHA-256 message
        # block is 512 bits (64 bytes). Value kept unchanged for compatibility.
        self.block_size = 256

        # Minimal SM reference; ``sms`` is the list that schedule_compute()
        # indexes and get_stats() reports on.
        self.primary_sm = StreamingMultiprocessor(0)
        self.sms = [self.primary_sm]

        # Optional backends (see class docstring).
        self.storage = storage
        self.vram = vram

        # Operation tracking
        self.active = False
        self.ops_count = 0

        # (peer chip, interconnect) pairs, at most one entry per peer chip_id.
        self.connected_chips = []
        # Last ID timestamp handed out by _unique_ns(); keeps IDs collision-free.
        self._last_issued_ns = 0
        self.chip_state = self._initial_chip_state(len(self.sms))

    @staticmethod
    def _initial_chip_state(num_sms: int) -> Dict[str, Any]:
        """Build an empty state dict containing every key the other methods use."""
        return {
            "connected_chips": {},
            "pcie_state": {"active_transfers": {}, "bandwidth_usage": 0.0},
            "memory_controller": {"active_requests": {}, "bandwidth_usage": 0.0},
            "power_state": {
                "sm_power": [0] * num_sms,
                "total_watts": 0,
                "vram_power": 0,
            },
        }

    def _unique_ns(self) -> int:
        """Return a ns timestamp strictly greater than any this chip issued before.

        ``time.time_ns()`` can return the same value for calls made in quick
        succession, which would make timestamp-based dict keys overwrite
        each other; bumping by one guarantees uniqueness per chip.
        """
        now = time.time_ns()
        stamp = now if now > self._last_issued_ns else self._last_issued_ns + 1
        self._last_issued_ns = stamp
        return stamp

    def _require_backend(self, name: str) -> Any:
        """Return the backend attribute ``name``; raise RuntimeError if it is None."""
        backend = getattr(self, name, None)
        if backend is None:
            raise RuntimeError(f"chip {self.chip_id} has no {name} backend attached")
        return backend

    def store_chip_state(self):
        """Persist ``chip_state`` through the storage backend, if one is attached.

        Without a storage backend this is a no-op, so the chip remains usable
        in the storage-less configuration.
        """
        if self.storage is None:
            return
        self.storage.store_state(f"chip_{self.chip_id}", "state", self.chip_state)

    def connect_chip(self, other_chip: 'GPUChip', interconnect: OpticalInterconnect) -> None:
        """Connect to another GPU chip via optical interconnect.

        Reconnecting to an already-connected chip replaces the previous link,
        so ``transfer_data`` and ``chip_state['connected_chips']`` stay in sync.
        """
        peer_id = other_chip.chip_id
        # Drop any earlier link to the same peer (in place, preserving list identity).
        self.connected_chips[:] = [
            (chip, link) for chip, link in self.connected_chips if chip.chip_id != peer_id
        ]
        self.connected_chips.append((other_chip, interconnect))
        self.chip_state['connected_chips'][peer_id] = {
            'bandwidth_tbps': interconnect.bandwidth_tbps,
            'latency_ns': interconnect.latency_ns,
            'active': True
        }
        self.store_chip_state()

    def transfer_data(self, target_chip: 'GPUChip', data_size: int) -> float:
        """Transfer data to another chip and return the estimated time in seconds.

        The transfer is recorded under ``chip_state['pcie_state']['active_transfers']``.

        Raises:
            ValueError: If this chip has no link to ``target_chip``.
        """
        target_id = target_chip.chip_id
        interconnect = next(
            (link for chip, link in self.connected_chips if chip.chip_id == target_id),
            None,
        )
        if interconnect is None:
            raise ValueError(f"No connection found to chip {target_chip.chip_id}")

        transfer_time = interconnect.transfer_time(data_size)
        transfer_id = f"transfer_{self._unique_ns()}"
        self.chip_state['pcie_state']['active_transfers'][transfer_id] = {
            'target_chip': target_id,
            'size': data_size,
            'estimated_time': transfer_time
        }
        self.store_chip_state()
        return transfer_time

    def allocate_memory(self, size: int, virtual_addr: Optional[str] = None) -> str:
        """Allocate ``size`` bytes through the VRAM backend and return the block id.

        If ``virtual_addr`` is given (non-empty), it is mapped to the new block.

        Raises:
            RuntimeError: If no VRAM backend is attached.
        """
        vram = self._require_backend("vram")
        block_id = vram.allocate_block(size)
        if virtual_addr:
            vram.map_address(virtual_addr, block_id)

        # Update memory controller state
        self.chip_state["memory_controller"]["active_requests"][block_id] = {
            "type": "allocation",
            "size": size,
            "timestamp": time.time_ns()
        }
        self.store_chip_state()

        return block_id

    def transfer_to_device(self, data: bytes, virtual_addr: Optional[str] = None) -> str:
        """Simulate a PCIe host-to-device copy of ``data`` and return its VRAM block id.

        Raises:
            RuntimeError: If the storage or VRAM backend is missing. This is
                checked before anything is recorded, so no half-finished
                transfer is left in ``chip_state``.
        """
        storage = self._require_backend("storage")
        self._require_backend("vram")

        # Simulate PCIe transfer
        transfer_id = f"transfer_{self._unique_ns()}"
        transfers = self.chip_state["pcie_state"]["active_transfers"]
        transfers[transfer_id] = {
            "direction": "to_device",
            "size": len(data),
            "timestamp": time.time_ns()
        }
        self.store_chip_state()

        # Allocate and store in VRAM
        block_id = self.allocate_memory(len(data), virtual_addr)
        storage.store_tensor(block_id, data)

        # Update transfer state
        transfers[transfer_id]["completed"] = True
        self.store_chip_state()

        return block_id

    def schedule_compute(self, sm_index: int, warp_state: Dict[str, Any]) -> str:
        """Schedule a warp on SM ``sm_index`` and return the generated warp id.

        Raises:
            ValueError: If ``sm_index`` is outside ``range(len(self.sms))``.
        """
        if not 0 <= sm_index < len(self.sms):
            raise ValueError(f"Invalid SM index: {sm_index}")

        warp_id = f"warp_{self._unique_ns()}"
        self.sms[sm_index].schedule_warp(warp_id, warp_state)

        # Simulated power model: fixed +10 per scheduled warp on that SM;
        # total_watts is the sum over SMs (vram_power is not included).
        power = self.chip_state["power_state"]
        power["sm_power"][sm_index] += 10
        power["total_watts"] = sum(power["sm_power"])
        self.store_chip_state()

        return warp_id

    def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive chip statistics.

        ``vram`` is ``None`` when no VRAM backend is attached; the PCIe,
        power and memory-controller figures are read from ``chip_state``.
        """
        state = self.chip_state
        return {
            "chip_id": self.chip_id,
            "vram": self.vram.get_stats() if self.vram is not None else None,
            "sms": [sm.get_stats() for sm in self.sms],
            "pcie": {
                "active_transfers": len(state["pcie_state"]["active_transfers"]),
                "bandwidth_usage": state["pcie_state"]["bandwidth_usage"]
            },
            "power": {
                "total_watts": state["power_state"]["total_watts"],
                "vram_watts": state["power_state"]["vram_power"]
            },
            "memory_controller": {
                "active_requests": len(state["memory_controller"]["active_requests"]),
                "bandwidth_usage": state["memory_controller"]["bandwidth_usage"]
            }
        }
