"""
Example usage of WeightManager for model inference
"""
import os
import numpy as np
from PIL import Image
import torch
from transformers import AutoTokenizer
from weight_manager import WeightManager
from helium.main import register_device
import helium as hl
from virtual_gpu_driver.src.driver_api import VirtualGPUDriver

def setup_virtual_gpu():
    """Create a virtual GPU driver and register it under a fixed device id.

    Each call builds a new ``VirtualGPUDriver`` and passes it to
    ``register_device`` with the id ``"vgpu0"``.

    Returns:
        The device id string that the driver was registered under.
    """
    vgpu_name = "vgpu0"
    # Arguments are evaluated left to right, so the driver is constructed
    # immediately before it is handed to register_device.
    register_device(vgpu_name, VirtualGPUDriver())
    return vgpu_name

def run_vision_model_example():
    """Load a Vision Transformer checkpoint and run it once on random input.

    Creates a ``WeightManager``, registers a virtual GPU through
    ``setup_virtual_gpu``, loads ``google/vit-base-patch16-224`` and calls
    ``run_inference`` with ``model_type="vision"``. The input is a single
    row of ``3 * 224 * 224`` standard-normal samples standing in for a
    flattened RGB image.

    Returns:
        The value returned by ``WeightManager.run_inference``; its
        ``shape`` attribute is printed before returning.
    """
    model_name = "google/vit-base-patch16-224"
    channels, height, width = 3, 224, 224

    weights = WeightManager()
    vgpu = setup_virtual_gpu()

    print(f"Loading {model_name}...")
    weights.load_model(model_name)

    # One sample, flattened into a single row of channels * height * width values.
    flat_image = np.random.randn(1, channels * height * width)

    print("Running vision model inference...")
    output = weights.run_inference(
        model_name=model_name,
        input_data=flat_image,
        device_id=vgpu,
        model_type="vision",
    )

    print(f"Output shape: {output.shape}")
    return output

def run_text_model_example():
    """Load ``bert-base-uncased`` and run it once on a sample sentence.

    Creates a ``WeightManager``, registers a virtual GPU through
    ``setup_virtual_gpu``, loads the model, and tokenizes one sentence with
    the matching ``AutoTokenizer`` (``padding='max_length'``,
    ``max_length=128``, ``truncation=True``, NumPy output). The resulting
    ``input_ids`` are passed to ``run_inference`` with
    ``model_type="encoder"``.

    Returns:
        The value returned by ``WeightManager.run_inference``; its
        ``shape`` attribute is printed before returning.
    """
    model_name = "bert-base-uncased"

    weights = WeightManager()
    vgpu = setup_virtual_gpu()

    print(f"Loading {model_name}...")
    weights.load_model(model_name)

    # Tokenizer that matches the loaded checkpoint.
    bert_tokenizer = AutoTokenizer.from_pretrained(model_name)

    sentence = "This is an example sentence for inference."
    tokenizer_options = {
        "padding": 'max_length',
        "max_length": 128,
        "truncation": True,
        "return_tensors": 'np',
    }
    encoded = bert_tokenizer(sentence, **tokenizer_options)

    print("Running text model inference...")
    token_ids = encoded['input_ids']
    # BERT is an encoder-only model, hence model_type="encoder".
    output = weights.run_inference(
        model_name=model_name,
        input_data=token_ids,
        device_id=vgpu,
        model_type="encoder",
    )

    print(f"Output shape: {output.shape}")
    return output

def run_language_model_example():
    """Load ``gpt2`` and run it once on a short prompt.

    Creates a ``WeightManager``, registers a virtual GPU through
    ``setup_virtual_gpu``, loads the model, and tokenizes the prompt with
    the matching ``AutoTokenizer`` (NumPy output, no padding arguments).
    The resulting ``input_ids`` are passed to ``run_inference`` with
    ``model_type="decoder"``.

    Returns:
        The value returned by ``WeightManager.run_inference``; its
        ``shape`` attribute is printed before returning.
    """
    model_name = "gpt2"

    weights = WeightManager()
    vgpu = setup_virtual_gpu()

    print(f"Loading {model_name}...")
    weights.load_model(model_name)

    # Tokenizer that matches the loaded checkpoint.
    gpt2_tokenizer = AutoTokenizer.from_pretrained(model_name)

    prompt = "Once upon a time"
    encoded = gpt2_tokenizer(prompt, return_tensors='np')

    print("Running language model inference...")
    prompt_ids = encoded['input_ids']
    # GPT-2 is a decoder-only model, hence model_type="decoder".
    output = weights.run_inference(
        model_name=model_name,
        input_data=prompt_ids,
        device_id=vgpu,
        model_type="decoder",
    )

    print(f"Output shape: {output.shape}")
    return output

def main():
    """Run the vision, text, and language-model examples in sequence.

    Before the examples run, the entries returned by
    ``WeightManager.list_models()`` are printed when that list is
    non-empty. Each entry is read through its ``name`` and
    ``total_size_bytes`` keys, with the size shown in units of 1e9 bytes.

    The example functions' return values are not used here, so they are
    deliberately not bound to local names.
    """
    print("=== Weight Manager Examples ===")

    # List any existing models in database
    manager = WeightManager()
    stored_models = manager.list_models()
    if stored_models:
        print("\nStored models:")
        for model in stored_models:
            print(f"- {model['name']} ({model['total_size_bytes'] / 1e9:.2f} GB)")

    # Run examples; any exception propagates and skips the final message.
    print("\n1. Vision Transformer Example")
    run_vision_model_example()

    print("\n2. BERT Text Classification Example")
    run_text_model_example()

    print("\n3. GPT-2 Language Model Example")
    run_language_model_example()

    print("\nAll examples completed successfully!")

# Run the examples only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
