""" |
|
|
Model-based Processing Pipeline for News Dashboard |
|
|
Handles summarization and translation using Hugging Face transformers |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import torch |
|
|
from typing import List, Dict, Any, Optional |
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModelForSeq2SeqLM, |
|
|
pipeline, |
|
|
BartForConditionalGeneration, |
|
|
BartTokenizer |
|
|
) |
|
|
import warnings |
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
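# Typical usage (a minimal sketch; both models are downloaded from the
# Hugging Face Hub on first use, so network access is required):
#
#     processor = ModelProcessor(device="auto")
#     if processor.load_models():
#         result = processor.process_content("Some long article text ...")
#         print(result["summary"], result["summary_somali"])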
|
class ModelProcessor:
    """
    Model-based processing for summarization and translation.
    """

    def __init__(self, device: str = "auto"):
        """
        Initialize the model processor.

        Args:
            device: Device to run models on ("auto", "cpu", "cuda")
        """
        self.device = self._get_device(device)
        self.summarization_model = None
        self.summarization_tokenizer = None
        self.translation_model = None
        self.translation_tokenizer = None
        self.models_loaded = False

        logger.info(f"ModelProcessor initialized on device: {self.device}")
|
    def _get_device(self, device: str) -> str:
        """
        Determine the best device to use.

        Args:
            device: Requested device

        Returns:
            Device string
        """
        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                return "mps"
            else:
                return "cpu"
        return device
|
    def load_models(self) -> bool:
        """
        Load all required models.

        Returns:
            True if all models loaded successfully, False otherwise
        """
        try:
            logger.info("Loading summarization model...")
            self._load_summarization_model()

            logger.info("Loading translation model...")
            self._load_translation_model()

            self.models_loaded = True
            logger.info("All models loaded successfully!")
            return True

        except Exception as e:
            logger.error(f"Error loading models: {str(e)}")
            return False
|
    def _load_summarization_model(self):
        """
        Load the summarization model and tokenizer.
        """
        try:
            # Distilled BART checkpoint fine-tuned for news summarization
            model_name = "sshleifer/distilbart-cnn-12-6"

            self.summarization_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.summarization_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

            # Move to the target device and switch to inference mode
            self.summarization_model.to(self.device)
            self.summarization_model.eval()

            logger.info(f"Summarization model loaded: {model_name}")

        except Exception as e:
            logger.error(f"Error loading summarization model: {str(e)}")
            raise
|
    def _load_translation_model(self):
        """
        Load the translation model and tokenizer.
        """
        try:
            # English-to-Somali translation model
            model_name = "Helsinki-NLP/opus-mt-synthetic-en-so"

            self.translation_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.translation_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

            # Move to the target device and switch to inference mode
            self.translation_model.to(self.device)
            self.translation_model.eval()

            logger.info(f"Translation model loaded: {model_name}")

        except Exception as e:
            logger.error(f"Error loading translation model: {str(e)}")
            raise
|
    def process_content(self, content: str, max_length: int = 150, min_length: int = 30) -> Dict[str, Any]:
        """
        Process content through summarization and translation.

        Args:
            content: Text content to process
            max_length: Maximum length for summary
            min_length: Minimum length for summary

        Returns:
            Dictionary with keys 'summary', 'summary_somali', 'translation',
            'bullet_points', 'bullet_points_somali', 'processing_success',
            and 'error'. Returns an empty dict if models are not loaded.
        """
        if not self.models_loaded:
            logger.error("Models not loaded. Call load_models() first.")
            return {}

        if not content or len(content.strip()) < 50:
            logger.warning("Content too short for processing")
            return {
                'summary': '',
                'summary_somali': '',
                'translation': '',
                'bullet_points': [],
                'bullet_points_somali': [],
                'processing_success': False,
                'error': 'Content too short'
            }

        try:
            # Summarize first, then derive bullet points from the summary
            summary = self._summarize_content(content, max_length, min_length)
            bullet_points = self._create_bullet_points(summary)

            # Translate the summary, the full content, and each bullet point
            summary_somali = self._translate_to_somali(summary)
            content_somali = self._translate_to_somali(content)
            bullet_points_somali = [self._translate_to_somali(point) for point in bullet_points]

            return {
                'summary': summary,
                'summary_somali': summary_somali,
                'translation': content_somali,
                'bullet_points': bullet_points,
                'bullet_points_somali': bullet_points_somali,
                'processing_success': True,
                'error': None
            }

        except Exception as e:
            logger.error(f"Error processing content: {str(e)}")
            return {
                'summary': '',
                'summary_somali': '',
                'translation': '',
                'bullet_points': [],
                'bullet_points_somali': [],
                'processing_success': False,
                'error': str(e)
            }
|
    def _summarize_content(self, content: str, max_length: int, min_length: int) -> str:
        """
        Summarize content using the loaded model.

        Args:
            content: Text to summarize
            max_length: Maximum summary length
            min_length: Minimum summary length

        Returns:
            Summarized text, or an empty string on failure
        """
        try:
            # Tokenize; inputs longer than 1024 tokens are truncated
            inputs = self.summarization_tokenizer(
                content,
                max_length=1024,
                truncation=True,
                return_tensors="pt"
            ).to(self.device)

            # Generate the summary with beam search
            with torch.no_grad():
                summary_ids = self.summarization_model.generate(
                    inputs.input_ids,
                    max_length=max_length,
                    min_length=min_length,
                    length_penalty=2.0,
                    num_beams=4,
                    early_stopping=True
                )

            summary = self.summarization_tokenizer.decode(
                summary_ids[0],
                skip_special_tokens=True
            )

            return summary.strip()

        except Exception as e:
            logger.error(f"Error in summarization: {str(e)}")
            return ""
|
    def _translate_to_somali(self, text: str) -> str:
        """
        Translate text to Somali using the loaded model.

        Args:
            text: Text to translate

        Returns:
            Translated text, or the original text on failure
        """
        if not text or len(text.strip()) < 5:
            return ""

        try:
            # Tokenize; inputs longer than 512 tokens are truncated
            inputs = self.translation_tokenizer(
                text,
                max_length=512,
                truncation=True,
                return_tensors="pt"
            ).to(self.device)

            # Generate the translation with beam search
            with torch.no_grad():
                translated_ids = self.translation_model.generate(
                    inputs.input_ids,
                    max_length=512,
                    num_beams=4,
                    early_stopping=True
                )

            translation = self.translation_tokenizer.decode(
                translated_ids[0],
                skip_special_tokens=True
            )

            return translation.strip()

        except Exception as e:
            logger.error(f"Error in translation: {str(e)}")
            return text
|
    def _create_bullet_points(self, summary: str) -> List[str]:
        """
        Convert a summary into bullet points.

        Args:
            summary: Summarized text

        Returns:
            List of bullet points (at most five)
        """
        if not summary:
            return []

        # Naive sentence split on periods; adequate for short model summaries
        sentences = [s.strip() for s in summary.split('.') if s.strip()]

        # Keep up to five sentences, restoring the trailing period
        bullet_points = []
        for sentence in sentences[:5]:
            if not sentence.endswith('.'):
                sentence += '.'
            bullet_points.append(sentence)

        return bullet_points
|
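    # Illustrative input item shape for process_batch (based on the demo at
    # the bottom of this module; only 'content.cleaned_text' is consumed here,
    # and the item ids and metadata shown are hypothetical):
    #
    #     {
    #         'id': 'article-1',
    #         'content': {'cleaned_text': 'Full article text ...'},
    #         'source_metadata': {'title': '...', 'url': '...'}
    #     }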
    def process_batch(self, data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Process a batch of data items.

        Args:
            data_list: List of data items to process

        Returns:
            List of processed data items
        """
        if not self.models_loaded:
            logger.error("Models not loaded. Call load_models() first.")
            return data_list

        processed_data = []

        for i, item in enumerate(data_list):
            logger.info(f"Processing item {i+1}/{len(data_list)}")

            # Pull the cleaned text out of the item's content payload
            content = item.get('content', {})
            if isinstance(content, dict):
                text_content = content.get('cleaned_text', '')
            else:
                text_content = str(content)

            # Run summarization and translation
            model_results = self.process_content(text_content)

            # Attach the full results to the item
            item['model_processing'] = model_results

            # Mirror key results into the content dict for convenience
            if isinstance(content, dict):
                content['model_summary'] = model_results['summary']
                content['model_summary_somali'] = model_results['summary_somali']
                content['model_translation'] = model_results['translation']
                content['bullet_points'] = model_results['bullet_points']
                content['bullet_points_somali'] = model_results['bullet_points_somali']

            processed_data.append(item)

        logger.info(f"Batch processing completed: {len(processed_data)} items processed")
        return processed_data
|
    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about loaded models.

        Returns:
            Dictionary with model information
        """
        return {
            'models_loaded': self.models_loaded,
            'device': self.device,
            'summarization_model': 'sshleifer/distilbart-cnn-12-6' if self.summarization_model else None,
            'translation_model': 'Helsinki-NLP/opus-mt-synthetic-en-so' if self.translation_model else None,
            'cuda_available': torch.cuda.is_available(),
            'mps_available': hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
        }
|
|
def process_with_models(data_list: List[Dict[str, Any]], device: str = "auto") -> List[Dict[str, Any]]:
    """
    Convenience function to process data with models.

    Args:
        data_list: List of data items to process
        device: Device to run models on

    Returns:
        List of processed data items (returned unchanged if model loading fails)
    """
    processor = ModelProcessor(device=device)

    if not processor.load_models():
        logger.error("Failed to load models")
        return data_list

    return processor.process_batch(data_list)
|
|
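# Running this module directly exercises the full pipeline on one sample
# article; note that the first run downloads both models from the Hub,
# which can take a while on a slow connection.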
if __name__ == "__main__":
    # Smoke test with a single sample article
    sample_data = [
        {
            'id': 'test1',
            'content': {
                'cleaned_text': 'This is a sample article about water management in Somalia. The article discusses the challenges of water scarcity and the need for sustainable water management practices. It also covers the role of international organizations in supporting water infrastructure development.'
            },
            'source_metadata': {
                'title': 'Water Management in Somalia',
                'url': 'https://example.com'
            }
        }
    ]

    processed = process_with_models(sample_data)

    for item in processed:
        results = item.get('model_processing')
        if not results:
            print("No model results available (model loading may have failed)")
            continue
        print(f"Original: (text length: {len(item['content']['cleaned_text'])} chars)")
        print(f"Summary: {results['summary']}")
        print(f"Bullet Points: {results['bullet_points']}")
        print(f"Somali Translation: {results['summary_somali']}")
        print("-" * 50)