| import os |
| import gradio as gr |
| import requests |
| import re |
| import time |
| import pandas as pd |
| from typing import Dict, Tuple, List, Optional |
|
|
| |
# Endpoint of the punctuation-restoration service that this app POSTs to.
API_URL = "http://localhost:5685/punctuate"

# Punctuation marks that are kept and scored, mapped to the label shown for
# each one in the evaluation table.
punc_dict = {
    "!": "EXCLAMATION",
    "?": "QUESTION",
    ",": "COMMA",
    ";": "SEMICOLON",
    ":": "COLON",
    "-": "HYPHEN",
    "।": "DARI",
}

# Same marks as a set, used for membership tests and to build regex classes.
allowed_punctuations = set(punc_dict)
|
|
def clean_and_normalize_text(text, remove_punctuations=False):
    """Clean and normalize Bangla text with consistent spacing.

    Args:
        text: Raw text to normalize.
        remove_punctuations: When true, every mark in
            ``allowed_punctuations`` is removed and whitespace is collapsed;
            all other characters are left as they are. When false, those
            marks are kept, and everything else that is neither whitespace
            nor in the range U+0980-U+09FF is dropped.

    Returns:
        The normalized string with single spaces between words and no
        leading or trailing whitespace. With punctuation kept, each mark is
        attached to the text before it and followed by a space when more
        text follows.
    """
    punct_class = f"[{re.escape(''.join(allowed_punctuations))}]"

    if remove_punctuations:
        # Substitute a space, not the empty string: otherwise words that are
        # separated only by a mark ("ক,খ") would be glued into one word,
        # whereas the keep-punctuation branch below separates them.
        without_marks = re.sub(punct_class, " ", text)
        return re.sub(r"\s+", " ", without_marks).strip()

    pieces = []
    # The capturing group makes re.split keep the marks as their own chunks.
    for chunk in re.split(f"({punct_class})", text):
        if chunk in allowed_punctuations:
            pieces.append(chunk)
            continue

        # U+0980-U+09FF already contains the digit range U+09E6-U+09EF, so a
        # single range is enough here.
        bangla_only = re.sub(r"[^\u0980-\u09FF\s]", "", chunk)
        bangla_only = re.sub(r"\s+", " ", bangla_only).strip()
        if bangla_only:
            # Leading space separates this text from a preceding mark or word.
            pieces.append(" " + bangla_only)

    return re.sub(r"\s+", " ", "".join(pieces)).strip()
|
|
def restore_punctuation(text, timeout=30.0):
    """Call the punctuation restoration API.

    Args:
        text: Text to punctuate; sent as ``{"text": text}`` in a JSON POST
            to ``API_URL``.
        timeout: Seconds to wait for the service before giving up.

    Returns:
        A ``(result, seconds)`` tuple. On success ``result`` is the value of
        ``"restored_text"`` in the JSON response. On failure it is a string
        starting with ``"API Error:"`` (non-200 status, or a 200 response
        without ``"restored_text"``) or ``"Connection Error:"`` (the request
        or the JSON decoding raised). ``seconds`` is the duration of the
        HTTP call, or ``0.0`` when an exception was raised.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments.
    started = time.perf_counter()
    try:
        # Without a timeout, requests waits indefinitely on a stalled
        # server, which would freeze the calling UI handler.
        response = requests.post(API_URL, json={"text": text}, timeout=timeout)
        api_time = time.perf_counter() - started

        if response.status_code != 200:
            return f"API Error: {response.status_code} - {response.text}", api_time

        restored_text = response.json().get("restored_text")
    except Exception as e:
        # Deliberately broad: this function reports every failure to its
        # caller as an error string instead of raising.
        return f"Connection Error: {str(e)}", 0.0

    if restored_text is None:
        # A 200 response without the expected key would otherwise hand None
        # to callers that expect text.
        return "API Error: response did not contain 'restored_text'", api_time

    return restored_text, api_time
|
|
def dummy_restore_punctuation(text):
    """Stand-in for the real API, for demos when the service is unavailable.

    Sleeps for half a second, then decorates the whitespace-separated words
    of ``text``: with more than five words the third word gets a comma and
    the last a question mark; with three to five words the last word gets
    an exclamation mark; shorter input is returned unchanged (re-joined with
    single spaces).

    Returns:
        A ``(text, seconds)`` tuple where ``seconds`` is always ``0.5``.
    """
    simulated_latency = 0.5
    time.sleep(simulated_latency)

    tokens = text.split()
    token_count = len(tokens)

    if token_count > 5:
        tokens[2] = f"{tokens[2]},"
        tokens[-1] = f"{tokens[-1]}?"
    elif token_count > 2:
        tokens[-1] = f"{tokens[-1]}!"

    return " ".join(tokens), simulated_latency
|
|
def tokenize_with_punctuation(text):
    """Split text into word tokens and stand-alone punctuation tokens.

    The text is cut at every mark in ``allowed_punctuations``; each mark
    becomes its own token and the text between marks is split on
    whitespace. Empty and whitespace-only pieces produce no tokens.

    Returns:
        The tokens in their original order.
    """
    # The capturing group makes re.split return the marks as separate pieces.
    splitter = f"([{re.escape(''.join(allowed_punctuations))}])"

    tokens = []
    for piece in re.split(splitter, text):
        if piece in allowed_punctuations:
            tokens.append(piece)
        else:
            # str.split() discards surrounding whitespace and never yields
            # empty strings, so blank pieces contribute nothing.
            tokens.extend(piece.split())
    return tokens
|
|
def _index_punctuation(tokens):
    """Split tokens into words and a map of word index -> marks following it.

    Marks that appear before the first word are stored under index ``-1``.
    """
    words = []
    marks_after = {}
    for token in tokens:
        if token in allowed_punctuations:
            marks_after.setdefault(len(words) - 1, []).append(token)
        else:
            words.append(token)
    return words, marks_after


def compare_texts(ground_truth, predicted):
    """Compare ground truth and predicted text word by word and mark by mark.

    Both texts are tokenized with ``tokenize_with_punctuation``. Words are
    aligned by position, and the punctuation following the i-th word of one
    text is compared, in order, with the punctuation following the i-th word
    of the other.

    Args:
        ground_truth: Reference text containing the expected punctuation.
        predicted: Text whose punctuation is being evaluated.

    Returns:
        A tuple ``(comparison_result, correct_puncs, wrong_puncs,
        gt_punc_counts)``:

        * ``comparison_result``: list of ``(text, status, color)`` tuples in
          display order. ``status`` is one of ``correct``, ``word_diff``,
          ``missing_word``, ``extra_word``, ``wrong_punct``,
          ``missing_punct`` or ``extra_punct``.
        * ``correct_puncs``: label -> number of ground-truth marks matched.
        * ``wrong_puncs``: label -> number of ground-truth marks that were
          replaced by another mark or are missing. Extra predicted marks are
          displayed but not counted here.
        * ``gt_punc_counts``: label -> occurrences in the ground truth.
    """
    gt_tokens = tokenize_with_punctuation(ground_truth)
    pred_tokens = tokenize_with_punctuation(predicted)

    gt_words, gt_punct_map = _index_punctuation(gt_tokens)
    pred_words, pred_punct_map = _index_punctuation(pred_tokens)

    gt_punc_counts = {}
    for token in gt_tokens:
        if token in allowed_punctuations:
            punc_name = punc_dict[token]
            gt_punc_counts[punc_name] = gt_punc_counts.get(punc_name, 0) + 1

    comparison_result = []
    correct_puncs = {}
    wrong_puncs = {}

    max_words = max(len(gt_words), len(pred_words))

    # Start at -1 so that punctuation placed before the first word (stored
    # under index -1) is scored as well; starting at 0 would count those
    # marks in gt_punc_counts but never classify or display them.
    for idx in range(-1, max_words):
        if idx >= 0:
            gt_word = gt_words[idx] if idx < len(gt_words) else None
            pred_word = pred_words[idx] if idx < len(pred_words) else None

            if gt_word is None:
                comparison_result.append((f"''→{pred_word}", "extra_word", "red"))
            elif pred_word is None:
                comparison_result.append((f"{gt_word}→''", "missing_word", "red"))
            elif gt_word == pred_word:
                comparison_result.append((gt_word, "correct", "black"))
            else:
                comparison_result.append((f"{gt_word}→{pred_word}", "word_diff", "orange"))

        gt_puncs = gt_punct_map.get(idx, [])
        pred_puncs = pred_punct_map.get(idx, [])

        for j in range(max(len(gt_puncs), len(pred_puncs))):
            gt_punc = gt_puncs[j] if j < len(gt_puncs) else None
            pred_punc = pred_puncs[j] if j < len(pred_puncs) else None

            if gt_punc is None:
                # Extra predicted mark: shown, but there is no ground-truth
                # label to count it against.
                comparison_result.append((f"''→{pred_punc}", "extra_punct", "red"))
                continue

            punc_name = punc_dict[gt_punc]
            if pred_punc is None:
                wrong_puncs[punc_name] = wrong_puncs.get(punc_name, 0) + 1
                comparison_result.append((f"{gt_punc}→''", "missing_punct", "red"))
            elif gt_punc == pred_punc:
                correct_puncs[punc_name] = correct_puncs.get(punc_name, 0) + 1
                comparison_result.append((gt_punc, "correct", "green"))
            else:
                wrong_puncs[punc_name] = wrong_puncs.get(punc_name, 0) + 1
                comparison_result.append((f"{gt_punc}→{pred_punc}", "wrong_punct", "red"))

    return comparison_result, correct_puncs, wrong_puncs, gt_punc_counts
|
|
def create_evaluation_table(correct_puncs, wrong_puncs, gt_punc_counts):
    """Build the per-punctuation metrics table.

    Args:
        correct_puncs: Label -> number of correctly predicted marks.
        wrong_puncs: Label -> number of wrongly predicted marks.
        gt_punc_counts: Label -> occurrences in the ground truth. Its keys
            decide which rows appear and in which order.

    Returns:
        A DataFrame with one row per label in ``gt_punc_counts``; labels
        absent from ``correct_puncs`` or ``wrong_puncs`` get a count of 0.
    """
    column_names = [
        "Punctuation Name",
        "Correctly Classified",
        "Wrongly Classified",
        "Count in Ground Truth",
    ]
    rows = [
        [label, correct_puncs.get(label, 0), wrong_puncs.get(label, 0), total]
        for label, total in gt_punc_counts.items()
    ]
    return pd.DataFrame(rows, columns=column_names)
|
|
def format_comparison_html(comparison_result):
    """Render a comparison result as an HTML snippet followed by a legend.

    Args:
        comparison_result: Iterable of ``(token, status, color)`` tuples.
            ``color`` selects the highlight: ``"green"`` together with
            status ``"correct"`` for a correct mark, ``"red"`` for an error,
            ``"orange"`` for a word difference; anything else is rendered
            without a background. Red tokens ending in ``→''`` or starting
            with ``''→`` are shown with ``∅`` in place of the empty side.

    Returns:
        An HTML string: one ``<span>`` per entry inside a bordered ``<div>``,
        then a legend ``<div>``.
    """
    # Imported here so the block is self-contained; a module-level
    # ``import html`` would also clash with the obvious local name.
    from html import escape

    green_style = ("background-color: #d4edda; color: #155724; padding: 2px 4px; "
                   "margin: 1px; border-radius: 3px; font-weight: bold;")
    red_style = ("background-color: #f8d7da; color: #721c24; padding: 2px 4px; "
                 "margin: 1px; border-radius: 3px; font-weight: bold;")
    orange_style = ("background-color: #fff3cd; color: #856404; padding: 2px 4px; "
                    "margin: 1px; border-radius: 3px;")
    plain_style = "padding: 2px 4px; margin: 1px;"

    missing_suffix = "→''"
    extra_prefix = "''→"

    # Collect fragments and join once instead of growing a string with +=.
    parts = [
        "<div style='font-family: monospace; font-size: 16px; line-height: 1.8; "
        "padding: 20px; border: 1px solid #ddd; border-radius: 5px;'>"
    ]

    for token, status, color in comparison_result:
        text = token
        if status == "correct" and color == "green":
            style = green_style
        elif color == "red":
            style = red_style
            # Anchor on the ends of the token rather than searching anywhere
            # in it, so a word that itself contains "→" is not mis-split.
            if token.endswith(missing_suffix):
                text = f"{token[:-len(missing_suffix)]}→∅"
            elif token.startswith(extra_prefix):
                text = f"∅→{token[len(extra_prefix):]}"
        elif color == "orange":
            style = orange_style
        else:
            style = plain_style

        # Tokens come from user input and from the API response; escape
        # them so "<", "&" or quotes cannot break or inject markup.
        parts.append(f"<span style='{style}'>{escape(text)}</span>")
        parts.append(" ")

    parts.append("</div>")

    parts.append("""
    <div style='margin-top: 15px; padding: 10px; background-color: #f8f9fa; border-radius: 5px; font-size: 14px;'>
    <strong>Legend:</strong><br>
    <span style='background-color: #d4edda; color: #155724; padding: 1px 3px; border-radius: 2px; margin: 2px;'>✓</span> Correct punctuation
    <span style='background-color: #f8d7da; color: #721c24; padding: 1px 3px; border-radius: 2px; margin: 2px;'>✗</span> Wrong/Missing/Extra punctuation
    <span style='background-color: #fff3cd; color: #856404; padding: 1px 3px; border-radius: 2px; margin: 2px;'>~</span> Word difference
    <span style='padding: 1px 3px; margin: 2px;'>◦</span> Correct word<br>
    <strong>∅</strong> = Empty/Missing
    </div>
    """)

    return "".join(parts)
|
|
def process_punctuation_restoration(input_text, ground_truth=""):
    """Restore punctuation for ``input_text`` and optionally evaluate it.

    Args:
        input_text: Text to send to the restoration API.
        ground_truth: Optional reference text. When non-blank it is
            normalized with ``clean_and_normalize_text`` and compared with
            the prediction.

    Returns:
        A 5-tuple ``(predicted_text, comparison_html, time_info, eval_table,
        normalized_gt)``:

        * blank input: a prompt message and empty/``None`` placeholders;
        * API failure: the error message, a failure note in ``time_info``
          and no evaluation;
        * no ground truth: the prediction and the timing only;
        * otherwise: the prediction, the comparison HTML, the timing, the
          metrics DataFrame and the normalized ground truth label.
    """
    if not input_text or not input_text.strip():
        return "Please enter input text", "", "", None, ""

    try:
        predicted_text, api_time = restore_punctuation(input_text)
    except Exception as exc:
        # Boundary of the UI callback: report the failure instead of
        # letting it propagate into the interface.
        predicted_text, api_time = f"Connection Error: {exc}", 0.0

    # The result may arrive as a list (presumably one entry per input —
    # confirm against the API); only the first entry is used.
    if isinstance(predicted_text, list):
        predicted_text = predicted_text[0] if predicted_text else None

    # Match the exact prefixes produced by restore_punctuation rather than
    # searching for "Error" anywhere, which would misfire on legitimate text
    # that happens to contain that word.
    error_prefixes = ("API Error:", "Connection Error:")
    if not isinstance(predicted_text, str):
        failure = "API Error: no restored text in response"
    elif predicted_text.startswith(error_prefixes):
        failure = predicted_text
    else:
        failure = None

    if failure is not None:
        # Show the real cause and skip the evaluation: scoring an error
        # message against the ground truth would misalign every token.
        return failure, "", f"API call failed after {api_time:.3f} seconds", None, ""

    time_info = f"API call completed in {api_time:.3f} seconds"

    print(f"input_text: {input_text}", flush=True)
    print(f"predicted_text: {predicted_text}", flush=True)

    if not ground_truth or not ground_truth.strip():
        return predicted_text, "", time_info, None, ""

    ground_truth_normalized = clean_and_normalize_text(ground_truth)

    comparison_result, correct_puncs, wrong_puncs, gt_punc_counts = compare_texts(
        ground_truth_normalized, predicted_text
    )

    comparison_html = format_comparison_html(comparison_result)
    eval_table = create_evaluation_table(correct_puncs, wrong_puncs, gt_punc_counts)

    return (
        predicted_text,
        comparison_html,
        time_info,
        eval_table,
        f"Normalized Ground Truth: {ground_truth_normalized}",
    )
|
|
| |
def create_interface():
    """Assemble the Gradio Blocks UI for the evaluator.

    The left column holds the input and optional ground-truth boxes plus the
    submit button; the right column shows the API time, the prediction, the
    normalized ground truth, the token comparison and the metrics table.
    The submit button is wired to ``process_punctuation_restoration``.

    Returns:
        The ``gr.Blocks`` app; it is built here but not launched.
    """
    sample_input = "পুরুষের সংখ্যা মোট জনসংখ্যার ৫২ এবং নারীর সংখ্যা ৪৮ শহরের সাক্ষরতার হার কত"
    sample_truth = "পুরুষের সংখ্যা মোট জনসংখ্যার ৫২, এবং নারীর সংখ্যা ৪৮। শহরের সাক্ষরতার হার কত?"
    table_headers = [
        "Punctuation Name",
        "Correctly Classified",
        "Wrongly Classified",
        "Count in Ground Truth",
    ]

    with gr.Blocks(title="Punctuation Restoration Evaluator", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🔤 Punctuation Restoration Evaluator")
        gr.Markdown("Enter text to restore punctuation. Optionally provide ground truth for evaluation.")

        with gr.Row():
            # Left column: user inputs.
            with gr.Column(scale=1):
                input_box = gr.Textbox(
                    label="Input Text (without punctuation)",
                    placeholder=sample_input,
                    lines=4,
                )
                truth_box = gr.Textbox(
                    label="Ground Truth (optional)",
                    placeholder=sample_truth,
                    lines=4,
                )
                run_button = gr.Button("🚀 Restore Punctuation", variant="primary")

            # Right column: results.
            with gr.Column(scale=2):
                time_box = gr.Textbox(label="⏱️ API Response Time", interactive=False)
                prediction_box = gr.Textbox(
                    label="📝 Predicted Output",
                    lines=3,
                    interactive=False,
                )
                normalized_truth_box = gr.Textbox(
                    label="📋 Normalized Ground Truth",
                    lines=2,
                    interactive=False,
                )
                comparison_view = gr.HTML(
                    label="🔍 Token-wise Comparison",
                    value="<p>Comparison will appear here after processing with ground truth.</p>",
                )
                metrics_table = gr.Dataframe(
                    label="📊 Punctuation Evaluation Metrics",
                    headers=table_headers,
                    interactive=False,
                )

        gr.Markdown("""
        ### 🎨 Color Legend:
        - 🟢 **Green**: Correctly predicted punctuation
        - 🔴 **Red**: Incorrectly predicted, missing, or extra punctuation/word
        - 🟡 **Orange**: Word-level differences
        - ⚫ **Black**: Correct words/tokens
        - **∅**: Empty/Missing (instead of showing word→word or punct→word)
        """)

        # Output order must match the tuple returned by the handler.
        run_button.click(
            fn=process_punctuation_restoration,
            inputs=[input_box, truth_box],
            outputs=[prediction_box, comparison_view, time_box, metrics_table, normalized_truth_box],
        )

        gr.Markdown("### 📚 Example")
        gr.Examples(
            examples=[
                [sample_input, sample_truth],
                [
                    "ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান",
                    "",
                ],
            ],
            inputs=[input_box, truth_box],
        )

    return app
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on port 7860 of address
    # 0.0.0.0, without a public share link and with debug mode on.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": True,
    }
    app = create_interface()
    app.launch(**launch_options)