"""
Florence-2 Batch Image Captioning System
Using optimized electron-speed tensor processing with parallel chunk distribution
"""

import os
from pathlib import Path
import numpy as np
import logging
import time
import asyncio
from typing import List, Dict, Tuple
import safetensors.numpy
import re
from tensor_chunk_manager import ChunkManager
from parallel_array_distributor import ParallelArrayDistributor
from tensor_core_distributor import TensorCoreDistributor
import json
from transformers import BartTokenizerFast

# Configure the root logger at import time: INFO level, "<timestamp> - <message>" lines.
logging.basicConfig(level=logging.INFO,
                   format='%(asctime)s - %(message)s')
# Module-wide logger used by FlorenceBatchCaptioner and main().
logger = logging.getLogger("FlorenceCaptioner")

class FlorenceBatchCaptioner:
    def __init__(self,
                 model_path: str = "models/microsoft--Florence-2-large",
                 batch_size: int = 512,
                 num_cores: int = 8,
                 chunks_per_core: int = 1000):
        """
        Set up the captioner: chunk manager, distributors, statistics, config and tokenizer.

        Model weights are NOT loaded here; ``self.weights`` stays ``None`` until
        ``await initialize()`` is called.

        Args:
            model_path: Path to the Florence-2 model directory
            batch_size: Maximum number of images to process per batch
            num_cores: Core count handed to the chunk manager and the distributors
            chunks_per_core: Number of chunks per core for weight distribution

        Raises:
            FileNotFoundError: If the model ``config.json`` is missing.
            Exception: Any tokenizer loading error is logged and re-raised.
        """
        start_time = time.time()
        logger.info(f"Initializing with {num_cores} cores, {chunks_per_core} chunks per core")

        self.model_path = Path(model_path)
        self.max_batch_size = batch_size
        self.chunks_per_core = chunks_per_core
        self.num_cores = num_cores
        # Instance alias of the module logger; _decode_captions refers to self.logger.
        self.logger = logger

        # Build the chunk manager exactly once, already configured, so that the
        # chunk_cache alias below points at the cache of the manager that is
        # actually used (it used to alias a throw-away default instance).
        self.chunk_manager = ChunkManager(
            num_cores=num_cores,
            chunks_per_core=chunks_per_core,
            persistent=True
        )
        self.chunk_cache = self.chunk_manager.tensor_chunk_cache  # Alias for compatibility
        logger.info("Initialized ChunkManager with high-speed configuration")

        # Not used inside this class; kept because they are public attributes
        # that external code may rely on.
        self.parallel_distributor = ParallelArrayDistributor(num_sms=num_cores)
        self.tensor_distributor = TensorCoreDistributor()

        # Persistent core and unit assignments - 10 units per core
        logger.info("Creating persistent core and unit assignments...")
        self.array_distributor = ParallelArrayDistributor(
            num_sms=num_cores,
            cores_per_sm=10,  # Reduced units per core
            persistent=True   # Enable persistent assignments
        )

        # Tensor processor used by _process_batch
        self.tensor_processor = TensorCoreDistributor(
            num_sms=num_cores,
            cores_per_sm=10  # Match the reduced unit count
        )

        self.memory_tracker = self._init_memory_tracker()
        self.perf_stats = {
            'total_images': 0,
            'total_time': 0,
            'stage_times': {
                'preprocessing': [],
                'feature_extraction': [],
                'caption_generation': [],
                'postprocessing': []
            },
            'memory_usage': [],
            'batch_sizes': [],
            'chunk_processing': [],
            'core_utilization': []
        }
        self.callbacks = []
        self.initialized = False

        # Load model configuration (raises FileNotFoundError when missing)
        self.config = self._load_config()

        # Model parameters
        self.image_size = (224, 224)  # Florence-2 default
        self.max_caption_length = 128
        self.vocab_size = self.config["text_config"]["vocab_size"]

        # Initialize tokenizer for decoding
        try:
            self.tokenizer = BartTokenizerFast.from_pretrained('facebook/bart-large')
            logger.info("Initialized BartTokenizer for caption decoding")
        except Exception as e:
            logger.error(f"Failed to initialize tokenizer: {e}")
            raise

        # Model weights will be loaded during initialize()
        self.weights = None

        logger.info(f"Captioner constructed in {time.time() - start_time:.2f}s")
        
    def _init_memory_tracker(self):
        """Build the memory-tracking record, or return None when psutil is missing.

        Returns:
            A dict with the current process handle (``process``), the machine's
            total memory in bytes (``total_memory``) and the usable fraction
            (``memory_threshold``); ``None`` if psutil cannot be imported.
        """
        try:
            import psutil as _psutil
            process_handle = _psutil.Process()
            system_total = _psutil.virtual_memory().total
        except ImportError:
            logger.warning("psutil not installed - memory optimization disabled")
            return None

        # Use up to 95% of available memory for max throughput
        return dict(
            process=process_handle,
            total_memory=system_total,
            memory_threshold=0.95,
        )
            
    def _get_optimal_batch_size(self, image_size: tuple) -> int:
        """Estimate how many images fit in memory, capped at ``max_batch_size``.

        Args:
            image_size: Spatial size (height, width) of one image.

        Returns:
            A batch size between 1 and ``self.max_batch_size``. Falls back to
            ``self.max_batch_size`` when memory tracking is disabled or the
            estimate cannot be computed.
        """
        if not self.memory_tracker:
            return self.max_batch_size

        try:
            tracker = self.memory_tracker

            # Memory not used by this process (total machine memory minus our RSS)
            current_memory = tracker['process'].memory_info().rss
            available_memory = tracker['total_memory'] - current_memory

            # Bytes for one float32 RGB image: H * W * 3 channels * 4 bytes
            img_memory = int(np.prod(image_size)) * 3 * 4
            if img_memory <= 0:
                return self.max_batch_size

            # Calculate max images that fit in the usable share of memory
            max_images = int((available_memory * tracker['memory_threshold']) / img_memory)

            # Never report less than one image: a zero or negative size would make
            # callers slice their batch down to nothing.
            optimal_size = max(1, min(max_images, self.max_batch_size))
            logger.debug(f"Optimal batch size: {optimal_size} (memory-based)")
            return optimal_size

        except Exception as e:
            logger.warning(f"Error calculating optimal batch size: {str(e)}")
            return self.max_batch_size
        
    def __enter__(self) -> "FlorenceBatchCaptioner":
        """Enter the ``with`` block and hand back this captioner instance."""
        return self
        
    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release resources when leaving the ``with`` block.

        If the instance has a ``storage`` attribute exposing ``cleanup()``, it is
        called. Cleanup failures are logged, never raised.

        Returns:
            Always ``False`` so that an exception raised inside the ``with``
            block propagates to the caller. (Returning a truthy value here would
            silently swallow every error from the block.)
        """
        try:
            # Clean up any stored tensors
            storage = getattr(self, 'storage', None)
            cleanup = getattr(storage, 'cleanup', None)
            if callable(cleanup):
                cleanup()

            logger.info("Successfully cleaned up resources")

        except Exception as e:
            # Best-effort cleanup: report the problem but do not mask the
            # original exception (if any) with a cleanup error.
            logger.error(f"Error during cleanup: {str(e)}")

        return False
        

    def _load_config(self) -> Dict:
        """Read the model's ``config.json`` from the pinned snapshot directory.

        Returns:
            The parsed configuration as a dict.

        Raises:
            FileNotFoundError: If ``config.json`` does not exist under
                ``<model_path>/snapshots/<revision>/``.
        """
        config_path = (
            self.model_path
            / "snapshots"
            / "00d2f1570b00c6dea5df998f5635db96840436bc"
            / "config.json"
        )
        if not config_path.exists():
            raise FileNotFoundError(f"Config not found at {config_path}")

        # JSON is UTF-8 by specification; do not depend on the platform default.
        with open(config_path, encoding="utf-8") as f:
            return json.load(f)
            
    async def initialize(self):
        """Load the weights and run each component's async ``initialize`` once.

        Subsequent calls are no-ops. Components that do not expose an
        ``initialize`` attribute are skipped.
        """
        if not self.initialized:
            await self._load_model_weights()

            # Same order as always: array distributor, tensor processor, chunk manager.
            for attr_name in ("array_distributor", "tensor_processor", "chunk_manager"):
                component = getattr(self, attr_name)
                if hasattr(component, 'initialize'):
                    await component.initialize()

            self.initialized = True
            logger.info("Florence captioner initialization complete")

    async def _load_model_weights(self) -> Dict:
        """Load ``model.safetensors`` and hand large tensors to the chunk manager.

        Tensors larger than 100 MB are passed to ``chunk_manager.chunk_tensor``
        and replaced by a descriptor dict (``chunked``, ``chunk_ids``, ``shape``,
        ``dtype``); smaller tensors are kept as-is.

        Returns:
            Mapping of weight name to either the tensor itself or its chunk
            descriptor. The same mapping is stored in ``self.weights``.

        Raises:
            RuntimeError: If the file is missing or loading/distribution fails;
                the original error is attached as ``__cause__``.
        """
        snapshot_dir = self.model_path / "snapshots" / "00d2f1570b00c6dea5df998f5635db96840436bc"
        safe_file = snapshot_dir / "model.safetensors"

        async def process_weight(name, tensor):
            if tensor.nbytes > 1e8:  # 100MB threshold for chunking
                logger.info(f"Processing tensor {name}: {tensor.nbytes/1e9:.2f}GB")
                chunk_ids = await self.chunk_manager.chunk_tensor(tensor)
                return name, {
                    'chunked': True,
                    'chunk_ids': chunk_ids,
                    'shape': tensor.shape,
                    'dtype': tensor.dtype
                }
            return name, tensor

        try:
            if not safe_file.exists():
                raise FileNotFoundError(f"Model weights not found at {safe_file}")

            logger.info("Loading and distributing model weights...")
            start_time = time.time()

            weights = safetensors.numpy.load_file(str(safe_file))
            logger.info(f"Loaded {len(weights)} weight tensors, starting distribution")

            # Run the per-tensor coroutines concurrently and collect (name, value) pairs
            results = await asyncio.gather(
                *(process_weight(name, tensor) for name, tensor in weights.items())
            )
            distributed_weights = dict(results)

        except Exception as e:
            logger.error(f"Failed to load and distribute weights: {str(e)}")
            raise RuntimeError(f"Failed to load and distribute weights: {str(e)}") from e

        # Publish the result; previously it was computed and then discarded,
        # leaving self.weights at None forever.
        self.weights = distributed_weights
        logger.info(
            f"Distributed {len(distributed_weights)} weight tensors in {time.time() - start_time:.2f}s"
        )
        return distributed_weights
            
    def _preprocess_images(self, image_paths: List[str], max_retries: int = 3) -> Tuple[np.ndarray, List[str]]:
        """
        Load, resize and normalize a batch of images, with per-image retries.

        Each image is validated with PIL (JPEG/PNG/WEBP only), read with OpenCV,
        resized to ``self.image_size``, converted BGR->RGB and scaled to [-1, 1].

        Failures that cannot change on retry (missing file, unsupported format,
        unreadable image data) are reported immediately; other errors are retried
        up to ``max_retries`` times with a linearly growing pause.

        Args:
            image_paths: Paths of the images to load.
            max_retries: Maximum attempts per image (values below 1 mean one attempt).

        Returns:
            Tuple of (processed_images, failed_paths). ``processed_images`` holds
            only the successfully loaded images, in input order.

        Raises:
            RuntimeError: If not a single image could be processed.
        """
        import cv2
        from PIL import Image

        processed = []
        failed_paths = []
        # At least one attempt, otherwise paths would vanish from both outputs.
        attempts = max(1, max_retries)

        for path in image_paths:
            for attempt in range(attempts):
                try:
                    # Validate file exists and is readable
                    if not os.path.isfile(path):
                        raise FileNotFoundError(f"Image file not found: {path}")

                    # Try loading with PIL first to validate image
                    with Image.open(path) as pil_img:
                        if pil_img.format not in ['JPEG', 'PNG', 'WEBP']:
                            raise ValueError(f"Unsupported image format: {pil_img.format}")

                    # Load and resize image with OpenCV
                    logger.info(f"Loading image: {path}")
                    img = cv2.imread(path)
                    if img is None:
                        raise ValueError(f"Failed to load image: {path}")

                    logger.info(f"Image shape before resize: {img.shape}")
                    img = cv2.resize(img, self.image_size)
                    logger.info(f"Image shape after resize: {img.shape}")

                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    logger.info(f"Image shape after color conversion: {img.shape}")

                    # Normalize to [-1, 1]
                    img = img.astype(np.float32) / 255.0
                    img = (img - 0.5) / 0.5

                    processed.append(img)
                    break

                except Exception as e:
                    error = str(e)
                    # A missing file or a rejected format/payload will fail the
                    # same way every time, so retrying only wastes time.
                    permanent = isinstance(e, (FileNotFoundError, ValueError))
                    if not permanent and attempt < attempts - 1:
                        logger.warning(f"Retry {attempt + 1}/{attempts} for {path}: {error}")
                        time.sleep(0.5 * (attempt + 1))  # Linear backoff: 0.5s, 1.0s, ...
                    else:
                        logger.error(f"Failed to process {path} after {attempt + 1} attempts: {error}")
                        failed_paths.append(path)
                        break

        if not processed:
            raise RuntimeError("No images could be processed successfully")

        return np.stack(processed), failed_paths
        
    async def _process_batch(self, image_batch: np.ndarray) -> List[Dict]:
        """Run one batch of preprocessed images through encoding, generation and decoding.

        The batch is split into up to ``2 * num_cores`` chunks; each chunk is
        registered with the chunk manager and encoded via
        ``tensor_processor.process_tensor(chunk, "vision_encode")``. The joined
        features go through ``"text_generate"``; the argmax over the last axis
        is decoded with ``_decode_captions``.

        If the batch is larger than the memory-based optimum it is processed in
        consecutive sub-batches, so every input image yields a caption.

        Args:
            image_batch: Array of images; the first axis indexes images and
                axes 1-2 are passed as the image size to the memory estimate.

        Returns:
            One caption dict per input image, in input order.

        Raises:
            RuntimeError: If feature extraction or token generation returns nothing.
        """
        batch_start = time.time()
        batch_size = len(image_batch)
        if batch_size == 0:
            return []

        # Respect the memory budget by splitting, not by dropping images.
        optimal_size = max(1, self._get_optimal_batch_size(image_batch.shape[1:3]))
        if batch_size > optimal_size:
            logger.warning(
                f"Splitting batch of {batch_size} into sub-batches of {optimal_size} due to memory constraints"
            )
            captions = []
            for offset in range(0, batch_size, optimal_size):
                captions.extend(await self._process_batch(image_batch[offset:offset + optimal_size]))
            return captions

        def per_second(count, seconds):
            # Guard against a zero interval on coarse clocks.
            return count / seconds if seconds > 0 else float("inf")

        logger.info(f"\nProcessing batch of {batch_size} images with parallel distribution...")

        # Split batch into chunks and distribute to cores
        feature_start = time.time()

        # np.array_split takes the NUMBER of sections: aim for two chunks per
        # core, but never more chunks than images (that would create empty chunks).
        num_chunks = max(1, min(batch_size, self.num_cores * 2))
        chunks = np.array_split(image_batch, num_chunks)

        chunk_ids = []
        for chunk in chunks:
            chunk_ids.extend(await self.chunk_manager.chunk_tensor(chunk))

        # Encode all chunks concurrently and stitch the features back together
        feature_chunks = await asyncio.gather(
            *(self.tensor_processor.process_tensor(chunk, "vision_encode") for chunk in chunks)
        )
        if not feature_chunks or any(part is None for part in feature_chunks):
            raise RuntimeError("Failed to extract image features")
        features = np.concatenate(feature_chunks, axis=0)

        feature_time = time.time() - feature_start
        logger.info(f"Feature extraction: {feature_time:.9f}s ({per_second(batch_size, feature_time):.1f} img/s)")

        # Generate text tokens
        token_start = time.time()
        tokens = await self.tensor_processor.process_tensor(features, "text_generate")
        token_time = time.time() - token_start
        logger.info(f"Token generation: {token_time:.9f}s ({per_second(batch_size, token_time):.1f} tokens/s)")

        if tokens is None:
            raise RuntimeError("Failed to generate tokens")

        # Convert logits to token ids with timing
        decode_start = time.time()
        token_ids = np.argmax(tokens, axis=-1)
        captions = await self._decode_captions(token_ids)
        decode_time = time.time() - decode_start

        # Record stage timings (this also notifies registered callbacks)
        self._update_perf_stats('feature_extraction', feature_time, batch_size)
        self._update_perf_stats('caption_generation', token_time)
        self._update_perf_stats('postprocessing', decode_time)

        # Track chunk processing metrics
        for chunk_id in chunk_ids:
            chunk_meta = self.chunk_manager.chunk_metadata[chunk_id]
            if chunk_meta.processing_end is not None and chunk_meta.processing_start is not None:
                self.perf_stats['chunk_processing'].append({
                    'chunk_id': chunk_id,
                    'processing_time': chunk_meta.processing_end - chunk_meta.processing_start,
                    'electron_cycles': chunk_meta.electron_cycles,
                    'quantum_ops': chunk_meta.quantum_ops
                })

        # Log performance metrics
        total_time = time.time() - batch_start
        logger.info(f"\nBatch Stats:")
        logger.info(f"Total time: {total_time:.3f}s")
        logger.info(f"Throughput: {per_second(batch_size, total_time):.1f} images/sec")
        logger.info(f"Per-stage times:")
        logger.info(f"- Features: {feature_time*1000:.1f}ms")
        logger.info(f"- Tokens: {token_time*1000:.1f}ms")
        logger.info(f"- Decode: {decode_time*1000:.1f}ms")

        return captions
        
    def _process_unit_batch(self,
                           images: np.ndarray,
                           processor_unit=None,
                           batch_key: str = None) -> List[str]:
        """Process a subset of images on a single tensor unit

        Alternative code path that drives one processor unit directly:
        "vision_encode" on the images, then "text_generate" on the resulting
        features, then caption decoding. Nothing in this file calls it.

        Args:
            images: Images to caption.
            processor_unit: Object exposing ``process_tensor``; when omitted, the
                default unit of ``self.processor`` is used.
            batch_key: Prefix for the key under which intermediate features are
                stored; defaults to a timestamp-based key.

        NOTE(review): ``self.processor`` is not assigned anywhere in this class,
        so calling this method without ``processor_unit`` raises AttributeError
        unless something else sets that attribute - confirm the intended source.
        NOTE(review): ``_decode_captions`` is declared ``async``, so the value
        returned here is an un-awaited coroutine rather than a list - callers
        would have to await it; confirm and fix before using this path.
        """
        if processor_unit is None:
            processor_unit = self.processor.get_default_unit()

        if batch_key is None:
            # Timestamp-based key so successive batches do not collide
            batch_key = f"batch_{time.time()}"

        # Process images through the vision encoder of the chosen unit
        vision_features = processor_unit.process_tensor(
            images,
            operation="vision_encode",
            electron_accelerated=True
        )

        # Store intermediate results if storage is available
        # (``storage`` is optional and not created in __init__)
        if hasattr(self, 'storage'):
            self.storage.store_tensor(
                f"{batch_key}_vision_features",
                vision_features,
                metadata={"timestamp": time.time()}
            )

        # Generate caption tokens from the vision features
        caption_tokens = processor_unit.process_tensor(
            vision_features,
            operation="text_generate",
            max_length=self.max_caption_length,
            electron_accelerated=True
        )

        # See NOTE(review) above: this is a coroutine object, not the decoded list
        return self._decode_captions(caption_tokens)
        
    async def _decode_captions(self, token_ids: np.ndarray, token_probs: np.ndarray = None) -> List[Dict]:
        """
        Decode token IDs to human-readable captions with optional confidence scores

        Each sequence is stripped of zero ids and decoded with the tokenizer. If
        that yields an empty string, the ids are looked up one by one in the
        vocabulary loaded from ``vocab.json`` and joined with spaces. The text is
        then normalized (punctuation spacing, sentence capitalization, final period).

        Args:
            token_ids: Array of token IDs, one row per caption
            token_probs: Optional array of token probabilities with the same
                layout as ``token_ids``

        Returns:
            List of dicts, one per sequence, with ``caption``, ``metadata``
            (word count ``length`` and ``timestamp``) and, when ``token_probs``
            is given, ``confidence`` (``overall``, ``min_token``, ``per_token``).
        """
        # Only needed for instances created without __init__
        if not hasattr(self, 'tokenizer'):
            self.tokenizer = BartTokenizerFast.from_pretrained('facebook/bart-large')

        # Load vocabulary if not already loaded
        if not hasattr(self, 'vocab_dict'):
            self._load_vocabulary()

        id_to_token = getattr(self, 'id_to_token', {})
        special_tokens = {'[CLS]', '[SEP]', '[PAD]', '[MASK]', '<s>', '</s>', '<pad>', '<unk>'}

        results = []
        for i, seq in enumerate(token_ids):
            seq = np.atleast_1d(np.asarray(seq))

            # Drop zero ids, and drop the matching probabilities with the same
            # mask so ids and probabilities stay aligned.
            keep = seq != 0
            kept_ids = seq[keep]
            kept_probs = None
            if token_probs is not None:
                kept_probs = np.atleast_1d(np.asarray(token_probs[i]))[keep]

            # Try using the tokenizer first
            caption = self.tokenizer.decode(kept_ids.tolist(), skip_special_tokens=True)
            logger.debug(f"Tokenizer decode result: {caption}")

            # Probabilities of every kept token; replaced below if manual decoding is used
            word_probs = [float(p) for p in kept_probs] if kept_probs is not None else []

            if not caption:  # If tokenizer decode fails, use manual decoding
                logger.debug(f"Tokenizer decode failed, trying manual decoding for sequence {kept_ids}")
                words = []
                word_probs = []

                for j, token_id in enumerate(kept_ids):
                    word = id_to_token.get(int(token_id))
                    # Skip unknown ids and special tokens
                    if word is None or word in special_tokens:
                        continue
                    words.append(word)
                    if kept_probs is not None:
                        word_probs.append(float(kept_probs[j]))

                if words:
                    caption = " ".join(words)
                    logger.debug(f"Manual decode result: {caption}")
                else:
                    logger.debug("Manual decoding produced no words")

            caption = self._clean_caption(caption)

            result = {
                "caption": caption,
                "metadata": {
                    "length": len(caption.split()),
                    "timestamp": time.time()
                }
            }

            if token_probs is not None:
                # Calculate confidence scores
                avg_prob = np.mean(word_probs) if word_probs else 0
                min_prob = np.min(word_probs) if word_probs else 0

                result["confidence"] = {
                    "overall": float(avg_prob),
                    "min_token": float(min_prob),
                    "per_token": word_probs
                }

            results.append(result)

        return results

    def _load_vocabulary(self) -> None:
        """Load ``vocab.json`` into ``vocab_dict`` and build the ``id_to_token`` map."""
        vocab_path = Path(self.model_path) / "snapshots" / "00d2f1570b00c6dea5df998f5635db96840436bc" / "vocab.json"
        if not vocab_path.exists():
            logger.warning(f"Florence-2 vocabulary file not found at {vocab_path}, falling back to basic decoding")
            # Remember the miss so the warning is not repeated on every call
            self.vocab_dict = {}
            self.id_to_token = {}
            return

        with open(vocab_path, encoding='utf-8') as f:
            self.vocab_dict = json.load(f)

        # Reverse mapping from id to cleaned token text
        self.id_to_token = {}
        for token, idx in self.vocab_dict.items():
            cleaned_token = token
            if token.startswith('Ġ'):  # Handle space prefix tokens
                cleaned_token = token[1:]
            elif token.startswith('##'):  # Handle subword tokens
                cleaned_token = token[2:]
            elif token.startswith('▁'):  # Handle alternative space tokens
                cleaned_token = ' ' + token[1:]
            elif token in ['<s>', '</s>', '<pad>', '<unk>']:
                continue
            self.id_to_token[idx] = cleaned_token

    @staticmethod
    def _clean_caption(caption: str) -> str:
        """Normalize spacing and punctuation, capitalize sentences, end with a period."""
        # No whitespace before . , ! ? and exactly one space after
        caption = re.sub(r'\s*([.,!?])\s*', r'\1 ', caption)
        caption = re.sub(r'\s+', ' ', caption).strip()

        # Upper-case only the first character of each sentence; str.capitalize()
        # would also lower-case the rest (e.g. proper nouns).
        caption = '. '.join(s[:1].upper() + s[1:] for s in caption.split('. '))
        if caption and not caption.endswith(('.', '!', '?')):
            caption += "."
        return caption
        
    def register_callback(self, callback) -> None:
        """Register a callback for progress updates

        Registered callbacks are invoked by ``_update_perf_stats`` with a single
        dict argument holding ``stage``, ``time``, ``batch_size``, ``memory``
        and ``total_images``.
        """
        self.callbacks.append(callback)
        
    def _update_perf_stats(self, stage: str, time_taken: float, batch_size: int = None):
        """Record one stage timing and broadcast it to the registered callbacks.

        Args:
            stage: Key into ``perf_stats['stage_times']``.
            time_taken: Duration of the stage in seconds.
            batch_size: Optional batch size; recorded only when truthy.
        """
        history = self.perf_stats
        history['stage_times'][stage].append(time_taken)

        if batch_size:
            history['batch_sizes'].append(batch_size)

        # Sample the process RSS only when memory tracking is enabled
        mem_used = None
        if self.memory_tracker:
            mem_used = self.memory_tracker['process'].memory_info().rss
            history['memory_usage'].append(mem_used)

        # Notify callbacks; a failing callback is logged and must not stop the others
        snapshot = {
            'stage': stage,
            'time': time_taken,
            'batch_size': batch_size,
            'memory': mem_used,
            'total_images': history['total_images'],
        }
        for notify in self.callbacks:
            try:
                notify(snapshot)
            except Exception as e:
                logger.error(f"Error in callback: {str(e)}")
                
    def get_performance_stats(self) -> Dict:
        """Summarize the collected statistics.

        Returns:
            Dict with ``total_images``, ``total_time``, ``average_throughput``
            (images per second, with the time floored at 1 ms),
            ``stage_averages`` (mean time of every stage that has samples) and
            ``memory`` (``peak``/``average``; zeros when memory tracking is off
            or no samples exist).
        """
        recorded = self.perf_stats
        image_count = recorded['total_images']
        elapsed = recorded['total_time']

        summary = {
            'total_images': image_count,
            'total_time': elapsed,
            'average_throughput': image_count / max(elapsed, 0.001),
            'stage_averages': {
                stage: sum(samples) / len(samples)
                for stage, samples in recorded['stage_times'].items()
                if samples
            },
        }

        # Memory figures are reported only when tracking is enabled and sampled
        memory_samples = recorded.get('memory_usage', []) if self.memory_tracker else []
        if memory_samples:
            summary['memory'] = {
                'peak': max(memory_samples),
                'average': sum(memory_samples) / len(memory_samples),
            }
        else:
            summary['memory'] = {'peak': 0, 'average': 0}

        return summary
        
    async def generate_captions(self, image_paths: List[str]) -> List[Dict]:
        """Generate captions for a list of images

        Images are preprocessed and captioned in batches of ``max_batch_size``.
        A batch in which no image could be loaded is skipped with a warning
        instead of aborting the run. Images that fail to load produce no
        result, so the output can be shorter than the input.

        Args:
            image_paths: Paths of the images to caption.

        Returns:
            Caption dicts in processing order. When the number of captions of a
            batch matches its successfully loaded images, each dict also carries
            the source path under ``image_path``.

        Raises:
            Exception: Errors from ``_process_batch`` are logged and re-raised.
        """
        total = len(image_paths)
        logger.info(f"\nProcessing {total} images in batches of {self.max_batch_size}")
        start_time = time.time()
        all_captions = []
        failed_total = 0

        step = max(1, self.max_batch_size)
        num_batches = (total + step - 1) // step

        # Process in batches
        for batch_index, offset in enumerate(range(0, total, step), start=1):
            batch_start = time.time()
            batch_paths = image_paths[offset:offset + step]
            logger.info(f"\nBatch {batch_index}/{num_batches}:")

            # Preprocess batch and handle failed images
            try:
                image_batch, failed_paths = self._preprocess_images(batch_paths)
            except RuntimeError as e:
                # _preprocess_images raises when every image of the batch failed
                logger.warning(f"No images were successfully processed in this batch: {e}")
                failed_total += len(batch_paths)
                continue
            self._update_perf_stats('preprocessing', time.time() - batch_start)

            if failed_paths:
                logger.warning(f"Failed to process {len(failed_paths)} images in this batch")
                failed_total += len(failed_paths)

            if image_batch.size == 0:
                logger.warning("No images were successfully processed in this batch")
                continue

            try:
                # Generate captions
                batch_captions = await self._process_batch(image_batch)
            except Exception as e:
                logger.error(f"Error processing batch: {str(e)}")
                logger.error(f"Image batch shape: {image_batch.shape}")
                raise

            # Tag each result with its source image so callers do not have to
            # guess the alignment after load failures.
            failed = set(failed_paths)
            loaded_paths = [p for p in batch_paths if p not in failed]
            if len(loaded_paths) == len(batch_captions):
                for path, result in zip(loaded_paths, batch_captions):
                    if isinstance(result, dict):
                        result.setdefault("image_path", path)
            all_captions.extend(batch_captions)

            batch_time = time.time() - batch_start
            batch_rate = len(batch_paths) / batch_time if batch_time > 0 else float("inf")
            logger.info(f"Batch processing time: {batch_time:.2f}s ({batch_rate:.2f} images/s)")

        total_time = time.time() - start_time

        # Feed the totals reported by get_performance_stats()
        self.perf_stats['total_images'] += len(all_captions)
        self.perf_stats['total_time'] += total_time

        average_rate = len(all_captions) / total_time if total_time > 0 else float("inf")
        logger.info(f"\nTotal processing complete:")
        logger.info(f"Total time: {total_time:.2f}s")
        logger.info(f"Average speed: {average_rate:.2f} images/s")
        logger.info(f"Images processed: {len(all_captions)}")
        if failed_total:
            logger.warning(f"Images that could not be loaded: {failed_total}")

        return all_captions

def main():
    """Caption every PNG in ``sample_task/`` and log the captions and a summary.

    ``initialize`` and ``generate_captions`` are coroutines, so they are driven
    with ``asyncio.run``. Any error is logged and re-raised.
    """
    logger.info("=== Florence-2 Batch Captioning System ===")

    frame_dir = Path("sample_task")
    pattern = "*.png"

    async def caption_all(captioner, paths):
        # Weights are loaded by initialize(), not by the constructor
        await captioner.initialize()
        return await captioner.generate_captions(paths)

    # Initialize captioner using context manager for proper cleanup
    try:
        with FlorenceBatchCaptioner() as captioner:
            # Collect the PNG files from the sample_task directory, sorted by name
            logger.info(f"Searching for images in {frame_dir.absolute()} with pattern {pattern}")
            image_paths = []
            for img_path in sorted(frame_dir.glob(pattern)):
                logger.info(f"Found image: {img_path}")
                image_paths.append(str(img_path.absolute()))

            if not image_paths:
                logger.warning(f"No images found in {frame_dir}")
                return

            logger.info(f"Found {len(image_paths)} images to process")

            # Run the async pipeline to completion
            captions = asyncio.run(caption_all(captioner, image_paths))

            if len(captions) != len(image_paths):
                logger.warning(
                    f"Got {len(captions)} captions for {len(image_paths)} images; some images were skipped"
                )

            # Display results
            logger.info("\n=== Generated Captions ===")
            for path, result in zip(image_paths, captions):
                # Prefer the path recorded with the result, if there is one
                logger.info(f"\nImage: {result.get('image_path', path)}")
                logger.info(f"Caption: {result['caption']}")
                if 'confidence' in result:
                    logger.info(f"Confidence: {result['confidence']['overall']:.2f}")

            # Show performance stats
            stats = captioner.get_performance_stats()
            logger.info("\n=== Performance Summary ===")
            logger.info(f"Total time: {stats['total_time']:.2f}s")
            logger.info(f"Average throughput: {stats['average_throughput']:.2f} images/s")

    except Exception as e:
        logger.error(f"Error during caption generation: {str(e)}")
        raise

# Run the captioning demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()