Spaces:
Running
Running
| """Unified multimodal processing service for text, audio, and image inputs.""" | |
| from functools import lru_cache | |
| from typing import Any | |
| import structlog | |
| from gradio.data_classes import FileData | |
| from src.services.audio_processing import AudioService, get_audio_service | |
| from src.services.image_ocr import ImageOCRService, get_image_ocr_service | |
| from src.utils.config import settings | |
| logger = structlog.get_logger(__name__) | |
class MultimodalService:
    """Unified multimodal processing service.

    Combines transcribed microphone audio, OCR text extracted from uploaded
    images, and the user's original text into a single prompt string.
    """

    # Recognized file extensions (lowercase). Stored as class-level tuples so
    # they are built once and can be passed directly to str.endswith, which
    # accepts a tuple of suffixes.
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tiff", ".tif")
    _AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac", ".wma")

    def __init__(
        self,
        audio_service: AudioService | None = None,
        ocr_service: ImageOCRService | None = None,
    ) -> None:
        """Initialize multimodal service.

        Args:
            audio_service: Audio service instance (default: get_audio_service())
            ocr_service: Image OCR service instance (default: get_image_ocr_service())
        """
        self.audio = audio_service or get_audio_service()
        self.ocr = ocr_service or get_image_ocr_service()

    async def process_multimodal_input(
        self,
        text: str,
        files: list[FileData] | None = None,
        audio_input: tuple[int, Any] | None = None,
        hf_token: str | None = None,
        prepend_multimodal: bool = True,
    ) -> str:
        """Process multimodal input (text + images + audio) and return combined text.

        Each modality is processed best-effort: a failure in audio
        transcription or image OCR is logged and skipped rather than
        aborting the whole request.

        Args:
            text: Text input string
            files: List of uploaded files (images, audio, etc.)
            audio_input: Audio input tuple (sample_rate, audio_array)
            hf_token: HuggingFace token for authenticated Gradio Spaces
            prepend_multimodal: If True, prepend audio/image text to original text;
                otherwise append

        Returns:
            Combined text from all inputs, sections separated by blank lines;
            empty string when no input produced any text.
        """
        multimodal_parts: list[str] = []

        # Microphone audio first, so its transcription precedes file-derived text.
        if audio_input is not None and settings.enable_audio_input:
            try:
                transcribed = await self.audio.process_audio_input(audio_input, hf_token=hf_token)
                if transcribed:
                    multimodal_parts.append(transcribed)
            except Exception as e:
                # Best-effort: a failed transcription must not block the text path.
                logger.warning("audio_processing_failed", error=str(e))

        # Uploaded files: OCR images; audio-file uploads are not yet supported.
        # NOTE(review): this whole branch (including audio files) is gated on
        # enable_image_input — presumably intentional while audio uploads are
        # unsupported, but worth confirming.
        if files and settings.enable_image_input:
            for file_data in files:
                # Gradio may hand us FileData objects or plain path strings.
                file_path = file_data.path if isinstance(file_data, FileData) else str(file_data)
                if self._is_image_file(file_path):
                    try:
                        extracted_text = await self.ocr.extract_text(file_path, hf_token=hf_token)
                        if extracted_text:
                            multimodal_parts.append(extracted_text)
                    except Exception as e:
                        logger.warning("image_ocr_failed", file_path=file_path, error=str(e))
                elif self._is_audio_file(file_path):
                    # Transcription of uploaded audio files is not implemented;
                    # surface that to the logs instead of failing silently.
                    logger.warning("audio_file_upload_not_supported", file_path=file_path)

        # Original text contributes only if non-blank.
        text_parts: list[str] = [text.strip()] if text and text.strip() else []

        # Order the sections per the caller's preference.
        if prepend_multimodal:
            combined_parts = multimodal_parts + text_parts
        else:
            combined_parts = text_parts + multimodal_parts

        # join() on an empty list already yields "", so no fallback is needed.
        combined_text = "\n\n".join(combined_parts)
        logger.info(
            "multimodal_input_processed",
            text_length=len(combined_text),
            num_files=len(files) if files else 0,
            has_audio=audio_input is not None,
        )
        return combined_text

    def _is_image_file(self, file_path: str) -> bool:
        """Check if file is an image.

        Args:
            file_path: Path to file

        Returns:
            True if file_path ends with a known image extension (case-insensitive)
        """
        return file_path.lower().endswith(self._IMAGE_EXTENSIONS)

    def _is_audio_file(self, file_path: str) -> bool:
        """Check if file is an audio file.

        Args:
            file_path: Path to file

        Returns:
            True if file_path ends with a known audio extension (case-insensitive)
        """
        return file_path.lower().endswith(self._AUDIO_EXTENSIONS)
@lru_cache(maxsize=1)
def get_multimodal_service() -> MultimodalService:
    """Get or create singleton multimodal service instance.

    The lru_cache decorator makes this an actual singleton: the first call
    constructs the service and every later call returns that same instance.
    (Without it the function built a new MultimodalService — and its
    underlying audio/OCR services — on every call, contradicting its own
    docstring; the unused `lru_cache` import at the top of the file shows
    the decorator was intended.)

    Returns:
        MultimodalService instance (shared across calls)
    """
    return MultimodalService()