import json
import logging
import time
from typing import Dict, Any, Optional, List

import numpy as np

from electron_speed import max_switch_freq, drift_velocity, transit_time

# Electron physics constants
# Module-level aliases for the values imported from electron_speed.
ELECTRON_FREQ = max_switch_freq  # ~9.80e14 Hz
ELECTRON_VELOCITY = drift_velocity  # m/s
GATE_DELAY = transit_time  # seconds
CYCLE_TIME = 1.0 / ELECTRON_FREQ  # reciprocal of ELECTRON_FREQ (one switching period)

class StreamingMultiprocessor:
    """Pure electron-speed SM implementation"""
    def __init__(self, sm_id: int):
        """Create an SM using the module-level electron constants.

        Args:
            sm_id: Identifier stored on the instance as ``self.sm_id``.
        """
        # NOTE(review): other methods of this class read attributes that are
        # not assigned here (e.g. sm_state, state_lock, storage, chip_id);
        # presumably they must be set up elsewhere before use — confirm.
        self.sm_id = sm_id

        # Per-operation bookkeeping, keyed by operation id
        self.op_timing = dict()
        self.op_cycles = dict()

        # Timing model copied from the module constants
        self.cycle_time = CYCLE_TIME
        self.drift_speed = ELECTRON_VELOCITY
        self.gate_delay = GATE_DELAY
        self.switch_freq = ELECTRON_FREQ
        
    def process_block(self, block_data: bytes) -> dict:
        """Process a block of bytes and report the modelled switching metrics.

        The block is read as one big-endian unsigned integer and every bit is
        charged one gate delay.

        Args:
            block_data: Raw bytes to process; may be empty.

        Returns:
            A dict with the keys ``result`` (integer value of the block),
            ``cycles`` (number of bits), ``frequency`` (``self.switch_freq``),
            ``timing`` (seconds actually spent in this call) and
            ``throughput`` (bytes per modelled second). For an empty block the
            throughput is ``0.0``; for a non-empty block with a zero gate
            delay it is ``float('inf')``.
        """
        # perf_counter is monotonic, so the elapsed time can never go negative
        # the way a wall-clock difference can.
        started = time.perf_counter()

        num_bytes = len(block_data)
        cycles = num_bytes * 8  # Bit-level processing: one cycle per bit
        compute_time = cycles * self.gate_delay
        result = int.from_bytes(block_data, 'big')

        # A zero modelled time (empty block or zero gate delay) used to raise
        # ZeroDivisionError; report a well-defined throughput instead.
        if compute_time:
            throughput = num_bytes / compute_time
        else:
            throughput = float('inf') if num_bytes else 0.0

        return {
            'result': result,
            'cycles': cycles,
            'frequency': self.switch_freq,
            'timing': time.perf_counter() - started,
            'throughput': throughput,
        }
        
    def acquire_compute(self, op_id: str, op_info: Dict[str, Any]) -> dict:
        """Register an operation on this SM and return its initial state.

        Args:
            op_id: Key under which the operation is tracked.
            op_info: Accepted but not used by this method.

        Returns:
            A dict with ``op_id``, ``sm_id``, ``cycles`` (0), ``status``
            (``'initialized'``) and ``timing`` (seconds spent in this call).
        """
        began = time.time()

        # Fresh bookkeeping; an existing entry for the same op_id is replaced.
        self.op_cycles[op_id] = 0
        self.op_timing[op_id] = dict(
            start=began,
            frequency=self.switch_freq,
            gate_delay=self.gate_delay,
        )

        state = dict(
            op_id=op_id,
            sm_id=self.sm_id,
            cycles=0,
            status='initialized',
        )
        state['timing'] = time.time() - began
        return state

    def release_compute(self, op_id: str) -> dict:
        """Finish a tracked operation and return its final metrics.

        Args:
            op_id: Key of an operation previously registered in
                ``self.op_timing``.

        Returns:
            For a tracked operation, a dict with ``op_id``, ``sm_id``,
            ``cycles``, ``frequency``, ``total_time`` (seconds since the
            recorded start), ``gate_delays`` (``cycles * self.gate_delay``)
            and ``status`` ``'completed'``; its tracking entries are removed.
            For an unknown ``op_id``, a dict with ``op_id``, ``sm_id`` and
            ``status`` ``'not_found'``.
        """
        end_time = time.time()

        if op_id not in self.op_timing:
            return {
                'op_id': op_id,
                'sm_id': self.sm_id,
                'status': 'not_found'
            }

        start_time = self.op_timing[op_id]['start']
        # pop with a default: the cycle counter may be absent even when the
        # timing entry exists, and a plain `del` would raise KeyError here.
        cycles = self.op_cycles.pop(op_id, 0)
        del self.op_timing[op_id]

        # NOTE(review): statements that used to follow the final return of
        # this method were unreachable and have been removed. They assigned
        # matrix_ops_lock, matrix_scheduler, sm_key and sm_state and called
        # store_sm_state(), using names that are not in scope here (threading,
        # hashlib, MatrixOpScheduler, chip_id). They look like the tail of a
        # lost constructor — confirm where that initialisation should live.
        return {
            'op_id': op_id,
            'sm_id': self.sm_id,
            'cycles': cycles,
            'frequency': self.switch_freq,
            'total_time': end_time - start_time,
            'gate_delays': cycles * self.gate_delay,
            'status': 'completed'
        }

    def store_sm_state(self):
        """Store SM state in local storage.

        Writes one row (SM id, chip id, JSON-serialised ``self.sm_state``)
        into the ``sm_states`` table with INSERT OR REPLACE and commits.
        """
        # NOTE(review): this method is defined a second time further down in
        # this class. Python keeps the later definition, so this version is
        # shadowed and is never the one that runs.
        with self.state_lock:
            self.storage.cursor.execute("""
                INSERT OR REPLACE INTO sm_states 
                (sm_id, chip_id, state_json) 
                VALUES (?, ?, ?)
            """, (str(self.sm_id), str(self.chip_id), json.dumps(self.sm_state)))
            self.storage.conn.commit()

    def tensor_core_matmul(self, A: np.ndarray, B: np.ndarray, tensor_core_id: int = 0) -> Optional[np.ndarray]:
        """Execute matrix multiplication on tensor core.

        Records the operation in ``self.sm_state["tensor_operations"]``,
        computes ``np.matmul(A, B)`` and returns the product, or ``None`` when
        the core id is out of range or the multiplication raises.
        """
        # NOTE(review): an identical method with this name is defined again
        # later in this class; the later definition is the one Python keeps,
        # so this copy is shadowed.
        op_id = f"tensor_op_{time.time_ns()}"
        
        with self.matrix_ops_lock:
            # Check tensor core availability
            # (only the upper bound is checked; negative ids are not rejected)
            if tensor_core_id >= self.sm_state["tensor_cores"]["count"]:
                logging.error(f"Invalid tensor core ID: {tensor_core_id}")
                return None
                
            try:
                # Update operation state
                self.sm_state["tensor_operations"][op_id] = {
                    "type": "matmul",
                    "tensor_core_id": tensor_core_id,
                    "status": "running",
                    "start_time": time.time()
                }
                self.store_sm_state()
                
                # Execute matrix multiplication
                result = np.matmul(A, B)
                
                # Update operation status
                self.sm_state["tensor_operations"][op_id]["status"] = "completed"
                self.sm_state["tensor_operations"][op_id]["end_time"] = time.time()
                self.store_sm_state()
                
                return result
                
            except Exception as e:
                # Mark the operation as failed (if it was recorded) and
                # signal the failure to the caller with None.
                logging.error(f"Tensor core matmul failed: {str(e)}")
                if op_id in self.sm_state["tensor_operations"]:
                    self.sm_state["tensor_operations"][op_id]["status"] = "failed"
                    self.sm_state["tensor_operations"][op_id]["error"] = str(e)
                    self.store_sm_state()
                return None
                
    def read_matrix_from_shared_memory(self, addr: int, n: int, m: int) -> np.ndarray:
        """Read a matrix from shared memory.

        Builds an ``n`` x ``m`` array whose element (i, j) is looked up in
        ``self.sm_state["shared_memory"]`` under the string key
        ``addr + i * m + j``; missing keys read as 0.0.
        """
        # NOTE(review): an identical method with this name is defined again
        # later in this class, so this copy is shadowed.
        matrix = np.zeros((n, m))
        for i in range(n):
            for j in range(m):
                key = f"{addr + i * m + j}"  # row-major offset from addr
                matrix[i, j] = self.sm_state["shared_memory"].get(key, 0.0)
        return matrix
        
    def write_matrix_to_shared_memory(self, addr: int, matrix: np.ndarray) -> None:
        """Write a matrix to shared memory.

        Stores element (i, j) as a float in ``self.sm_state["shared_memory"]``
        under the string key ``addr + i * m + j`` and then calls
        ``store_sm_state()``. ``matrix.shape`` must unpack into two values.
        """
        # NOTE(review): an identical method with this name is defined again
        # later in this class, so this copy is shadowed.
        n, m = matrix.shape
        for i in range(n):
            for j in range(m):
                key = f"{addr + i * m + j}"  # row-major offset from addr
                self.sm_state["shared_memory"][key] = float(matrix[i, j])
        self.store_sm_state()
                
    def tensor_core_matmul_from_memory(self, addr_A: int, shape_A: tuple, 
                                     addr_B: int, shape_B: tuple,
                                     addr_C: int, tensor_core_id: int = 0) -> bool:
        """Execute matrix multiplication using data from shared memory.

        Reads A and B from shared memory, multiplies them with
        ``tensor_core_matmul`` and writes the product at ``addr_C``.
        Returns True on success and False when the multiplication returns
        None or any step raises.
        """
        # NOTE(review): a method with this name (and an extra ``warp_id``
        # parameter) is defined again later in this class; the later
        # definition is the one Python keeps, so this version is shadowed.
        try:
            # Read input matrices
            A = self.read_matrix_from_shared_memory(addr_A, *shape_A)
            B = self.read_matrix_from_shared_memory(addr_B, *shape_B)
            
            # Perform multiplication
            C = self.tensor_core_matmul(A, B, tensor_core_id=tensor_core_id)
            if C is None:
                return False
                
            # Write result
            self.write_matrix_to_shared_memory(addr_C, C)
            return True
            
        except Exception as e:
            logging.error(f"Tensor core matmul from memory failed: {str(e)}")
            return False
            
    def tensor_core_matmul(self, A: np.ndarray, B: np.ndarray, tensor_core_id: int = 0) -> Optional[np.ndarray]:
        """Execute matrix multiplication on tensor core"""
        op_id = f"tensor_op_{time.time_ns()}"
        
        with self.matrix_ops_lock:
            # Check tensor core availability
            if tensor_core_id >= self.sm_state["tensor_cores"]["count"]:
                logging.error(f"Invalid tensor core ID: {tensor_core_id}")
                return None
                
            try:
                # Update operation state
                self.sm_state["tensor_operations"][op_id] = {
                    "type": "matmul",
                    "tensor_core_id": tensor_core_id,
                    "status": "running",
                    "start_time": time.time()
                }
                self.store_sm_state()
                
                # Execute matrix multiplication
                result = np.matmul(A, B)
                
                # Update operation status
                self.sm_state["tensor_operations"][op_id]["status"] = "completed"
                self.sm_state["tensor_operations"][op_id]["end_time"] = time.time()
                self.store_sm_state()
                
                return result
                
            except Exception as e:
                logging.error(f"Tensor core matmul failed: {str(e)}")
                if op_id in self.sm_state["tensor_operations"]:
                    self.sm_state["tensor_operations"][op_id]["status"] = "failed"
                    self.sm_state["tensor_operations"][op_id]["error"] = str(e)
                    self.store_sm_state()
                return None
                
    def read_matrix_from_shared_memory(self, addr: int, n: int, m: int) -> np.ndarray:
        """Read a matrix from shared memory"""
        matrix = np.zeros((n, m))
        for i in range(n):
            for j in range(m):
                key = f"{addr + i * m + j}"
                matrix[i, j] = self.sm_state["shared_memory"].get(key, 0.0)
        return matrix
        
    def write_matrix_to_shared_memory(self, addr: int, matrix: np.ndarray) -> None:
        """Write a matrix to shared memory"""
        n, m = matrix.shape
        for i in range(n):
            for j in range(m):
                key = f"{addr + i * m + j}"
                self.sm_state["shared_memory"][key] = float(matrix[i, j])
        self.store_sm_state()
                
    def tensor_core_matmul_from_memory(self, addr_A: int, shape_A: tuple, 
                                     addr_B: int, shape_B: tuple,
                                     addr_C: int, tensor_core_id: int = 0,
                                     warp_id: Optional[str] = None) -> bool:
        """Execute matrix multiplication using data from shared memory with enhanced tracking"""
        try:
            # Schedule the operation
            op_metadata = self.matrix_op_scheduler.schedule_operation(
                op_type="matmul",
                input_shapes=[shape_A, shape_B],
                warp_id=warp_id
            )
            
            if op_metadata is None:
                logging.error("Failed to schedule matrix operation - resources unavailable")
                return False
                
            try:
                # Read input matrices
                A = self.read_matrix_from_shared_memory(addr_A, *shape_A)
                B = self.read_matrix_from_shared_memory(addr_B, *shape_B)
                
                # Acquire matrix operation lock
                if not self.matrix_op_lock.acquire_matrix_op(op_metadata.op_id, {
                    "type": "matmul",
                    "input_shapes": [shape_A, shape_B],
                    "warp_id": warp_id,
                    "tensor_core_id": tensor_core_id
                }):
                    raise RuntimeError("Failed to acquire matrix operation lock")
                    
                try:
                    # Perform multiplication with tensor core
                    C = self.tensor_core_matmul(A, B, tensor_core_id, warp_id)
                    if C is None:
                        raise RuntimeError("Matrix multiplication failed")
                        
                    # Write result
                    self.write_matrix_to_shared_memory(addr_C, C)
                    
                    # Complete operation successfully
                    self.matrix_op_scheduler.complete_operation(
                        op_metadata,
                        output_shape=C.shape,
                        success=True
                    )
                    
                    # Update operation history
                    self.tensor_op_history.append({
                        "op_id": op_metadata.op_id,
                        "type": "matmul",
                        "input_shapes": [shape_A, shape_B],
                        "output_shape": C.shape,
                        "warp_id": warp_id,
                        "tensor_core_id": tensor_core_id,
                        "start_time": op_metadata.start_time,
                        "end_time": time.time_ns(),
                        "status": "completed"
                    })
                    
                    return True
                    
                finally:
                    # Always release the matrix operation lock
                    self.matrix_op_lock.release_matrix_op(op_metadata.op_id)
                    
            except Exception as e:
                # Handle operation failure
                self.matrix_op_scheduler.complete_operation(
                    op_metadata,
                    output_shape=None,
                    success=False,
                    error=str(e)
                )
                raise
                
        except Exception as e:
            logging.error(f"Tensor core matmul from memory failed: {str(e)}")
            return False
        
    def store_sm_state(self):
        """Store SM state in storage"""
        with self.state_lock:
            try:
                # Store state using SQLite
                self.storage.cursor.execute("""
                    INSERT OR REPLACE INTO sm_states 
                    (sm_id, chip_id, state_json, sm_key, timestamp) 
                    VALUES (?, ?, ?, ?, ?)
                """, (
                    str(self.sm_id),
                    str(self.chip_id),
                    json.dumps(self.sm_state),
                    self.sm_key,
                    time.time_ns()
                ))
                self.storage.conn.commit()
                
                # Update last sync time
                self.sm_state["storage_state"]["last_sync"] = time.time_ns()
                return True
                
            except Exception as e:
                logging.error(f"Error storing SM state: {str(e)}")
                return False
            
    def allocate_shared_memory(self, size: int, block_id: str) -> str:
        """Allocate shared memory block in remote storage"""
        shared_id = f"shared_{self.chip_id}_{self.sm_id}_{block_id}_{time.time_ns()}"
        
        with self.state_lock:
            # Create memory block metadata
            memory_block = {
                "size": size,
                "block_id": block_id,
                "allocated_at": time.time_ns(),
                "sm_key": self.sm_key,
                "shared_id": shared_id
            }
            
            # Store metadata in SM state and remote storage
            self.sm_state["shared_memory"][shared_id] = memory_block
            
            try:
                # Store initial empty tensor to reserve the space
                empty_tensor = np.zeros(size, dtype=np.float32)
                self.storage.store_tensor(shared_id, empty_tensor, {
                    "sm_key": self.sm_key,
                    "block_id": block_id,
                    "allocated_at": time.time_ns(),
                    "size": size,
                    "status": "allocated"
                })
                
                # Update SM state in storage
                self.store_sm_state()
                return shared_id
                
            except Exception as e:
                # Cleanup on failure
                del self.sm_state["shared_memory"][shared_id]
                logging.error(f"Failed to allocate shared memory: {str(e)}")
                raise RuntimeError(f"Shared memory allocation failed: {str(e)}")
        
    def write_shared_memory(self, shared_id: str, data: np.ndarray):
        """Write to shared memory using remote storage"""
        with self.state_lock:
            if shared_id not in self.sm_state["shared_memory"]:
                raise ValueError(f"Shared memory block {shared_id} not allocated")
                
            try:
                # Store data with metadata
                success = self.storage.store_tensor(shared_id, data, {
                    "sm_key": self.sm_key,
                    "block_id": self.sm_state["shared_memory"][shared_id]["block_id"],
                    "last_write": time.time_ns(),
                    "shape": data.shape,
                    "dtype": str(data.dtype),
                    "status": "written"
                })
                
                if not success:
                    raise RuntimeError("Failed to store tensor data")
                
                # Update access timestamp and state
                self.sm_state["shared_memory"][shared_id]["last_accessed"] = time.time_ns()
                self.sm_state["shared_memory"][shared_id]["last_write"] = time.time_ns()
                self.store_sm_state()
                
                return True
                
            except Exception as e:
                logging.error(f"Error writing to shared memory: {str(e)}")
                return False
            
    def read_shared_memory(self, shared_id: str) -> Optional[np.ndarray]:
        """Read from shared memory using remote storage"""
        with self.state_lock:
            if shared_id not in self.sm_state["shared_memory"]:
                raise ValueError(f"Shared memory block {shared_id} not allocated")
                
            try:
                # Read from remote storage
                result = self.storage.load_tensor(shared_id)
                
                if result is not None:
                    data, metadata = result
                    # Update cache hit/miss stats
                    self.sm_state["storage_state"]["cache_hits"] += 1
                    # Update access timestamp
                    self.sm_state["shared_memory"][shared_id]["last_accessed"] = time.time_ns()
                    return data
                else:
                    self.sm_state["storage_state"]["cache_misses"] += 1
                    logging.warning(f"Cache miss for shared memory block {shared_id}")
                    return None
                    
            except Exception as e:
                logging.error(f"Error reading from shared memory: {str(e)}")
                return None
                
            finally:
                # Always update access timestamp and state
                self.sm_state["shared_memory"][shared_id]["last_accessed"] = time.time_ns()
                self.store_sm_state()
        
    def schedule_warp(self, warp_id: str, warp_state: Dict[str, Any]):
        """Schedule a warp for execution with enhanced state tracking and resource management"""
        with self.state_lock:
            # Generate unique storage key for warp
            warp_key = f"warp_{self.chip_id}_{self.sm_id}_{warp_id}_{time.time_ns()}"
            
            try:
                # Check resource availability and dependencies
                resource_state = self._check_warp_resources(warp_id, warp_state)
                if not resource_state['available']:
                    logging.warning(f"Resources not available for warp {warp_id}: {resource_state['reason']}")
                    self.sm_state["warp_scheduler_state"]["blocked_warps"][warp_id] = {
                        "reason": resource_state['reason'],
                        "blocking_resources": resource_state['blocking_resources'],
                        "timestamp": time.time_ns()
                    }
                    return False
                    
                # Check for dependencies
                dependencies = warp_state.get('dependencies', [])
                if dependencies:
                    for dep_id in dependencies:
                        if dep_id not in self.sm_state["warp_scheduler_state"]["completed_warps"]:
                            self.sm_state["warp_scheduler_state"]["warp_dependencies"][warp_id] = dependencies
                            logging.info(f"Warp {warp_id} waiting for dependencies: {dependencies}")
                            return False
                            
                # Prepare enhanced warp state with resource tracking
                enhanced_warp_state = {
                    **warp_state,
                    "warp_key": warp_key,
                    "scheduled_at": time.time_ns(),
                    "resources": resource_state['allocated_resources'],
                    "priority": warp_state.get('priority', 0),
                    "expected_duration": warp_state.get('expected_duration'),
                    "matrix_ops": [],
                    "sync_points": []
                }
                
                # Store state in remote storage with resource metadata
                success = self.storage.store_state(
                    component=f"warp_{self.chip_id}_{self.sm_id}",
                    state_id=warp_key,
                    state_data={
                        "warp_id": warp_id,
                        "warp_state": enhanced_warp_state,
                        "sm_key": self.sm_key,
                        "scheduled_at": time.time_ns(),
                        "status": "scheduled",
                        "resource_state": resource_state
                    }
                )
                
                if not success:
                    raise RuntimeError("Failed to store warp state")
                
                # Update scheduler state
                self.sm_state["warp_scheduler_state"]["active_warps"].append(warp_id)
                self.sm_state["warp_scheduler_state"]["warp_priorities"][warp_id] = enhanced_warp_state["priority"]
                
                # Update active warps with resource tracking
                self.sm_state["active_warps"][warp_id] = enhanced_warp_state
                
                # Clear any blocked state
                if warp_id in self.sm_state["warp_scheduler_state"]["blocked_warps"]:
                    del self.sm_state["warp_scheduler_state"]["blocked_warps"][warp_id]
                
                # Update SM state in storage
                self.store_sm_state()
                logging.info(f"Successfully scheduled warp {warp_id} with priority {enhanced_warp_state['priority']}")
                return True
                
            except Exception as e:
                logging.error(f"Error scheduling warp {warp_id}: {str(e)}")
                # Cleanup on failure
                if warp_id in self.sm_state["active_warps"]:
                    del self.sm_state["active_warps"][warp_id]
                if warp_id in self.sm_state["warp_scheduler_state"]["active_warps"]:
                    self.sm_state["warp_scheduler_state"]["active_warps"].remove(warp_id)
                if warp_id in self.sm_state["warp_scheduler_state"]["warp_priorities"]:
                    del self.sm_state["warp_scheduler_state"]["warp_priorities"][warp_id]
                return False
                
    def _check_warp_resources(self, warp_id: str, warp_state: Dict[str, Any]) -> Dict[str, Any]:
        """Report whether the resources a warp asks for are available.

        Reads ``warp_state["resource_requirements"]`` and checks the
        optional ``tensor_cores`` and ``shared_memory`` entries against the
        SM state. Nothing is reserved here; the returned allocation record
        starts out empty.

        Returns:
            On refusal: ``available`` False plus ``reason`` and
            ``blocking_resources``. On success: ``available`` True plus
            ``allocated_resources`` and ``allocation_id``.
        """
        def refuse(reason: str, blocking: Dict[str, Any]) -> Dict[str, Any]:
            # Shape shared by both refusal paths
            return {
                'available': False,
                'reason': reason,
                'blocking_resources': blocking
            }

        requirements = warp_state.get('resource_requirements', {})

        # Tensor cores: free cores = total count minus cores with a current op
        if 'tensor_cores' in requirements:
            wanted_cores = requirements['tensor_cores']
            core_info = self.sm_state["tensor_cores"]
            free_cores = core_info["count"] - len(core_info["current_ops"])
            if free_cores < wanted_cores:
                return refuse('insufficient_tensor_cores',
                              {'tensor_cores': wanted_cores - free_cores})

        # Shared memory: current usage plus the request must fit the maximum
        if 'shared_memory' in requirements:
            wanted_memory = requirements['shared_memory']
            usage = self.sm_state["matrix_operations"]["resource_usage"]["shared_memory_usage"]
            in_use = sum(usage.values())
            if in_use + wanted_memory > self._get_max_shared_memory():
                return refuse('insufficient_shared_memory',
                              {'shared_memory': wanted_memory})

        # Everything requested is available
        allocation = {
            'tensor_cores': [],  # Will be filled when actually used
            'shared_memory': 0,  # Will be updated when memory is actually allocated
            'allocation_time': time.time_ns()
        }
        return {
            'available': True,
            'allocated_resources': allocation,
            'allocation_id': f"alloc_{warp_id}_{time.time_ns()}"
        }

    def complete_warp(self, warp_id: str):
        """Mark a warp as completed using remote storage"""
        with self.state_lock:
            if warp_id in self.sm_state["active_warps"]:
                try:
                    # Get warp state and key
                    warp_state = self.sm_state["active_warps"][warp_id]
                    warp_key = warp_state.get("warp_key")
                    
                    if warp_key:
                        # Update warp state in storage
                        success = self.storage.store_state(
                            component=f"warp_{self.chip_id}_{self.sm_id}",
                            state_id=warp_key,
                            state_data={
                                "warp_id": warp_id,
                                "warp_state": warp_state,
                                "sm_key": self.sm_key,
                                "completed_at": time.time_ns(),
                                "status": "completed"
                            }
                        )
                        
                        if not success:
                            logging.error(f"Failed to store completed state for warp {warp_id}")
                    
                    # Update local state
                    self.sm_state["warp_scheduler_state"]["active_warps"].remove(warp_id)
                    self.sm_state["warp_scheduler_state"]["completed_warps"].append(warp_id)
                    self.sm_state["active_warps"].pop(warp_id)
                    
                    # Update SM state
                    self.store_sm_state()
                    return True
                    
                except Exception as e:
                    logging.error(f"Error completing warp {warp_id}: {str(e)}")
                    return False
            
            return False
            
    def write_register(self, warp_id: str, reg_id: str, data: np.ndarray):
        """Write to register file using remote storage"""
        reg_key = f"reg_{self.chip_id}_{self.sm_id}_{warp_id}_{reg_id}_{time.time_ns()}"
        
        try:
            # Store register data with metadata
            success = self.storage.store_tensor(reg_key, data, {
                "warp_id": warp_id,
                "reg_id": reg_id,
                "sm_key": self.sm_key,
                "chip_id": self.chip_id,
                "written_at": time.time_ns(),
                "shape": data.shape,
                "dtype": str(data.dtype)
            })
            
            if success:
                # Update register file state
                self.sm_state["register_file"][reg_key] = {
                    "warp_id": warp_id,
                    "reg_id": reg_id,
                    "last_accessed": time.time_ns(),
                    "storage_key": reg_key
                }
                self.store_sm_state()
                return True
                
            return False
            
        except Exception as e:
            logging.error(f"Error writing to register {reg_id} for warp {warp_id}: {str(e)}")
            return False
        
    def read_register(self, warp_id: str, reg_id: str) -> Optional[np.ndarray]:
        """Read from register file using remote storage"""
        # Find the latest register key for this warp/reg combination
        reg_keys = [k for k in self.sm_state["register_file"].keys() 
                   if k.startswith(f"reg_{self.chip_id}_{self.sm_id}_{warp_id}_{reg_id}")]
        
        if not reg_keys:
            return None
            
        # Get the latest register key
        latest_key = max(reg_keys, key=lambda k: self.sm_state["register_file"][k]["last_accessed"])
        
        try:
            # Read from storage
            result = self.storage.load_tensor(latest_key)
            
            if result is not None:
                data, metadata = result
                # Update access timestamp
                self.sm_state["register_file"][latest_key]["last_accessed"] = time.time_ns()
                self.store_sm_state()
                return data
                
            return None
            
        except Exception as e:
            logging.error(f"Error reading register {reg_id} for warp {warp_id}: {str(e)}")
            return None
        
    def get_stats(self) -> Dict[str, Any]:
        """Summarise the SM as a dict of counters.

        Returns:
            ``sm_id`` and ``num_cores`` as stored on the instance, plus the
            number of active warps, shared memory blocks, register file
            entries and completed warps found in ``self.sm_state``.
        """
        # NOTE(review): num_cores is not assigned in the visible constructor —
        # confirm where it is set.
        stats: Dict[str, Any] = {"sm_id": self.sm_id, "num_cores": self.num_cores}

        state = self.sm_state
        stats["active_warps"] = len(state["active_warps"])
        stats["shared_memory_blocks"] = len(state["shared_memory"])
        stats["register_file_entries"] = len(state["register_file"])
        stats["completed_warps"] = len(state["warp_scheduler_state"]["completed_warps"])
        return stats
