"""Unified multimodal processing service for text, audio, and image inputs."""

from functools import lru_cache
from typing import Any

import structlog
from gradio.data_classes import FileData

from src.services.audio_processing import AudioService, get_audio_service
from src.services.image_ocr import ImageOCRService, get_image_ocr_service
from src.utils.config import settings

logger = structlog.get_logger(__name__)

# Extension allow-lists used to route uploaded files. Kept as tuples so they can
# be passed straight to ``str.endswith``, which accepts a tuple of suffixes.
_IMAGE_EXTENSIONS: tuple[str, ...] = (
    ".jpg",
    ".jpeg",
    ".png",
    ".gif",
    ".bmp",
    ".webp",
    ".tiff",
    ".tif",
)
_AUDIO_EXTENSIONS: tuple[str, ...] = (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac", ".wma")


class MultimodalService:
    """Unified multimodal processing service.

    Converts a mix of text, recorded audio and uploaded files into one text
    string: recorded audio is transcribed through ``AudioService``, image files
    are run through ``ImageOCRService``, and the results are joined with the
    user's text using blank-line separators.
    """

    def __init__(
        self,
        audio_service: AudioService | None = None,
        ocr_service: ImageOCRService | None = None,
    ) -> None:
        """Initialize multimodal service.

        Args:
            audio_service: Audio service instance (default: get_audio_service())
            ocr_service: Image OCR service instance (default: get_image_ocr_service())
        """
        # Explicit ``is None`` checks so an injected service is never replaced
        # just because it happens to evaluate as falsy.
        self.audio = audio_service if audio_service is not None else get_audio_service()
        self.ocr = ocr_service if ocr_service is not None else get_image_ocr_service()

    async def process_multimodal_input(
        self,
        text: str,
        files: list[FileData] | None = None,
        audio_input: tuple[int, Any] | None = None,
        hf_token: str | None = None,
        prepend_multimodal: bool = True,
    ) -> str:
        """Process multimodal input (text + images + audio) and return combined text.

        Failures in audio transcription or image OCR are logged and skipped, so
        the user's text is always returned even if every extra input fails.

        Args:
            text: Text input string
            files: List of uploaded files (``FileData`` objects; plain path
                strings and Gradio-style ``{"path": ...}`` dicts are also accepted)
            audio_input: Audio input tuple (sample_rate, audio_array)
            hf_token: HuggingFace token for authenticated Gradio Spaces
            prepend_multimodal: If True, prepend audio/image text to original text;
                otherwise append

        Returns:
            Combined text from all inputs, parts separated by a blank line
            ("" when nothing produced text).
        """
        # Audio first, then files, matching the order inputs are reported in.
        multimodal_parts = await self._collect_audio_text(audio_input, hf_token)
        multimodal_parts.extend(await self._collect_file_text(files, hf_token))

        stripped_text = text.strip() if text else ""
        text_parts = [stripped_text] if stripped_text else []

        if prepend_multimodal:
            # Prepend: multimodal content first, then original text
            combined_parts = multimodal_parts + text_parts
        else:
            # Append: original text first, then multimodal content
            combined_parts = text_parts + multimodal_parts

        # ``str.join`` on an empty list already yields "".
        combined_text = "\n\n".join(combined_parts)

        logger.info(
            "multimodal_input_processed",
            text_length=len(combined_text),
            num_files=len(files) if files else 0,
            has_audio=audio_input is not None,
        )

        return combined_text

    async def _collect_audio_text(
        self,
        audio_input: tuple[int, Any] | None,
        hf_token: str | None,
    ) -> list[str]:
        """Transcribe recorded audio (if enabled); return zero or one text part."""
        if audio_input is None or not settings.enable_audio_input:
            return []
        try:
            transcribed = await self.audio.process_audio_input(audio_input, hf_token=hf_token)
        except Exception as e:
            # Best-effort: a failed transcription must not block the text input.
            logger.warning("audio_processing_failed", error=str(e))
            return []
        cleaned = self._clean_part(transcribed)
        return [cleaned] if cleaned else []

    async def _collect_file_text(
        self,
        files: list[FileData] | None,
        hf_token: str | None,
    ) -> list[str]:
        """Extract text from uploaded files, in upload order.

        Images go through OCR when ``settings.enable_image_input`` is set.
        Audio files are recognised but not yet transcribed; when
        ``settings.enable_audio_input`` is set, a warning is logged for them.
        Other file types are skipped.
        """
        parts: list[str] = []
        for file_data in files or []:
            file_path = self._resolve_file_path(file_data)

            if self._is_image_file(file_path):
                if not settings.enable_image_input:
                    continue
                try:
                    extracted_text = await self.ocr.extract_text(file_path, hf_token=hf_token)
                except Exception as e:
                    # Best-effort: one unreadable image must not drop the rest.
                    logger.warning("image_ocr_failed", file_path=file_path, error=str(e))
                    continue
                cleaned = self._clean_part(extracted_text)
                if cleaned:
                    parts.append(cleaned)

            elif self._is_audio_file(file_path):
                if settings.enable_audio_input:
                    # TODO: load and transcribe uploaded audio files; for now only
                    # surface that they were ignored.
                    logger.warning("audio_file_upload_not_supported", file_path=file_path)

            else:
                logger.debug("unsupported_file_type", file_path=file_path)
        return parts

    @staticmethod
    def _resolve_file_path(file_data: Any) -> str:
        """Return the filesystem path for one uploaded-file entry.

        Accepts a ``FileData`` object, a serialized Gradio file dict carrying a
        ``"path"`` key, or anything whose ``str()`` is the path (e.g. ``str``).
        """
        if isinstance(file_data, FileData):
            return file_data.path or ""
        if isinstance(file_data, dict) and "path" in file_data:
            return str(file_data["path"] or "")
        return str(file_data)

    @staticmethod
    def _clean_part(value: str | None) -> str:
        """Strip surrounding whitespace from service output; ``None`` becomes ""."""
        return (value or "").strip()

    def _is_image_file(self, file_path: str) -> bool:
        """Check if file is an image (case-insensitive extension match).

        Args:
            file_path: Path to file

        Returns:
            True if file is an image
        """
        return file_path.lower().endswith(_IMAGE_EXTENSIONS)

    def _is_audio_file(self, file_path: str) -> bool:
        """Check if file is an audio file (case-insensitive extension match).

        Args:
            file_path: Path to file

        Returns:
            True if file is an audio file
        """
        return file_path.lower().endswith(_AUDIO_EXTENSIONS)


@lru_cache(maxsize=1)
def get_multimodal_service() -> MultimodalService:
    """Get or create singleton multimodal service instance.

    ``lru_cache(maxsize=1)`` on a zero-argument function caches the first
    instance created, so every call returns that same object.

    Returns:
        MultimodalService instance
    """
    return MultimodalService()