"""
Hugging Face model weight manager for Helium inference with database storage
"""
import os
import json
from typing import Dict, List, Optional, Union, Any
import numpy as np
from pathlib import Path
import duckdb
from transformers import AutoConfig, AutoTokenizer
import helium as hl

from helium.run_transformer import (
    run_vision_transformer_inference,
    run_bert_inference,
    run_gpt2_inference
)

from .git_model_utils import git_clone_model, download_lfs_files, validate_model_files
from helium.utils import parse_hf_config, create_causal_mask
from helium.encoder import TransformerEncoder
from helium.decoder_model import TransformerDecoder
from helium.utils import (
    map_hf_weights_to_blocks,
    map_hf_weights_to_blocks_bert,
    map_hf_weights_to_blocks_t5
)
import torch  # For loading weights

class WeightManager:
    """Manages model weights using Hugging Face Hub and DuckDB storage.

    Models are cloned with Git (plus Git LFS) into ``cache_dir``. Their config and
    every tensor of their checkpoint are persisted in a DuckDB database, so later
    ``load_model`` calls are served from the database instead of re-reading the
    checkout. Every loaded model is also kept in ``self.loaded_models``.

    The manager holds an open DuckDB connection; call :meth:`close` (or use the
    manager as a context manager) when finished with it.
    """

    # ``.bin`` files that hold pickled training state rather than a state dict.
    _NON_WEIGHT_BIN_FILES = frozenset({'training_args.bin'})

    # safetensors dtype tags -> numpy dtypes (the format is always little-endian).
    # BF16 is handled separately because numpy has no bfloat16 dtype.
    _SAFETENSORS_DTYPES = {
        'F64': '<f8', 'F32': '<f4', 'F16': '<f2',
        'I64': '<i8', 'I32': '<i4', 'I16': '<i2', 'I8': 'i1',
        'U64': '<u8', 'U32': '<u4', 'U16': '<u2', 'U8': 'u1',
        'BOOL': '?',
    }

    # Config attributes that may hold the encoder / decoder depth, in lookup order.
    # The first name in each tuple is the one this class originally required.
    _ENCODER_LAYER_ATTRS = ('encoder_layers', 'num_layers')
    _DECODER_LAYER_ATTRS = ('decoder_layers', 'num_decoder_layers', 'num_layers')

    def __init__(self, cache_dir: Optional[str] = None, db_path: Optional[str] = None):
        """Initialize weight manager with optional cache directory and database path

        Args:
            cache_dir: Directory for caching downloaded models
                (default: ``~/.cache/helium/models``; created if missing)
            db_path: Path to DuckDB database file for storing model data
                (default: ``<cache_dir>/models.db``)
        """
        self.cache_dir = cache_dir or os.path.join(os.path.expanduser("~"), ".cache", "helium", "models")
        os.makedirs(self.cache_dir, exist_ok=True)

        # Initialize database
        self.db_path = db_path or os.path.join(self.cache_dir, "models.db")
        self.conn = duckdb.connect(self.db_path)
        self._init_db()

        # model_name -> {'weights': {name: ndarray}, 'config': PretrainedConfig, 'path': str}
        self.loaded_models: Dict[str, Dict[str, Any]] = {}

    def close(self) -> None:
        """Close the DuckDB connection; the manager must not be used afterwards."""
        self.conn.close()

    def __enter__(self) -> "WeightManager":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def _init_db(self):
        """Create the registry and weight tables if they do not exist yet."""
        # Model registry table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS model_registry (
                model_name VARCHAR PRIMARY KEY,
                model_path VARCHAR,
                config JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # Model weights table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS model_weights (
                model_name VARCHAR,
                weight_name VARCHAR,
                weight_data BLOB,
                shape VARCHAR,  -- JSON array of dimensions
                dtype VARCHAR,
                PRIMARY KEY (model_name, weight_name),
                FOREIGN KEY (model_name) REFERENCES model_registry(model_name)
            )
        """)

    # ------------------------------------------------------------------ loading

    def load_model(self, model_name: str, force_reload: bool = False) -> Dict[str, Any]:
        """Download and load complete model weights from Hugging Face Hub and store in database

        Args:
            model_name: The Hugging Face model name (e.g. 'bert-base-uncased')
            force_reload: If True, redownload and replace the stored copy even if
                the model already exists in the DB

        Returns:
            Dict with keys ``'weights'`` (name -> numpy array), ``'config'``
            (a ``PretrainedConfig`` on both the cached and the fresh path) and
            ``'path'`` (local checkout directory). The same dict is stored in
            ``self.loaded_models[model_name]``.

        Raises:
            RuntimeError: if downloading/validating the model fails or the
                checkout contains no weight files.
        """
        if not force_reload:
            cached = self._load_from_db(model_name)
            if cached is not None:
                self.loaded_models[model_name] = cached
                return cached

        model_path, config = self._download_model(model_name)
        state_dict = self._read_checkpoint(model_path)

        # Convert every tensor up front so a conversion failure cannot leave a
        # half-written model in the database.
        weights_np = {
            name: self._tensor_to_numpy(value)
            for name, value in state_dict.items()
            if isinstance(value, (torch.Tensor, np.ndarray))
        }
        self._store_model(model_name, model_path, config, weights_np)

        self.loaded_models[model_name] = {
            'weights': weights_np,
            'config': config,
            'path': model_path
        }
        return self.loaded_models[model_name]

    def _load_from_db(self, model_name: str) -> Optional[Dict[str, Any]]:
        """Return the stored entry for ``model_name``, or None if it is not usable."""
        row = self.conn.execute(
            "SELECT model_path, config FROM model_registry WHERE model_name = ?",
            [model_name]
        ).fetchone()
        if row is None:
            return None
        model_path, config_json = row

        weight_rows = self.conn.execute("""
            SELECT weight_name, weight_data, shape, dtype
            FROM model_weights
            WHERE model_name = ?
        """, [model_name]).fetchall()
        if not weight_rows:
            # A registry row without weights is a leftover of an interrupted
            # write; treat it as a cache miss so the model is fetched again.
            return None

        weights_np = {
            weight_name: np.frombuffer(data, dtype=dtype).reshape(json.loads(shape))
            for weight_name, data, shape, dtype in weight_rows
        }
        return {
            'weights': weights_np,
            # Rebuild a config object: callers use attribute access
            # (config.num_hidden_layers, ...), which a plain dict does not support.
            'config': self._restore_config(json.loads(config_json), model_path),
            'path': model_path
        }

    @staticmethod
    def _restore_config(config_dict: Dict[str, Any], model_path: str) -> Any:
        """Rebuild the transformers config object from its stored ``to_dict()`` form."""
        try:
            # to_dict() records 'model_type', which selects the concrete config class.
            return AutoConfig.for_model(**config_dict)
        except (TypeError, ValueError):
            # 'model_type' missing or unknown to this transformers install:
            # fall back to the config.json in the local checkout.
            return AutoConfig.from_pretrained(model_path)

    def _download_model(self, model_name: str):
        """Clone ``model_name`` with Git + LFS, validate it, and return (path, config)."""
        print(f"Downloading model {model_name} using Git...")
        try:
            model_path = git_clone_model(
                repo_id=model_name,
                cache_dir=self.cache_dir
            )

            # Download LFS files
            print("Downloading LFS files...")
            download_lfs_files(model_path)

            # Validate model files
            if not validate_model_files(model_path):
                raise RuntimeError("Model files validation failed")

            # Load model config from the local files
            config = AutoConfig.from_pretrained(model_path)
        except Exception as e:
            raise RuntimeError(f"Failed to download and setup model {model_name}: {str(e)}") from e
        return model_path, config

    @classmethod
    def _read_checkpoint(cls, model_path: str) -> Dict[str, Any]:
        """Read every weight file in ``model_path`` into one name -> tensor/array dict.

        ``.safetensors`` files are preferred; ``.bin`` files are only read when no
        safetensors file exists, so a repo shipping both is not loaded twice.
        Files are read in sorted order so sharded checkpoints merge deterministically.

        Raises:
            RuntimeError: if no weights were found.
        """
        filenames = sorted(
            name for name in os.listdir(model_path)
            if os.path.isfile(os.path.join(model_path, name))
        )
        safetensors_files = [name for name in filenames if name.endswith('.safetensors')]
        bin_files = [
            name for name in filenames
            if name.endswith('.bin') and name not in cls._NON_WEIGHT_BIN_FILES
        ]

        state_dict: Dict[str, Any] = {}
        if safetensors_files:
            for name in safetensors_files:
                state_dict.update(cls._load_safetensors(os.path.join(model_path, name)))
        else:
            for name in bin_files:
                # .bin checkpoints are pickles: load onto the CPU (weights may have
                # been saved from a GPU) and restrict unpickling to tensor data.
                loaded = torch.load(
                    os.path.join(model_path, name),
                    map_location='cpu',
                    weights_only=True
                )
                if isinstance(loaded, dict):
                    state_dict.update(loaded)

        if not state_dict:
            raise RuntimeError(f"No weight files (.safetensors or .bin) found in {model_path}")
        return state_dict

    @classmethod
    def _load_safetensors(cls, path: str) -> Dict[str, np.ndarray]:
        """Read a ``.safetensors`` file into numpy arrays.

        Layout: an 8-byte little-endian header length N, an N-byte JSON header
        mapping tensor names to ``dtype``/``shape``/``data_offsets`` (plus an
        optional ``__metadata__`` entry), then the raw tensor bytes. Offsets are
        relative to the start of the byte buffer. BF16 tensors are widened to
        float32 because numpy has no bfloat16.

        Raises:
            ValueError: for a dtype tag this reader does not support.
        """
        with open(path, 'rb') as fh:
            header_len = int.from_bytes(fh.read(8), 'little')
            header = json.loads(fh.read(header_len))
            payload = memoryview(fh.read())  # zero-copy slicing below

        tensors = {}
        for name, info in header.items():
            if name == '__metadata__':
                continue
            begin, end = info['data_offsets']
            chunk = payload[begin:end]
            tag = info['dtype']
            if tag == 'BF16':
                # bfloat16 is the high half of a float32: shift into place and reinterpret.
                bits = np.frombuffer(chunk, dtype='<u2').astype(np.uint32) << np.uint32(16)
                array = bits.view(np.float32)
            elif tag in cls._SAFETENSORS_DTYPES:
                array = np.frombuffer(chunk, dtype=cls._SAFETENSORS_DTYPES[tag])
            else:
                raise ValueError(f"Unsupported safetensors dtype {tag!r} for tensor {name!r} in {path}")
            tensors[name] = array.reshape(info['shape'])
        return tensors

    @staticmethod
    def _tensor_to_numpy(value: Any) -> np.ndarray:
        """Convert a torch tensor (or pass through an ndarray) to a CPU numpy array."""
        if isinstance(value, np.ndarray):
            return value
        tensor = value.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()  # numpy cannot represent bfloat16
        return tensor.numpy()

    def _store_model(self, model_name: str, model_path: str, config: Any,
                     weights: Dict[str, np.ndarray]) -> None:
        """Replace any stored copy of ``model_name`` with ``config`` and ``weights``."""
        # Remove the previous copy first (weights before registry, as the foreign
        # key requires) so a force_reload cannot keep weights the new checkpoint
        # no longer has, and so the registry row is inserted rather than upserted.
        self._delete_from_db(model_name)
        self.conn.execute("""
            INSERT INTO model_registry (model_name, model_path, config)
            VALUES (?, ?, ?)
        """, [model_name, model_path, json.dumps(config.to_dict())])

        try:
            for weight_name, array in weights.items():
                self.conn.execute("""
                    INSERT INTO model_weights
                    (model_name, weight_name, weight_data, shape, dtype)
                    VALUES (?, ?, ?, ?, ?)
                """, [
                    model_name,
                    weight_name,
                    array.tobytes(),
                    json.dumps(array.shape),
                    str(array.dtype)
                ])
        except Exception:
            # Never leave a partially written model behind for the cache path.
            self._delete_from_db(model_name)
            raise

    def _delete_from_db(self, model_name: str) -> None:
        """Delete ``model_name``'s weights and registry row (children first for the FK)."""
        self.conn.execute("DELETE FROM model_weights WHERE model_name = ?", [model_name])
        self.conn.execute("DELETE FROM model_registry WHERE model_name = ?", [model_name])

    # ------------------------------------------------------------ weight mapping

    def prepare_helium_weights(
        self,
        weights: Dict[str, np.ndarray],
        config: Any,
        model_type: str
    ) -> Dict[str, Any]:
        """Prepare weights in Helium's expected format

        Per-layer weights ``<src_prefix><i>.<rest>`` are re-keyed as
        ``<dst_prefix><i>.<rest>`` for ``0 <= i < number of layers``. Source keys
        may carry the usual base-model prefix (``vit.``, ``bert.``,
        ``transformer.``); unprefixed keys take precedence when both exist.
        Named weights missing from the checkpoint map to ``None``.

        Args:
            weights: Raw weights dictionary
            config: Model configuration
            model_type: Type of model ('vision', 'encoder', 'decoder'); any other
                value is treated as 'encoder-decoder'

        Returns:
            Dictionary with weights mapped to Helium's expected format. For
            encoder-decoder models it has nested 'encoder' and 'decoder' dicts.
        """
        if model_type == 'vision':
            # Map vision transformer weights
            src = self._strip_prefix(weights, 'vit.')
            return {
                'vit.patch_embed.proj.weight': src.get('patch_embed.projection.weight'),
                'vit.patch_embed.proj.bias': src.get('patch_embed.projection.bias'),
                'vit.pos_embed': src.get('position_embedding'),
                'vit.cls_token': src.get('cls_token'),
                'vit.head.weight': src.get('classifier.weight'),
                'vit.head.bias': src.get('classifier.bias'),
                **self._remap_layers(src, 'encoder.layer.', 'vit.encoder.layer.',
                                     config.num_hidden_layers)
            }
        if model_type == 'encoder':
            # Map BERT-style weights
            src = self._strip_prefix(weights, 'bert.')
            return {
                'bert.embeddings.word_embeddings.weight': src.get('embeddings.word_embeddings.weight'),
                'bert.embeddings.position_embeddings.weight': src.get('embeddings.position_embeddings.weight'),
                'bert.embeddings.token_type_embeddings.weight': src.get('embeddings.token_type_embeddings.weight'),
                'bert.embeddings.LayerNorm.weight': src.get('embeddings.LayerNorm.weight'),
                'bert.embeddings.LayerNorm.bias': src.get('embeddings.LayerNorm.bias'),
                **self._remap_layers(src, 'encoder.layer.', 'bert.encoder.layer.',
                                     config.num_hidden_layers)
            }
        if model_type == 'decoder':
            # Map GPT-style weights
            src = self._strip_prefix(weights, 'transformer.')
            return {
                'transformer.wte.weight': src.get('wte.weight'),
                'transformer.wpe.weight': src.get('wpe.weight'),
                'lm_head.weight': src.get('lm_head.weight'),
                **self._remap_layers(src, 'h.', 'transformer.h.', config.num_hidden_layers)
            }

        # encoder-decoder: map T5/BART-style weights. Token embeddings may only be
        # stored once as 'shared.weight' when the checkpoint ties them.
        encoder_layers = self._config_attr(config, *self._ENCODER_LAYER_ATTRS)
        decoder_layers = self._config_attr(config, *self._DECODER_LAYER_ATTRS)
        return {
            'encoder': {
                'embeddings.weight': self._first_present(
                    weights, 'encoder.embed_tokens.weight', 'shared.weight'),
                **self._remap_layers(weights, 'encoder.block.', 'encoder.block.', encoder_layers)
            },
            'decoder': {
                'embeddings.weight': self._first_present(
                    weights, 'decoder.embed_tokens.weight', 'shared.weight'),
                **self._remap_layers(weights, 'decoder.block.', 'decoder.block.', decoder_layers)
            }
        }

    @staticmethod
    def _strip_prefix(weights: Dict[str, Any], prefix: str) -> Dict[str, Any]:
        """Return ``weights`` with ``prefix`` removed from every key that starts with it."""
        stripped = {key[len(prefix):]: value for key, value in weights.items() if key.startswith(prefix)}
        # Keys that were already unprefixed win over stripped duplicates.
        stripped.update((key, value) for key, value in weights.items() if not key.startswith(prefix))
        return stripped

    @staticmethod
    def _remap_layers(weights: Dict[str, Any], src_prefix: str, dst_prefix: str,
                      num_layers: int) -> Dict[str, Any]:
        """Re-key ``{src_prefix}{i}.{rest}`` as ``{dst_prefix}{i}.{rest}`` for ``0 <= i < num_layers``.

        Single pass over ``weights``; keys outside the layer range are dropped.
        """
        wanted = {str(i) for i in range(num_layers)}
        remapped = {}
        for key, value in weights.items():
            if not key.startswith(src_prefix):
                continue
            index, dot, rest = key[len(src_prefix):].partition('.')
            if dot and index in wanted:
                remapped[f'{dst_prefix}{index}.{rest}'] = value
        return remapped

    @staticmethod
    def _first_present(weights: Dict[str, Any], *names: str) -> Any:
        """Return the value of the first of ``names`` present (and not None) in ``weights``."""
        for name in names:
            value = weights.get(name)
            if value is not None:
                return value
        return None

    @staticmethod
    def _config_attr(config: Any, *names: str, default: Any = None) -> Any:
        """Return the first non-None attribute among ``names`` on ``config``.

        Raises:
            AttributeError: if none is set and no ``default`` was given.
        """
        for name in names:
            value = getattr(config, name, None)
            if value is not None:
                return value
        if default is None:
            raise AttributeError(f"{type(config).__name__} defines none of: {', '.join(names)}")
        return default

    # ----------------------------------------------------------------- inference

    def run_inference(
        self,
        model_name: str,
        input_data: np.ndarray,
        device_id: Optional[str] = None,
        model_type: Optional[str] = None
    ) -> np.ndarray:
        """Run inference using Helium framework

        Args:
            model_name: The model name to use for inference (loaded on demand)
            input_data: Input tensor as numpy array
            device_id: Optional virtual GPU device ID; also set as Helium's default device
            model_type: Optional model type override ('encoder', 'decoder', 'encoder-decoder', 'vision')

        Returns:
            Model output as numpy array

        Raises:
            KeyError: if the checkpoint lacks the token-embedding weight required
                by the selected model type.
        """
        if model_name not in self.loaded_models:
            self.load_model(model_name)

        model_data = self.loaded_models[model_name]
        config = model_data['config']
        weights = model_data['weights']

        # Set up Helium device if specified
        if device_id:
            hl.set_default_device(device_id)

        if not model_type:
            model_type = self._infer_model_type(config)

        # Parse config for Helium
        parsed_config = parse_hf_config(config)

        # Map weights to Helium's format
        mapped_weights = self.prepare_helium_weights(weights, config, model_type)

        # Get driver from Helium device registry
        driver = hl.get_device(device_id) if device_id else None

        if model_type == 'vision':
            # Vision Transformer
            return run_vision_transformer_inference(
                hf_weights=mapped_weights,
                config=parsed_config,
                input_data=input_data,
                driver=driver
            )

        if model_type == 'encoder':
            # Encoder-only models (BERT, RoBERTa, etc)
            encoder = TransformerEncoder(
                vocab_size=config.vocab_size,
                hidden_dim=config.hidden_size,
                num_layers=config.num_hidden_layers,
                num_heads=config.num_attention_heads,
                max_seq_len=config.max_position_embeddings,
                embedding_weights=self._require_weight(
                    mapped_weights, 'bert.embeddings.word_embeddings.weight'),
                block_weights_list=self._split_layers(
                    mapped_weights, 'bert.encoder.layer.', config.num_hidden_layers),
                driver=driver
            )
            return encoder.forward(input_data)

        if model_type == 'decoder':
            # Decoder-only models (GPT, LLaMA, etc)
            decoder = TransformerDecoder(
                vocab_size=config.vocab_size,
                hidden_dim=config.hidden_size,
                num_layers=config.num_hidden_layers,
                num_heads=config.num_attention_heads,
                max_seq_len=config.max_position_embeddings,
                embedding_weights=self._require_weight(mapped_weights, 'transformer.wte.weight'),
                block_weights_list=self._split_layers(
                    mapped_weights, 'transformer.h.', config.num_hidden_layers),
                driver=driver
            )
            causal_mask = create_causal_mask(input_data.shape[1])
            return decoder.forward(input_data, attention_mask=causal_mask)

        # Encoder-Decoder models (T5, BART, etc)
        encoder_layers = self._config_attr(config, *self._ENCODER_LAYER_ATTRS)
        decoder_layers = self._config_attr(config, *self._DECODER_LAYER_ATTRS)
        # Not every encoder-decoder config defines max_position_embeddings; try
        # n_positions, then fall back to the length of the given input.
        max_seq_len = self._config_attr(
            config, 'max_position_embeddings', 'n_positions', default=input_data.shape[-1])

        encoder = TransformerEncoder(
            vocab_size=config.vocab_size,
            hidden_dim=config.hidden_size,
            num_layers=encoder_layers,
            num_heads=config.num_attention_heads,
            max_seq_len=max_seq_len,
            embedding_weights=self._require_weight(mapped_weights['encoder'], 'embeddings.weight'),
            block_weights_list=self._split_layers(
                mapped_weights['encoder'], 'encoder.block.', encoder_layers),
            driver=driver
        )

        decoder = TransformerDecoder(
            vocab_size=config.vocab_size,
            hidden_dim=config.hidden_size,
            num_layers=decoder_layers,
            num_heads=config.num_attention_heads,
            max_seq_len=max_seq_len,
            embedding_weights=self._require_weight(mapped_weights['decoder'], 'embeddings.weight'),
            block_weights_list=self._split_layers(
                mapped_weights['decoder'], 'decoder.block.', decoder_layers),
            driver=driver
        )

        # Encode input
        encoder_output = encoder.forward(input_data)

        # Start decoding from the config's decoder start token (0 if unspecified).
        start_id = getattr(config, 'decoder_start_token_id', None)
        decoder_input = np.full(
            (input_data.shape[0], 1), 0 if start_id is None else start_id, dtype=np.int64)

        # Decode
        return decoder.forward(
            decoder_input,
            encoder_hidden_states=encoder_output
        )

    @staticmethod
    def _infer_model_type(config: Any) -> str:
        """Guess the Helium model type from a config when the caller did not give one.

        Explicit flags are checked first (``is_vision_model``, ``is_decoder_only``,
        ``is_encoder_decoder``). Then heuristics apply: an architecture name ending in
        ``ForCausalLM``/``LMHeadModel`` means 'decoder', and a config with both
        ``patch_size`` and ``image_size`` means 'vision'. Otherwise 'encoder'.
        """
        if getattr(config, 'is_vision_model', False):
            return 'vision'
        if getattr(config, 'is_decoder_only', False):
            return 'decoder'
        if getattr(config, 'is_encoder_decoder', False):
            return 'encoder-decoder'
        architectures = getattr(config, 'architectures', None) or []
        if any(str(arch).endswith(('ForCausalLM', 'LMHeadModel')) for arch in architectures):
            return 'decoder'
        if getattr(config, 'patch_size', None) is not None and getattr(config, 'image_size', None) is not None:
            return 'vision'
        return 'encoder'  # Default to encoder

    @staticmethod
    def _require_weight(weights: Dict[str, Any], key: str) -> Any:
        """Return ``weights[key]``, raising KeyError if it is missing or None."""
        value = weights.get(key)
        if value is None:
            raise KeyError(f"Checkpoint provides no weight for {key!r}")
        return value

    @staticmethod
    def _split_layers(weights: Dict[str, Any], prefix: str, num_layers: int) -> List[Dict[str, Any]]:
        """Group ``{prefix}{i}.*`` entries into one dict per layer ``i`` in ``range(num_layers)``."""
        return [
            {key: value for key, value in weights.items() if key.startswith(f'{prefix}{i}.')}
            for i in range(num_layers)
        ]

    # ------------------------------------------------------------- bookkeeping

    def clear_model(self, model_name: str):
        """Remove a model from both database and memory cache

        Args:
            model_name: Name of the model to remove (no-op if unknown)
        """
        self._delete_from_db(model_name)
        self.loaded_models.pop(model_name, None)

    def list_models(self) -> List[Dict[str, Any]]:
        """List all models stored in the database

        Returns:
            List of dicts with keys 'name', 'path', 'config' (parsed JSON dict),
            'created_at', 'num_weights' and 'total_size_bytes' (bytes of stored
            weight blobs; 0 when the model has none).
        """
        # octet_length is DuckDB's byte-size function for BLOB values.
        results = self.conn.execute("""
            SELECT
                r.model_name,
                r.model_path,
                r.config,
                r.created_at,
                COUNT(w.weight_name) as num_weights,
                COALESCE(SUM(octet_length(w.weight_data)), 0) as total_size
            FROM model_registry r
            LEFT JOIN model_weights w ON r.model_name = w.model_name
            GROUP BY r.model_name, r.model_path, r.config, r.created_at
        """).fetchall()

        return [{
            'name': row[0],
            'path': row[1],
            'config': json.loads(row[2]),
            'created_at': row[3],
            'num_weights': row[4],
            'total_size_bytes': row[5]
        } for row in results]
      