"""
Example of running the GPT-OSS-20B model on a virtual GPU with Git-based weight loading.
"""
import os
import numpy as np
from typing import List, Optional
from transformers import AutoTokenizer
from weight_manager import WeightManager
from helium.main import register_device
import helium as hl
from virtual_gpu_driver.src.driver_api import VirtualGPUDriver

class GPTOSSInference:
    """Text-generation wrapper for GPT-OSS-20B running on a Helium virtual GPU.

    Model files are fetched and tracked by ``WeightManager``; the tokenizer is
    loaded from the local model path the manager records. Call
    :meth:`load_model` before :meth:`generate`.
    """

    def __init__(self, cache_dir: Optional[str] = None):
        """Create the weight manager and register the virtual GPU device.

        Args:
            cache_dir: Directory handed to ``WeightManager`` for cached model
                files. ``None`` is passed through as-is (presumably the manager
                then uses its own default — confirm in ``WeightManager``).
        """
        self.model_name = "openai/gpt-oss-20b"
        self.manager = WeightManager(cache_dir=cache_dir)
        self.device_id = self._setup_virtual_gpu()
        self.tokenizer = None  # set by load_model()
        self.max_length = 2048  # Max total sequence length (prompt + new tokens)

    def _setup_virtual_gpu(self) -> str:
        """Register a ``VirtualGPUDriver`` under the id ``"vgpu0"`` and return the id."""
        driver = VirtualGPUDriver()
        device_id = "vgpu0"
        register_device(device_id, driver)
        return device_id

    def load_model(self):
        """Load GPT-OSS-20B via the weight manager (Git-based download) and its tokenizer.

        Raises:
            KeyError: if the manager does not record an entry/path for the model
                after loading (via the ``loaded_models`` lookup below).
        """
        print(f"Loading {self.model_name} using Git-based download...")
        self.manager.load_model(self.model_name)

        # The tokenizer is read from the same local directory the weights were
        # placed in, so no separate hub download is needed.
        print("Initializing tokenizer...")
        model_path = self.manager.loaded_models[self.model_name]['path']
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    def generate(self, prompt: str, max_new_tokens: int = 100) -> str:
        """Generate text from a prompt.

        The prompt is truncated so that prompt tokens plus ``max_new_tokens``
        fit within ``self.max_length`` (at least one prompt token is kept).

        Args:
            prompt: Input text prompt.
            max_new_tokens: Maximum number of new tokens to generate; negative
                values are treated as 0 when budgeting the prompt length.

        Returns:
            The decoded first sequence returned by the weight manager, with
            special tokens removed.

        Raises:
            RuntimeError: if :meth:`load_model` has not been called yet.
        """
        if self.tokenizer is None:
            raise RuntimeError("Model and tokenizer must be loaded first")

        # Reserve room for the generated tokens inside the sequence budget;
        # never ask the tokenizer for a non-positive max_length.
        prompt_budget = max(1, self.max_length - max(0, max_new_tokens))

        tokens = self.tokenizer(
            prompt,
            return_tensors='np',
            max_length=prompt_budget,
            truncation=True
        )

        # Single call on the whole (batch-of-one) prompt.
        # TODO(review): max_new_tokens is not forwarded to run_inference because
        # its signature is not visible here — confirm whether WeightManager
        # accepts a generation-length argument and pass it through.
        print("Running GPT-OSS inference...")
        output = self.manager.run_inference(
            model_name=self.model_name,
            input_data=tokens['input_ids'],
            device_id=self.device_id,
            model_type="decoder"  # GPT models are decoder-only
        )

        # Only the first sequence of the returned batch is decoded.
        generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
        return generated_text

def main():
    """Run GPT-OSS-20B over a few demo prompts and print each completion."""
    # Weights are cached under ~/.cache/helium/models.
    home = os.path.expanduser("~")
    cache_dir = os.path.join(home, ".cache", "helium", "models")
    runner = GPTOSSInference(cache_dir=cache_dir)

    # Fetch the weights (Git clone) and set up the tokenizer.
    runner.load_model()

    demo_prompts = (
        "Write a short story about",
        "Explain how quantum computers",
        "The future of artificial intelligence",
    )

    for idx, text in enumerate(demo_prompts, start=1):
        print(f"\n=== Example {idx} ===")
        print(f"Prompt: {text}")

        completion = runner.generate(text)
        print(f"\nGenerated text:\n{completion}")


if __name__ == "__main__":
    main()
