import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from PIL import Image
import gc

class ParallelBatchProcessor:
    """Run a model over many image files in fixed-size batches.

    Image decoding is fanned out over a thread pool; inference runs on the
    stacked batch, on the device the model's parameters live on.
    """

    def __init__(self, model, batch_size=32, num_workers=4):
        """
        Args:
            model: Module/callable applied to each stacked batch. Wrapped in
                ``DataParallel`` when more than one CUDA device is visible.
                NOTE(review): the model is not switched to eval mode here;
                callers presumably call ``model.eval()`` first — confirm.
            batch_size: Number of image paths processed per batch.
            num_workers: Thread-pool size used for preprocessing.
        """
        self.model = DataParallel(model) if torch.cuda.device_count() > 1 else model
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _input_device(self):
        """Return the device input batches must be moved to before inference."""
        try:
            # Inputs must live where the model's weights live; this also covers
            # CPU-only models and models placed on a non-default GPU.
            return next(self.model.parameters()).device
        except (StopIteration, AttributeError):
            # Parameter-less modules or plain callables: best available device.
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def preprocess_image(self, image_path):
        """Load one image as an RGB tensor; called concurrently from worker threads.

        Returns:
            A tensor built from the image's RGB pixel array, or ``None`` if
            the image could not be loaded (the error is printed).
        """
        try:
            # Context manager closes the underlying file handle; convert()
            # returns a new, fully-loaded image, so it outlives the `with`.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
            # Add your preprocessing steps here
            # np.array() already yields a fresh, writable array, so share its
            # memory instead of copying it a second time.
            return torch.from_numpy(np.array(image))
        except Exception as e:
            print(f"Error preprocessing {image_path}: {e}")
            return None

    def process_batch(self, image_paths):
        """Preprocess and run inference on one batch of image paths.

        Images that fail to preprocess are silently dropped, so the rows of
        the result do not necessarily correspond 1:1 to ``image_paths``.

        Returns:
            The model output for the stacked batch, or ``[]`` when nothing
            could be preprocessed or inference raised (the error is printed).

        Raises:
            RuntimeError: from ``torch.stack`` if preprocessed tensors differ
                in shape.
        """
        # Parallel preprocessing
        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            preprocessed = list(executor.map(self.preprocess_image, image_paths))

        # Filter out failed preprocessings
        preprocessed = [p for p in preprocessed if p is not None]

        if not preprocessed:
            return []

        batch = torch.stack(preprocessed).to(self._input_device())

        try:
            with torch.no_grad():
                outputs = self.model(batch)
        except Exception as e:
            print(f"Error in batch inference: {e}")
            return []
        finally:
            # Release the input batch on both success and failure (e.g. OOM),
            # not only when inference succeeds.
            del batch
            torch.cuda.empty_cache()
            gc.collect()

        # NOTE(review): outputs stay on the model's device; process_all keeps
        # them all alive, so device memory grows with the total image count.
        return outputs

    def process_all(self, all_image_paths):
        """Process every path in consecutive batches of ``self.batch_size``.

        Returns:
            A flat list of per-image outputs (one entry per first-dimension
            row of each batch output). Failed images are dropped, so indices
            do not necessarily line up with ``all_image_paths``.
        """
        results = []
        for i in range(0, len(all_image_paths), self.batch_size):
            batch_paths = all_image_paths[i:i + self.batch_size]
            results.extend(self.process_batch(batch_paths))
        return results