"""
Test script for Florence image captioning using optimized tensor processing with async support
"""

import logging
import time
import asyncio
import numpy as np
from pathlib import Path
from florence_batch_captioner import FlorenceBatchCaptioner
from tensor_core_v2 import ParallelTensorProcessor, TensorCore

# Configure logging once for the whole script: INFO and above go both to
# florence_test.log and to the console, each line stamped with time and level.
logging.basicConfig(
    handlers=[logging.FileHandler('florence_test.log'), logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

def format_number(num, threshold=1e6):
    """Format a number for human-readable log output.

    Values whose magnitude is at least ``threshold`` are rendered in
    scientific notation with two decimals (e.g. ``2.50e+06``). Smaller values
    are rendered with comma thousands separators (e.g. ``12,345``).

    Args:
        num: The int or float to format.
        threshold: Magnitude at or above which scientific notation is used.
            Defaults to 1e6, the original cut-off.

    Returns:
        The formatted string.
    """
    # Compare the magnitude, not the signed value, so that large negative
    # numbers are abbreviated the same way as large positive ones.
    if abs(num) >= threshold:
        return f"{num:,.2e}"
    return f"{num:,}"

def log_chunk_stats(stage, chunks_info):
    """Log human-readable chunk processing statistics, grouped by core.

    Args:
        stage: Label for the processing stage, used in the section header.
        chunks_info: Mapping of core id -> list of per-chunk dicts. Each chunk
            dict is read for its ``'unit_id'``, ``'ops'`` and
            ``'processing_time'`` (seconds) keys.
    """
    logging.info(f"\n=== {stage} Processing Stats ===")
    # chunks_info is keyed by core, so len(chunks_info) counts cores; the
    # chunk total must be summed across every core's chunk list.
    total_chunks = sum(len(core_chunks) for core_chunks in chunks_info.values())
    logging.info(f"Total cores: {len(chunks_info):,}")
    logging.info(f"Total chunks: {total_chunks:,}")

    for core_id, core_chunks in chunks_info.items():
        logging.info(f"\nCore {core_id}:")
        logging.info(f"  Total chunks: {len(core_chunks):,}")
        logging.info(f"  Units active: {len({c['unit_id'] for c in core_chunks}):,}")

        # Aggregate this core's work to derive an ops-per-second rate.
        total_ops = sum(c['ops'] for c in core_chunks)
        total_time = sum(c['processing_time'] for c in core_chunks)
        if total_time > 0:
            ops_per_sec = total_ops / total_time
            logging.info(f"  Operations: {format_number(total_ops)}")
            logging.info(f"  Processing time: {total_time:.9f}s")
            logging.info(f"  Speed: {format_number(ops_per_sec)} ops/s")
        else:
            # A rate is undefined without elapsed time; say so rather than
            # silently omitting this core's throughput lines.
            logging.info("  No processing time recorded; speed unavailable")

async def run_test(image_dir="sample_task"):
    """Run an end-to-end test of the Florence captioning system.

    Captions every ``*.png`` file in ``image_dir`` with a
    ``FlorenceBatchCaptioner`` (8 cores, batch size 512, 1000 chunks per
    core), then logs each caption and a performance summary.

    Args:
        image_dir: Directory searched (non-recursively) for ``*.png`` test
            images. Defaults to ``"sample_task"``.

    Raises:
        Exception: Any error from captioner setup, caption generation, or
            result reporting is logged with its traceback and re-raised.
    """
    logging.info("\n=== Starting Florence Captioning System Test ===\n")

    # Path.glob yields files in arbitrary filesystem order; sort so runs are
    # reproducible and logs are easy to compare.
    test_images = sorted(Path(image_dir).glob("*.png"))
    if not test_images:
        # Bail out rather than report "success" for a run that tested nothing.
        logging.warning(f"No .png images found in {image_dir}; nothing to test")
        return

    def performance_callback(stats):
        """Log the per-stage metrics the captioner reports via its callback."""
        logging.info(f"\nStage: {stats['stage']}")
        logging.info(f"Time taken: {stats['time']:.9f}s")
        if 'memory' in stats and stats['memory']:
            logging.info(f"Memory used: {stats['memory']/1024/1024:.2f}MB")
        if 'batch_size' in stats and stats['batch_size']:
            logging.info(f"Batch size: {stats['batch_size']}")

    try:
        # Setup lives inside the try so initialization failures are logged too.
        captioner = FlorenceBatchCaptioner(
            num_cores=8,
            batch_size=512,  # Using optimized batch size
            chunks_per_core=1000,  # Using optimized chunks per core
        )
        await captioner.initialize()
        captioner.register_callback(performance_callback)

        logging.info(f"Processing {len(test_images)} images...")
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock, which can jump and is coarse on some platforms.
        start_time = time.perf_counter()
        captions = await captioner.generate_captions(test_images)
        total_time = time.perf_counter() - start_time
        # Materialize so the count check below works for any iterable result.
        captions = list(captions)

        logging.info("\n=== Results ===")
        if len(captions) != len(test_images):
            # zip() pairs positionally and silently drops any unmatched tail.
            logging.warning(
                f"Expected {len(test_images)} captions but got {len(captions)}"
            )
        for img_path, caption in zip(test_images, captions):
            logging.info(f"\nImage: {img_path}")
            logging.info(f"Caption: {caption['caption']}")
            if 'confidence' in caption:
                logging.info(f"Confidence: {caption['confidence']['overall']:.4f}")

        stats = captioner.get_performance_stats()
        logging.info("\n=== Performance Summary ===")
        logging.info(f"Total processing time: {total_time:.4f}s")
        if total_time > 0:
            logging.info(f"Average throughput: {len(test_images)/total_time:.2f} images/s")
        else:
            # Guard the division in case the measured elapsed time is zero.
            logging.info("Average throughput: n/a (elapsed time too small to measure)")

        logging.info("\nStage Timings:")
        for stage, avg_time in (stats.get('stage_averages') or {}).items():
            logging.info(f"  {stage}: {avg_time:.6f}s")

        # Skip the section when memory stats are absent or empty/None.
        memory = stats.get('memory')
        if memory:
            logging.info("\nMemory Usage:")
            logging.info(f"  Peak: {memory['peak']/1024/1024:.2f}MB")
            logging.info(f"  Average: {memory['average']/1024/1024:.2f}MB")

        logging.info("\n=== Test Completed Successfully ===")

    except Exception as e:
        # logging.exception records the full traceback, not just the message.
        logging.exception(f"Test failed: {e}")
        raise

if __name__ == "__main__":
    # Script entry point: build the test coroutine and drive it to completion
    # on a fresh event loop.
    main_coro = run_test()
    asyncio.run(main_coro)