cassandrasestier commited on
Commit
d969648
ยท
verified ยท
1 Parent(s): 3e7d5d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -89
app.py CHANGED
@@ -1,62 +1,44 @@
1
  # ================================
2
  # ๐Ÿชž MoodMirror+ โ€” Conversational Emotional Self-Care
3
- # Advice + Inspirational quotes + Emotion-based color + SQLite DB
4
- # GoEmotions model + loads GoEmotions dataset ("simplified" config)
5
  # ================================
6
  import os
7
  import re
8
  import random
9
  import sqlite3
 
10
  from datetime import datetime
11
 
12
  import gradio as gr
13
- import torch
14
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
15
  from datasets import load_dataset
16
-
17
- # --- Storage paths (robust across local dev vs. HF Spaces) ---
 
 
 
 
 
 
18
  def _pick_data_dir():
19
- # Prefer /data if it exists AND is writable (Spaces with persistent storage).
20
  if os.path.isdir("/data") and os.access("/data", os.W_OK):
21
  return "/data"
22
- # Otherwise, fall back to the repo working directory.
23
  return os.getcwd()
24
 
25
  DATA_DIR = os.getenv("MM_DATA_DIR", _pick_data_dir())
26
  os.makedirs(DATA_DIR, exist_ok=True)
27
  DB_PATH = os.path.join(DATA_DIR, "moodmirror.db")
 
 
 
28
  print(f"[MM] Using data dir: {DATA_DIR}")
29
- print(f"[MM] SQLite path: {DB_PATH}")
 
30
 
31
- # --- Load GoEmotions dataset ("simplified") ---
32
- # This pulls from: google-research-datasets/go_emotions
33
- # The "simplified" config uses train/validation/test splits and label indices.
34
- try:
35
- ds = load_dataset("google-research-datasets/go_emotions", "simplified")
36
- LABEL_NAMES = ds["train"].features["labels"].feature.names # e.g. ['admiration', ..., 'neutral']
37
- print("[MM] GoEmotions dataset loaded.")
38
- except Exception as e:
39
- ds = None
40
- LABEL_NAMES = None
41
- print(f"[WARN] Could not load GoEmotions dataset: {e}")
42
-
43
- # --- GoEmotions model (multi-label: 27 emotions + neutral) ---
44
- MODEL_ID = "SamLowe/roberta-base-go_emotions"
45
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
46
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
47
- pipe = TextClassificationPipeline(
48
- model=model,
49
- tokenizer=tokenizer,
50
- return_all_scores=True, # list of dicts for every label
51
- function_to_apply="sigmoid", # multi-label probabilities per label
52
- device=0 if torch.cuda.is_available() else -1,
53
- )
54
-
55
- # --- Regex detection ---
56
  CRISIS_RE = re.compile(r"\b(self[- ]?harm|suicid|kill myself|end my life|overdose|cutting|i don.?t want to live|can.?t go on)\b", re.I)
57
  CLOSING_RE = re.compile(r"\b(thanks?|thank you|that'?s all|bye|goodbye|see you|take care|ok bye|no thanks?)\b", re.I)
58
 
59
- # --- Crisis resources ---
60
  CRISIS_NUMBERS = {
61
  "United States": "Call or text **988** (24/7 Suicide & Crisis Lifeline). If in immediate danger, call **911**.",
62
  "Canada": "Call or text **988** (Suicide Crisis Helpline, 24/7). If in immediate danger, call **911**.",
@@ -66,7 +48,6 @@ CRISIS_NUMBERS = {
66
  "Other / Not listed": "Call your local emergency number (**112/911**) or search โ€œsuicide crisis hotlineโ€ + your country.",
67
  }
68
 
69
- # --- Psychology-informed suggestions ---
70
  SUGGESTIONS = {
71
  "sadness": "Be gentle with yourself. Rest, cry, or connect โ€” emotions pass when theyโ€™re acknowledged.",
72
  "fear": "Ground yourself: 5 things you see, 4 you feel, 3 you hear, 2 you smell, 1 you taste.",
@@ -81,7 +62,6 @@ SUGGESTIONS = {
81
  "neutral": "Take a mindful moment: breathe deeply and release any hidden tension in your shoulders.",
82
  }
83
 
84
- # --- Inspirational quotes (short & emotionally tuned) ---
85
  QUOTES = {
86
  "sadness": [
87
  "โ€œEven the darkest night will end and the sun will rise.โ€ โ€“ Victor Hugo",
@@ -133,7 +113,7 @@ COLOR_MAP = {
133
  "neutral": "#F5F5F5",
134
  }
135
 
136
- # --- Map GoEmotions โ†’ app categories (27 emotions + neutral) ---
137
  GOEMO_TO_APP = {
138
  "admiration": "gratitude",
139
  "amusement": "joy",
@@ -165,11 +145,10 @@ GOEMO_TO_APP = {
165
  "neutral": "neutral",
166
  }
167
 
168
- THRESHOLD = 0.35 # tune to be more/less sensitive
169
 
170
- # --- SQLite setup ---
171
  def get_conn():
172
- # timeout helps if multiple requests hit the DB at once
173
  return sqlite3.connect(DB_PATH, check_same_thread=False, timeout=10)
174
 
175
  def init_db():
@@ -189,8 +168,7 @@ def init_db():
189
  conn.commit()
190
  finally:
191
  try:
192
- if conn is not None:
193
- conn.close()
194
  except Exception:
195
  pass
196
 
@@ -206,34 +184,130 @@ def log_session(country, msg, emotion):
206
  conn.commit()
207
  finally:
208
  try:
209
- if conn is not None:
210
- conn.close()
211
  except Exception:
212
  pass
213
 
214
- # --- Emotion detection (multi-label via model) ---
215
- def detect_emotions(text: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  """
217
- Returns:
218
- - chosen: list of (label, score) above threshold, sorted desc
219
- - main_app: top mapped category for UI/tips/colors
220
  """
 
 
 
 
221
  try:
222
- preds = pipe(text)[0] # list of {'label': 'joy', 'score': 0.82} for all labels
223
- chosen = [p for p in preds if p["score"] >= THRESHOLD]
224
- chosen.sort(key=lambda x: x["score"], reverse=True)
225
-
226
- # Map to app categories and pick the strongest mapped bucket
227
- bucket = {}
228
- for p in chosen:
229
- app_label = GOEMO_TO_APP.get(p["label"].lower(), "neutral")
230
- bucket[app_label] = max(bucket.get(app_label, 0.0), float(p["score"]))
231
- main_app = max(bucket, key=bucket.get) if bucket else "neutral"
232
- return chosen, main_app
233
- except Exception:
234
- return [], "neutral"
235
 
236
- # --- Chat logic ---
 
 
 
 
 
 
 
 
 
 
 
 
237
  def crisis_block(country):
238
  msg = CRISIS_NUMBERS.get(country, CRISIS_NUMBERS["Other / Not listed"])
239
  return (
@@ -243,19 +317,16 @@ def crisis_block(country):
243
  )
244
 
245
  def chat_step(message, history, country, save_session):
246
- # Crisis check
247
  if CRISIS_RE.search(message):
248
  return crisis_block(country), "#FFD6E7"
249
 
250
  if CLOSING_RE.search(message):
251
  return ("You're very welcome ๐Ÿ’› Take care of yourself. Small steps matter. ๐ŸŒฟ", "#FFFFFF")
252
 
253
- # Focus on the most recent ~100 words (simple heuristic)
254
  recent = " ".join(message.split()[-100:])
255
  detected, main = detect_emotions(recent)
256
  color = COLOR_MAP.get(main, "#FFFFFF")
257
 
258
- # Save anonymized session
259
  if save_session:
260
  log_session(country, message, main)
261
 
@@ -274,20 +345,14 @@ def chat_step(message, history, country, save_session):
274
  if not history:
275
  reply += "\n\n*Can you tell me a bit more about whatโ€™s behind that feeling?*"
276
 
 
 
 
 
 
277
  return reply, color
278
 
279
- # --- Helper: sample dataset rows for UI preview ---
280
- def sample_goemotions(n=5, split="train", seed=42):
281
- if ds is None:
282
- return [{"text": "Dataset not loaded", "labels": []}]
283
- rows = ds[split].shuffle(seed=seed).select(range(min(n, len(ds[split]))))
284
- out = []
285
- names = LABEL_NAMES or []
286
- for text, labs in zip(rows["text"], rows["labels"]):
287
- out.append({"text": text, "labels": [names[i] for i in labs]})
288
- return out
289
-
290
- # --- Gradio interface ---
291
  init_db()
292
 
293
  custom_css = """
@@ -296,11 +361,11 @@ custom_css = """
296
  @keyframes blink { 50% {opacity: 0.4;} }
297
  """
298
 
299
- with gr.Blocks(css=custom_css, title="๐Ÿชž MoodMirror+ (GoEmotions Edition)") as demo:
300
  style_injector = gr.HTML("")
301
  gr.Markdown(
302
  "### ๐Ÿชž MoodMirror+ โ€” Emotional Support & Inspiration ๐ŸŒธ\n"
303
- "Share how you feel โ€” Iโ€™ll respond with care, science-based advice, or inspiring thoughts.\n\n"
304
  "_Not medical advice. If you feel unsafe, please reach out for help immediately._"
305
  )
306
 
@@ -308,13 +373,13 @@ with gr.Blocks(css=custom_css, title="๐Ÿชž MoodMirror+ (GoEmotions Edition)") as
308
  country = gr.Dropdown(choices=list(CRISIS_NUMBERS.keys()), value="Other / Not listed", label="Country")
309
  save_ok = gr.Checkbox(value=False, label="Save anonymized session (no personal data)")
310
 
311
- chat = gr.Chatbot(height=350)
312
  msg = gr.Textbox(placeholder="Type how you feel...", label="Your message")
313
  send = gr.Button("Send")
314
  typing = gr.Markdown("", elem_classes="typing")
315
 
316
- # Dataset preview UI
317
- with gr.Accordion("๐Ÿ”Ž Preview GoEmotions samples (from the linked dataset)", open=False):
318
  with gr.Row():
319
  n_examples = gr.Slider(1, 10, value=5, step=1, label="Number of examples")
320
  split = gr.Dropdown(["train", "validation", "test"], value="train", label="Split")
@@ -322,9 +387,13 @@ with gr.Blocks(css=custom_css, title="๐Ÿชž MoodMirror+ (GoEmotions Edition)") as
322
  table = gr.Dataframe(headers=["text", "labels"], row_count=5, wrap=True)
323
 
324
  def refresh_samples(n, split_name):
325
- rows = sample_goemotions(int(n), split=split_name)
326
- # Convert to a list of [text, "label1, label2, ..."] rows for display
327
- return [[r["text"], ", ".join(r["labels"])] for r in rows]
 
 
 
 
328
 
329
  refresh.click(refresh_samples, inputs=[n_examples, split], outputs=[table])
330
 
@@ -337,8 +406,10 @@ with gr.Blocks(css=custom_css, title="๐Ÿชž MoodMirror+ (GoEmotions Edition)") as
337
  style_tag = f"<style>:root,body,.gradio-container{{background:{color}!important;}}</style>"
338
  yield chat_hist + [[user_msg, reply]], "", style_tag, ""
339
 
340
- send.click(respond, inputs=[msg, chat, country, save_ok], outputs=[chat, typing, style_injector, msg], queue=True)
341
- msg.submit(respond, inputs=[msg, chat, country, save_ok], outputs=[chat, typing, style_injector, msg], queue=True)
 
 
342
 
343
  if __name__ == "__main__":
344
  demo.queue()
 
1
  # ================================
2
  # ๐Ÿชž MoodMirror+ โ€” Conversational Emotional Self-Care
3
+ # Uses ONLY the GoEmotions dataset (no pretrained model)
4
+ # Trains TF-IDF + OneVsRest Logistic Regression on first run, caches to /data
5
  # ================================
6
  import os
7
  import re
8
  import random
9
  import sqlite3
10
+ import joblib
11
  from datetime import datetime
12
 
13
  import gradio as gr
 
 
14
  from datasets import load_dataset
15
+ from sklearn.feature_extraction.text import TfidfVectorizer
16
+ from sklearn.preprocessing import MultiLabelBinarizer
17
+ from sklearn.linear_model import LogisticRegression
18
+ from sklearn.multiclass import OneVsRestClassifier
19
+ from sklearn.pipeline import Pipeline
20
+ from sklearn.metrics import f1_score
21
+
22
+ # ---------------- Storage paths (robust local vs. HF Spaces) ----------------
23
  def _pick_data_dir():
 
24
  if os.path.isdir("/data") and os.access("/data", os.W_OK):
25
  return "/data"
 
26
  return os.getcwd()
27
 
28
  DATA_DIR = os.getenv("MM_DATA_DIR", _pick_data_dir())
29
  os.makedirs(DATA_DIR, exist_ok=True)
30
  DB_PATH = os.path.join(DATA_DIR, "moodmirror.db")
31
+ MODEL_PATH = os.path.join(DATA_DIR, "goemo_sklearn.joblib") # pipeline + mlb
32
+ MODEL_VERSION = "v1-tfidf-lr-ovr" # bump if you change training
33
+
34
  print(f"[MM] Using data dir: {DATA_DIR}")
35
+ print(f"[MM] SQLite path: {DB_PATH}")
36
+ print(f"[MM] Model path: {MODEL_PATH}")
37
 
38
+ # ---------------- Crisis & regex ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  CRISIS_RE = re.compile(r"\b(self[- ]?harm|suicid|kill myself|end my life|overdose|cutting|i don.?t want to live|can.?t go on)\b", re.I)
40
  CLOSING_RE = re.compile(r"\b(thanks?|thank you|that'?s all|bye|goodbye|see you|take care|ok bye|no thanks?)\b", re.I)
41
 
 
42
  CRISIS_NUMBERS = {
43
  "United States": "Call or text **988** (24/7 Suicide & Crisis Lifeline). If in immediate danger, call **911**.",
44
  "Canada": "Call or text **988** (Suicide Crisis Helpline, 24/7). If in immediate danger, call **911**.",
 
48
  "Other / Not listed": "Call your local emergency number (**112/911**) or search โ€œsuicide crisis hotlineโ€ + your country.",
49
  }
50
 
 
51
  SUGGESTIONS = {
52
  "sadness": "Be gentle with yourself. Rest, cry, or connect โ€” emotions pass when theyโ€™re acknowledged.",
53
  "fear": "Ground yourself: 5 things you see, 4 you feel, 3 you hear, 2 you smell, 1 you taste.",
 
62
  "neutral": "Take a mindful moment: breathe deeply and release any hidden tension in your shoulders.",
63
  }
64
 
 
65
  QUOTES = {
66
  "sadness": [
67
  "โ€œEven the darkest night will end and the sun will rise.โ€ โ€“ Victor Hugo",
 
113
  "neutral": "#F5F5F5",
114
  }
115
 
116
+ # Map GoEmotions label -> your UI buckets
117
  GOEMO_TO_APP = {
118
  "admiration": "gratitude",
119
  "amusement": "joy",
 
145
  "neutral": "neutral",
146
  }
147
 
148
+ THRESHOLD = 0.30 # probability threshold for selecting labels
149
 
150
+ # ---------------- SQLite helpers ----------------
151
  def get_conn():
 
152
  return sqlite3.connect(DB_PATH, check_same_thread=False, timeout=10)
153
 
154
  def init_db():
 
168
  conn.commit()
169
  finally:
170
  try:
171
+ if conn: conn.close()
 
172
  except Exception:
173
  pass
174
 
 
184
  conn.commit()
185
  finally:
186
  try:
187
+ if conn: conn.close()
 
188
  except Exception:
189
  pass
190
 
191
+ # ---------------- Train / Load model from DATASET ONLY ----------------
192
+ def load_goemotions_dataset():
193
+ # "simplified" gives 'text' and 'labels' as list[int] indices
194
+ ds = load_dataset("google-research-datasets/go_emotions", "simplified")
195
+ label_names = ds["train"].features["labels"].feature.names
196
+ return ds, label_names
197
+
198
+ def _prepare_xy(split):
199
+ # Each example has text and labels (list of ints)
200
+ X = split["text"]
201
+ y = split["labels"] # list[list[int]]
202
+ return X, y
203
+
204
+ def train_or_load_model():
205
+ # Try cache first
206
+ if os.path.isfile(MODEL_PATH):
207
+ print("[MM] Loading cached classifier...")
208
+ bundle = joblib.load(MODEL_PATH)
209
+ if bundle.get("version") == MODEL_VERSION:
210
+ return bundle["pipeline"], bundle["mlb"], bundle["label_names"]
211
+ else:
212
+ print("[MM] Cached model version mismatch; retraining...")
213
+
214
+ print("[MM] Loading GoEmotions dataset...")
215
+ ds, label_names = load_goemotions_dataset()
216
+
217
+ print("[MM] Preparing data...")
218
+ X_train, y_train_idx = _prepare_xy(ds["train"])
219
+ X_val, y_val_idx = _prepare_xy(ds["validation"])
220
+
221
+ # MultiLabelBinarizer to convert list[int] -> multi-hot
222
+ mlb = MultiLabelBinarizer(classes=list(range(len(label_names))))
223
+ Y_train = mlb.fit_transform(y_train_idx)
224
+ Y_val = mlb.transform(y_val_idx)
225
+
226
+ # Build pipeline
227
+ # - TfidfVectorizer with simple English settings
228
+ # - LogisticRegression (saga) in One-vs-Rest for multi-label probabilities
229
+ clf = Pipeline(steps=[
230
+ ("tfidf", TfidfVectorizer(
231
+ lowercase=True,
232
+ ngram_range=(1,2),
233
+ min_df=2,
234
+ max_df=0.9,
235
+ strip_accents="unicode",
236
+ )),
237
+ ("ovr", OneVsRestClassifier(
238
+ LogisticRegression(
239
+ solver="saga",
240
+ max_iter=1000,
241
+ n_jobs=-1,
242
+ class_weight="balanced",
243
+ ),
244
+ n_jobs=-1
245
+ ))
246
+ ])
247
+
248
+ print("[MM] Training classifier (this happens once; cached afterward)...")
249
+ clf.fit(X_train, Y_train)
250
+
251
+ # Quick validation metric (macro F1 over labels present in val)
252
+ Y_val_pred = clf.predict(X_val)
253
+ macro_f1 = f1_score(Y_val, Y_val_pred, average="macro", zero_division=0)
254
+ print(f"[MM] Validation macro F1: {macro_f1:.3f}")
255
+
256
+ # Cache model
257
+ joblib.dump({
258
+ "version": MODEL_VERSION,
259
+ "pipeline": clf,
260
+ "mlb": mlb,
261
+ "label_names": label_names
262
+ }, MODEL_PATH)
263
+ print(f"[MM] Saved classifier to {MODEL_PATH}")
264
+
265
+ return clf, mlb, label_names
266
+
267
+ # Train/load at startup
268
+ try:
269
+ CLASSIFIER, MLB, LABEL_NAMES = train_or_load_model()
270
+ except Exception as e:
271
+ print(f"[WARN] Failed to train/load classifier: {e}")
272
+ CLASSIFIER, MLB, LABEL_NAMES = None, None, None
273
+
274
+ # ---------------- Inference using ONLY the trained classifier ----------------
275
+ def classify_text(text: str):
276
  """
277
+ Returns list of (label_name, prob) for labels above THRESHOLD, sorted desc.
 
 
278
  """
279
+ if not CLASSIFIER or not MLB or not LABEL_NAMES:
280
+ return []
281
+
282
+ # predict_proba returns array shape (1, n_labels)
283
  try:
284
+ proba = CLASSIFIER.predict_proba([text])[0]
285
+ except AttributeError:
286
+ # If estimator doesn't support predict_proba (shouldn't happen with LR),
287
+ # fall back to decision_function -> sigmoid
288
+ import numpy as np
289
+ from scipy.special import expit
290
+ scores = CLASSIFIER.decision_function([text])[0]
291
+ proba = expit(scores)
292
+
293
+ idxs = [i for i, p in enumerate(proba) if p >= THRESHOLD]
294
+ # Sort by probability desc
295
+ idxs.sort(key=lambda i: proba[i], reverse=True)
296
+ return [(LABEL_NAMES[i], float(proba[i])) for i in idxs]
297
 
298
+ def detect_emotions(text: str):
299
+ chosen = classify_text(text)
300
+ if not chosen:
301
+ return [], "neutral"
302
+ # Map to app buckets and take the strongest
303
+ bucket = {}
304
+ for label, p in chosen:
305
+ app = GOEMO_TO_APP.get(label.lower(), "neutral")
306
+ bucket[app] = max(bucket.get(app, 0.0), p)
307
+ main = max(bucket, key=bucket.get) if bucket else "neutral"
308
+ return chosen, main
309
+
310
+ # ---------------- Chat logic ----------------
311
  def crisis_block(country):
312
  msg = CRISIS_NUMBERS.get(country, CRISIS_NUMBERS["Other / Not listed"])
313
  return (
 
317
  )
318
 
319
  def chat_step(message, history, country, save_session):
 
320
  if CRISIS_RE.search(message):
321
  return crisis_block(country), "#FFD6E7"
322
 
323
  if CLOSING_RE.search(message):
324
  return ("You're very welcome ๐Ÿ’› Take care of yourself. Small steps matter. ๐ŸŒฟ", "#FFFFFF")
325
 
 
326
  recent = " ".join(message.split()[-100:])
327
  detected, main = detect_emotions(recent)
328
  color = COLOR_MAP.get(main, "#FFFFFF")
329
 
 
330
  if save_session:
331
  log_session(country, message, main)
332
 
 
345
  if not history:
346
  reply += "\n\n*Can you tell me a bit more about whatโ€™s behind that feeling?*"
347
 
348
+ # (Optional) append detected emotions summary
349
+ if detected:
350
+ summary = ", ".join([f"{lbl} ({p:.2f})" for lbl, p in detected[:3]])
351
+ reply += f"\n\nDetected: {summary}"
352
+
353
  return reply, color
354
 
355
+ # ---------------- Gradio UI ----------------
 
 
 
 
 
 
 
 
 
 
 
356
  init_db()
357
 
358
  custom_css = """
 
361
  @keyframes blink { 50% {opacity: 0.4;} }
362
  """
363
 
364
+ with gr.Blocks(css=custom_css, title="๐Ÿชž MoodMirror+ (Dataset-only Edition)") as demo:
365
  style_injector = gr.HTML("")
366
  gr.Markdown(
367
  "### ๐Ÿชž MoodMirror+ โ€” Emotional Support & Inspiration ๐ŸŒธ\n"
368
+ "Powered only by the **GoEmotions dataset** (trained locally on startup).\n\n"
369
  "_Not medical advice. If you feel unsafe, please reach out for help immediately._"
370
  )
371
 
 
373
  country = gr.Dropdown(choices=list(CRISIS_NUMBERS.keys()), value="Other / Not listed", label="Country")
374
  save_ok = gr.Checkbox(value=False, label="Save anonymized session (no personal data)")
375
 
376
+ chat = gr.Chatbot(height=360)
377
  msg = gr.Textbox(placeholder="Type how you feel...", label="Your message")
378
  send = gr.Button("Send")
379
  typing = gr.Markdown("", elem_classes="typing")
380
 
381
+ # Optional: dataset sample preview (for transparency)
382
+ with gr.Accordion("๐Ÿ”Ž Preview GoEmotions samples", open=False):
383
  with gr.Row():
384
  n_examples = gr.Slider(1, 10, value=5, step=1, label="Number of examples")
385
  split = gr.Dropdown(["train", "validation", "test"], value="train", label="Split")
 
387
  table = gr.Dataframe(headers=["text", "labels"], row_count=5, wrap=True)
388
 
389
  def refresh_samples(n, split_name):
390
+ try:
391
+ ds = load_dataset("google-research-datasets/go_emotions", "simplified")
392
+ names = ds["train"].features["labels"].feature.names
393
+ rows = ds[split_name].shuffle(seed=42).select(range(min(int(n), len(ds[split_name]))))
394
+ return [[t, ", ".join([names[i] for i in labs])] for t, labs in zip(rows["text"], rows["labels"])]
395
+ except Exception as e:
396
+ return [[f"Dataset load error: {e}", ""]]
397
 
398
  refresh.click(refresh_samples, inputs=[n_examples, split], outputs=[table])
399
 
 
406
  style_tag = f"<style>:root,body,.gradio-container{{background:{color}!important;}}</style>"
407
  yield chat_hist + [[user_msg, reply]], "", style_tag, ""
408
 
409
+ send.click(respond, inputs=[msg, chat, country, save_ok],
410
+ outputs=[chat, typing, style_injector, msg], queue=True)
411
+ msg.submit(respond, inputs=[msg, chat, country, save_ok],
412
+ outputs=[chat, typing, style_injector, msg], queue=True)
413
 
414
  if __name__ == "__main__":
415
  demo.queue()