from typing import List, Optional
from dataclasses import dataclass
import time
import json
from queue import Queue
from threading import Lock
import duckdb
from huggingface_hub import HfApi, HfFileSystem
from config import get_hf_token_cached

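# NOTE(review): the Hugging Face token was presumably meant to be initialized
# from .env here, but nothing does so: `get_hf_token_cached` is imported and
# never called, and `Stream` only authenticates if it has an `HF_TOKEN` attribute.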



@dataclass
class Event:
    """Completion marker for a point in a stream, modelled on CUDA events.

    Waiters poll ``completed``; ``state_json`` carries whatever state payload
    was stored alongside the event, if any.
    """
    # Unique identifier; Stream also uses it as the key of the event's
    # row in the ``stream_events`` table.
    event_id: str
    # Moment the event was recorded (Stream.record_event passes time.time()).
    timestamp: float
    # Flipped to True once the work guarded by this event has finished.
    completed: bool = False
    # Optional state payload attached to the event; None until one is loaded.
    state_json: Optional[dict] = None

class Stream:
    """A CUDA-like stream: a FIFO queue of operations with completion events.

    Operations are queued with :meth:`add_operation` and executed one at a
    time by :meth:`execute_next`. Each executed operation gets an
    :class:`Event`, tracked in ``self.events`` and mirrored as a row in the
    ``stream_events`` table of the DuckDB database at ``db_url``.

    ``self.lock`` is a plain, non-reentrant ``threading.Lock``. Public methods
    acquire it themselves; helpers whose name ends in ``_locked`` expect the
    caller to already hold it and must never acquire it again.
    """
    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    def __init__(self, stream_id: int, db_url: Optional[str] = None):
        """Create the stream, open its database connection and create tables.

        Args:
            stream_id: Identifier stored with every event row of this stream.
            db_url: Database location; falls back to ``DB_URL`` when omitted.

        Raises:
            RuntimeError: If no connection could be opened after
                ``max_retries`` attempts.
        """
        self.stream_id = stream_id
        self.events: List[Event] = []
        self.operation_queue: Queue = Queue()
        self.lock = Lock()
        self.is_active = True

        # Initialize database connection
        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        self._connect_with_retries()
        self._setup_database()

    def _connect_with_retries(self):
        """Open ``self.conn``, retrying up to ``max_retries`` times.

        Sleeps one second between attempts.

        Raises:
            RuntimeError: When the final attempt fails; the underlying
                exception is attached as ``__cause__``.
        """
        for attempt in range(self.max_retries):
            try:
                self.conn = self._init_db_connection()
                return
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise RuntimeError(
                        f"Failed to initialize database after {self.max_retries} attempts: {str(e)}"
                    ) from e
                time.sleep(1)

    @staticmethod
    def _configure_httpfs(conn) -> None:
        """Install/load the httpfs extension and apply the endpoint settings."""
        # NOTE(review): these are S3 settings pointed at hf.co; whether they
        # influence hf:// URLs should be confirmed against the DuckDB docs.
        conn.execute("INSTALL httpfs;")
        conn.execute("LOAD httpfs;")
        conn.execute("SET s3_endpoint='hf.co';")
        conn.execute("SET s3_use_ssl=true;")
        conn.execute("SET s3_url_style='path';")

    def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
        """Open and configure a DuckDB connection to ``self.db_url``.

        A throwaway in-memory connection is configured first, then the real
        connection is opened and configured the same way. Both connections are
        closed if anything fails, so retries do not accumulate open handles.

        Returns:
            The configured connection.
        """
        # First create an in-memory connection to configure settings
        temp_conn = duckdb.connect(":memory:")
        try:
            self._configure_httpfs(temp_conn)

            # Now create the real connection
            conn = duckdb.connect(self.db_url, config={'http_keep_alive': 'true'})
            try:
                self._configure_httpfs(conn)

                # Configure HuggingFace authentication if a token is available.
                # NOTE(review): nothing in this class defines HF_TOKEN, so this
                # only runs when a subclass or caller sets the attribute.
                token = getattr(self, 'HF_TOKEN', None)
                if token:
                    # The token is embedded in the SQL text, so double any
                    # single quote to keep it inside the string literal.
                    quoted = str(token).replace("'", "''")
                    conn.execute(f"SET s3_access_key_id='{quoted}';")
                    conn.execute(f"SET s3_secret_access_key='{quoted}';")
            except Exception:
                conn.close()
                raise
            return conn
        finally:
            # Close temporary connection on success and on failure alike
            temp_conn.close()

    def _setup_database(self):
        """Create the ``stream_events`` and ``stream_operations`` tables if missing."""
        # Events table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS stream_events (
                event_id VARCHAR PRIMARY KEY,
                stream_id BIGINT,
                timestamp DOUBLE,
                completed BOOLEAN DEFAULT false,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                completed_at TIMESTAMP,
                state_json JSON
            )
        """)

        # Operations table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS stream_operations (
                operation_id VARCHAR PRIMARY KEY,
                stream_id BIGINT,
                operation_type VARCHAR,
                args JSON,
                kwargs JSON,
                status VARCHAR DEFAULT 'pending',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                started_at TIMESTAMP,
                completed_at TIMESTAMP,
                error_message VARCHAR
            )
        """)

    def record_event(self) -> Event:
        """Record a new, not-yet-completed event in the stream.

        The event is inserted into ``stream_events`` and appended to
        ``self.events``. Within this class only :meth:`execute_next` marks
        events completed, so an event created directly through this method
        stays pending until something else completes it.

        Returns:
            The newly created event.
        """
        with self.lock:
            return self._record_event_locked()

    def _record_event_locked(self) -> Event:
        """Create and persist an event; the caller must hold ``self.lock``."""
        event_id = f"event_{self.stream_id}_{time.time_ns()}"
        event = Event(event_id=event_id, timestamp=time.time())

        # Record event in database. The state is passed as JSON text so the
        # JSON column receives a value it can store without an implicit
        # dict conversion.
        self.conn.execute("""
            INSERT INTO stream_events (
                event_id, stream_id, timestamp, state_json
            ) VALUES (?, ?, ?, ?)
        """, [event_id, self.stream_id, event.timestamp,
              json.dumps({"status": "created"})])

        self.events.append(event)
        return event

    def _mark_event_completed(self, event: Event) -> None:
        """Flag ``event`` as completed in memory and in ``stream_events``."""
        # Set the in-memory flag first so waiters are released even if the
        # database update below fails.
        event.completed = True
        self.conn.execute("""
            UPDATE stream_events
            SET completed = true,
                completed_at = CURRENT_TIMESTAMP,
                state_json = ?
            WHERE event_id = ?
        """, [json.dumps({"status": "completed"}), event.event_id])

    @staticmethod
    def _decode_state(raw):
        """Turn a ``state_json`` column value into a Python object."""
        # JSON columns come back as text; decode so Event.state_json holds
        # the parsed structure rather than a string.
        if isinstance(raw, str):
            return json.loads(raw)
        return raw

    def wait_event(self, event: Event):
        """Block until ``event`` is completed, polling once per millisecond.

        Completion is detected either from the event's database row (in which
        case ``event.state_json`` is refreshed from the row) or from the
        in-memory ``event.completed`` flag. There is no timeout.
        """
        while True:
            # Check database for completion
            row = self.conn.execute("""
                SELECT completed, state_json
                FROM stream_events
                WHERE event_id = ?
            """, [event.event_id]).fetchone()

            if row and row[0]:
                event.completed = True
                event.state_json = self._decode_state(row[1])
                break

            if event.completed:
                break

            time.sleep(0.001)  # Small sleep to prevent busy waiting

    def synchronize(self):
        """Wait for every tracked event, then drop the completed ones.

        Holds ``self.lock`` for the whole wait, so no operation can be queued
        or executed on this stream until it returns. Completed rows of this
        stream are deleted from ``stream_events`` and ``self.events`` is
        cleared.
        """
        with self.lock:
            for event in self.events:
                self.wait_event(event)

            # Clear completed events
            self.conn.execute("""
                DELETE FROM stream_events
                WHERE stream_id = ? AND completed = true
            """, [self.stream_id])

            self.events.clear()

    def add_operation(self, operation: callable, *args, **kwargs):
        """Queue ``operation`` to be called later as ``operation(*args, **kwargs)``."""
        with self.lock:
            self.operation_queue.put((operation, args, kwargs))

    def execute_next(self) -> bool:
        """Execute the next queued operation, if any.

        The lock is held while the operation runs, so the operation itself
        must not call methods of this stream that take the lock
        (``add_operation``, ``record_event``, ``synchronize``,
        ``execute_next``).

        Returns:
            True when an operation ran to completion. False when the queue
            was empty, or when the operation or its event bookkeeping raised
            (the error is printed and the operation is not retried).
        """
        try:
            with self.lock:
                if self.operation_queue.empty():
                    return False

                operation, args, kwargs = self.operation_queue.get()
                # self.lock is not reentrant: calling record_event() here
                # would block forever, so use the lock-free helper.
                event = self._record_event_locked()

                try:
                    operation(*args, **kwargs)
                finally:
                    # Completed even on failure so waiters never hang on it.
                    self._mark_event_completed(event)

                return True
        except Exception as e:
            print(f"Error in stream {self.stream_id}: {str(e)}")
            return False

class StreamManager:
    """Manages multiple CUDA-like streams.

    Streams are kept in ``self.streams`` in creation order; a stream's ID is
    its index in that list. One stream is created up front and exposed as
    ``default_stream``.
    """
    def __init__(self):
        """Create the manager together with its default stream (ID 0)."""
        self.streams: List["Stream"] = []
        self.default_stream = self.create_stream()

    def create_stream(self) -> "Stream":
        """Create a new stream, register it and return it.

        The new stream's ID is the number of streams created before it.
        """
        stream_id = len(self.streams)
        stream = Stream(stream_id)
        self.streams.append(stream)
        return stream

    def get_stream(self, stream_id: int) -> Optional["Stream"]:
        """Return the stream with the given ID, or None if the ID is out of range."""
        if 0 <= stream_id < len(self.streams):
            return self.streams[stream_id]
        return None

    def synchronize_all(self):
        """Synchronize every stream, in creation order."""
        for stream in self.streams:
            stream.synchronize()

    def synchronize_stream(self, stream_id: int):
        """Synchronize one stream; unknown IDs are silently ignored."""
        stream = self.get_stream(stream_id)
        if stream:
            stream.synchronize()

    def execute_streams(self):
        """Run queued operations round-robin until every queue is drained.

        Each pass calls ``execute_next`` once on every stream.
        ``execute_next`` returns False both for an empty queue and for an
        operation that failed, so a pass without a success is not enough to
        conclude the work is done: the loop only stops when no stream
        reported success and every ``operation_queue`` is empty. This relies
        on ``execute_next`` consuming one queued operation per call, even
        when that operation fails.
        """
        while True:
            executed = False
            for stream in self.streams:
                if stream.execute_next():
                    executed = True
            if executed:
                continue
            # Nothing succeeded this pass; stop only if nothing is left.
            if all(stream.operation_queue.empty() for stream in self.streams):
                break
