from enum import Enum
import re
from typing import Dict, List, Optional, Union
import hashlib
import duckdb
import json
from config import get_db_url, get_hf_token_cached
import logging

class ShaderType(Enum):
    """Shader stages known to the compiler.

    Each member's value is its lowercase string form, so a plain string such
    as ``"vertex"`` can be converted with ``ShaderType("vertex")``.
    """
    VERTEX = "vertex"
    FRAGMENT = "fragment"
    COMPUTE = "compute"

class ShaderError(Exception):
    """Error raised when a shader cannot be parsed, compiled or linked."""

class ShaderCompilerDB:
    """DuckDB-backed storage for shaders, linked programs, register
    allocations and per-instruction optimization metadata.

    Creating an instance connects to the database (retrying on failure) and
    creates the tables if they do not exist yet. The live connection is
    exposed as ``self.conn``.
    """

    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    def __init__(self, db_url: Optional[str] = None):
        """Initialize shader compiler database.

        Args:
            db_url: URL of the database; ``DB_URL`` is used when omitted.

        Raises:
            RuntimeError: If the connection could not be set up after
                ``max_retries`` attempts.
        """
        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        # Defined up front so reconnect logic can always inspect it.
        self.conn = None
        self._connect_with_retries()

    def _connect_with_retries(self):
        """Establish database connection with retry logic.

        Raises:
            RuntimeError: When the last attempt fails; the underlying error
                is chained as ``__cause__``.
        """
        for attempt in range(1, self.max_retries + 1):
            try:
                self._init_db_connection()
                self._setup_database()
                return
            except Exception as e:
                if attempt == self.max_retries:
                    raise RuntimeError(
                        f"Failed to initialize database after {self.max_retries} attempts: {str(e)}"
                    ) from e
                logging.warning("Database connection attempt %d failed, retrying...", attempt)

    def _init_db_connection(self):
        """Initialize database connection with HuggingFace configuration."""
        # Resolve the token before opening anything so that a failing lookup
        # cannot leave a half-configured connection behind. The original code
        # read ``self.HF_TOKEN``, which is never defined; an attribute of that
        # name is still honoured if a subclass provides one.
        # NOTE(review): assumes get_hf_token_cached() takes no arguments and
        # returns the token string — confirm against the config module.
        token = getattr(self, "HF_TOKEN", None)
        if token is None:
            token = get_hf_token_cached()

        # Convert HF URL to S3 path and connect directly
        # NOTE(review): for the default DB_URL this split yields
        # owner='datasets', dataset='Fred808', db_file='helium/storage.json',
        # i.e. the names are shifted by one path segment. The resulting path
        # is left unchanged here; confirm the intended bucket layout.
        _, _, owner, dataset, db_file = self.db_url.split('/', 4)
        db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"

        # Release a previous (possibly broken) connection before replacing it.
        stale = getattr(self, "conn", None)
        if stale is not None:
            try:
                stale.close()
            except Exception:
                # Best effort only: the connection is being discarded anyway.
                logging.debug("Ignoring error while closing stale connection", exc_info=True)
            self.conn = None

        # Connect directly to remote database
        self.conn = duckdb.connect(db_path)
        self.conn.execute("""
            INSTALL httpfs;
            LOAD httpfs;
            SET s3_region='us-east-1';
            SET s3_endpoint='s3.us-east-1.amazonaws.com';
            SET s3_url_style='path';
            SET s3_access_key_id='none';
            SET s3_secret_access_key=?;
        """, [token])

    def ensure_connection(self):
        """Ensure database connection is active and valid.

        Runs a trivial query and reconnects (with retries) if it fails.
        """
        try:
            self.conn.execute("SELECT 1")
        except Exception:
            # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt
            # and SystemExit are not swallowed. A missing connection
            # (``self.conn is None``) also lands here via AttributeError.
            logging.warning("Database connection lost, attempting to reconnect...")
            self._connect_with_retries()

    def _setup_database(self):
        """Set up shader compiler tables.

        Expects ``self.conn`` to be a freshly opened connection. It does not
        call ``ensure_connection()``: this method runs inside
        ``_connect_with_retries()``, and reconnecting from here would re-enter
        that method recursively.
        """
        # Shader storage
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS shaders (
                id VARCHAR PRIMARY KEY,
                type VARCHAR,
                source TEXT,
                variables JSON,
                instructions JSON
            )
        """)

        # Program storage (linked shaders)
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS programs (
                id VARCHAR PRIMARY KEY,
                vertex_shader_id VARCHAR,
                fragment_shader_id VARCHAR,
                uniforms JSON,
                attributes JSON,
                varyings JSON,
                linked BOOLEAN
            )
        """)

        # Register allocation table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS registers (
                shader_id VARCHAR,
                var_name VARCHAR,
                register_name VARCHAR,
                PRIMARY KEY (shader_id, var_name)
            )
        """)

        # Instructions table with optimization metadata
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS instructions (
                id INTEGER PRIMARY KEY,
                shader_id VARCHAR,
                opcode VARCHAR,
                args JSON,
                result VARCHAR,
                is_dead BOOLEAN DEFAULT FALSE,
                depends_on JSON
            )
        """)

        self.conn.commit()

class Instruction:
    """One instruction of a compiled shader: opcode, argument names and an
    optional result name."""

    def __init__(self, opcode: str, args: List[str], result: Optional[str] = None):
        self.opcode, self.args, self.result = opcode, args, result
        # Database id; None until the instruction has been stored.
        self.id = None

    def __str__(self):
        """Render as ``"<result> = <opcode> <arg>, <arg>"`` (result part optional)."""
        operation = f"{self.opcode} {', '.join(self.args)}"
        if self.result:
            return f"{self.result} = {operation}"
        return operation

    def to_dict(self):
        """Return the opcode, args and result as a plain dict."""
        return {field: getattr(self, field) for field in ("opcode", "args", "result")}

    @classmethod
    def from_dict(cls, data):
        """Build an instruction from a dict with keys ``opcode``, ``args`` and ``result``."""
        opcode, args, result = (data[field] for field in ("opcode", "args", "result"))
        return cls(opcode, args, result)

class Variable:
    """A shader variable: name, type, input/output flags and the register
    assigned to it (if any)."""

    def __init__(self, name: str, var_type: str, is_input: bool = False, is_output: bool = False):
        self.name, self.type = name, var_type
        self.is_input, self.is_output = is_input, is_output
        # None until register allocation assigns one.
        self.register = None

    def to_dict(self):
        """Return all fields as a plain dict."""
        fields = ("name", "type", "is_input", "is_output", "register")
        return {field: getattr(self, field) for field in fields}

    @classmethod
    def from_dict(cls, data):
        """Rebuild a variable from a dict shaped like the output of ``to_dict()``."""
        ctor_args = [data[key] for key in ("name", "type", "is_input", "is_output")]
        variable = cls(*ctor_args)
        variable.register = data["register"]
        return variable

class ShaderCompiler:
    def __init__(self, db_url: Optional[str] = None):
        # Initialize database
        self.db = ShaderCompilerDB(db_url)
        self.temp_counter = 0
        
        # Initialize operation mappings
        self.vector_ops = {
            '+': 'add', '-': 'sub', '*': 'mul', '/': 'div',
            'dot': 'dot', 'cross': 'cross', 'normalize': 'normalize'
        }
        
        # Built-in functions and their implementations
        self.built_ins = {
            'texture2D': self._compile_texture2D,
            'normalize': self._compile_normalize,
            'dot': self._compile_dot,
            'mix': self._compile_mix,
            'clamp': self._compile_clamp
        }
        
        # Current compilation context
        self.current_shader_id = None

    def compile_shader(self, shader_source: str, shader_type: Union[str, ShaderType]) -> dict:
        """Compile a shader from source code into virtual GPU instructions."""
        if isinstance(shader_type, str):
            shader_type = ShaderType(shader_type)
        
        try:
            # Generate shader ID and set as current
            self.current_shader_id = self._generate_shader_id(shader_source)
            self.temp_counter = 0
            
            # Parse input/output variables
            variables = self._parse_interface_variables(shader_source, shader_type)
            
            # Parse and compile the main function
            instructions = self._parse_main_function(shader_source)
            
            # Store initial instructions in DB
            self._store_instructions(instructions)
            
            # Perform optimizations
            self._optimize_instructions()
            
            # Allocate registers
            self._allocate_registers(variables)
            
            # Store shader in database
            self.db.conn.execute("""
                INSERT INTO shaders (id, type, source, variables, instructions)
                VALUES (?, ?, ?, ?, ?)
            """, (
                self.current_shader_id,
                shader_type.value,
                shader_source,
                json.dumps({name: var.to_dict() for name, var in variables.items()}),
                json.dumps([instr.to_dict() for instr in instructions])
            ))
            
            # Fetch the complete compiled program
            result = self.db.conn.execute("""
                SELECT type, source, variables, instructions
                FROM shaders WHERE id = ?
            """, [self.current_shader_id]).fetchone()
            
            compiled_program = {
                "id": self.current_shader_id,
                "type": result[0],
                "source": result[1],
                "variables": json.loads(result[2]),
                "instructions": json.loads(result[3])
            }
            
            self.db.conn.commit()
            return compiled_program
            
        except Exception as e:
            raise ShaderError(f"Compilation failed: {str(e)}")

    def _parse_interface_variables(self, source: str, shader_type: ShaderType):
        """Parse input and output variable declarations."""
        # Match input/output variable declarations
        input_pattern = r'in\s+(\w+)\s+(\w+)\s*;'
        output_pattern = r'out\s+(\w+)\s+(\w+)\s*;'
        
        for match in re.finditer(input_pattern, source):
            var_type, var_name = match.groups()
            self.variables[var_name] = Variable(var_name, var_type, is_input=True)
            
        for match in re.finditer(output_pattern, source):
            var_type, var_name = match.groups()
            self.variables[var_name] = Variable(var_name, var_type, is_output=True)
            
        # Add built-in variables based on shader type
        if shader_type == ShaderType.VERTEX:
            self.variables['gl_Position'] = Variable('gl_Position', 'vec4', is_output=True)
        elif shader_type == ShaderType.FRAGMENT:
            self.variables['gl_FragColor'] = Variable('gl_FragColor', 'vec4', is_output=True)

    def _parse_main_function(self, source: str):
        """Parse and compile the main function body."""
        # Extract main function body
        main_pattern = r'void\s+main\s*\(\s*\)\s*{([^}]*)}'
        main_match = re.search(main_pattern, source)
        if not main_match:
            raise ShaderError("No main function found")
            
        main_body = main_match.group(1)
        
        # Split into statements
        statements = [s.strip() for s in main_body.split(';') if s.strip()]
        
        # Compile each statement
        for stmt in statements:
            self._compile_statement(stmt)

    def _compile_statement(self, statement: str):
        """Compile a single statement into instructions."""
        # Handle assignments
        if '=' in statement:
            target, expr = [s.strip() for s in statement.split('=')]
            result = self._compile_expression(expr)
            self.instructions.append(Instruction('mov', [result], target))
            return
            
        # Handle function calls
        if '(' in statement:
            self._compile_expression(statement)
            return
            
        raise ShaderError(f"Unsupported statement: {statement}")

    def _compile_expression(self, expr: str) -> str:
        """Compile an expression and return the register/variable containing the result."""
        # Handle parentheses first
        if '(' in expr:
            # Handle function calls
            if any(builtin in expr for builtin in self.built_ins):
                for builtin, compiler in self.built_ins.items():
                    if builtin in expr:
                        return compiler(expr)
                        
            # Handle parenthesized expressions
            inner = self._extract_parenthesized(expr)
            result = self._compile_expression(inner)
            return result
            
        # Handle basic arithmetic
        for op in self.vector_ops:
            if op in expr:
                left, right = [s.strip() for s in expr.split(op)]
                left_reg = self._compile_expression(left)
                right_reg = self._compile_expression(right)
                result = self._new_temp()
                self.instructions.append(Instruction(
                    self.vector_ops[op],
                    [left_reg, right_reg],
                    result
                ))
                return result
                
        # Must be a variable or literal
        return expr

    def _compile_texture2D(self, expr: str) -> str:
        """Compile a texture2D builtin function call."""
        args = self._extract_args(expr)
        if len(args) != 2:
            raise ShaderError("texture2D requires 2 arguments")
            
        sampler = self._compile_expression(args[0])
        coords = self._compile_expression(args[1])
        result = self._new_temp()
        
        self.instructions.append(Instruction(
            'texture2D',
            [sampler, coords],
            result
        ))
        return result

    def _optimize_instructions(self):
        """Perform basic optimizations on the instruction stream."""
        # Constant folding
        self._fold_constants()
        
        # Dead code elimination
        self._eliminate_dead_code()
        
        # Common subexpression elimination
        self._eliminate_common_subexpressions()
        
        # Update optimized instructions in DB
        self.db.conn.commit()

    def _allocate_registers(self, variables: Dict[str, Variable]):
        """Allocate hardware registers to variables."""
        used_registers = set()
        
        # Get existing register allocations
        result = self.db.conn.execute("""
            SELECT var_name, register_name FROM registers 
            WHERE shader_id = ?
        """, [self.current_shader_id]).fetchall()
        
        existing_registers = {r[0]: r[1] for r in result}
        used_registers.update(existing_registers.values())
        
        # Allocate input/output variables first
        for var in variables.values():
            if var.is_input or var.is_output:
                if var.name not in existing_registers:
                    reg = self._find_free_register(used_registers)
                    var.register = reg
                    used_registers.add(reg)
                    # Store in DB
                    self.db.ensure_connection()
                    self.db.conn.execute("""
                        INSERT INTO registers (shader_id, var_name, register_name)
                        VALUES (?, ?, ?)
                    """, (self.current_shader_id, var.name, reg))
                else:
                    var.register = existing_registers[var.name]
        
        # Allocate temporaries
        self.db.ensure_connection()
        result = self.db.conn.execute("""
            SELECT DISTINCT result FROM instructions 
            WHERE shader_id = ? AND result IS NOT NULL
        """, [self.current_shader_id]).fetchall()
        
        for (temp_var,) in result:
            if temp_var not in existing_registers:
                reg = self._find_free_register(used_registers)
                # Store in DB
                self.db.ensure_connection()
                self.db.conn.execute("""
                    INSERT INTO registers (shader_id, var_name, register_name)
                    VALUES (?, ?, ?)
                """, (self.current_shader_id, temp_var, reg))
                used_registers.add(reg)
                
        self.db.conn.commit()

    def link_program(self, vertex_shader: dict, fragment_shader: dict) -> dict:
        """Link vertex and fragment shaders into a complete program."""
        # Verify shader types
        if vertex_shader['type'] != 'vertex' or fragment_shader['type'] != 'fragment':
            raise ShaderError("Invalid shader types for linking")
            
        # Check interface compatibility
        self._verify_interface_compatibility(vertex_shader, fragment_shader)
        
        # Generate program ID
        program_id = self._generate_program_id(vertex_shader, fragment_shader)
        
        # Create linked program in database
        self.db.ensure_connection()
        self.db.conn.execute("""
            INSERT INTO programs (
                id, vertex_shader_id, fragment_shader_id,
                uniforms, attributes, varyings, linked
            ) VALUES (?, ?, ?, ?, ?, ?, ?)
        """, (
            program_id,
            vertex_shader['id'],
            fragment_shader['id'],
            json.dumps(self._collect_uniforms(vertex_shader, fragment_shader)),
            json.dumps(self._collect_attributes(vertex_shader)),
            json.dumps(self._collect_varyings(vertex_shader, fragment_shader)),
            True
        ))
        
        # Fetch complete program
        result = self.db.conn.execute("""
            SELECT p.*, 
                vs.source as vertex_source, 
                vs.instructions as vertex_instructions,
                fs.source as fragment_source,
                fs.instructions as fragment_instructions
            FROM programs p
            JOIN shaders vs ON p.vertex_shader_id = vs.id
            JOIN shaders fs ON p.fragment_shader_id = fs.id
            WHERE p.id = ?
        """, [program_id]).fetchone()
        
        program = {
            "id": result[0],
            "vertex_shader": {
                "id": result[1],
                "source": result[7],
                "instructions": json.loads(result[8])
            },
            "fragment_shader": {
                "id": result[2],
                "source": result[9],
                "instructions": json.loads(result[10])
            },
            "uniforms": json.loads(result[3]),
            "attributes": json.loads(result[4]),
            "varyings": json.loads(result[5]),
            "linked": result[6]
        }
        
        self.db.conn.commit()
        return program

    def validate_program(self, program: dict) -> bool:
        """Validate a linked program."""
        try:
            # Check required components
            if not all(k in program for k in ['vertex_shader', 'fragment_shader', 'linked']):
                return False
                
            # Verify shader validity
            for shader in [program['vertex_shader'], program['fragment_shader']]:
                if not all(k in shader for k in ['type', 'instructions', 'variables']):
                    return False
                    
            # Check interface compatibility
            vertex_outputs = {
                name for name, var in program['vertex_shader']['variables'].items()
                if var['is_output']
            }
            fragment_inputs = {
                name for name, var in program['fragment_shader']['variables'].items()
                if var['is_input']
            }
            
            if not fragment_inputs.issubset(vertex_outputs):
                return False
                
            return True
            
        except Exception:
            return False

    def _new_temp(self) -> str:
        """Generate a new temporary variable name."""
        self.temp_counter += 1
        return f"temp_{self.temp_counter}"

    def _extract_parenthesized(self, expr: str) -> str:
        """Extract content between outermost parentheses."""
        start = expr.index('(')
        count = 1
        for i, c in enumerate(expr[start + 1:], start + 1):
            if c == '(':
                count += 1
            elif c == ')':
                count -= 1
                if count == 0:
                    return expr[start + 1:i]
        raise ShaderError("Mismatched parentheses")

    def _extract_args(self, expr: str) -> List[str]:
        """Extract function arguments."""
        args_str = self._extract_parenthesized(expr)
        return [arg.strip() for arg in args_str.split(',')]

    def _generate_shader_id(self, source: str) -> str:
        """Generate a unique shader ID."""
        return f"shader_{hashlib.md5(source.encode()).hexdigest()[:8]}"

    def _generate_program_id(self, vertex_shader: dict, fragment_shader: dict) -> str:
        """Generate a unique program ID."""
        combined = vertex_shader['id'] + fragment_shader['id']
        return f"program_{hashlib.md5(combined.encode()).hexdigest()[:8]}"

    def _compile_normalize(self, expr: str) -> str:
        args = self._extract_args(expr)
        if len(args) != 1:
            raise ShaderError("normalize requires 1 argument")
        vec = self._compile_expression(args[0])
        result = self._new_temp()
        self.instructions.append(Instruction('normalize', [vec], result))
        return result

    def _compile_dot(self, expr: str) -> str:
        args = self._extract_args(expr)
        if len(args) != 2:
            raise ShaderError("dot requires 2 arguments")
        vec1 = self._compile_expression(args[0])
        vec2 = self._compile_expression(args[1])
        result = self._new_temp()
        self.instructions.append(Instruction('dot', [vec1, vec2], result))
        return result

    def _compile_mix(self, expr: str) -> str:
        args = self._extract_args(expr)
        if len(args) != 3:
            raise ShaderError("mix requires 3 arguments")
        x = self._compile_expression(args[0])
        y = self._compile_expression(args[1])
        a = self._compile_expression(args[2])
        result = self._new_temp()
        self.instructions.append(Instruction('mix', [x, y, a], result))
        return result

    def _compile_clamp(self, expr: str) -> str:
        args = self._extract_args(expr)
        if len(args) != 3:
            raise ShaderError("clamp requires 3 arguments")
        x = self._compile_expression(args[0])
        min_val = self._compile_expression(args[1])
        max_val = self._compile_expression(args[2])
        result = self._new_temp()
        self.instructions.append(Instruction('clamp', [x, min_val, max_val], result))
        return result

    def _fold_constants(self):
        """Perform constant folding optimization."""
        # Get constant expressions
        self.db.conn.execute("""
            UPDATE instructions 
            SET is_dead = TRUE
            WHERE shader_id = ? AND opcode IN ('add', 'sub', 'mul', 'div')
            AND args[0] LIKE '%[0-9]+%' AND args[1] LIKE '%[0-9]+%'
        """, [self.current_shader_id])

    def _eliminate_dead_code(self):
        """Perform dead code elimination."""
        # Get output variables
        outputs = self.db.conn.execute("""
            SELECT var_name FROM registers r
            JOIN shaders s ON r.shader_id = s.id
            JOIN json_each(s.variables) v ON v.key = r.var_name
            WHERE s.id = ? AND json_extract(v.value, '$.is_output') = true
        """, [self.current_shader_id]).fetchall()
        
        # Mark used variables recursively
        used_vars = set(r[0] for r in outputs)
        while True:
            new_used = self.db.conn.execute("""
                SELECT DISTINCT a.value::VARCHAR
                FROM instructions i,
                json_array_elements_text(i.args) a
                WHERE i.shader_id = ?
                AND i.result IN ?
                AND NOT i.is_dead
                AND a.value NOT IN ?
            """, [self.current_shader_id, tuple(used_vars), tuple(used_vars)]).fetchall()
            
            if not new_used:
                break
            used_vars.update(r[0] for r in new_used)
        
        # Mark unused instructions as dead
        self.db.conn.execute("""
            UPDATE instructions
            SET is_dead = TRUE
            WHERE shader_id = ?
            AND (result IS NULL OR result NOT IN ?)
        """, [self.current_shader_id, tuple(used_vars)])

    def _eliminate_common_subexpressions(self):
        """Perform common subexpression elimination."""
        # Find duplicate expressions
        self.db.conn.execute("""
            WITH expr_groups AS (
                SELECT opcode, args, MIN(id) as first_id,
                array_agg(id) as duplicate_ids
                FROM instructions
                WHERE shader_id = ? AND NOT is_dead
                GROUP BY opcode, args
                HAVING COUNT(*) > 1
            )
            UPDATE instructions i
            SET is_dead = TRUE
            WHERE id IN (
                SELECT unnest(duplicate_ids[2:])
                FROM expr_groups
            )
            AND shader_id = ?
        """, [self.current_shader_id, self.current_shader_id])
        
        # Add move instructions for the duplicates
        duplicates = self.db.conn.execute("""
            WITH expr_groups AS (
                SELECT opcode, args, result,
                MIN(id) as first_id,
                array_agg(id) as duplicate_ids,
                array_agg(result) as results
                FROM instructions
                WHERE shader_id = ?
                GROUP BY opcode, args
                HAVING COUNT(*) > 1
            )
            SELECT first_id, results
            FROM expr_groups
        """, [self.current_shader_id]).fetchall()
        
        for first_id, results in duplicates:
            results = json.loads(results)
            original_result = results[0]
            for result in results[1:]:
                self.db.conn.execute("""
                    INSERT INTO instructions (shader_id, opcode, args, result)
                    VALUES (?, 'mov', ?, ?)
                """, [self.current_shader_id, json.dumps([original_result]), result])

    def _find_free_register(self, used_registers: set) -> str:
        """Find an unused register."""
        i = 0
        while f"r{i}" in used_registers:
            i += 1
        return f"r{i}"

    def _verify_interface_compatibility(self, vertex_shader: dict, fragment_shader: dict):
        """Verify that shader interfaces are compatible."""
        vertex_outputs = {
            name: var for name, var in vertex_shader['variables'].items()
            if var['is_output']
        }
        fragment_inputs = {
            name: var for name, var in fragment_shader['variables'].items()
            if var['is_input']
        }
        
        # Check that all fragment inputs have matching vertex outputs
        for name, var in fragment_inputs.items():
            if name not in vertex_outputs:
                raise ShaderError(f"Fragment shader input '{name}' has no matching vertex output")
            if vertex_outputs[name]['type'] != var['type']:
                raise ShaderError(f"Type mismatch for varying '{name}'")

    def _collect_uniforms(self, vertex_shader: dict, fragment_shader: dict) -> dict:
        """Collect all uniform variables from both shaders."""
        uniforms = {}
        for shader in [vertex_shader, fragment_shader]:
            for name, var in shader['variables'].items():
                if 'uniform' in var.get('qualifiers', []):
                    uniforms[name] = var
        return uniforms

    def _collect_attributes(self, vertex_shader: dict) -> dict:
        """Collect vertex attributes."""
        return {
            name: var for name, var in vertex_shader['variables'].items()
            if var['is_input']
        }

    def _collect_varyings(self, vertex_shader: dict, fragment_shader: dict) -> dict:
        """Collect varying variables (vertex outputs / fragment inputs)."""
        return {
            name: var for name, var in vertex_shader['variables'].items()
            if var['is_output'] and name in fragment_shader['variables']
        }

