"""
High-level model interface for helium, optimized for VGPU execution
"""
from typing import Optional, Union, List, Dict, Any, Tuple
import numpy as np

from .device import Device, DeviceType
from ..virtual_gpu_driver.src.hal.hal import HardwareAbstractionLayer
from ..virtual_gpu_driver.src.memory_pool import MemoryPool

class Pipeline:
    """Encoder/decoder pipeline exposing two execution paths.

    VGPU/HAL path (``load_weights``, ``run_encoder``, ``run_decoder``,
    ``generate_vgpu``): ``encoder``/``decoder`` are config dicts whose
    ``"num_layers"`` entry sets the depth, and every layer is executed as a
    sequence of ``HardwareAbstractionLayer`` primitives using the weights
    registered by ``load_weights``.

    Model-object path (``encode``, ``decode``, ``generate``,
    ``save_pretrained``, ``from_pretrained``): ``encoder``/``decoder`` are
    callable model objects and tensors travel as ``dtensor`` values.

    NOTE(review): ``dtensor`` and ``get_driver`` are used by the model-object
    path but are not imported in this module as shown; that path raises
    ``NameError`` until they are imported at module level. Annotations that
    mention them are quoted so the class itself can still be defined.
    """

    def __init__(self,
                 encoder: Optional[Any] = None,
                 decoder: Optional[Any] = None,
                 device_id: str = "vgpu0",
                 device: Optional[Union[str, "Device"]] = None):
        """Create the HAL handles and, optionally, bind to ``device``.

        Args:
            encoder: Encoder config dict (VGPU path; must hold ``"num_layers"``)
                or an encoder model object (model-object path).
            decoder: Decoder config dict or model object; ``None`` for an
                encoder-only pipeline.
            device_id: Label printed by ``load_weights``.
            device: Optional device spec parsed with ``Device.parse``. When
                given, model objects exposing ``.to()`` are moved to it and
                ``_setup_pipeline`` is run. ``from_pretrained`` passes this.
        """
        self.device_id = device_id
        self.hal = HardwareAbstractionLayer()
        self.stream_manager = self.hal.create_stream_manager()
        self.default_stream = self.stream_manager.create_stream()

        # Memory pools: 8 GiB "global" and 48 KiB "shared".
        self.global_memory = MemoryPool(8 * 1024 * 1024 * 1024)
        self.shared_memory = MemoryPool(48 * 1024)

        # The VGPU path reads *_config; the model-object path reads
        # self.encoder / self.decoder. Both views refer to the same objects.
        self.encoder_config = encoder
        self.decoder_config = decoder
        self.encoder = encoder
        self.decoder = decoder
        self.loaded_weights: Dict[str, str] = {}  # weight name -> HAL tensor name
        self.current_inputs: Dict[str, str] = {}  # input name -> HAL tensor name
        # key -> shared-storage handle, filled by _pipeline_encoder_outputs.
        self._cached_encoder_outputs: Dict[str, Any] = {}

        self.device: Optional["Device"] = None
        if device is not None:
            self.device = Device.parse(device)
            # Config dicts have no .to(); only move real model objects.
            for model in (self.encoder, self.decoder):
                if model is not None and hasattr(model, "to"):
                    model.to(self.device)
            self._setup_pipeline()

    def load_weights(self, weights_path: str) -> None:
        """Load a safetensors file and register every entry as a HAL tensor.

        Each weight ``name`` is uploaded under the tensor name
        ``"weight_<name>"`` and recorded in ``self.loaded_weights``; the
        ``run_*`` methods look weights up by names such as
        ``"encoder.layer.0.ffn.weight"``.
        """
        import safetensors.numpy

        weights = safetensors.numpy.load_file(weights_path)

        for name, weight in weights.items():
            tensor_name = f"weight_{name}"
            self.hal.create_tensor(tensor_name, weight)
            self.loaded_weights[name] = tensor_name

        print(f"Loaded {len(weights)} weights to VGPU {self.device_id}")

    def to_vgpu(self, data: Union[np.ndarray, List, Tuple], name: str) -> str:
        """Upload ``data`` as HAL tensor ``"input_<name>"`` and return that name.

        Lists/tuples are converted with ``np.array`` first. Reusing ``name``
        reuses the same tensor name.
        """
        if not isinstance(data, np.ndarray):
            data = np.array(data)

        tensor_name = f"input_{name}"
        self.hal.create_tensor(tensor_name, data)
        return tensor_name

    def prepare_inputs(self,
                       inputs: Dict[str, np.ndarray],
                       batch_size: Optional[int] = None) -> None:
        """Upload ``inputs`` as HAL tensors and record them in ``current_inputs``.

        Previously prepared inputs are forgotten first. When ``batch_size`` is
        given and an input's leading dimension differs from it, the input is
        reshaped to ``(batch_size, -1, *shape[1:])`` before upload.
        """
        self.current_inputs.clear()

        for name, data in inputs.items():
            # Lists/tuples are accepted too; .shape/.reshape need an ndarray.
            data = np.asarray(data)
            if batch_size is not None and data.shape[0] != batch_size:
                data = data.reshape(batch_size, -1, *data.shape[1:])
            self.current_inputs[name] = self.to_vgpu(data, name)

    def _layer_weights(self, prefix: str, names: Tuple[str, ...]) -> Dict[str, str]:
        """Map each short weight name to its registered tensor name under ``prefix``."""
        return {name: self.loaded_weights[prefix + name] for name in names}

    def _self_attention_block(self, hidden_states, ln_weight, qkv_weight, mask, stream):
        """Return ``hidden_states + attention(layer_norm(hidden_states))``.

        The projection output is split into three equal parts along the last
        axis and used as query, key and value.
        """
        normed = self.hal.layer_norm(hidden_states, ln_weight, stream=stream)
        qkv = self.hal.matmul(normed, qkv_weight, stream=stream)
        q, k, v = self.hal.split(qkv, 3, axis=-1)
        attn_output = self.hal.scaled_dot_product_attention(
            q, k, v,
            mask=mask,
            stream=stream
        )
        # NOTE(review): add/gelu are issued without ``stream=`` unlike the
        # other ops; confirm the HAL orders them after work queued on ``stream``.
        return self.hal.add(hidden_states, attn_output)

    def _feed_forward_block(self, hidden_states, ln_weight, ffn_weight, stream):
        """Return ``hidden_states + gelu(layer_norm(hidden_states) @ ffn_weight)``."""
        normed = self.hal.layer_norm(hidden_states, ln_weight, stream=stream)
        ffn_output = self.hal.matmul(normed, ffn_weight, stream=stream)
        ffn_output = self.hal.gelu(ffn_output)
        return self.hal.add(hidden_states, ffn_output)

    def _collect_outputs(self, key: str, final_tensor, layer_outputs: Dict[str, Any],
                         use_cache: bool) -> Dict[str, Any]:
        """Download the final tensor (and, with ``use_cache``, every layer output)."""
        outputs: Dict[str, Any] = {key: self.hal.get_tensor(final_tensor)}
        if use_cache:
            outputs["layer_outputs"] = {
                name: self.hal.get_tensor(tensor_name)
                for name, tensor_name in layer_outputs.items()
            }
        return outputs

    def run_encoder(self,
                    encoder_inputs: Dict[str, np.ndarray],
                    batch_size: Optional[int] = None,
                    use_cache: bool = True) -> Dict[str, Any]:
        """Run the encoder stack on the HAL.

        Args:
            encoder_inputs: Must contain ``"input_ids"``; an optional
                ``"attention_mask"`` is passed to every attention call.
            batch_size: Forwarded to ``prepare_inputs``.
            use_cache: When true, also return every layer's output.

        Returns:
            ``{"encoder_output": array}``, plus ``"layer_outputs"`` (a dict
            ``"layer_<i>" -> array``) when ``use_cache`` is true.

        Raises:
            RuntimeError: If no encoder config was given.
            KeyError: If an expected input or layer weight is missing.
        """
        if self.encoder_config is None:
            raise RuntimeError("No encoder config set on this pipeline")

        self.prepare_inputs(encoder_inputs, batch_size)

        # NOTE(review): a new stream per call; no release API is visible here.
        stream = self.stream_manager.create_stream()
        mask = self.current_inputs.get("attention_mask")

        # NOTE(review): input_ids go straight into layer 0; no embedding
        # lookup happens in this method -- confirm that is intended.
        hidden_states = self.current_inputs["input_ids"]
        layer_outputs: Dict[str, Any] = {}

        for i in range(self.encoder_config["num_layers"]):
            weights = self._layer_weights(
                f"encoder.layer.{i}.",
                ("attention.weight", "ffn.weight", "ln.weight"),
            )
            hidden_states = self._self_attention_block(
                hidden_states, weights["ln.weight"], weights["attention.weight"],
                mask, stream,
            )
            hidden_states = self._feed_forward_block(
                hidden_states, weights["ln.weight"], weights["ffn.weight"], stream,
            )
            if use_cache:
                layer_outputs[f"layer_{i}"] = hidden_states

        stream.synchronize()
        return self._collect_outputs("encoder_output", hidden_states, layer_outputs, use_cache)

    def run_decoder(self,
                    decoder_inputs: Dict[str, np.ndarray],
                    encoder_outputs: Dict[str, Any],
                    batch_size: Optional[int] = None,
                    use_cache: bool = True) -> Dict[str, Any]:
        """Run the decoder stack on the HAL.

        Args:
            decoder_inputs: Must contain ``"input_ids"``; optional
                ``"attention_mask"`` (self-attention) and
                ``"cross_attention_mask"`` (cross-attention).
            encoder_outputs: Result of ``run_encoder``. Its
                ``"encoder_output"`` array is used as cross-attention key and
                value; nested dict entries such as ``"layer_outputs"`` are
                ignored.
            batch_size: Forwarded to ``prepare_inputs``.
            use_cache: When true, also return every layer's output.

        Returns:
            ``{"decoder_output": array}``, plus ``"layer_outputs"`` when
            ``use_cache`` is true.

        Raises:
            RuntimeError: If no decoder config was given.
            KeyError: If ``"encoder_output"``, an input or a weight is missing.
        """
        if self.decoder_config is None:
            raise RuntimeError("No decoder config set on this pipeline")

        self.prepare_inputs(decoder_inputs, batch_size)

        encoder_tensors: Dict[str, str] = {}
        for name, data in encoder_outputs.items():
            # run_encoder(use_cache=True) adds a nested "layer_outputs" dict,
            # which is not an array and cannot become a single tensor.
            if isinstance(data, dict):
                continue
            tensor_name = f"encoder_{name}"
            self.hal.create_tensor(tensor_name, data)
            encoder_tensors[name] = tensor_name
        if "encoder_output" not in encoder_tensors:
            raise KeyError("encoder_outputs must contain an 'encoder_output' array")
        encoder_hidden = encoder_tensors["encoder_output"]

        # NOTE(review): a new stream per call; no release API is visible here.
        stream = self.stream_manager.create_stream()
        self_mask = self.current_inputs.get("attention_mask")
        cross_mask = self.current_inputs.get("cross_attention_mask")

        hidden_states = self.current_inputs["input_ids"]
        layer_outputs: Dict[str, Any] = {}

        for i in range(self.decoder_config["num_layers"]):
            weights = self._layer_weights(
                f"decoder.layer.{i}.",
                ("self_attention.weight", "cross_attention.weight",
                 "ffn.weight", "ln.weight"),
            )
            hidden_states = self._self_attention_block(
                hidden_states, weights["ln.weight"], weights["self_attention.weight"],
                self_mask, stream,
            )

            # Cross attention: queries from the decoder, keys/values are the
            # encoder output as-is (no key/value projection).
            normed = self.hal.layer_norm(hidden_states, weights["ln.weight"], stream=stream)
            cross_q = self.hal.matmul(normed, weights["cross_attention.weight"], stream=stream)
            cross_attn_output = self.hal.scaled_dot_product_attention(
                cross_q, encoder_hidden, encoder_hidden,
                mask=cross_mask,
                stream=stream
            )
            hidden_states = self.hal.add(hidden_states, cross_attn_output)

            hidden_states = self._feed_forward_block(
                hidden_states, weights["ln.weight"], weights["ffn.weight"], stream,
            )
            if use_cache:
                layer_outputs[f"layer_{i}"] = hidden_states

        stream.synchronize()
        return self._collect_outputs("decoder_output", hidden_states, layer_outputs, use_cache)

    def generate_vgpu(self,
                      input_ids: np.ndarray,
                      max_length: int = 100,
                      num_beams: int = 1,
                      top_k: Optional[int] = None,
                      top_p: Optional[float] = None,
                      temperature: float = 1.0,
                      rng: Optional[np.random.Generator] = None) -> np.ndarray:
        """Append ``max_length`` tokens to ``input_ids`` using the HAL path.

        Without ``top_k``/``top_p`` the next token is the argmax of the last
        position. With either set, the last position is filtered by
        ``_sample_logits`` and one token per row is drawn from its softmax
        using ``rng`` (a fresh ``np.random.default_rng()`` when omitted).

        Raises:
            NotImplementedError: If ``num_beams != 1`` (no beam search here).
        """
        if num_beams != 1:
            raise NotImplementedError("generate_vgpu only supports num_beams=1")

        input_ids = np.asarray(input_ids)
        batch_size = input_ids.shape[0]

        # Only the final encoder state feeds cross-attention; skip layer cache.
        encoder_outputs = self.run_encoder(
            {"input_ids": input_ids},
            batch_size=batch_size,
            use_cache=False
        )

        sampling = top_k is not None or top_p is not None
        if sampling and rng is None:
            rng = np.random.default_rng()

        current_ids = input_ids
        for _ in range(max_length):
            decoder_outputs = self.run_decoder(
                {"input_ids": current_ids},
                encoder_outputs,
                batch_size=batch_size,
                use_cache=False
            )
            # NOTE(review): decoder_output is the final hidden state; no vocab
            # projection is applied, so the chosen index ranges over the
            # hidden dimension -- confirm an LM head is not missing.
            last_logits = decoder_outputs["decoder_output"][:, -1, :]

            if sampling:
                filtered = self._sample_logits(
                    last_logits,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature
                )
                next_ids = self._draw_from_logits(filtered, rng)
            else:
                next_ids = np.argmax(last_logits, axis=-1)

            current_ids = np.concatenate([current_ids, next_ids[:, None]], axis=1)

        return current_ids

    @staticmethod
    def _draw_from_logits(logits: np.ndarray, rng: np.random.Generator) -> np.ndarray:
        """Draw one index per row (last axis) from ``softmax(logits)``; ``-inf`` entries are never drawn."""
        vocab = logits.shape[-1]
        flat = np.asarray(logits, dtype=np.float64).reshape(-1, vocab)
        # Subtract the row max so exp() cannot overflow.
        flat = flat - flat.max(axis=-1, keepdims=True)
        probs = np.exp(flat)
        probs /= probs.sum(axis=-1, keepdims=True)
        draws = [rng.choice(vocab, p=row) for row in probs]
        return np.asarray(draws, dtype=np.int64).reshape(logits.shape[:-1])

    def _sample_logits(self,
                       logits: np.ndarray,
                       top_k: Optional[int] = None,
                       top_p: Optional[float] = None,
                       temperature: float = 1.0) -> np.ndarray:
        """Return a filtered copy of ``logits`` for top-k / nucleus sampling.

        Works along the last axis; the input array is never modified.

        Args:
            logits: Scores with the vocabulary on the last axis.
            top_k: Keep the ``top_k`` largest scores per row (clamped to
                ``[1, vocab]``); entries tied with the k-th value are kept.
            top_p: Keep the smallest descending-sorted prefix whose softmax
                mass exceeds ``top_p``; the top entry is always kept.
            temperature: Divisor applied before filtering.

        Returns:
            Float array of the same shape with removed entries set to ``-inf``.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        # Private float copy: the caller's array stays untouched, and -inf can
        # be stored even when integer scores are passed in.
        logits = np.asarray(logits)
        logits = logits.astype(np.result_type(logits.dtype, np.float32))

        if temperature != 1.0:
            logits = logits / temperature

        if top_k is not None:
            k = max(1, min(int(top_k), logits.shape[-1]))
            kth_largest = np.partition(logits, -k, axis=-1)[..., -k, None]
            logits = np.where(logits < kth_largest, -np.inf, logits)

        if top_p is not None:
            order = np.argsort(-logits, axis=-1, kind="stable")  # descending
            sorted_logits = np.take_along_axis(logits, order, axis=-1)
            # Stable softmax: shift by the row max (first after sorting).
            probs = np.exp(sorted_logits - sorted_logits[..., :1])
            probs /= probs.sum(axis=-1, keepdims=True)

            remove_sorted = np.cumsum(probs, axis=-1) > top_p
            # Shift right so the token that crosses the threshold is kept.
            remove_sorted[..., 1:] = remove_sorted[..., :-1].copy()
            remove_sorted[..., 0] = False

            remove = np.empty_like(remove_sorted)
            np.put_along_axis(remove, order, remove_sorted, axis=-1)
            logits[remove] = -np.inf

        return logits

    def _setup_pipeline(self) -> None:
        """Configure the driver for ``self.device`` and keep its storage manager.

        Non-CPU devices get ``configure_memory_hierarchy(...,
        use_pinned_memory=True)``. Called from ``__init__`` when a ``device``
        is supplied.

        NOTE(review): ``get_driver`` is not imported in this module as shown.
        """
        driver = get_driver()

        if self.device.type != DeviceType.CPU:
            driver.configure_memory_hierarchy(
                device_type=self.device.type.value,
                device_index=self.device.index,
                use_pinned_memory=True
            )

        # Used by _pipeline_encoder_outputs / _get_pipelined_encoder_outputs.
        self.storage_manager = driver.storage_manager

    def encode(self,
               input_ids: Union["dtensor", np.ndarray],
               attention_mask: Optional[Union["dtensor", np.ndarray]] = None,
               return_dict: bool = True) -> Dict[str, "dtensor"]:
        """Run the encoder model on ``self.device``.

        Array inputs are wrapped as ``dtensor``. When a decoder is configured,
        the outputs are also copied to shared storage so ``decode`` can find
        them without being passed ``encoder_outputs``.
        """
        if not isinstance(input_ids, dtensor):
            input_ids = dtensor(input_ids, device=self.device)
        if attention_mask is not None and not isinstance(attention_mask, dtensor):
            attention_mask = dtensor(attention_mask, device=self.device)

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=return_dict
        )

        if self.decoder is not None:
            self._pipeline_encoder_outputs(encoder_outputs)

        return encoder_outputs

    def decode(self,
               input_ids: Union["dtensor", np.ndarray],
               encoder_outputs: Optional[Dict[str, "dtensor"]] = None,
               attention_mask: Optional[Union["dtensor", np.ndarray]] = None,
               encoder_attention_mask: Optional[Union["dtensor", np.ndarray]] = None,
               past_key_values: Optional[List[Tuple["dtensor", "dtensor"]]] = None,
               return_dict: bool = True) -> Dict[str, "dtensor"]:
        """Run the decoder model on ``self.device``.

        Array inputs are wrapped as ``dtensor``. When ``encoder_outputs`` is
        omitted, the copy stored by the last ``encode`` call is used.

        Raises:
            RuntimeError: If the pipeline has no decoder.
        """
        if self.decoder is None:
            raise RuntimeError("No decoder configured in pipeline")

        if not isinstance(input_ids, dtensor):
            input_ids = dtensor(input_ids, device=self.device)
        if attention_mask is not None and not isinstance(attention_mask, dtensor):
            attention_mask = dtensor(attention_mask, device=self.device)
        if encoder_attention_mask is not None and not isinstance(encoder_attention_mask, dtensor):
            encoder_attention_mask = dtensor(encoder_attention_mask, device=self.device)

        if encoder_outputs is None:
            encoder_outputs = self._get_pipelined_encoder_outputs()

        return self.decoder(
            input_ids=input_ids,
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            return_dict=return_dict
        )

    def generate(self,
                 input_ids: Union["dtensor", np.ndarray],
                 max_length: int = 20,
                 num_beams: int = 1,
                 temperature: float = 1.0,
                 top_k: Optional[int] = None,
                 top_p: Optional[float] = None,
                 repetition_penalty: float = 1.0,
                 pad_token_id: Optional[int] = None,
                 eos_token_id: Optional[int] = None,
                 **kwargs) -> "dtensor":
        """Sample up to ``max_length`` tokens with the encoder/decoder models.

        Decoding is seeded with the last token of ``input_ids``. Each step
        applies temperature and optional top-k / top-p filtering to the last
        position's logits, then draws one token per sequence with
        ``dtensor.multinomial``.

        Stopping: with ``eos_token_id`` set, a sequence is finished once it
        *generates* that token (the seed token is not checked), and the loop
        stops when every sequence is finished. Finished sequences receive
        ``pad_token_id`` (when given) on later steps.

        NOTE(review): ``num_beams``, ``repetition_penalty`` and ``kwargs`` are
        accepted but not used.

        Returns:
            ``dtensor`` holding the seed token followed by the generated ones.
        """
        # Convert once so .shape and .numpy() work for plain ndarray input too.
        if not isinstance(input_ids, dtensor):
            input_ids = dtensor(input_ids, device=self.device)
        encoder_outputs = self.encode(input_ids)

        batch_size = input_ids.shape[0]
        generated = dtensor(
            input_ids.numpy()[:, -1:],
            device=self.device
        )
        past_key_values = None
        finished = np.zeros(batch_size, dtype=bool)

        for _ in range(max_length):
            # With a cache the decoder already holds earlier positions, so only
            # the newest token is fed (assumes HF-style past_key_values -- confirm).
            step_ids = generated if past_key_values is None else generated[:, -1:]
            outputs = self.decode(
                input_ids=step_ids,
                encoder_outputs=encoder_outputs,
                past_key_values=past_key_values
            )
            past_key_values = outputs.get("past_key_values")
            # Only the last position is sampled, so filter only that slice.
            logits = outputs["logits"][:, -1, :]

            if temperature != 1.0:
                logits = logits / temperature

            if top_k is not None:
                kth_largest = dtensor.topk(logits, top_k)[0][..., -1, None]
                logits[logits < kth_largest] = float('-inf')

            if top_p is not None:
                sort_result = dtensor.sort(logits, descending=True)
                sorted_logits, sorted_indices = sort_result[0], sort_result[1]
                cumulative_probs = dtensor.cumsum(
                    dtensor.softmax(sorted_logits, dim=-1),
                    dim=-1
                )
                sorted_to_remove = cumulative_probs > top_p
                # Shift right so the token crossing the threshold is kept.
                # NOTE(review): overlapping in-place slice copy; if dtensor has
                # torch semantics the right-hand side may need a clone.
                sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1]
                sorted_to_remove[..., 0] = 0
                to_remove = sorted_to_remove.scatter(
                    dim=-1,
                    index=sorted_indices,
                    src=sorted_to_remove
                )
                logits[to_remove] = float('-inf')

            probs = dtensor.softmax(logits, dim=-1)
            next_tokens = dtensor.multinomial(probs, num_samples=1)

            if eos_token_id is not None:
                next_np = np.asarray(next_tokens.numpy()).reshape(batch_size)
                if pad_token_id is not None and finished.any():
                    next_np = np.where(finished, pad_token_id, next_np)
                    next_tokens = dtensor(next_np.reshape(batch_size, 1), device=self.device)
                finished |= next_np == eos_token_id

            generated = dtensor.cat([generated, next_tokens], dim=-1)

            if eos_token_id is not None and finished.all():
                break

        return generated

    def _pipeline_encoder_outputs(self, encoder_outputs: Dict[str, "dtensor"]) -> None:
        """Copy each encoder output into shared storage for a later ``decode``.

        The handle map is rebuilt on every call so keys from an earlier
        ``encode`` cannot leak into the next decode.
        """
        cached: Dict[str, Any] = {}
        for key, tensor in encoder_outputs.items():
            cached[key] = self.storage_manager.allocate_shared_tensor(
                tensor.numpy(),
                name=f"encoder_output_{key}",
                device_type=self.device.type.value,
                device_index=self.device.index
            )
        self._cached_encoder_outputs = cached

    def _get_pipelined_encoder_outputs(self) -> Dict[str, "dtensor"]:
        """Rebuild ``dtensor`` encoder outputs from shared storage (``{}`` if none stored)."""
        outputs = {}
        for key, handle in self._cached_encoder_outputs.items():
            tensor = self.storage_manager.get_shared_tensor(handle)
            outputs[key] = dtensor(tensor, device=self.device)
        return outputs

    def save_pretrained(self, path: str) -> None:
        """Save encoder, optional decoder and ``config.json`` under ``path``.

        Layout: ``path/encoder``, ``path/decoder`` (only with a decoder) and
        ``path/config.json`` with ``"device"`` (``null`` when unset) and
        ``"has_decoder"``.
        """
        import os
        import json

        os.makedirs(path, exist_ok=True)

        encoder_path = os.path.join(path, "encoder")
        os.makedirs(encoder_path, exist_ok=True)
        self.encoder.save_pretrained(encoder_path)

        if self.decoder is not None:
            decoder_path = os.path.join(path, "decoder")
            os.makedirs(decoder_path, exist_ok=True)
            self.decoder.save_pretrained(decoder_path)

        # Store null rather than the string "None" so from_pretrained does not
        # try to parse it as a device.
        config = {
            "device": None if self.device is None else str(self.device),
            "has_decoder": self.decoder is not None
        }
        with open(os.path.join(path, "config.json"), "w", encoding="utf-8") as f:
            json.dump(config, f)

    @classmethod
    def from_pretrained(cls,
                        path: str,
                        device: Optional[Union[str, "Device"]] = None,
                        **kwargs) -> "Pipeline":
        """Load a pipeline written by ``save_pretrained``.

        ``device`` overrides the device stored in ``config.json``; ``kwargs``
        are forwarded to ``AutoModel.from_pretrained``.
        """
        import os
        import json

        with open(os.path.join(path, "config.json"), encoding="utf-8") as f:
            config = json.load(f)

        from ..models import AutoModel
        encoder = AutoModel.from_pretrained(
            os.path.join(path, "encoder"),
            **kwargs
        )

        decoder = None
        if config.get("has_decoder", False):
            decoder = AutoModel.from_pretrained(
                os.path.join(path, "decoder"),
                **kwargs
            )

        return cls(
            encoder=encoder,
            decoder=decoder,
            device=device or config.get("device")
        )
