"""
Matrix Operations Scheduler for GPU Processing
"""
from typing import Dict, Any, List, Optional
import numpy as np
import time


class MatrixOpMetadata:
    """Shape and cost bookkeeping for a single matrix operation.

    Attributes are plain data:
      * ``op_type``: operation name; ``"matmul"`` and the element-wise ops
        ``"add"``, ``"sub"``, ``"mul"``, ``"div"`` have cost models, any
        other value is estimated as zero cost.
      * ``input_shape`` / ``output_shape``: shape tuples of the operation.
      * ``timestamp``: wall-clock time (``time.time()``) at construction.
      * ``compute_cycles`` / ``memory_accesses``: start at 0; this class
        never fills them in itself, whoever owns the object stores the
        estimates there.
    """

    # Operation names that are costed per element of ``input_shape``.
    _ELEMENTWISE_OPS = ("add", "sub", "mul", "div")

    def __init__(self, op_type: str, input_shape: tuple, output_shape: tuple):
        self.op_type = op_type
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.timestamp = time.time()
        self.compute_cycles = 0
        self.memory_accesses = 0

    def _matmul_dims(self) -> tuple:
        """Return ``(m, k, n)`` for a matmul described by the stored shapes.

        ``m`` and ``k`` are the first two entries of ``input_shape`` and
        ``n`` is the second entry of ``output_shape``.

        Raises:
            IndexError: if either shape has fewer than two dimensions.
        """
        m, k = self.input_shape[0], self.input_shape[1]
        n = self.output_shape[1]
        return m, k, n

    def _element_count(self) -> int:
        """Return the number of elements in ``input_shape`` as a Python int.

        The product is accumulated with Python integers rather than
        ``np.prod`` so the result is exact for arbitrarily large shapes
        (fixed-width NumPy integers wrap silently) and is a real ``int``
        (NumPy scalars are not JSON-serializable). An empty shape ``()``
        counts as a single element.
        """
        count = 1
        for dim in self.input_shape:
            count *= int(dim)
        return count

    def estimate_compute_cycles(self) -> int:
        """Estimate number of compute cycles needed based on operation type and shapes.

        Returns:
            ``m * n * k`` for ``"matmul"`` (one cycle per multiply-add),
            the element count of ``input_shape`` for element-wise ops
            (one cycle per element), and 0 for any other ``op_type``.

        Raises:
            IndexError: for ``"matmul"`` when a shape has fewer than two
                dimensions.
        """
        if self.op_type == "matmul":
            m, k, n = self._matmul_dims()
            return m * n * k
        if self.op_type in self._ELEMENTWISE_OPS:
            return self._element_count()
        return 0

    def estimate_memory_accesses(self) -> int:
        """Estimate number of memory accesses needed.

        Returns:
            ``m*k + k*n + m*n`` for ``"matmul"`` (each input element read
            once, each output element written once), twice the element
            count of ``input_shape`` for element-wise ops, and 0 for any
            other ``op_type``.

        Raises:
            IndexError: for ``"matmul"`` when a shape has fewer than two
                dimensions.
        """
        if self.op_type == "matmul":
            m, k, n = self._matmul_dims()
            return m * k + k * n + m * n
        if self.op_type in self._ELEMENTWISE_OPS:
            # One read plus one write per element of ``input_shape``.
            # NOTE(review): a second operand of a binary op is not counted,
            # since only one input shape is tracked - confirm this is intended.
            return self._element_count() * 2
        return 0


class MatrixOpScheduler:
    """FIFO queue of matrix operations with running cost statistics.

    Lifecycle of an operation: ``schedule_op`` estimates its costs and
    queues it, ``get_next_op`` dequeues it and makes it the current
    operation, and ``complete_current_op`` moves it to the completed list
    and folds its costs into ``stats``.
    """

    def __init__(self):
        self.pending_ops: List[MatrixOpMetadata] = []
        self.completed_ops: List[MatrixOpMetadata] = []
        self.current_op: Optional[MatrixOpMetadata] = None
        self.stats = {
            "total_compute_cycles": 0,
            "total_memory_accesses": 0,
            "ops_completed": 0
        }

    def schedule_op(self, op: "MatrixOpMetadata") -> None:
        """Add a new matrix operation to the scheduler queue.

        The cost estimates are computed and stored on ``op`` before it is
        queued, so an operation whose estimate raises (for example a
        malformed shape) is never left in the queue with zeroed costs.

        Args:
            op: operation to estimate and enqueue.
        """
        op.compute_cycles = op.estimate_compute_cycles()
        op.memory_accesses = op.estimate_memory_accesses()
        self.pending_ops.append(op)

    def get_next_op(self) -> "Optional[MatrixOpMetadata]":
        """Get the next operation to process.

        The oldest pending operation is removed from the queue and recorded
        as ``current_op`` so that ``complete_current_op`` can account for
        it. When the queue is empty, ``current_op`` is left untouched.

        Returns:
            The dequeued operation, or ``None`` if nothing is pending.
        """
        if not self.pending_ops:
            return None
        next_op = self.pending_ops.pop(0)
        # Without this assignment nothing ever becomes "current" and
        # complete_current_op() would silently do nothing.
        self.current_op = next_op
        return next_op

    def complete_current_op(self) -> None:
        """Mark the current operation as complete and update stats.

        Does nothing when there is no current operation.
        """
        finished = self.current_op
        if finished is None:
            return
        self.completed_ops.append(finished)
        self.stats["total_compute_cycles"] += finished.compute_cycles
        self.stats["total_memory_accesses"] += finished.memory_accesses
        self.stats["ops_completed"] += 1
        self.current_op = None

    def get_stats(self) -> Dict[str, Any]:
        """Get current scheduler statistics.

        Returns:
            A new dict holding the accumulated totals from ``stats`` plus
            the current lengths of the pending and completed lists under
            ``"pending_ops"`` and ``"completed_ops"``.
        """
        return {
            **self.stats,
            "pending_ops": len(self.pending_ops),
            "completed_ops": len(self.completed_ops)
        }
