Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import logging | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Process-wide log format ("timestamp | LEVEL | message"), INFO and above.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
)
# Module-level logger shared by the classifier backend below.
logger = logging.getLogger(__name__)
class MagicSupportClassifier:
    """
    Encapsulates the customer support intent classification model.

    Loads a Hugging Face sequence-classification checkpoint once at
    construction time and exposes :meth:`predict` for single-query inference.
    Label names are read from ``model.config.id2label`` at runtime, so the UI
    adapts to retrained label sets without code changes.
    """

    # Ordered (keywords, icon) rules used by _get_iconography. The first rule
    # with a keyword occurring in the lowercased label wins, so the order
    # encodes priority (e.g. "cancel_order" matches "order" before "cancel").
    _ICON_RULES = (
        (("order", "delivery", "track"), "📦"),
        (("refund", "payment", "invoice", "fee"), "💳"),
        (("account", "password", "register", "profile"), "👤"),
        (("cancel", "delete", "problem", "issue"), "⚠️"),
        (("contact", "service", "support"), "🎧"),
    )
    # Fallback icon for labels that match none of the rules above.
    _DEFAULT_ICON = "🔹"

    def __init__(self, model_id: str = "learn-abc/magicSupport-intent-classifier"):
        """
        Args:
            model_id: Hugging Face Hub repo id (or local path) passed to
                ``from_pretrained`` for both the tokenizer and the model.

        Raises:
            Exception: Anything raised while loading is logged and re-raised
                by :meth:`_load_model`.
        """
        self.model_id = model_id
        # Token limit for the tokenizer; longer queries are truncated.
        self.max_length = 128
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()

    def _load_model(self) -> None:
        """
        Load tokenizer and model, move the model to ``self.device`` and put it
        in eval mode. Also sets ``self.num_classes`` from ``config.id2label``.
        """
        logger.info("Initializing model %s on %s...", self.model_id, self.device)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id)
            self.model.to(self.device)
            self.model.eval()  # inference behaviour (e.g. dropout disabled)
            # Extract number of classes dynamically from the checkpoint config.
            self.num_classes = len(self.model.config.id2label)
            logger.info("Model loaded successfully with %d intent classes.", self.num_classes)
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Failed to load model: %s", e)
            raise

    def _get_iconography(self, label: str) -> str:
        """
        Return a UI icon for ``label`` via case-insensitive keyword matching.

        Rules in ``_ICON_RULES`` are checked in order; labels matching no rule
        get ``_DEFAULT_ICON``, which future-proofs the UI against retrained
        label sets.
        """
        label_lower = label.lower()
        for keywords, icon in self._ICON_RULES:
            if any(keyword in label_lower for keyword in keywords):
                return icon
        return self._DEFAULT_ICON

    def _format_label(self, label: str) -> str:
        """Turn a raw dataset label (e.g. ``change_shipping_address``) into spaced title case."""
        return label.replace("_", " ").title()

    def _render_results_html(self, top_labels: list, top_probs: list) -> str:
        """
        Build the results panel HTML.

        Args:
            top_labels: Raw labels ordered by descending probability (non-empty).
            top_probs: Matching probabilities, same order as ``top_labels``.

        Returns:
            A headline for the top intent followed by one progress bar per entry.
        """
        best = top_labels[0]
        chunks = [f"""
        <h2 style='margin-bottom: 5px; display: flex; align-items: center; gap: 8px;'>{self._get_iconography(best)} {self._format_label(best)}</h2>
        <p style='margin-top: 0; font-size: 16px;'><b>Confidence:</b> {top_probs[0] * 100:.1f}%</p>
        <hr style='border-top: 1px solid var(--border-color-primary); margin: 20px 0;'/>
        <h3 style='margin-bottom: 15px;'>📊 Top {len(top_labels)} Predictions</h3>
        """]
        for raw_label, prob in zip(top_labels, top_probs):
            pct = prob * 100
            chunks.append(f"""
            <div style="margin-bottom: 16px;">
                <div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
                    <strong>{self._get_iconography(raw_label)} {self._format_label(raw_label)}</strong>
                    <span style="font-family:monospace;">{pct:.1f}%</span>
                </div>
                <div style="background-color: var(--background-fill-secondary); border: 1px solid var(--border-color-primary); border-radius: 6px; width: 100%; height: 10px;">
                    <div style="background-color: #8b5cf6; width: {pct}%; height: 100%; border-radius: 5px;"></div>
                </div>
            </div>
            """)
        return "".join(chunks)

    def predict(self, text: str, top_k: int = 5):
        """
        Classify one customer query and render the result for the UI.

        Args:
            text: Raw customer message; surrounding whitespace is stripped.
            top_k: Number of ranked intents to list. Coerced with ``int()``
                (slider values may arrive as floats) and clamped to
                ``[1, num_classes]``.

        Returns:
            ``(html, chart_data)``: ``html`` is the rendered summary and
            ``chart_data`` maps every formatted label to its softmax
            probability. On empty input or any inference failure, ``html`` is
            an error notice and ``chart_data`` is ``None``.
        """
        if not text or not text.strip():
            return "<div style='color: #ef4444; padding: 10px;'>⚠️ <b>Input Required:</b> Please enter a customer query.</div>", None
        try:
            inputs = self.tokenizer(
                text.strip(),
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length,
                padding=True,
            ).to(self.device)

            # Clamp to [1, num_classes]; int() accepts float slider values,
            # which torch.topk would otherwise reject.
            actual_top_k = max(1, min(int(top_k), self.num_classes))

            # inference_mode skips autograd bookkeeping, so no graph or saved
            # activations are kept for this forward pass.
            with torch.inference_mode():
                logits = self.model(**inputs).logits
                # Row 0 is the single query; indexing (unlike squeeze) keeps a
                # 1-D tensor even for a single-class head.
                probs = F.softmax(logits, dim=-1)[0]
                top = torch.topk(probs, k=actual_top_k)  # computed once

            id2label = self.model.config.id2label
            top_labels = [id2label[i] for i in top.indices.tolist()]
            result_html = self._render_results_html(top_labels, top.values.tolist())

            # Full distribution for the chart; a single tolist() replaces one
            # .item() device sync per class.
            chart_data = {
                self._format_label(id2label[i]): p
                for i, p in enumerate(probs.tolist())
            }
            return result_html, chart_data
        except Exception as e:
            logger.exception("Inference error: %s", e)
            return "<div style='color: #ef4444;'>❌ <b>System Error:</b> Inference failed. Check application logs.</div>", None
# Single shared classifier backend: created once at import time and reused
# by every Gradio event handler wired up further down.
app_backend = MagicSupportClassifier()
# Demo scenarios drawn from the Bitext taxonomy. Each entry is a
# [customer query, top-k] pair, matching the inputs fed to gr.Examples.
EXAMPLES = [
    [query, top_k]
    for query, top_k in (
        ("I need to cancel my order immediately, it was placed by mistake.", 5),
        ("Where can I find the invoice for my last purchase?", 3),
        ("The item arrived damaged and I want a full refund.", 5),
        ("How do I change the shipping address on my account?", 3),
        ("I forgot my password and cannot log in.", 3),
        ("Are there any hidden fees if I cancel my subscription now?", 5),
    )
]
# Build Gradio Interface
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="violet", secondary_hue="slate"),
    title="MagicSupport Intent Classifier R&D Dashboard",
    css="""
    .header-box { text-align: center; padding: 25px; background: var(--background-fill-secondary); border-radius: 10px; border: 1px solid var(--border-color-primary); margin-bottom: 20px;}
    .header-box h1 { color: var(--body-text-color); margin-bottom: 5px; }
    .header-box p { color: var(--body-text-color-subdued); font-size: 16px; margin-top: 0; }
    .badge { display: inline-block; padding: 4px 12px; border-radius: 12px; font-size: 13px; font-weight: 600; margin: 4px; }
    .domain-badge { background: #ede9fe; color: #5b21b6; border: 1px solid #ddd6fe;}
    .metric-badge { background: #f1f5f9; color: #334155; border: 1px solid #cbd5e1;}
    footer { display: none !important; }
    """,
) as demo:
    # Page header: title, tagline and domain badges (styled by the css above).
    gr.HTML("""
    <div class="header-box">
        <h1>🎧 MagicSupport Intent Classifier</h1>
        <p>
            High-precision semantic routing for automated customer support pipelines.
        </p>
        <div style="margin-top:12px;">
            <span class="badge domain-badge">E-commerce & Retail</span>
            <span class="badge domain-badge">Account Management</span>
            <span class="badge domain-badge">Billing & Refunds</span>
            <span class="badge metric-badge">Based on Bitext Taxonomy</span>
        </div>
    </div>
    """)

    with gr.Row():
        # Left column: query input, top-k control, actions and examples.
        with gr.Column(scale=5):
            text_input = gr.Textbox(
                label="Input Customer Query",
                placeholder="Type a customer message here (e.g., 'Where is my package?')...",
                lines=3,
            )
            with gr.Row():
                top_k_slider = gr.Slider(
                    minimum=1, maximum=15, value=5, step=1,
                    label="Display Top-K Predictions"
                )
            with gr.Row():
                predict_btn = gr.Button("🚀 Execute Prediction", variant="primary")
                clear_btn = gr.Button("🗑️ Clear Interface", variant="secondary")
            gr.Examples(
                examples=EXAMPLES,
                inputs=[text_input, top_k_slider],
                label="Actionable Test Scenarios",
                examples_per_page=6,
            )

        # Right column: rendered HTML summary and full probability distribution.
        with gr.Column(scale=5):
            result_output = gr.HTML(label="Inference Results")
            with gr.Row():
                chart_output = gr.Label(
                    label="Full Semantic Distribution Map",
                    num_top_classes=app_backend.num_classes  # Dynamically set based on model config
                )

    with gr.Accordion("⚙️ Technical Architecture & Model Details", open=False):
        gr.Markdown("""
        ### Core Specifications
        * **Target Model:** `learn-abc/magicSupport-intent-classifier`
        * **Objective:** Multi-class text sequence classification for customer support routing.
        * **Dataset Lineage:** Trained on the comprehensive `bitext/Bitext-customer-support-llm-chatbot-training-dataset`.

        ### Pipeline Features
        * **Dynamic Label Resolution:** The UI heuristic engine automatically maps raw dataset labels (e.g., `change_shipping_address`) into clean, professional UI elements (e.g., Change Shipping Address) and assigns contextual iconography.
        * **Optimized Inference:** Utilizes PyTorch `inference_mode` for reduced memory footprint and accelerated compute during forward passes.
        """)

    # Event Wiring: the button and pressing Enter in the textbox run the same prediction.
    predict_btn.click(
        fn=app_backend.predict,
        inputs=[text_input, top_k_slider],
        outputs=[result_output, chart_output],
    )
    text_input.submit(
        fn=app_backend.predict,
        inputs=[text_input, top_k_slider],
        outputs=[result_output, chart_output],
    )
    # Reset the query, the slider (back to its initial value 5), the HTML panel and the chart.
    clear_btn.click(
        fn=lambda: ("", 5, "", None),
        outputs=[text_input, top_k_slider, result_output, chart_output],
    )

if __name__ == "__main__":
    demo.launch()