"""
Helium Probability and Tokenization Module
Handles probability calculations, argmax operations, and tokenization
"""
import array
import json
import math
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

class HeliumTokenizer:
    """
    Whitespace tokenizer with a greedy longest-prefix subword fallback.

    Ids 0-4 are reserved for the special tokens defined in ``__init__``;
    tokens loaded by ``load_vocabulary`` are numbered sequentially after them.
    """
    def __init__(self, vocab_path: Optional[str] = None):
        """
        Args:
            vocab_path: Optional path to a JSON vocabulary file; when given it
                is loaded immediately via ``load_vocabulary``.
        """
        self.special_tokens = {
            "[PAD]": 0,
            "[UNK]": 1,
            "[CLS]": 2,
            "[SEP]": 3,
            "[MASK]": 4
        }
        # Seed both lookup tables with the special tokens so encode/decode are
        # consistent even before (or without) loading a vocabulary file.
        self.vocab: Dict[str, int] = dict(self.special_tokens)
        self.inv_vocab: Dict[int, str] = {v: k for k, v in self.vocab.items()}
        if vocab_path:
            self.load_vocabulary(vocab_path)

    def load_vocabulary(self, vocab_path: str) -> None:
        """
        Load a vocabulary from a JSON file, replacing the current one.

        The top-level JSON value is iterated and every entry not already
        present becomes a token, numbered sequentially after the special
        tokens. If the top level is a JSON object, iteration yields its keys,
        so any ids it maps to are ignored. NOTE(review): confirm the intended
        file format with the vocabulary producer.

        Args:
            vocab_path: Path to a UTF-8 encoded JSON file.
        """
        with open(vocab_path, 'r', encoding='utf-8') as f:
            vocab_data = json.load(f)

        # Special tokens always keep their reserved ids.
        self.vocab = self.special_tokens.copy()
        next_id = len(self.special_tokens)
        for token in vocab_data:
            if token not in self.vocab:  # skip duplicates and special tokens
                self.vocab[token] = next_id
                next_id += 1

        self.inv_vocab = {v: k for k, v in self.vocab.items()}

    def encode(self, text: str, max_length: Optional[int] = None,
               padding: bool = True, truncation: bool = True) -> List[int]:
        """
        Encode text to token IDs laid out as ``[CLS] tokens... [SEP]``.

        Text is split on whitespace. Words missing from the vocabulary are
        split by ``_break_into_subwords``; pieces still unknown map to
        ``[UNK]``.

        Args:
            text: Input text to encode.
            max_length: Target sequence length; ``None`` leaves it unchanged.
            padding: Append ``[PAD]`` up to ``max_length`` when shorter.
            truncation: Cut down to ``max_length`` when longer. The last kept
                position becomes ``[SEP]``; ``max_length=0`` yields ``[]``.

        Returns:
            List of token IDs.

        Raises:
            ValueError: If ``max_length`` is negative.
        """
        if max_length is not None and max_length < 0:
            raise ValueError(f"max_length must be non-negative, got {max_length}")

        unk_id = self.special_tokens["[UNK]"]
        sep_id = self.special_tokens["[SEP]"]

        token_ids = [self.special_tokens["[CLS]"]]
        for word in text.split():
            if word in self.vocab:
                token_ids.append(self.vocab[word])
            else:
                token_ids.extend(
                    self.vocab.get(piece, unk_id)
                    for piece in self._break_into_subwords(word)
                )
        token_ids.append(sep_id)

        if max_length is not None:
            if truncation and len(token_ids) > max_length:
                if max_length == 0:
                    # Slicing with max_length - 1 == -1 would keep almost the
                    # whole sequence, so the empty case is handled explicitly.
                    token_ids = []
                else:
                    token_ids = token_ids[:max_length - 1] + [sep_id]
            elif padding and len(token_ids) < max_length:
                token_ids.extend([self.special_tokens["[PAD]"]] * (max_length - len(token_ids)))

        return token_ids

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """
        Decode token IDs back to space-joined text.

        Args:
            token_ids: List of token IDs.
            skip_special_tokens: Omit special tokens from the output.

        Returns:
            Decoded text. IDs not present in the vocabulary are dropped.
        """
        words = []
        for token_id in token_ids:
            token = self.inv_vocab.get(token_id)
            if token is None:
                continue  # unknown id: silently skipped
            if skip_special_tokens and token in self.special_tokens:
                continue
            words.append(token)
        return " ".join(words)

    def _break_into_subwords(self, token: str) -> List[str]:
        """
        Split a word into vocabulary pieces by greedy longest-prefix matching.

        When no prefix of the remaining text is in the vocabulary, a single
        character is emitted as its own piece. ``encode`` then maps that piece
        to ``[UNK]`` unless it is itself a vocabulary token.
        """
        pieces: List[str] = []
        rest = token
        while rest:
            for end in range(len(rest), 0, -1):
                if rest[:end] in self.vocab:
                    pieces.append(rest[:end])
                    rest = rest[end:]
                    break
            else:
                # No prefix matched: emit one character and continue.
                pieces.append(rest[0])
                rest = rest[1:]
        return pieces

class ProbabilityCalculator:
    """
    Handles softmax probabilities, argmax, top-k selection and sampling.
    """
    def __init__(self):
        # Input and output of the most recent compute_probabilities() call.
        self.last_logits = None
        self.last_probs = None

    def compute_probabilities(self, logits: array.array,
                              temperature: float = 1.0) -> array.array:
        """
        Convert logits to probabilities with a temperature-scaled softmax.

        Args:
            logits: Raw model outputs.
            temperature: Every logit is divided by this before the softmax
                (higher = flatter distribution).

        Returns:
            ``array('f')`` of probabilities, one per logit. The result is also
            stored on ``self.last_probs`` and the input on ``self.last_logits``.

        Raises:
            ValueError: If ``logits`` is empty.
            ZeroDivisionError: If ``temperature`` is 0.
        """
        scaled_logits = array.array('f', [logit / temperature for logit in logits])
        if not scaled_logits:
            raise ValueError("logits must be non-empty")

        # Subtract the max before exp() so the largest term is exp(0) == 1
        # and nothing overflows.
        max_logit = max(scaled_logits)
        exps = array.array('f')
        exp_sum = 0.0
        for logit in scaled_logits:
            exp_val = math.exp(logit - max_logit)
            exps.append(exp_val)
            exp_sum += exp_val

        probs = array.array('f', [exp_val / exp_sum for exp_val in exps])

        self.last_logits = logits
        self.last_probs = probs
        return probs

    def argmax(self, values: array.array, stream_id: Optional[int] = None) -> int:
        """
        Compute the index of the maximum value by dispatching to the VGPU.

        Each call creates a new HardwareAbstractionLayer and VGPUArgMax, copies
        ``values`` into freshly allocated device memory, and returns the result
        of ``VGPUArgMax.argmax``. When ``stream_id`` is given, it returns the
        result of ``stream_argmax`` instead.

        Args:
            values: Array of values.
            stream_id: Optional stream id; when set, ``stream_argmax`` is used.
                NOTE(review): whether that returns an index or an async handle
                is not visible here; confirm against VGPUArgMax.

        Returns:
            The VGPU call's result (presumably the index of the maximum value).
        """
        from .argmax import VGPUArgMax
        from ..hardware.hal import HardwareAbstractionLayer

        # Initialize hardware components
        hal = HardwareAbstractionLayer()
        vgpu_argmax = VGPUArgMax(hal)

        # Allocate tensor on VGPU.
        # NOTE(review): the size assumes 4-byte elements (matches array('f'));
        # an array('d') would be under-allocated. Confirm what write_memory expects.
        # NOTE(review): tensor_addr is never released in this method; check
        # whether the HAL reclaims it or this leaks on every call.
        tensor_addr = hal.allocate_memory(len(values) * 4)  # 4 bytes per float
        hal.write_memory(tensor_addr, values)

        # Execute argmax on VGPU
        if stream_id is not None:
            # Asynchronous execution
            return vgpu_argmax.stream_argmax(
                tensor_addr,
                shape=(len(values),),
                stream_id=stream_id
            )
        else:
            # Synchronous execution
            return vgpu_argmax.argmax(
                tensor_addr,
                shape=(len(values),)
            )

    def top_k_indices(self, values: array.array, k: int) -> List[int]:
        """
        Return the indices of the ``k`` largest values, largest first.

        Ties keep their original index order, because the sort is stable.

        Args:
            values: Sequence of values.
            k: Number of indices to return (slice semantics apply to ``k``).

        Returns:
            List of indices.
        """
        vals = list(values)
        order = sorted(range(len(vals)), key=vals.__getitem__, reverse=True)
        return order[:k]

    def sample_from_probs(self, probs: array.array,
                          temperature: float = 1.0,
                          top_k: Optional[int] = None,
                          top_p: Optional[float] = None,
                          rng: Optional[random.Random] = None) -> int:
        """
        Sample a token index from a probability distribution.

        Filters are applied in order: top-k, then top-p (nucleus). The
        surviving weights are then sampled proportionally. Tokens whose weight
        is zero (including filtered ones) are never returned.

        Args:
            probs: Non-negative probabilities/weights (need not sum to 1).
            temperature: NOT applied here; kept for API compatibility. Apply
                temperature to logits via ``compute_probabilities`` instead.
            top_k: Keep only the ``top_k`` most probable tokens. ``None`` or a
                value <= 0 disables this filter.
            top_p: Keep the smallest set of most probable tokens whose
                cumulative probability reaches ``top_p`` of the remaining
                mass. The token that crosses the threshold is included, so at
                least one token survives. ``None`` or a value >= 1 disables it.
            rng: Object with a ``random()`` method (e.g. ``random.Random``) for
                reproducible sampling; defaults to the ``random`` module.

        Returns:
            Sampled token index.

        Raises:
            ValueError: If ``probs`` is empty or contains no positive weight.
        """
        weights = [float(p) for p in probs]
        if not weights:
            raise ValueError("probs must be non-empty")

        if top_k is not None and top_k > 0:
            keep = set(self.top_k_indices(weights, top_k))
            weights = [w if i in keep else 0.0 for i, w in enumerate(weights)]

        if top_p is not None and top_p < 1.0:
            # Compare against the mass left after top-k, so top_p refers to
            # the renormalized distribution.
            threshold = top_p * sum(weights)
            keep = set()
            cumsum = 0.0
            for idx in self.top_k_indices(weights, len(weights)):
                keep.add(idx)  # include the token that crosses the threshold
                cumsum += weights[idx]
                if cumsum >= threshold:
                    break
            weights = [w if i in keep else 0.0 for i, w in enumerate(weights)]

        candidates = [(i, w) for i, w in enumerate(weights) if w > 0.0]
        if not candidates:
            raise ValueError("probs must contain at least one positive probability")

        total = 0.0
        for _, w in candidates:
            total += w
        # Scaling the uniform draw by the total is equivalent to renormalizing.
        target = (random if rng is None else rng).random() * total
        running = 0.0
        for i, w in candidates:
            running += w
            if target < running:
                return i
        # Only reachable through float rounding at the very top of the range.
        return candidates[-1][0]
