Spaces:
Running
Running
| import os | |
| import streamlit as st | |
| import numpy as np | |
| import time | |
| from sentence_transformers import SentenceTransformer | |
| import datetime | |
| import feedparser | |
| from huggingface_hub import hf_hub_download | |
| import faiss, pickle | |
| import aiohttp | |
| import asyncio | |
| import sqlite3 | |
# -------------------
# SQLite query cache
# -------------------
def init_cache_db(path="query_cache.db"):
    """Open (creating if needed) the SQLite answer cache and return the connection.

    The ``cache`` table has one row per distinct query text (``query`` is
    UNIQUE) holding the generated answer, the query embedding as a BLOB and a
    use counter that defaults to 1.

    Args:
        path: SQLite database file. ``":memory:"`` gives a throwaway cache,
            which is handy for tests.

    Returns:
        An open ``sqlite3.Connection``.
    """
    conn = sqlite3.connect(path)
    conn.execute("""
    CREATE TABLE IF NOT EXISTS cache (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        query TEXT UNIQUE,
        answer TEXT,
        embedding BLOB,
        frequency INTEGER DEFAULT 1
    )
    """)
    conn.commit()
    return conn


cache_conn = init_cache_db()
def store_in_cache(query, answer, embedding, conn=None):
    """Insert or overwrite the cache row for ``query`` and bump its use counter.

    ``INSERT OR REPLACE`` deletes any existing row with the same query text
    and inserts a new one, so the row ``id`` changes. The subquery reads the
    old ``frequency`` first, so the new row carries ``old + 1`` (or 1 when the
    query is new).

    Args:
        query: The user's question text (the UNIQUE key).
        answer: The generated answer to cache.
        embedding: Array-like embedding of ``query``.
        conn: SQLite connection holding the ``cache`` table; defaults to the
            module-level ``cache_conn``.
    """
    conn = cache_conn if conn is None else conn
    # Normalise to a flat float32 buffer: readers decode the blob with
    # dtype=np.float32, so storing float64 bytes would silently corrupt it.
    emb_bytes = np.asarray(embedding, dtype=np.float32).reshape(-1).tobytes()
    conn.execute(
        """
        INSERT OR REPLACE INTO cache (query, answer, embedding, frequency)
        VALUES (?, ?, ?, COALESCE(
            (SELECT frequency FROM cache WHERE query=?), 0
        ) + 1)
        """,
        (query, answer, emb_bytes, query),
    )
    conn.commit()
def search_cache(query, embed_model, threshold=0.85, conn=None):
    """Return the cached answer whose stored query is most similar to ``query``.

    ``query`` is embedded with ``embed_model`` and compared by cosine
    similarity against every cached embedding. The best row strictly above
    ``threshold`` wins. On a hit, that row's ``frequency`` is incremented so
    the "Most Asked Questions" panel reflects repeated questions. Previously
    only cache misses ever wrote the counter, so it stayed at 1.

    Args:
        query: The user's question text.
        embed_model: Object exposing ``encode(list_of_texts, convert_to_numpy=True)``
            (the SentenceTransformer loaded in this file).
        threshold: Minimum cosine similarity (exclusive) for a hit.
        conn: SQLite connection holding the ``cache`` table; defaults to the
            module-level ``cache_conn``.

    Returns:
        The cached answer, or None when no row clears the threshold.
    """
    conn = cache_conn if conn is None else conn
    q_emb = np.asarray(embed_model.encode([query], convert_to_numpy=True)[0], dtype=np.float32)
    q_norm = np.linalg.norm(q_emb)
    if q_norm == 0:
        return None

    best_sim = -1.0
    best_query = None
    best_answer = None
    for qry, ans, emb_blob in conn.execute("SELECT query, answer, embedding FROM cache"):
        if emb_blob is None:
            continue
        emb = np.frombuffer(emb_blob, dtype=np.float32)
        # Skip rows that cannot be compared (other embedding size, or a zero
        # vector) instead of raising on the dot product or producing NaN.
        if emb.shape != q_emb.shape:
            continue
        emb_norm = np.linalg.norm(emb)
        if emb_norm == 0:
            continue
        sim = float(np.dot(q_emb, emb) / (q_norm * emb_norm))
        if sim > threshold and sim > best_sim:
            best_sim, best_query, best_answer = sim, qry, ans

    if best_query is None:
        return None
    conn.execute(
        "UPDATE cache SET frequency = COALESCE(frequency, 0) + 1 WHERE query = ?",
        (best_query,),
    )
    conn.commit()
    return best_answer
# -------------------
# Load FAISS index + metadata
# -------------------
@st.cache_resource
def load_index():
    """Download and load the FAISS index, its metadata and the embedding model.

    Wrapped in ``st.cache_resource`` because Streamlit re-executes this whole
    script on every widget interaction. Without it, the index file, the
    pickled metadata and the SentenceTransformer would be reloaded on each
    rerun.

    Returns:
        ``(index, metadata, embed_model)``. ``metadata`` is the unpickled
        object; callers in this file index it as ``metadata["texts"]`` and
        ``metadata["sources"]`` by FAISS id.
    """
    repo_id = "krishnasimha/health-chatbot-data"
    faiss_path = hf_hub_download(
        repo_id=repo_id,
        filename="health_index.faiss",
        repo_type="dataset",
    )
    pkl_path = hf_hub_download(
        repo_id=repo_id,
        filename="health_metadata.pkl",
        repo_type="dataset",
    )
    index = faiss.read_index(faiss_path)
    # SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in
    # the file. This presumes the dataset repo is controlled by the app author
    # (confirm). Never point this at a repo you do not trust.
    with open(pkl_path, "rb") as f:
        metadata = pickle.load(f)
    embed_model = SentenceTransformer("all-MiniLM-L6-v2")
    return index, metadata, embed_model


index, metadata, embed_model = load_index()
# -------------------
# FAISS Benchmark
# -------------------
def benchmark_faiss(n_queries=100, k=3, queries=None):
    """Time ``n_queries`` FAISS searches and report the mean latency in the sidebar.

    The sample questions are embedded once up front. Each timed iteration
    then runs only ``index.search`` on one of them, chosen at random.

    Args:
        n_queries: Number of timed searches.
        k: Neighbours requested per search.
        queries: Sample questions to embed; defaults to three built-in
            health questions.

    Returns:
        The mean search time in milliseconds (also written to the sidebar).
    """
    if queries is None:
        queries = ["What is diabetes?", "How to prevent malaria?", "Symptoms of dengue?"]
    query_embs = embed_model.encode(queries, convert_to_numpy=True)
    rng = np.random.default_rng()
    times = []
    for _ in range(n_queries):
        q = query_embs[rng.integers(len(query_embs))].reshape(1, -1)
        # perf_counter is the monotonic, highest-resolution clock intended for
        # short durations; time.time() can jump with wall-clock adjustments.
        start = time.perf_counter()
        index.search(q, k)
        times.append(time.perf_counter() - start)
    avg_time = np.mean(times) * 1000
    st.sidebar.write(f"β‘ FAISS Benchmark: {avg_time:.2f} ms/query over {n_queries} queries")
    return avg_time
# -------------------
# Chat session management
# -------------------
SYSTEM_PROMPT = "You are a helpful public health awareness chatbot."


def _new_chat_messages():
    """Return a fresh message list seeded with the system prompt."""
    return [{"role": "system", "content": SYSTEM_PROMPT}]


if "chats" not in st.session_state:
    st.session_state.chats = {}
if "current_chat" not in st.session_state:
    st.session_state.current_chat = "New Chat 1"
    st.session_state.chats["New Chat 1"] = _new_chat_messages()
st.sidebar.header("Chat Manager")
if st.sidebar.button("β New Chat"):
    # Choose the first unused "New Chat N". Using len(chats) + 1 alone can
    # collide with a chat the user renamed to that label and wipe its history.
    chat_count = len(st.session_state.chats) + 1
    while f"New Chat {chat_count}" in st.session_state.chats:
        chat_count += 1
    new_chat_name = f"New Chat {chat_count}"
    st.session_state.chats[new_chat_name] = _new_chat_messages()
    st.session_state.current_chat = new_chat_name
benchmark_faiss()
# -------------------
# Most Asked Questions
# -------------------
def get_top_cached_queries(limit=5, conn=None):
    """Return up to ``limit`` ``(query, frequency)`` rows, most used first.

    Ties on ``frequency`` are broken by ``id DESC`` (the most recently
    written row first), so the list order is stable across reruns instead of
    being left to SQLite.

    Args:
        limit: Maximum number of rows to return.
        conn: SQLite connection holding the ``cache`` table; defaults to the
            module-level ``cache_conn``.
    """
    conn = cache_conn if conn is None else conn
    return conn.execute(
        """
        SELECT query, frequency FROM cache
        ORDER BY frequency DESC, id DESC
        LIMIT ?
        """,
        (limit,),
    ).fetchall()
# Sidebar panel: list the most frequently used cached questions.
st.sidebar.subheader("π₯ Most Asked Questions")
top_qs = get_top_cached_queries()
for question_text, use_count in top_qs:
    st.sidebar.write(f"**{question_text}** β used {use_count} times")
# -------------------
# Chat selector
# -------------------
chat_list = list(st.session_state.chats.keys())
selected_chat = st.sidebar.selectbox(
    "Your chats:", chat_list, index=chat_list.index(st.session_state.current_chat), key="chat_select"
)
st.session_state.current_chat = selected_chat
new_name = st.sidebar.text_input("Rename Chat:", st.session_state.current_chat).strip()
old_name = st.session_state.current_chat
if new_name and new_name != old_name:
    if new_name in st.session_state.chats:
        # Previously this was silently ignored; tell the user why nothing happened.
        st.sidebar.warning(f"A chat named '{new_name}' already exists.")
    else:
        # Rebuild the dict instead of pop + insert so the renamed chat keeps its
        # position in the selector rather than jumping to the end.
        st.session_state.chats = {
            (new_name if name == old_name else name): messages
            for name, messages in st.session_state.chats.items()
        }
        st.session_state.current_chat = new_name
# -------------------
# RSS News Fetcher (async)
# -------------------
RSS_URL = "https://news.google.com/rss/search?q=health+disease+awareness&hl=en-IN&gl=IN&ceid=IN:en"


async def fetch_rss_url(url, timeout=15):
    """GET ``url`` and return the response body as text.

    Args:
        url: Feed URL to fetch.
        timeout: Total request budget in seconds. aiohttp's default total
            timeout is 5 minutes, which would freeze the Streamlit page for
            that long if the feed server hangs.

    Raises:
        asyncio.TimeoutError: if the request exceeds ``timeout``.
        aiohttp.ClientError: on connection-level failures.
    """
    client_timeout = aiohttp.ClientTimeout(total=timeout)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        async with session.get(url) as resp:
            return await resp.text()
def fetch_news(limit=5):
    """Download the health RSS feed at ``RSS_URL`` and return its first articles.

    Args:
        limit: Maximum number of feed entries to return.

    Returns:
        A list of dicts with ``title``, ``link`` and ``published`` keys. A field
        missing from a feed entry comes back as ``""`` instead of raising
        AttributeError.
    """
    raw_xml = asyncio.run(fetch_rss_url(RSS_URL))
    feed = feedparser.parse(raw_xml)
    return [
        {
            "title": entry.get("title", ""),
            "link": entry.get("link", ""),
            "published": entry.get("published", ""),
        }
        for entry in feed.entries[:limit]
    ]
def update_news_hourly():
    """Refresh ``st.session_state.news_articles`` when missing or over an hour old.

    The attempt time is recorded before fetching, so a failing feed is retried
    at most once per hour rather than on every rerun. Network failures are
    reported as a warning, and any previously fetched articles are kept.
    """
    now = datetime.datetime.now()
    last = st.session_state.get("last_news_update")
    # total_seconds(), not .seconds: timedelta.seconds is only the 0-86399
    # seconds component and ignores whole days, so a gap of 1 day 10 minutes
    # would read as 600 s and the news would never refresh.
    if last is None or (now - last).total_seconds() > 3600:
        st.session_state.last_news_update = now
        try:
            st.session_state.news_articles = fetch_news()
        except (aiohttp.ClientError, asyncio.TimeoutError) as err:
            # News is a secondary feature; don't take the whole page down.
            st.warning(f"Could not refresh health news: {err}")
# -------------------
# Async Together API
# -------------------
async def async_together_chat(messages, model="deepseek-ai/DeepSeek-V3", timeout=120):
    """Send ``messages`` to the Together chat-completions API and return the reply text.

    Args:
        messages: Chat history as a list of ``{"role": ..., "content": ...}`` dicts.
        model: Together model identifier.
        timeout: Total request budget in seconds.

    Returns:
        The content of the first choice's message.

    Raises:
        KeyError: if the ``TOGETHER_API_KEY`` environment variable is unset.
        RuntimeError: if the API returns an HTTP error or a body without
            ``choices``. The API's error payload is included in the message,
            instead of surfacing as a bare ``KeyError: 'choices'``.
    """
    url = "https://api.together.xyz/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": messages,
    }
    client_timeout = aiohttp.ClientTimeout(total=timeout)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        async with session.post(url, headers=headers, json=payload) as resp:
            # content_type=None: error bodies may not be labelled application/json.
            result = await resp.json(content_type=None)
            status = resp.status
    if status >= 400 or not isinstance(result, dict) or "choices" not in result:
        detail = result.get("error", result) if isinstance(result, dict) else result
        raise RuntimeError(f"Together API request failed (HTTP {status}): {detail}")
    return result["choices"][0]["message"]["content"]
# -------------------
# Query function
# -------------------
def retrieve_answer(query, k=3):
    """Answer ``query`` from the semantic cache, or via FAISS retrieval + the LLM.

    1. Return a cached answer for a semantically similar question, if any.
    2. Otherwise embed the query, take the top-``k`` passages from the FAISS
       index, and ask the chat model to answer using those passages.
    3. Cache the new answer and record the turn in the current chat.

    Returns:
        ``(answer, sources)``. ``sources`` lists the metadata sources of the
        retrieved passages, or ``[]`` for a cache hit.
    """
    # 1. Try the semantic cache first.
    cached_answer = search_cache(query, embed_model)
    if cached_answer:
        st.sidebar.success("β‘ Retrieved from cache")
        return cached_answer, []  # no FAISS sources

    # 2. Cache miss: normal FAISS retrieval pipeline.
    query_emb = embed_model.encode([query], convert_to_numpy=True)
    _, ids = index.search(query_emb, k)
    # FAISS pads results with -1 when the index holds fewer than k vectors;
    # -1 would silently select the last metadata entry, so drop those slots.
    hits = [i for i in ids[0] if i >= 0]
    retrieved = [metadata["texts"][i] for i in hits]
    sources = [metadata["sources"][i] for i in hits]
    context = "\n".join(retrieved)

    history = st.session_state.chats[st.session_state.current_chat]
    # The retrieved context goes only into this request. History keeps the
    # plain question, so the transcript shows what the user typed and old
    # contexts are not re-sent on every later turn.
    grounded_message = {
        "role": "user",
        "content": f"Answer based on the context below:\n\n{context}\n\nQuestion: {query}"
    }
    answer = asyncio.run(async_together_chat(history + [grounded_message]))

    # 3. Cache the query + embedding + answer. Record the turn only after the
    # API call succeeded, so a failure leaves no unanswered user message behind.
    store_in_cache(query, answer, query_emb[0])
    history.append({"role": "user", "content": query})
    history.append({"role": "assistant", "content": answer})
    return answer, sources
# -------------------
# Background news task
# -------------------
async def background_news_updater():
    """Refresh ``st.session_state.news_articles`` every hour, forever.

    NOTE(review): this coroutine is intentionally not scheduled; it is kept
    only so the name stays importable. It cannot work as written. It would
    run inside an event loop, but ``fetch_news`` calls ``asyncio.run``,
    which raises RuntimeError when called from a running loop.
    """
    while True:
        st.session_state.news_articles = fetch_news()
        await asyncio.sleep(3600)  # refresh every hour


# News is refreshed synchronously by update_news_hourly() on each script run.
# The code that used to be here created a new event loop per session and
# called loop.create_task(background_news_updater()) on it, but nothing ever
# ran that loop. The task never executed; it only leaked an unclosed loop and
# a never-awaited coroutine, so that scheduling was removed.
# -------------------
# Streamlit UI
# -------------------
st.title(st.session_state.current_chat)
update_news_hourly()
st.subheader("π° Latest Health Updates")
if "news_articles" in st.session_state:
    for art in st.session_state.news_articles:
        st.markdown(f"**{art['title']}** \n[Read more]({art['link']}) \n*Published: {art['published']}*")
        st.write("---")
user_query = st.text_input("Ask me about health, prevention, or awareness:")
if user_query:
    # Streamlit re-executes the script on every widget interaction and the
    # text box keeps its value. Without this guard, every sidebar click would
    # re-embed and re-search the same question (and re-count it in the cache).
    # Retrieval runs only when the submitted text changes; on failure
    # last_query is not updated, so the next rerun retries.
    if st.session_state.get("last_query") != user_query:
        with st.spinner("Searching knowledge base..."):
            st.session_state.last_result = retrieve_answer(user_query)
        st.session_state.last_query = user_query
    answer, sources = st.session_state.last_result
    st.write("### π‘ Answer")
    st.write(answer)
    st.write("### π Sources")
    for src in sources:
        st.write(f"- {src}")
# Transcript of the current chat; the system prompt is not displayed.
for msg in st.session_state.chats[st.session_state.current_chat]:
    if msg["role"] == "user":
        st.write(f"π§ **You:** {msg['content']}")
    elif msg["role"] == "assistant":
        st.write(f"π€ **Bot:** {msg['content']}")