from __future__ import annotations
from typing import Dict, List, Union, Optional, Any, Tuple
import numpy as np
from dataclasses import dataclass
from enum import Enum
import warnings
import json
import hashlib
import logging
from functools import lru_cache
from pathlib import Path

# Import local dependencies 
from .broadcast import ModalityType
from .attention_utils import AttentionState
from .core.db_manager import HeliumDBManager
from .virtual_gpu_device import VirtualGPUDevice

# Module-level pool of virtual GPU devices, keyed by name. It starts empty and
# nothing in the visible source of this module adds entries to it.
_gpu_devices: Dict[str, VirtualGPUDevice] = {}

# Configure logging.
# NOTE(review): logging.basicConfig at import time configures the *root* logger
# for the whole process that imports this module; libraries usually leave that
# to the application. Kept as-is to preserve existing behavior.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelType(Enum):
    """Transformer architecture families recognised by this module.

    Member values are the lowercase identifier strings, so for example
    ``ModelType("gpt2")`` resolves to ``ModelType.GPT2``.
    """

    GPT2 = "gpt2"
    BERT = "bert"
    T5 = "t5"
    LLAMA = "llama"
    MISTRAL = "mistral"
    FALCON = "falcon"

@dataclass
class ModelConfig:
    """Architecture-agnostic description of a transformer model.

    The first six fields are required constructor arguments; every field after
    them is optional and carries the default shown next to it.
    """

    # --- required, in constructor order ---
    model_type: ModelType
    num_layers: int
    num_heads: int
    hidden_dim: int
    vocab_size: int
    max_seq_len: int

    # --- optional ---
    intermediate_size: Optional[int] = None     # None when the source config omits it
    layer_norm_epsilon: float = 1e-5
    initializer_range: float = 0.02
    use_cache: bool = True
    use_fp16: bool = False
    rotary_dim: Optional[int] = None            # only set for models with rotary embeddings
    vocab_padding_size: Optional[int] = None    # optional padded vocabulary size

class CacheManager:
    """Key/value cache for model utilities backed by ``HeliumDBManager``.

    Reads and writes go through the object returned by
    ``HeliumDBManager.get_instance()`` via its ``get_activation`` /
    ``set_activation`` methods.
    """

    def __init__(self):
        self.db = HeliumDBManager.get_instance()

    def _compute_key(self, data: Any, prefix: str) -> str:
        """Compute a deterministic cache key of the form ``"<prefix>_<sha256 hex>"``.

        Args:
            data: Value to key on. NumPy arrays are hashed over their dtype,
                shape and raw bytes; any other value is hashed over ``str(data)``.
            prefix: Namespace string prepended to the digest.

        Returns:
            The cache key.
        """
        if isinstance(data, np.ndarray):
            # ``tobytes()`` alone is identical for arrays that differ only in
            # shape (zeros((2, 3)) vs zeros((3, 2))) or in a dtype sharing the
            # same bit pattern (int32 vs float32 zeros); fold dtype and shape in
            # so distinct arrays cannot collide on one key.
            digest = hashlib.sha256()
            digest.update(f"{data.dtype.str}|{data.shape}|".encode())
            digest.update(data.tobytes())
            return f"{prefix}_{digest.hexdigest()}"
        return f"{prefix}_{hashlib.sha256(str(data).encode()).hexdigest()}"

    def get(self, key: str) -> Optional[Any]:
        """Return the value stored under ``key``, as returned by the DB layer.

        Callers in this module treat ``None`` as a cache miss.
        """
        return self.db.get_activation(key)

    def set(self, key: str, value: Any, metadata: Dict):
        """Store ``value`` under ``key`` together with ``metadata``."""
        self.db.set_activation(key, value, metadata)

class MaskGenerator:
    """Optimized attention mask generator.

    Causal masks are memoized per instance (bounded, least-recently-used
    eviction) and, when ``use_cache`` is true, also persisted through
    ``CacheManager``.
    """

    # Upper bound on masks memoized per instance; same bound as the
    # ``lru_cache(maxsize=128)`` this memo replaces.
    _MEMO_SIZE = 128

    def __init__(self, use_cache: bool = True):
        self.cache_manager = CacheManager() if use_cache else None
        # Per-instance memo instead of ``functools.lru_cache`` on the method: a
        # method-level lru_cache lives on the class and holds a strong reference
        # to every ``self`` it has seen, so instances are never freed.
        self._mask_memo = {}

    def create_causal_mask(self, seq_len: int, dtype: np.dtype = np.bool_, device = None) -> np.ndarray:
        """Create causal (autoregressive) attention mask.

        Uses an in-memory memo and, if enabled, the persistent cache.

        Args:
            seq_len: Sequence length.
            dtype: Data type of the mask.
            device: Optional object with a ``to_gpu`` method; when given, the
                mask is passed through ``device.to_gpu`` before returning. Must
                be hashable because it is part of the memo key.

        Returns:
            Lower-triangular mask of shape ``(1, 1, seq_len, seq_len)`` (or
            whatever ``device.to_gpu`` returns for it). The object is shared
            between calls with equal arguments; do not modify it in place.
        """
        memo_key = (seq_len, dtype, device)
        memo = self._mask_memo
        if memo_key in memo:
            # Re-insert so this entry becomes the most recently used one.
            result = memo.pop(memo_key)
            memo[memo_key] = result
            return result

        mask = None
        if self.cache_manager:
            cache_key = self._compute_key((seq_len, str(dtype)), "causal_mask")
            mask = self.cache_manager.get(cache_key)

        if mask is None:
            mask = np.tril(np.ones((seq_len, seq_len), dtype=dtype))
            mask = mask[np.newaxis, np.newaxis, :, :]
            if self.cache_manager:
                metadata = {"type": "causal_mask", "seq_len": seq_len, "dtype": str(dtype)}
                self.cache_manager.set(cache_key, mask, metadata)

        # Apply device placement on every path; a persistent-cache hit used to
        # be returned as-is, silently ignoring ``device``.
        result = mask if device is None else device.to_gpu(mask)

        memo[memo_key] = result
        if len(memo) > self._MEMO_SIZE:
            # Dicts keep insertion order and hits are re-inserted, so the first
            # key is the least recently used entry.
            del memo[next(iter(memo))]
        return result

    def _compute_key(self, data: Any, prefix: str) -> str:
        """Compute cache key ``"<prefix>_<sha256 hex>"`` for ``data``.

        Tuples are flattened into a ``"_"``-joined string before hashing.
        """
        if isinstance(data, tuple):
            data = "_".join(str(x) for x in data)
        return f"{prefix}_{hashlib.sha256(str(data).encode()).hexdigest()}"

def split_heads(x: Union[str, "HeliumTensor"],
               num_heads: int,
               driver,
               modality: Optional[ModalityType] = None) -> Union[str, "HeliumTensor"]:
    """Split the hidden dimension of ``x`` into ``num_heads`` attention heads.

    Args:
        x: Tensor, or the name of a tensor fetched via ``driver.get_tensor``.
            Its ``shape`` must unpack into ``(batch, seq_len, hidden_dim)``.
        num_heads: Number of heads; must be positive and divide ``hidden_dim``.
        driver: Backend providing ``get_tensor``, ``reshape`` and ``transpose``.
        modality: Currently unused; accepted for API compatibility.

    Returns:
        ``x`` reshaped to ``(batch, seq_len, num_heads, head_dim)`` and then
        transposed with axes ``(0, 2, 1, 3)``, i.e.
        ``(batch, num_heads, seq_len, head_dim)`` assuming the driver follows
        NumPy reshape/transpose semantics.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_dim``.
    """
    if isinstance(x, str):
        x = driver.get_tensor(x)

    batch_size, seq_len, hidden_dim = x.shape
    # Validate up front: otherwise num_heads == 0 raises ZeroDivisionError and a
    # non-divisor only fails later inside the driver's reshape.
    if num_heads <= 0:
        raise ValueError(f"num_heads must be positive, got {num_heads}")
    if hidden_dim % num_heads != 0:
        raise ValueError(
            f"hidden_dim ({hidden_dim}) is not divisible by num_heads ({num_heads})"
        )
    head_dim = hidden_dim // num_heads

    x = driver.reshape(x, (batch_size, seq_len, num_heads, head_dim))
    return driver.transpose(x, (0, 2, 1, 3))

def apply_rotary_embedding(x: Union[str, HeliumTensor],
                         seq_len: int,
                         head_dim: int,
                         driver,
                         base: float = 10000.0) -> Union[str, HeliumTensor]:
    """Apply rotary positional embeddings (RoPE) to ``x``.

    Uses the "rotate-half" layout: feature ``j`` in the first half of the last
    axis is paired with feature ``j + head_dim // 2``, and the pair at position
    ``p`` is rotated by angle ``p * base ** (-2j / head_dim)``::

        out = x * cos(emb) + rotate_half(x) * sin(emb)
        rotate_half([first, second]) = [-second, first]   # halves of last axis

    Args:
        x: Tensor, or tensor name fetched via ``driver.get_tensor``. Its last
            two axes are expected to be ``(seq_len, head_dim)``, e.g.
            ``(batch, heads, seq_len, head_dim)``.
        seq_len: Number of positions (0 .. seq_len - 1) to encode.
        head_dim: Size of the last axis of ``x``; must be positive and even.
        driver: Backend providing ``get_tensor``, ``to_gpu``, ``matmul``,
            ``mul`` and ``add``; ``mul`` must broadcast a ``(seq_len, head_dim)``
            operand against ``x``.
        base: Frequency base; 10000 matches the previous hard-coded value.

    Returns:
        Tensor with the same shape as ``x`` with rotary embeddings applied.

    Raises:
        ValueError: If ``head_dim`` is not a positive even number.
    """
    if isinstance(x, str):
        x = driver.get_tensor(x)
    if head_dim <= 0 or head_dim % 2:
        raise ValueError(f"head_dim must be a positive even number, got {head_dim}")

    half = head_dim // 2

    # Inverse frequency per feature pair: base ** (-2j / head_dim), j in [0, half).
    inv_freq = np.exp(-np.arange(0, head_dim, 2) * np.log(base) / head_dim)
    angles = np.arange(seq_len)[:, None] * inv_freq[None, :]   # (seq_len, half)
    # Duplicate so both members of a pair use the same angle.
    emb = np.concatenate([angles, angles], axis=-1)            # (seq_len, head_dim)

    # rotate_half as a constant matmul, so the driver only needs matmul/mul/add:
    # (x @ rot)[j] = -x[j + half] for j < half, and x[j - half] for j >= half.
    rot = np.zeros((head_dim, head_dim))
    idx = np.arange(half)
    rot[idx + half, idx] = -1.0
    rot[idx, idx + half] = 1.0

    cos = driver.to_gpu(np.cos(emb))
    sin = driver.to_gpu(np.sin(emb))
    rot = driver.to_gpu(rot)

    # The previous version computed matmul(x, cos) - matmul(x, sin), which only
    # type-checks when head_dim == seq_len and does not implement a rotation.
    rotated = driver.matmul(x, rot)
    return driver.add(driver.mul(x, cos), driver.mul(rotated, sin))


def fuse_cross_modal_attention(q: Union[str, HeliumTensor],
                             k: Union[str, HeliumTensor],
                             v: Union[str, HeliumTensor],
                             q_modality: ModalityType,
                             kv_modality: ModalityType,
                             fusion_type: str,
                             driver,
                             state: AttentionState) -> Tuple[Union[str, HeliumTensor],
                                                           Union[str, HeliumTensor],
                                                           Union[str, HeliumTensor]]:
    """Fuse query and key tensors according to ``fusion_type``.

    Supported fusion types:

    * ``"additive"``: ``q = q + k``
    * ``"multiplicative"``: ``q = q * k``
    * ``"gated"``: ``gate = sigmoid(q @ W)``, ``q = gate * q + (1 - gate) * k``
      where ``W`` is ``state.stored_tensors["gate_weight"]``

    For each supported type the fused tensor is returned as both ``q`` and
    ``k``. Any other ``fusion_type`` returns the inputs unchanged. ``v`` is
    never modified. ``q_modality`` and ``kv_modality`` are currently unused.

    Args:
        q, k, v: Query, key and value tensors (or tensor handles).
        q_modality, kv_modality: Modalities of the query and key/value streams.
        fusion_type: One of the fusion names above.
        driver: Backend providing ``add``, ``mul``, ``sub``, ``matmul`` and
            ``sigmoid``; a string is resolved through the module's
            ``_gpu_devices`` pool.
        state: Object whose ``stored_tensors`` mapping supplies ``gate_weight``.

    Returns:
        ``(q, k, v)`` after fusion.

    Raises:
        ValueError: If a string ``driver`` is not in the device pool, or if
            gated fusion is requested without a ``gate_weight``.
    """
    if isinstance(driver, str):
        # `get_gpu_device` is neither defined nor imported in this module, so
        # the previous lookup always raised NameError; resolve the name from
        # the module-level device pool instead.
        # NOTE(review): assumes string drivers are keys of `_gpu_devices` — confirm.
        try:
            driver = _gpu_devices[driver]
        except KeyError:
            raise ValueError(f"Unknown GPU device: {driver!r}") from None

    if fusion_type == "additive":
        q = driver.add(q, k)
        k = q
    elif fusion_type == "multiplicative":
        q = driver.mul(q, k)
        k = q
    elif fusion_type == "gated":
        gate_weight = state.stored_tensors.get("gate_weight")
        if gate_weight is None:
            # Fail clearly instead of handing None to driver.matmul.
            raise ValueError("Gated fusion requires state.stored_tensors['gate_weight']")
        gate = driver.sigmoid(driver.matmul(q, gate_weight))
        q = driver.add(
            driver.mul(gate, q),
            driver.mul(driver.sub(1.0, gate), k)
        )
        k = q

    return q, k, v
import numpy as np
from enum import Enum
from dataclasses import dataclass
import warnings
from .core.db_manager import HeliumDBManager
import json
import hashlib
from pathlib import Path
import torch  # For tensor conversion utilities
import logging
from functools import lru_cache

# Import local dependencies
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from .tensor import HeliumTensor
from .broadcast import ModalityType
from .attention_utils import AttentionState

# Configure logging.
# NOTE(review): logging.basicConfig at import time configures the *root* logger
# for the whole process, and this module already makes the same call near the
# top of the file (basicConfig does nothing once the root logger has handlers).
# Libraries usually leave root-logger configuration to the application.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelType(Enum):
    """Transformer architecture families recognised by this module.

    Member values are the lowercase identifier strings, so for example
    ``ModelType("gpt2")`` resolves to ``ModelType.GPT2``.
    """

    GPT2 = "gpt2"
    BERT = "bert"
    T5 = "t5"
    LLAMA = "llama"
    MISTRAL = "mistral"
    FALCON = "falcon"

@dataclass
class ModelConfig:
    """Architecture-agnostic description of a transformer model.

    The first six fields are required constructor arguments; every field after
    them is optional and carries the default shown next to it.
    """

    # --- required, in constructor order ---
    model_type: ModelType
    num_layers: int
    num_heads: int
    hidden_dim: int
    vocab_size: int
    max_seq_len: int

    # --- optional ---
    intermediate_size: Optional[int] = None     # None when the source config omits it
    layer_norm_epsilon: float = 1e-5
    initializer_range: float = 0.02
    use_cache: bool = True
    use_fp16: bool = False
    rotary_dim: Optional[int] = None            # only set for models with rotary embeddings
    vocab_padding_size: Optional[int] = None    # optional padded vocabulary size

class CacheManager:
    """Key/value cache for model utilities backed by ``HeliumDBManager``.

    Reads and writes go through the object returned by
    ``HeliumDBManager.get_instance()`` via its ``get_activation`` /
    ``set_activation`` methods.
    """

    def __init__(self):
        self.db = HeliumDBManager.get_instance()

    def _compute_key(self, data: Any, prefix: str) -> str:
        """Compute a deterministic cache key of the form ``"<prefix>_<sha256 hex>"``.

        Args:
            data: Value to key on. NumPy arrays are hashed over their dtype,
                shape and raw bytes; any other value is hashed over ``str(data)``.
            prefix: Namespace string prepended to the digest.

        Returns:
            The cache key.
        """
        if isinstance(data, np.ndarray):
            # ``tobytes()`` alone is identical for arrays that differ only in
            # shape (zeros((2, 3)) vs zeros((3, 2))) or in a dtype sharing the
            # same bit pattern (int32 vs float32 zeros); fold dtype and shape in
            # so distinct arrays cannot collide on one key.
            digest = hashlib.sha256()
            digest.update(f"{data.dtype.str}|{data.shape}|".encode())
            digest.update(data.tobytes())
            return f"{prefix}_{digest.hexdigest()}"
        return f"{prefix}_{hashlib.sha256(str(data).encode()).hexdigest()}"

    def get(self, key: str) -> Optional[Any]:
        """Return the value stored under ``key``, as returned by the DB layer.

        Callers in this module treat ``None`` as a cache miss.
        """
        return self.db.get_activation(key)

    def set(self, key: str, value: Any, metadata: Dict):
        """Store ``value`` under ``key`` together with ``metadata``."""
        self.db.set_activation(key, value, metadata)

class MaskGenerator:
    """Optimized attention mask generator.

    Causal masks are memoized per instance (bounded, least-recently-used
    eviction) and, when ``use_cache`` is true, also persisted through
    ``CacheManager``.
    """

    # Upper bound on masks memoized per instance; same bound as the
    # ``lru_cache(maxsize=128)`` this memo replaces.
    _MEMO_SIZE = 128

    def __init__(self, use_cache: bool = True):
        self.cache_manager = CacheManager() if use_cache else None
        # Per-instance memo instead of ``functools.lru_cache`` on the method: a
        # method-level lru_cache lives on the class and holds a strong reference
        # to every ``self`` it has seen, so instances are never freed.
        self._mask_memo = {}

    def create_causal_mask(self, seq_len: int, dtype: np.dtype = np.bool_, device = None) -> np.ndarray:
        """Create causal (autoregressive) attention mask.
        Uses caching for common sequence lengths.

        Args:
            seq_len: Sequence length
            dtype: Data type for mask
            device: Optional object with a ``to_gpu`` method; when given, the
                mask is passed through ``device.to_gpu`` before returning. Must
                be hashable because it is part of the memo key.

        Returns:
            mask: Shape (1, 1, seq_len, seq_len) lower-triangular attention mask
                (or whatever ``device.to_gpu`` returns for it). The object is
                shared between calls with equal arguments; do not modify it in
                place.
        """
        memo_key = (seq_len, dtype, device)
        memo = self._mask_memo
        if memo_key in memo:
            # Re-insert so this entry becomes the most recently used one.
            result = memo.pop(memo_key)
            memo[memo_key] = result
            return result

        mask = None
        if self.cache_manager:
            cache_key = self.cache_manager._compute_key((seq_len, str(dtype)), "causal_mask")
            mask = self.cache_manager.get(cache_key)

        if mask is None:
            mask = np.tril(np.ones((seq_len, seq_len), dtype=dtype))
            mask = mask[np.newaxis, np.newaxis, :, :]
            if self.cache_manager:
                metadata = {
                    "type": "causal_mask",
                    "seq_len": seq_len,
                    "dtype": str(dtype)
                }
                self.cache_manager.set(cache_key, mask, metadata)

        # Apply device placement on every path; a persistent-cache hit used to
        # be returned as-is, silently ignoring ``device``.
        result = mask if device is None else device.to_gpu(mask)

        memo[memo_key] = result
        if len(memo) > self._MEMO_SIZE:
            # Dicts keep insertion order and hits are re-inserted, so the first
            # key is the least recently used entry.
            del memo[next(iter(memo))]
        return result

def split_heads(
    x: Union[str, "HeliumTensor"],
    num_heads: int,
    driver,
    modality: Optional["ModalityType"] = None
) -> Union[str, "HeliumTensor"]:
    """Split the hidden dimension of ``x`` into ``num_heads`` attention heads.

    Args:
        x: Tensor, or the name of a tensor fetched via ``driver.get_tensor``.
            Its ``shape`` must unpack into ``(batch, seq_len, hidden_dim)``.
        num_heads: Number of heads; must be positive and divide ``hidden_dim``.
        driver: Backend providing ``get_tensor``, ``reshape`` and ``transpose``.
        modality: Currently unused; accepted for API compatibility.

    Returns:
        ``x`` reshaped to ``(batch, seq_len, num_heads, head_dim)`` and then
        transposed with axes ``(0, 2, 1, 3)``, i.e.
        ``(batch, num_heads, seq_len, head_dim)`` assuming the driver follows
        NumPy reshape/transpose semantics.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_dim``.
    """
    if isinstance(x, str):
        x = driver.get_tensor(x)

    batch_size, seq_len, hidden_dim = x.shape
    # Validate up front: otherwise num_heads == 0 raises ZeroDivisionError and a
    # non-divisor only fails later inside the driver's reshape.
    if num_heads <= 0:
        raise ValueError(f"num_heads must be positive, got {num_heads}")
    if hidden_dim % num_heads != 0:
        raise ValueError(
            f"hidden_dim ({hidden_dim}) is not divisible by num_heads ({num_heads})"
        )
    head_dim = hidden_dim // num_heads

    x = driver.reshape(x, (batch_size, seq_len, num_heads, head_dim))
    return driver.transpose(x, (0, 2, 1, 3))

def apply_rotary_embedding(
    x: Union[str, "HeliumTensor"],
    seq_len: int,
    head_dim: int,
    driver,
    base: float = 10000.0
) -> Union[str, "HeliumTensor"]:
    """Apply rotary positional embeddings (RoPE) to ``x``.

    Uses the "rotate-half" layout: feature ``j`` in the first half of the last
    axis is paired with feature ``j + head_dim // 2``, and the pair at position
    ``p`` is rotated by angle ``p * base ** (-2j / head_dim)``::

        out = x * cos(emb) + rotate_half(x) * sin(emb)
        rotate_half([first, second]) = [-second, first]   # halves of last axis

    Args:
        x: Tensor, or tensor name fetched via ``driver.get_tensor``. Its last
            two axes are expected to be ``(seq_len, head_dim)``, e.g.
            ``(batch, heads, seq_len, head_dim)``.
        seq_len: Number of positions (0 .. seq_len - 1) to encode.
        head_dim: Size of the last axis of ``x``; must be positive and even.
        driver: Backend providing ``get_tensor``, ``to_gpu``, ``matmul``,
            ``mul`` and ``add``; ``mul`` must broadcast a ``(seq_len, head_dim)``
            operand against ``x``.
        base: Frequency base; 10000 matches the previous hard-coded value.

    Returns:
        Tensor with the same shape as ``x`` with rotary embeddings applied.

    Raises:
        ValueError: If ``head_dim`` is not a positive even number.
    """
    if isinstance(x, str):
        x = driver.get_tensor(x)
    if head_dim <= 0 or head_dim % 2:
        raise ValueError(f"head_dim must be a positive even number, got {head_dim}")

    half = head_dim // 2

    # Inverse frequency per feature pair: base ** (-2j / head_dim), j in [0, half).
    inv_freq = np.exp(-np.arange(0, head_dim, 2) * np.log(base) / head_dim)
    angles = np.arange(seq_len)[:, None] * inv_freq[None, :]   # (seq_len, half)
    # Duplicate so both members of a pair use the same angle.
    emb = np.concatenate([angles, angles], axis=-1)            # (seq_len, head_dim)

    # rotate_half as a constant matmul, so the driver only needs matmul/mul/add:
    # (x @ rot)[j] = -x[j + half] for j < half, and x[j - half] for j >= half.
    rot = np.zeros((head_dim, head_dim))
    idx = np.arange(half)
    rot[idx + half, idx] = -1.0
    rot[idx, idx + half] = 1.0

    cos = driver.to_gpu(np.cos(emb))
    sin = driver.to_gpu(np.sin(emb))
    rot = driver.to_gpu(rot)

    # The previous version computed matmul(x, cos) - matmul(x, sin), which only
    # type-checks when head_dim == seq_len and does not implement a rotation.
    rotated = driver.matmul(x, rot)
    return driver.add(driver.mul(x, cos), driver.mul(rotated, sin))

def fuse_cross_modal_attention(
    q: Union[str, "HeliumTensor"],
    k: Union[str, "HeliumTensor"], 
    v: Union[str, "HeliumTensor"],
    q_modality: "ModalityType",
    kv_modality: "ModalityType",
    fusion_type: str,
    driver,
    state: "AttentionState"
) -> Tuple[Union[str, "HeliumTensor"], Union[str, "HeliumTensor"], Union[str, "HeliumTensor"]]:
    """Fuse query and key tensors according to ``fusion_type``.

    Supported fusion types:

    * ``"additive"``: ``q = q + k``
    * ``"multiplicative"``: ``q = q * k``
    * ``"gated"``: ``gate = sigmoid(q @ W)``, ``q = gate * q + (1 - gate) * k``
      where ``W`` is ``state.stored_tensors["gate_weight"]``

    For each supported type the fused tensor is returned as both ``q`` and
    ``k``. Any other ``fusion_type`` returns the inputs unchanged. ``v`` is
    never modified. ``q_modality`` and ``kv_modality`` are currently unused.

    Args:
        q, k, v: Query, key and value tensors (or tensor handles).
        q_modality, kv_modality: Modalities of the query and key/value streams.
        fusion_type: One of the fusion names above.
        driver: Backend providing ``add``, ``mul``, ``sub``, ``matmul`` and
            ``sigmoid``.
        state: Object whose ``stored_tensors`` mapping supplies ``gate_weight``.

    Returns:
        ``(q, k, v)`` after fusion.

    Raises:
        ValueError: If gated fusion is requested without a ``gate_weight``.
    """
    if fusion_type == "additive":
        q = driver.add(q, k)
        k = q
    elif fusion_type == "multiplicative":
        q = driver.mul(q, k)
        k = q
    elif fusion_type == "gated":
        gate_weight = state.stored_tensors.get("gate_weight")
        if gate_weight is None:
            # Fail clearly instead of handing None to driver.matmul.
            raise ValueError("Gated fusion requires state.stored_tensors['gate_weight']")
        gate = driver.sigmoid(driver.matmul(q, gate_weight))
        q = driver.add(
            driver.mul(gate, q),
            driver.mul(driver.sub(1.0, gate), k)
        )
        k = q

    return q, k, v

class WeightMapper:
    """Optimized weight mapping utility for different model architectures.

    Converts a flat HuggingFace-style ``{parameter name: array}`` dictionary
    into a list with one dictionary per layer, keyed by this module's internal
    names (``ln1.weight``, ``attn.q_proj.weight``, ``ff1.weight``, ...).
    Results can optionally be persisted through ``CacheManager``.
    """

    def __init__(self, use_cache: bool = True):
        self.cache_manager = CacheManager() if use_cache else None

    def map_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig
    ) -> List[Dict[str, np.ndarray]]:
        """
        Map weights based on model type

        Args:
            hf_weights: HuggingFace weight dictionary (parameter name -> array)
            config: Model configuration; ``model_type`` selects the mapper and
                ``num_layers`` sets how many blocks are produced

        Returns:
            List of block weight dictionaries, one per layer

        Raises:
            ValueError: If the model type has no mapper or a required weight
                is missing.
        """
        # Only list mappers that exist on this class: a reference to a missing
        # method makes building this dict raise AttributeError on every call.
        # FALCON has no mapper yet and is reported as unsupported below.
        mapping_funcs = {
            ModelType.GPT2: self._map_gpt2_weights,
            ModelType.BERT: self._map_bert_weights,
            ModelType.T5: self._map_t5_weights,
            ModelType.LLAMA: self._map_llama_weights,
            ModelType.MISTRAL: self._map_mistral_weights,
        }

        mapper = mapping_funcs.get(config.model_type)
        if mapper is None:
            raise ValueError(f"Unsupported model type: {config.model_type}")

        cache_key = None
        if self.cache_manager:
            # Key on the weight contents (plus the config fields the mappers
            # read). Keying on parameter names alone made two checkpoints of the
            # same architecture share one entry and return the wrong weights.
            cache_key = self.cache_manager._compute_key(
                self._weights_fingerprint(hf_weights, config),
                "weight_mapping"
            )
            cached_mapping = self.cache_manager.get(cache_key)
            if cached_mapping is not None:
                return cached_mapping

        result = mapper(hf_weights, config)

        if self.cache_manager:
            self.cache_manager.set(
                cache_key,
                result,
                {'model_type': config.model_type.value}
            )

        return result

    @staticmethod
    def _weights_fingerprint(hf_weights: Dict[str, np.ndarray], config: ModelConfig) -> str:
        """Return a SHA-256 hex digest over the mapping-relevant config fields
        and every weight's name, dtype, shape and contents.

        This costs one pass over all weight bytes; a names-only key is not
        sufficient because different checkpoints share the same names.
        """
        digest = hashlib.sha256()
        digest.update(
            f"{config.model_type.value}|{config.num_layers}|{config.rotary_dim}".encode()
        )
        for name in sorted(hf_weights):
            arr = np.asarray(hf_weights[name])
            digest.update(f"|{name}|{arr.dtype.str}|{arr.shape}|".encode())
            digest.update(arr.tobytes())
        return digest.hexdigest()

    def _map_gpt2_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'transformer.h.'
    ) -> List[Dict[str, np.ndarray]]:
        """Map GPT-2 weights.

        Raises:
            ValueError: If a required weight is missing or the fused ``c_attn``
                weight's last dimension is not divisible by 3.
        """
        block_weights_list = []

        try:
            for i in range(config.num_layers):
                block = {}
                # Layer normalization weights
                block['ln1.weight'] = hf_weights[f'{prefix}{i}.ln_1.weight']
                block['ln1.bias']   = hf_weights[f'{prefix}{i}.ln_1.bias']

                # Fused q/k/v projection. The slices below split along the last
                # axis, so the split size must come from that axis (HF GPT-2's
                # Conv1D c_attn weight is (hidden, 3 * hidden)); deriving it from
                # shape[0] produced slices of width hidden // 3.
                attn_weight = hf_weights[f'{prefix}{i}.attn.c_attn.weight']
                qkv_dim = attn_weight.shape[-1]
                if qkv_dim % 3:
                    raise ValueError(
                        f"c_attn weight's last dimension must be divisible by 3, "
                        f"got shape {attn_weight.shape}"
                    )
                split_size = qkv_dim // 3
                block['attn.q_proj.weight'] = attn_weight[:, :split_size]
                block['attn.k_proj.weight'] = attn_weight[:, split_size:2*split_size]
                block['attn.v_proj.weight'] = attn_weight[:, 2*split_size:]

                # Output projection
                block['attn.out_proj.weight'] = hf_weights[f'{prefix}{i}.attn.c_proj.weight']

                # Second layer norm
                block['ln2.weight'] = hf_weights[f'{prefix}{i}.ln_2.weight']
                block['ln2.bias']   = hf_weights[f'{prefix}{i}.ln_2.bias']

                # Feed-forward weights
                block['ff1.weight'] = hf_weights[f'{prefix}{i}.mlp.c_fc.weight']
                block['ff1.bias']   = hf_weights[f'{prefix}{i}.mlp.c_fc.bias']
                block['ff2.weight'] = hf_weights[f'{prefix}{i}.mlp.c_proj.weight']
                block['ff2.bias']   = hf_weights[f'{prefix}{i}.mlp.c_proj.bias']

                # Optional rotary embeddings for newer variants
                if config.rotary_dim:
                    if f'{prefix}{i}.attn.rotary_emb.inv_freq' in hf_weights:
                        block['attn.rotary_emb.inv_freq'] = hf_weights[f'{prefix}{i}.attn.rotary_emb.inv_freq']

                block_weights_list.append(block)

        except KeyError as e:
            logger.error("Failed to map GPT-2 weights: %s", e)
            raise ValueError(f"Missing required weight: {str(e)}") from e

        return block_weights_list

    def _map_bert_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'bert.encoder.layer.'
    ) -> List[Dict[str, np.ndarray]]:
        """Map BERT weights.

        Raises:
            ValueError: If a required weight is missing.
        """
        block_weights_list = []

        try:
            for i in range(config.num_layers):
                block = {}
                # Layer normalization weights
                block['ln1.weight'] = hf_weights[f'{prefix}{i}.attention.output.LayerNorm.weight']
                block['ln1.bias']   = hf_weights[f'{prefix}{i}.attention.output.LayerNorm.bias']

                # Attention weights
                block['attn.q_proj.weight'] = hf_weights[f'{prefix}{i}.attention.self.query.weight']
                block['attn.k_proj.weight'] = hf_weights[f'{prefix}{i}.attention.self.key.weight']
                block['attn.v_proj.weight'] = hf_weights[f'{prefix}{i}.attention.self.value.weight']
                block['attn.out_proj.weight'] = hf_weights[f'{prefix}{i}.attention.output.dense.weight']

                # Second layer norm
                block['ln2.weight'] = hf_weights[f'{prefix}{i}.output.LayerNorm.weight']
                block['ln2.bias']   = hf_weights[f'{prefix}{i}.output.LayerNorm.bias']

                # Feed-forward weights
                block['ff1.weight'] = hf_weights[f'{prefix}{i}.intermediate.dense.weight']
                block['ff1.bias']   = hf_weights[f'{prefix}{i}.intermediate.dense.bias']
                block['ff2.weight'] = hf_weights[f'{prefix}{i}.output.dense.weight']
                block['ff2.bias']   = hf_weights[f'{prefix}{i}.output.dense.bias']

                # Position embeddings are model-wide; attach them to the first block only.
                if i == 0 and 'bert.embeddings.position_embeddings.weight' in hf_weights:
                    block['position_embeddings'] = hf_weights['bert.embeddings.position_embeddings.weight']

                block_weights_list.append(block)

        except KeyError as e:
            logger.error("Failed to map BERT weights: %s", e)
            raise ValueError(f"Missing required weight: {str(e)}") from e

        return block_weights_list

    def _map_t5_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'encoder.block.'
    ) -> List[Dict[str, np.ndarray]]:
        """Map T5 encoder weights.

        Raises:
            ValueError: If a required weight is missing.
        """
        block_weights_list = []

        try:
            for i in range(config.num_layers):
                block = {}
                # Layer normalization
                block['ln1.weight'] = hf_weights[f'{prefix}{i}.layer.0.layer_norm.weight']

                # Attention weights
                block['attn.q_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.q.weight']
                block['attn.k_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.k.weight']
                block['attn.v_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.v.weight']
                block['attn.out_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.o.weight']

                # Second layer norm
                block['ln2.weight'] = hf_weights[f'{prefix}{i}.layer.1.layer_norm.weight']

                # Feed-forward weights
                block['ff1.weight'] = hf_weights[f'{prefix}{i}.layer.1.DenseReluDense.wi.weight']
                block['ff2.weight'] = hf_weights[f'{prefix}{i}.layer.1.DenseReluDense.wo.weight']

                # Relative position bias if available. HF T5 stores it as an
                # embedding named `...relative_attention_bias.weight`; the bare
                # name is still accepted for compatibility.
                bias_base = f'{prefix}{i}.layer.0.SelfAttention.relative_attention_bias'
                for bias_key in (f'{bias_base}.weight', bias_base):
                    if bias_key in hf_weights:
                        block['attn.relative_attention_bias'] = hf_weights[bias_key]
                        break

                block_weights_list.append(block)

        except KeyError as e:
            logger.error("Failed to map T5 weights: %s", e)
            raise ValueError(f"Missing required weight: {str(e)}") from e

        return block_weights_list

    def _map_llama_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'model.layers.'
    ) -> List[Dict[str, np.ndarray]]:
        """Map LLaMA weights.

        Raises:
            ValueError: If a required weight is missing.
        """
        block_weights_list = []

        try:
            for i in range(config.num_layers):
                block = {}
                # Attention weights
                block['attn.q_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.q_proj.weight']
                block['attn.k_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.k_proj.weight']
                block['attn.v_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.v_proj.weight']
                block['attn.o_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.o_proj.weight']

                # Rotary embeddings (stored as None when the checkpoint lacks them)
                if config.rotary_dim:
                    block['attn.rotary_emb.inv_freq'] = hf_weights.get(
                        f'{prefix}{i}.self_attn.rotary_emb.inv_freq'
                    )

                # RMSNorm weights
                block['input_layernorm.weight'] = hf_weights[f'{prefix}{i}.input_layernorm.weight']
                block['post_attention_layernorm.weight'] = hf_weights[f'{prefix}{i}.post_attention_layernorm.weight']

                # Feed-forward weights
                block['mlp.gate_proj.weight'] = hf_weights[f'{prefix}{i}.mlp.gate_proj.weight']
                block['mlp.up_proj.weight'] = hf_weights[f'{prefix}{i}.mlp.up_proj.weight']
                block['mlp.down_proj.weight'] = hf_weights[f'{prefix}{i}.mlp.down_proj.weight']

                block_weights_list.append(block)

        except KeyError as e:
            logger.error("Failed to map LLaMA weights: %s", e)
            raise ValueError(f"Missing required weight: {str(e)}") from e

        return block_weights_list

    def _map_mistral_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'model.layers.'
    ) -> List[Dict[str, np.ndarray]]:
        """Map Mistral weights.

        HF Mistral checkpoints use the same per-layer parameter names as LLaMA
        (``self_attn.*_proj``, ``mlp.*_proj``, ``input_layernorm``,
        ``post_attention_layernorm``), so this delegates to the LLaMA mapper.
        """
        return self._map_llama_weights(hf_weights, config, prefix)

class ConfigParser:
    """Enhanced configuration parser with caching.

    Turns HuggingFace-style configuration dictionaries into ``ModelConfig``
    instances, optionally persisting results through ``CacheManager``.
    """

    def __init__(self, use_cache: bool = True):
        self.cache_manager = CacheManager() if use_cache else None

    def parse_config(
        self,
        config: Dict[str, Any],
        model_type: Optional[ModelType] = None
    ) -> ModelConfig:
        """
        Parse model configuration with caching

        Args:
            config: Configuration dictionary
            model_type: Optional model type override; detected from ``config``
                when omitted

        Returns:
            ModelConfig instance

        Raises:
            ValueError: If the model type cannot be detected or is unsupported.
            KeyError: If a field required for the selected model type is missing.
        """
        cache_key = None
        if self.cache_manager:
            # The override is part of the key; otherwise a config parsed under
            # one model type would be served for a different override.
            override = model_type.value if model_type is not None else None
            cache_key = self.cache_manager._compute_key(
                (config, override), "config_parsing"
            )
            cached_config = self.cache_manager.get(cache_key)
            if cached_config is not None:
                return cached_config

        if model_type is None:
            model_type = self._detect_model_type(config)

        parsed_config = self._parse_by_type(config, model_type)

        if self.cache_manager:
            self.cache_manager.set(
                cache_key,
                parsed_config,
                {'model_type': model_type.value}
            )

        return parsed_config

    def _detect_model_type(self, config: Dict[str, Any]) -> ModelType:
        """Detect the model type from a config dictionary.

        An explicit ``model_type`` string naming a ``ModelType`` value wins.
        Otherwise architecture-specific keys are checked before the generic
        ``num_hidden_layers``: LLaMA, Mistral and Falcon configs carry that key
        too, so testing it first classified all of them as BERT.

        Raises:
            ValueError: If no rule matches.
        """
        declared = config.get('model_type')
        if isinstance(declared, str):
            try:
                return ModelType(declared.lower())
            except ValueError:
                pass  # unknown family name: fall back to key-based detection

        if 'n_layer' in config:
            return ModelType.GPT2
        if 'd_model' in config:
            return ModelType.T5
        # Mistral configs also carry `num_key_value_heads`, so test it first.
        if 'sliding_window' in config:
            return ModelType.MISTRAL
        if 'num_key_value_heads' in config:
            return ModelType.LLAMA
        if 'multi_query_group_num' in config:
            return ModelType.FALCON
        if 'num_hidden_layers' in config:
            return ModelType.BERT
        raise ValueError("Unable to detect model type from config")

    def _parse_by_type(
        self,
        config: Dict[str, Any],
        model_type: ModelType
    ) -> ModelConfig:
        """Build a ``ModelConfig`` from ``config`` using ``model_type``'s field names.

        Raises:
            ValueError: If ``model_type`` is not supported.
            KeyError: If a required field is missing.
        """
        if model_type == ModelType.GPT2:
            return ModelConfig(
                model_type=model_type,
                num_layers=config['n_layer'],
                num_heads=config['n_head'],
                hidden_dim=config['n_embd'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('n_positions', 1024),
                intermediate_size=config.get('n_inner', None),
                layer_norm_epsilon=config.get('layer_norm_epsilon', 1e-5)
            )
        elif model_type == ModelType.BERT:
            return ModelConfig(
                model_type=model_type,
                num_layers=config['num_hidden_layers'],
                num_heads=config['num_attention_heads'],
                hidden_dim=config['hidden_size'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('max_position_embeddings', 512),
                intermediate_size=config.get('intermediate_size', None),
                layer_norm_epsilon=config.get('layer_norm_eps', 1e-12)
            )
        elif model_type == ModelType.T5:
            return ModelConfig(
                model_type=model_type,
                num_layers=config['num_layers'],
                num_heads=config['num_heads'],
                hidden_dim=config['d_model'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('n_positions', 512),
                intermediate_size=config.get('d_ff', None),
                layer_norm_epsilon=config.get('layer_norm_epsilon', 1e-6)
            )
        elif model_type in (ModelType.LLAMA, ModelType.MISTRAL):
            # Mistral configs use the same field names as LLaMA.
            return ModelConfig(
                model_type=model_type,
                num_layers=config['num_hidden_layers'],
                num_heads=config['num_attention_heads'],
                hidden_dim=config['hidden_size'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('max_position_embeddings', 2048),
                intermediate_size=config.get('intermediate_size', None),
                layer_norm_epsilon=config.get('rms_norm_eps', 1e-6),
                rotary_dim=config.get('rotary_dim', None)
            )
        elif model_type == ModelType.FALCON:
            # Reads the same core fields the BERT branch reads; Falcon configs
            # were previously misdetected as BERT and parsed through that branch.
            return ModelConfig(
                model_type=model_type,
                num_layers=config['num_hidden_layers'],
                num_heads=config['num_attention_heads'],
                hidden_dim=config['hidden_size'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('max_position_embeddings', 2048),
                layer_norm_epsilon=config.get('layer_norm_epsilon', 1e-5)
            )
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

# Legacy function for backward compatibility
def parse_hf_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Legacy config parsing interface (deprecated).

    Emits a ``DeprecationWarning`` and delegates to
    ``ConfigParser(use_cache=True).parse_config``.

    Args:
        config: HuggingFace-style configuration dictionary.

    Returns:
        Dict with ``num_layers``, ``num_heads``, ``hidden_dim``, ``vocab_size``
        and ``max_seq_len`` taken from the parsed ``ModelConfig``.

    Raises:
        ValueError: Propagated from ``ConfigParser`` when the model type cannot
            be detected or is unsupported (a missing required field may surface
            as ``KeyError``).
    """
    # stacklevel=2 attributes the warning to the caller's line so the deprecated
    # call site can be found; the default points inside this wrapper.
    warnings.warn(
        "parse_hf_config is deprecated, use ConfigParser class instead",
        DeprecationWarning,
        stacklevel=2,
    )

    parser = ConfigParser(use_cache=True)
    parsed_config = parser.parse_config(config)

    return {
        'num_layers': parsed_config.num_layers,
        'num_heads': parsed_config.num_heads,
        'hidden_dim': parsed_config.hidden_dim,
        'vocab_size': parsed_config.vocab_size,
        'max_seq_len': parsed_config.max_seq_len
    }
   