"""
Test script for Florence-2 Batch Image Captioning System
Processes images from sample_task folder with detailed progress tracking
"""

import os
import time
import logging
import asyncio
from pathlib import Path
from florence_batch_captioner import FlorenceBatchCaptioner
from rich.progress import Progress, TextColumn, BarColumn, TimeRemainingColumn
from rich.console import Console
from rich.table import Table
from rich import print as rprint
import json

# Logging setup: every INFO-and-above record is written both to
# florence_captioning.log and to a console stream handler.
logging.basicConfig(
    handlers=[
        logging.FileHandler("florence_captioning.log"),
        logging.StreamHandler(),
    ],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

def _build_caption_table(results):
    """Return a rich ``Table`` listing each image's caption and confidence.

    ``results`` maps an image path to its result dict. Each result is read
    for a ``'caption'`` string and an optional ``['confidence']['overall']``
    value; captions longer than 100 characters are truncated for display.
    """
    table = Table(show_header=True, header_style="bold")
    table.add_column("Image")
    table.add_column("Caption")
    table.add_column("Confidence")

    for img_path, result in results.items():
        caption = result['caption']
        confidence = result.get('confidence', {}).get('overall', 'N/A')
        table.add_row(
            os.path.basename(img_path),
            caption[:100] + "..." if len(caption) > 100 else caption,
            # Floats get two-decimal formatting; anything else (including
            # the 'N/A' placeholder) is rendered with str().
            f"{confidence:.2f}" if isinstance(confidence, float) else str(confidence),
        )
    return table


async def process_image_batch(captioner, image_paths, batch_size=16):
    """Caption ``image_paths`` with ``captioner`` and report the results.

    All images are passed to ``captioner.generate_captions`` in one call (the
    captioner is relied on to batch internally). A summary table is then
    printed, and run statistics are read from
    ``captioner.get_performance_stats()``.

    Args:
        captioner: Object exposing ``async generate_captions(paths)`` and
            ``get_performance_stats()``; the latter must return a dict with
            ``'total_time'``, ``'total_images'`` and ``'average_throughput'``.
        image_paths: Sequence of image file paths.
        batch_size: Used only to report how many batches the run corresponds
            to; it should match the captioner's own batch size.

    Returns:
        ``(results, processing_stats)``: ``results`` maps each image path to
        its caption result; ``processing_stats`` holds ``total_time``,
        ``images_processed``, ``avg_time_per_image`` and ``batches_processed``.

    Raises:
        ValueError: If ``batch_size`` is less than 1. This is checked up front
            so the captioning run is not wasted on a later division error.
        Exception: Anything raised while captioning or collecting stats is
            reported on the console and re-raised.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    console = Console()
    total_images = len(image_paths)

    processing_stats = {
        'total_time': 0,
        'images_processed': 0,
        'avg_time_per_image': 0,
        'batches_processed': 0
    }

    with Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeRemainingColumn(),
        console=console
    ) as progress:
        overall_task = progress.add_task("[cyan]Overall Progress", total=total_images)

        try:
            # A single call covers every image, so the bar can only jump
            # from 0% to 100% once the call returns.
            batch_results = list(await captioner.generate_captions(image_paths))

            # zip() silently drops unmatched entries; make a count mismatch
            # visible instead of losing captions without a trace.
            if len(batch_results) != total_images:
                logging.warning(
                    "Captioner returned %d results for %d images; "
                    "unmatched entries are dropped",
                    len(batch_results), total_images,
                )
            results = dict(zip(image_paths, batch_results))

            progress.update(overall_task, advance=total_images)

            # Emit all output through the shared Console, like the table.
            console.print("\n=== Generated Captions ===")
            console.print(_build_caption_table(results))

            stats = captioner.get_performance_stats()
            throughput = stats['average_throughput']
            processing_stats.update({
                'total_time': stats['total_time'],
                'images_processed': stats['total_images'],
                'avg_time_per_image': 1.0 / throughput if throughput > 0 else 0,
                # Ceiling division: a partial final batch still counts as one.
                'batches_processed': (total_images + batch_size - 1) // batch_size,
            })
        except Exception as e:
            console.print(f"[red]Error during processing: {str(e)}[/red]")
            raise

    return results, processing_stats

def _print_summary(stats):
    """Print the end-of-run summary banner.

    ``stats`` is the ``processing_stats`` dict returned by
    ``process_image_batch``.
    """
    banner = '=' * 80
    print(f"\n{banner}")
    print("Processing Complete! Summary:")
    print(banner)
    print(f"Total images processed: {stats['images_processed']}")
    print(f"Total processing time: {stats['total_time']:.2f} seconds")
    print(f"Average time per image: {stats['avg_time_per_image']:.2f} seconds")
    print(f"Number of batches: {stats['batches_processed']}")
    print("\nResults saved to: captioning_results.json")
    print("Detailed logs available in: florence_captioning.log")
    print(f"{banner}\n")


async def main():
    """Caption every PNG in ``sample_task/`` and save the results.

    The captions, per-run stats and ``captioner.perf_stats`` are written to
    ``captioning_results.json``, and a summary is printed once at the end.
    Returns early, with a message, when no images are found. Any exception
    is printed and re-raised.
    """
    # One value drives both the captioner and the batch-count report, so
    # the reported number of batches cannot drift from the real setting.
    batch_size = 16

    try:
        captioner = FlorenceBatchCaptioner(
            num_cores=8,
            chunks_per_core=250,
            batch_size=batch_size,
        )

        sample_dir = Path("sample_task")
        image_paths = sorted(str(p) for p in sample_dir.glob("*.png"))

        if not image_paths:
            print("No images found in sample_task folder!")
            return

        banner = '=' * 80
        print(f"\n{banner}")
        print("Starting Florence-2 Image Captioning Test")
        print(f"Found {len(image_paths)} images to process")
        print(f"{banner}\n")

        # Load model weights. NOTE(review): `weights` is awaited as an
        # attribute, so it is presumably an awaitable that triggers loading;
        # confirm against FlorenceBatchCaptioner.
        await captioner.weights

        results, stats = await process_image_batch(
            captioner, image_paths, batch_size=batch_size
        )

        with open('captioning_results.json', 'w', encoding='utf-8') as f:
            json.dump({
                'captions': results,
                'stats': stats,
                'performance': captioner.perf_stats
            }, f, indent=2)

        _print_summary(stats)
    except Exception as e:
        print(f"Error in main: {str(e)}")
        raise

if __name__ == "__main__":
    # Drive the async entry point. On failure, print a short message and
    # re-raise so the traceback and non-zero exit status are preserved.
    try:
        asyncio.run(main())
    except Exception as exc:
        print(f"Failed to run main: {exc}")
        raise
