"""
Formats tensor outputs into human-readable results for different types of operations
"""

import numpy as np
from typing import Dict, List, Union, Optional

class TensorOutputFormatter:
    """Turn raw tensor results into JSON-friendly, human-readable summaries.

    ``format_output`` takes a ``result`` dict holding array-like values under
    ``'result'`` and a ``'metadata'`` dict whose ``'output_type'`` selects the
    formatter. Token ids are decoded with the GPT-2 ``tiktoken`` encoding when
    it can be loaded; otherwise they are rendered as ``"token_<id>"``.
    """

    # Number of entries listed by each top-k style formatter.
    _PROBABILITY_TOP_K = 10
    _TEXT_TOP_K = 5
    _ATTENTION_TOP_K = 5
    # Added inside log2 so zero probabilities contribute 0 * log2(eps) = 0
    # to the entropy instead of 0 * -inf = NaN.
    _LOG_EPS = 1e-10
    # Attention weights below this threshold count as "near zero" for sparsity.
    _SPARSITY_THRESHOLD = 0.01

    def __init__(self, vocab_size: int = 50257):
        # Default to the GPT-2 vocabulary size if not specified.
        self.vocab_size = vocab_size
        self._load_vocab()

    def _load_vocab(self) -> None:
        """Load the GPT-2 tokenizer, or fall back to placeholder token names.

        Sets ``self.has_tiktoken`` and ``self.encoding``. The fallback applies
        when ``tiktoken`` is not installed, or when the encoding cannot be
        loaded. For example, ``tiktoken`` fetches the BPE file on first use, so
        a network or cache failure (OSError) or a corrupted download
        (ValueError) can occur there.
        """
        try:
            import tiktoken
            self.encoding = tiktoken.get_encoding("gpt2")
            self.has_tiktoken = True
        except (ImportError, OSError, ValueError):
            # In this mode the formatters render tokens as "token_<id>" and
            # never read this ASCII table; it is kept so the public
            # ``encoding`` attribute keeps its previous value.
            self.has_tiktoken = False
            self.encoding = {i: chr(i % 128) for i in range(self.vocab_size)}

    def format_output(self, result: Dict) -> Dict:
        """Format a tensor result according to its ``metadata['output_type']``.

        Recognised types are ``'probability'``, ``'transformer'``,
        ``'attention_scores'`` and ``'text_generate'``. Any other type gets
        the raw output, its shape and mean/std/min/max statistics.

        Raises:
            KeyError: if ``result`` lacks ``'metadata'``, ``'output_type'``
                or ``'result'``.
        """
        handlers = {
            'probability': self._format_probabilities,
            'transformer': self._format_transformer_output,
            'attention_scores': self._format_attention_scores,
            'text_generate': self._format_text_output,
        }
        output_type = result['metadata']['output_type']
        return handlers.get(output_type, self._format_generic)(result)

    def _format_generic(self, result: Dict) -> Dict:
        """Return the raw output with its shape and basic statistics.

        ``shape`` comes from ``metadata['shape']`` when present, otherwise from
        the array itself. Statistics are NaN for an empty tensor.
        """
        values = np.asarray(result['result'])
        return {
            'raw_output': result['result'],
            'shape': result['metadata'].get('shape', list(values.shape)),
            'stats': {
                'mean': self._stat(np.mean, values),
                'std': self._stat(np.std, values),
                'min': self._stat(np.min, values),
                'max': self._stat(np.max, values),
            },
        }

    def _format_probabilities(self, result: Dict) -> Dict:
        """Summarise a probability distribution and list its top-10 entries.

        The distribution lies along the last axis of ``result['result']``.
        For N-D input only the final row is reported. ``top_k_results`` is
        ordered from most to least probable. ``entropy`` is in bits.
        """
        probs = self._final_row(result['result'])
        top_k = self._top_k_tokens(probs, self._PROBABILITY_TOP_K)
        return {
            'distribution_stats': {
                'entropy': float(self._entropy(probs)),
                'top_k_confidence': float(sum(item['probability'] for item in top_k)),
            },
            'top_k_results': top_k,
        }

    def _format_transformer_output(self, result: Dict) -> Dict:
        """Summarise transformer output values.

        Returns any ``layer_norm_stats`` from the metadata, the shape and
        activation statistics, and a pattern summary (see
        ``_summarize_attention``). Statistics are NaN for an empty tensor.
        """
        output = np.asarray(result['result'])
        metadata = result['metadata']
        return {
            'layer_stats': metadata.get('layer_norm_stats', {}),
            'output_summary': {
                'shape': metadata.get('shape', list(output.shape)),
                'activation_stats': {
                    'mean': self._stat(np.mean, output),
                    'max_activation': self._stat(np.max, output),
                    # Fraction of activations that are exactly zero.
                    'sparsity': self._stat(np.mean, output == 0),
                },
            },
            'attention_patterns': self._summarize_attention(output),
        }

    def _format_attention_scores(self, result: Dict) -> Dict:
        """Summarise attention weights and list the most-attended positions.

        The last axis of ``result['result']`` is the attended position. Any
        leading axes are treated as independent rows (presumably queries
        and/or heads; this is not checked).

        * ``entropy`` is the mean per-row entropy in bits. For 1-D input this
          is simply that row's entropy.
        * ``sparsity`` is the fraction of weights below 0.01.
        * ``top_attended_positions`` lists each row's top-5 positions
          (highest weight first), rows in row-major order.
        """
        scores = np.atleast_1d(np.asarray(result['result']))
        row_entropy = self._entropy(scores, axis=-1)
        attention_stats = {
            'max_attention': self._stat(np.max, scores),
            'entropy': self._stat(np.mean, row_entropy),
            'sparsity': self._stat(np.mean, scores < self._SPARSITY_THRESHOLD),
        }

        # A stable sort on the negated weights gives descending order while
        # keeping tied positions in ascending index order.
        order = np.argsort(-scores, axis=-1, kind='stable')[..., :self._ATTENTION_TOP_K]
        top_values = np.take_along_axis(scores, order, axis=-1)

        return {
            'stats': attention_stats,
            'top_attended_positions': [
                {'position': int(idx), 'attention_value': float(val)}
                for idx, val in zip(order.ravel(), top_values.ravel())
            ],
        }

    def _format_text_output(self, result: Dict) -> Dict:
        """Softmax generation logits and list the top-5 candidate tokens.

        Logits lie along the last axis of ``result['result']``. For N-D input
        only the final row is reported; for ``(seq_len, vocab)`` logits this
        is the last position. ``top_tokens`` is ordered from most to least
        probable. Fewer than 5 tokens are listed when the vocabulary is
        smaller.
        """
        logits = self._final_row(result['result'])
        if logits.size:
            # Subtracting the max keeps exp() from overflowing; it cancels out
            # in the normalisation.
            exp = np.exp(logits - np.max(logits))
            probs = exp / np.sum(exp)
        else:
            probs = logits.astype(float)

        return {
            'top_tokens': self._top_k_tokens(probs, self._TEXT_TOP_K),
            'generation_stats': {
                'entropy': float(self._entropy(probs)),
                'max_confidence': self._stat(np.max, probs),
            },
        }

    def _summarize_attention(self, output: np.ndarray) -> Dict:
        """Summarise an (at least 2-D) output as if it held attention weights.

        Statistics cover every element of ``output``. Whether the values
        really are attention weights depends on the caller and is not checked.
        Returns ``{}`` for 0-D or 1-D input.
        """
        output = np.asarray(output)
        if output.ndim < 2:
            return {}
        return {
            'pattern_stats': {
                'average_attention': self._stat(np.mean, output),
                'attention_sparsity': self._stat(np.mean, output < self._SPARSITY_THRESHOLD),
                'max_attention_value': self._stat(np.max, output),
            },
            'shape': list(output.shape),
        }

    # -- helpers ---------------------------------------------------------

    @staticmethod
    def _stat(reducer, values: np.ndarray) -> float:
        """Apply ``reducer`` and convert to float, or return NaN for empty input.

        numpy's min/max raise on empty arrays, and mean/std warn.
        """
        return float(reducer(values)) if values.size else float('nan')

    @staticmethod
    def _final_row(values) -> np.ndarray:
        """Return the 1-D slice along the last axis that the formatters report.

        * 1-D input is returned unchanged.
        * Non-empty N-D input has its leading axes flattened, and the final
          row is taken.
        * 0-D input becomes a length-1 array.
        * Empty input becomes an empty 1-D array.
        """
        arr = np.asarray(values)
        if arr.ndim > 1 and arr.size:
            return arr.reshape(-1, arr.shape[-1])[-1]
        return arr.reshape(-1)

    @classmethod
    def _entropy(cls, probs: np.ndarray, axis: Optional[int] = None):
        """Shannon entropy in bits, summed over ``axis`` (all axes by default)."""
        return -np.sum(probs * np.log2(probs + cls._LOG_EPS), axis=axis)

    def _top_k_tokens(self, dist: np.ndarray, k: int) -> List[Dict]:
        """Return up to ``k`` entries of 1-D ``dist``, highest value first.

        Each entry is ``{'token', 'probability', 'index'}``. The sort is stable
        on the negated values, so ties keep ascending index order.
        """
        order = np.argsort(-dist, kind='stable')[:k]
        return [
            {
                'token': self._decode_token(int(idx)),
                'probability': float(dist[idx]),
                'index': int(idx),
            }
            for idx in order
        ]

    def _decode_token(self, idx: int) -> str:
        """Decode a single token id, or render it as ``token_<id>``.

        The placeholder form is used when no tokenizer is loaded. It is also
        used for ids outside the tokenizer's ``n_vocab`` range, which can occur
        if the model's vocabulary is padded beyond the encoding.
        """
        if self.has_tiktoken and 0 <= idx < self.encoding.n_vocab:
            return self.encoding.decode([idx])
        return f"token_{idx}"