"""
Prefetching logic for weight streaming
"""
from typing import List, Dict, Set, Optional
import threading
import queue
import time

class PrefetchManager:
    """Queue layer prefetch requests and serve them on a background thread.

    Requests are placed on a FIFO queue and handed one at a time to
    ``weight_manager.prefetch_layers``. Layers that are queued or being
    fetched are tracked in ``active_prefetch`` so the same layer name is
    not queued twice.

    NOTE(review): layers are tracked by name only, so identically named
    layers of two different models are treated as duplicates -- confirm
    this is intended.
    """

    def __init__(self, weight_manager, max_prefetch: int = 3):
        """Create the manager and start its daemon worker thread.

        Args:
            weight_manager: Object exposing
                ``prefetch_layers(model_id, layer_names)``.
            max_prefetch: Maximum number of layers taken from a single
                ``request_prefetch`` call.
        """
        self.weight_manager = weight_manager
        self.max_prefetch = max_prefetch
        self.prefetch_queue = queue.Queue()
        self.active_prefetch: Set[str] = set()
        # Guards active_prefetch, which is touched by both the caller's
        # thread and the worker thread.
        self._lock = threading.Lock()
        self.prefetch_thread = threading.Thread(
            target=self._prefetch_worker,
            daemon=True
        )
        self.prefetch_thread.start()

    def request_prefetch(self,
                        model_id: str,
                        layer_names: List[str],
                        priority: int = 0):
        """Queue up to ``max_prefetch`` of the given layers for prefetching.

        Layers already queued or in flight are skipped.

        Args:
            model_id: Identifier passed through to the weight manager.
            layer_names: Candidate layers; only the first ``max_prefetch``
                entries are considered.
            priority: Stored with the queue entry. The queue is FIFO, so
                this value does not currently affect ordering.
        """
        # Hold the lock across check-and-add so two callers cannot both
        # enqueue the same layer.
        with self._lock:
            for layer in layer_names[:self.max_prefetch]:
                if layer not in self.active_prefetch:
                    self.active_prefetch.add(layer)
                    self.prefetch_queue.put((priority, model_id, layer))

    def _prefetch_worker(self):
        """Worker loop: take queued requests and trigger each prefetch."""
        while True:
            # Blocks until an item is available, so no polling is needed.
            _priority, model_id, layer = self.prefetch_queue.get()
            try:
                self.weight_manager.prefetch_layers(model_id, [layer])
            except Exception as e:
                # Keep the worker alive; a failed prefetch is not fatal.
                print(f"Prefetch error: {e}")
            finally:
                # Always release the layer, even after a failure, so it can
                # be requested again. discard() tolerates clear_prefetch()
                # having already emptied the set.
                with self._lock:
                    self.active_prefetch.discard(layer)
                self.prefetch_queue.task_done()
            time.sleep(0.001)  # Brief pause between consecutive jobs

    def clear_prefetch(self):
        """Drop all pending requests and reset the in-flight tracking set.

        A prefetch the worker has already started is not interrupted.
        """
        with self._lock:
            while True:
                try:
                    self.prefetch_queue.get_nowait()
                except queue.Empty:
                    break
                # Balance the put() so prefetch_queue.join() stays usable.
                self.prefetch_queue.task_done()
            self.active_prefetch.clear()
