# Author: krishnasimha
# Commit 5417d7d (verified): "Update app4.py"
import os
import streamlit as st
import numpy as np
import time
from sentence_transformers import SentenceTransformer
import datetime
import feedparser
from huggingface_hub import hf_hub_download
import faiss, pickle
import aiohttp
import asyncio
import sqlite3
# -------------------
# Query cache (SQLite)
# -------------------
def init_cache_db(db_path="query_cache.db"):
    """Open (creating if needed) the SQLite query cache and ensure its schema.

    Args:
        db_path: Path of the SQLite database file. Defaults to
            ``"query_cache.db"`` in the working directory; pass ``":memory:"``
            for a throwaway in-memory cache.

    Returns:
        An open ``sqlite3.Connection`` whose ``cache`` table has the columns
        ``id``, ``query`` (unique), ``answer``, ``embedding`` (raw bytes) and
        ``frequency`` (defaults to 1).

    Raises:
        sqlite3.Error: If the schema cannot be created; the connection is
            closed before the error propagates.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS cache (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                query TEXT UNIQUE,
                answer TEXT,
                embedding BLOB,
                frequency INTEGER DEFAULT 1
            )
            """
        )
        conn.commit()
    except sqlite3.Error:
        # Don't leak the connection if the schema can't be created.
        conn.close()
        raise
    return conn
# Module-level connection to the query cache, used by store_in_cache,
# search_cache and get_top_cached_queries.
# NOTE(review): this line re-runs whenever the script is re-executed, opening a
# new connection each time — confirm that is intended.
cache_conn = init_cache_db()
def store_in_cache(query, answer, embedding, conn=None):
    """Insert or refresh a cached answer and bump its usage counter.

    Storing a ``query`` that already exists replaces its answer and embedding
    and increments ``frequency``. A new query starts at ``frequency = 1``.

    Args:
        query: The user's question text (the table's unique key).
        answer: The answer text to cache.
        embedding: Array-like embedding of ``query``. It is flattened and
            stored as raw float32 bytes, which is the dtype ``search_cache``
            decodes with.
        conn: Optional ``sqlite3.Connection``; defaults to the module-level
            ``cache_conn``.
    """
    if conn is None:
        conn = cache_conn
    # Force float32 so the blob round-trips through
    # np.frombuffer(..., dtype=np.float32). A float64 input would otherwise be
    # stored as twice as many bytes and decoded into garbage values.
    blob = np.asarray(embedding, dtype=np.float32).reshape(-1).tobytes()
    conn.execute(
        """
        INSERT OR REPLACE INTO cache (query, answer, embedding, frequency)
        VALUES (?, ?, ?, COALESCE(
            (SELECT frequency FROM cache WHERE query=?), 0
        ) + 1)
        """,
        (query, answer, blob, query),
    )
    conn.commit()
def search_cache(query, embed_model, threshold=0.85, conn=None):
    """Return the cached answer whose stored query is most similar to ``query``.

    Similarity is the cosine similarity between the embedding of ``query``
    (from ``embed_model.encode``) and each stored embedding. Only matches
    strictly above ``threshold`` qualify. Among those, the most similar one
    wins; on ties, the earliest row wins.

    Args:
        query: The incoming question text.
        embed_model: Object exposing ``encode(list_of_texts, convert_to_numpy=True)``
            that returns one vector per input text.
        threshold: Minimum cosine similarity (exclusive) for a cache hit.
        conn: Optional ``sqlite3.Connection``; defaults to the module-level
            ``cache_conn``.

    Returns:
        The cached answer string, or ``None`` when nothing clears the threshold.
    """
    if conn is None:
        conn = cache_conn
    rows = conn.execute("SELECT answer, embedding FROM cache").fetchall()
    if not rows:
        # Empty cache: skip the embedding call entirely.
        return None
    q_emb = np.asarray(
        embed_model.encode([query], convert_to_numpy=True)[0], dtype=np.float32
    ).reshape(-1)
    q_norm = np.linalg.norm(q_emb)
    if q_norm == 0:
        # Cosine similarity is undefined for a zero vector.
        return None
    best_sim = threshold
    best_answer = None
    for answer, emb_blob in rows:
        emb = np.frombuffer(emb_blob, dtype=np.float32)
        # Skip rows from a model with a different embedding size (np.dot
        # would raise) and zero vectors (division by zero -> NaN).
        if emb.shape != q_emb.shape:
            continue
        emb_norm = np.linalg.norm(emb)
        if emb_norm == 0:
            continue
        sim = float(np.dot(q_emb, emb) / (q_norm * emb_norm))
        if sim > best_sim:
            best_sim = sim
            best_answer = answer
    return best_answer
# -------------------
# Load FAISS index + metadata
# -------------------
@st.cache_resource
def load_index():
    """Download and load the retrieval assets (memoized by st.cache_resource).

    Returns:
        A ``(index, metadata, embed_model)`` tuple:
        - the FAISS index read from ``health_index.faiss``;
        - the object unpickled from ``health_metadata.pkl``;
        - the ``all-MiniLM-L6-v2`` SentenceTransformer.
        Both files are fetched from the ``krishnasimha/health-chatbot-data``
        dataset repo on the Hugging Face Hub.
    """
    dataset_repo = "krishnasimha/health-chatbot-data"
    index_file, meta_file = (
        hf_hub_download(repo_id=dataset_repo, filename=fname, repo_type="dataset")
        for fname in ("health_index.faiss", "health_metadata.pkl")
    )
    faiss_index = faiss.read_index(index_file)
    # NOTE(review): pickle.load can execute arbitrary code from the file; this
    # is only acceptable while the dataset repo is fully trusted.
    with open(meta_file, "rb") as fh:
        meta = pickle.load(fh)
    return faiss_index, meta, SentenceTransformer("all-MiniLM-L6-v2")


index, metadata, embed_model = load_index()
# -------------------
# FAISS Benchmark
# -------------------
def benchmark_faiss(n_queries=100, k=3, queries=None):
    """Time ``index.search`` and report the mean latency in the sidebar.

    Args:
        n_queries: Number of timed searches. Each one searches a randomly
            chosen embedding from ``queries``.
        k: Neighbours requested per search.
        queries: Optional list of query strings to sample from; defaults to
            three sample health questions.

    Returns:
        The mean search latency in milliseconds, or ``None`` when
        ``n_queries`` is not positive or ``queries`` is empty. In that case
        nothing is timed or reported.
    """
    if queries is None:
        queries = ["What is diabetes?", "How to prevent malaria?", "Symptoms of dengue?"]
    if n_queries <= 0 or not queries:
        return None
    query_embs = embed_model.encode(queries, convert_to_numpy=True)
    timings = []
    for _ in range(n_queries):
        q = query_embs[np.random.randint(0, len(query_embs))].reshape(1, -1)
        # perf_counter is monotonic and high-resolution. time.time() can be
        # coarser than a single FAISS lookup on some platforms.
        start = time.perf_counter()
        index.search(q, k)
        timings.append(time.perf_counter() - start)
    avg_ms = float(np.mean(timings)) * 1000
    st.sidebar.write(f"⚑ FAISS Benchmark: {avg_ms:.2f} ms/query over {n_queries} queries")
    return avg_ms
# -------------------
# Chat session management
# -------------------
def _new_chat_history():
    """Return a fresh message list seeded with the chatbot's system prompt."""
    return [
        {"role": "system", "content": "You are a helpful public health awareness chatbot."}
    ]


def _next_chat_name(chats):
    """Return the first unused "New Chat N" name, starting at N = len(chats) + 1."""
    n = len(chats) + 1
    # len(chats) + 1 alone can collide after a rename (e.g. renaming
    # "New Chat 1" to "New Chat 2"), which would overwrite that chat's history.
    while f"New Chat {n}" in chats:
        n += 1
    return f"New Chat {n}"


if "chats" not in st.session_state:
    st.session_state.chats = {}
if "current_chat" not in st.session_state:
    st.session_state.current_chat = "New Chat 1"
    st.session_state.chats["New Chat 1"] = _new_chat_history()
st.sidebar.header("Chat Manager")
if st.sidebar.button("➕ New Chat"):
    new_chat_name = _next_chat_name(st.session_state.chats)
    st.session_state.chats[new_chat_name] = _new_chat_history()
    st.session_state.current_chat = new_chat_name
benchmark_faiss()
# -------------------
# Most Asked Questions
# -------------------
def get_top_cached_queries(limit=5, conn=None):
    """Return the most frequently used cached queries.

    Args:
        limit: Maximum number of rows to return.
        conn: Optional ``sqlite3.Connection``; defaults to the module-level
            ``cache_conn``.

    Returns:
        A list of ``(query, frequency)`` tuples, highest frequency first.
        Ties are ordered by row id, so the output is deterministic.
    """
    if conn is None:
        conn = cache_conn
    return conn.execute(
        "SELECT query, frequency FROM cache ORDER BY frequency DESC, id ASC LIMIT ?",
        (limit,),
    ).fetchall()
# Sidebar panel listing the most frequently cached questions.
# The original labels were mis-encoded UTF-8 (mojibake); restored to 🔥 and —.
st.sidebar.subheader("🔥 Most Asked Questions")
top_qs = get_top_cached_queries()
if not top_qs:
    st.sidebar.caption("No questions asked yet.")
for q, freq in top_qs:
    st.sidebar.write(f"**{q}** — used {freq} times")
# -------------------
# Chat selector
# -------------------
# Let the user switch between chats and rename the current one.
chat_list = list(st.session_state.chats.keys())
selected_chat = st.sidebar.selectbox(
    "Your chats:", chat_list, index=chat_list.index(st.session_state.current_chat), key="chat_select"
)
st.session_state.current_chat = selected_chat
# Strip so that a whitespace-only entry is ignored instead of becoming a chat name.
new_name = st.sidebar.text_input("Rename Chat:", st.session_state.current_chat).strip()
if new_name and new_name != st.session_state.current_chat:
    if new_name in st.session_state.chats:
        st.sidebar.warning(f"A chat named '{new_name}' already exists.")
    else:
        old_name = st.session_state.current_chat
        # Rebuild the dict instead of pop() + re-insert so the renamed chat
        # keeps its position in the selector list.
        st.session_state.chats = {
            (new_name if name == old_name else name): history
            for name, history in st.session_state.chats.items()
        }
        st.session_state.current_chat = new_name
# -------------------
# RSS News Fetcher (async)
# -------------------
RSS_URL = "https://news.google.com/rss/search?q=health+disease+awareness&hl=en-IN&gl=IN&ceid=IN:en"


async def fetch_rss_url(url, timeout=10):
    """Fetch ``url`` with a GET request and return the response body as text.

    Args:
        url: Feed URL to fetch.
        timeout: Total request timeout in seconds. fetch_news blocks the
            script run on this coroutine, so a stalled host must not hang it
            for aiohttp's much longer default timeout.

    Returns:
        The decoded response body.

    Raises:
        aiohttp.ClientError: On connection or protocol failures.
        asyncio.TimeoutError: If the request exceeds ``timeout`` seconds.
    """
    client_timeout = aiohttp.ClientTimeout(total=timeout)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        async with session.get(url) as resp:
            return await resp.text()
def fetch_news(limit=5):
    """Fetch the health-news RSS feed and return its first ``limit`` entries.

    Args:
        limit: Maximum number of articles to return.

    Returns:
        A list of dicts with ``title``, ``link`` and ``published`` keys;
        missing feed fields become ``""``. On a network error or timeout, a
        sidebar warning is shown and an empty list is returned. This keeps a
        feed outage from crashing the whole app.
    """
    try:
        raw_xml = asyncio.run(fetch_rss_url(RSS_URL))
    except (aiohttp.ClientError, asyncio.TimeoutError) as err:
        st.sidebar.warning(f"Could not load health news: {err}")
        return []
    feed = feedparser.parse(raw_xml)
    # Use .get: some feed entries lack fields (notably "published"), and
    # attribute access would raise AttributeError.
    return [
        {
            "title": entry.get("title", ""),
            "link": entry.get("link", ""),
            "published": entry.get("published", ""),
        }
        for entry in feed.entries[:limit]
    ]
def update_news_hourly(max_age_seconds=3600):
    """Refresh ``st.session_state.news_articles`` once it is older than ``max_age_seconds``.

    The first call in a session always fetches. ``last_news_update`` records
    when the most recent refresh happened.

    Args:
        max_age_seconds: Minimum age, in seconds, before the articles are
            fetched again. Defaults to one hour.
    """
    now = datetime.datetime.now()
    # Use total_seconds(), not .seconds: timedelta.seconds is only the sub-day
    # component (0-86399). A gap of 1 day 10 min would read as 600 s and skip
    # the refresh.
    if (
        "last_news_update" not in st.session_state
        or (now - st.session_state.last_news_update).total_seconds() > max_age_seconds
    ):
        st.session_state.last_news_update = now
        st.session_state.news_articles = fetch_news()
# -------------------
# Async Together API
# -------------------
async def async_together_chat(messages, model="deepseek-ai/DeepSeek-V3"):
    """Send ``messages`` to Together's chat-completions endpoint and return the reply.

    Args:
        messages: Chat history as a list of ``{"role": ..., "content": ...}`` dicts.
        model: Together model identifier.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        KeyError: If ``TOGETHER_API_KEY`` is not set in the environment.
        RuntimeError: If the API responds with an HTTP error status, or with
            a body that has no ``choices`` (e.g. an error payload).
    """
    url = "https://api.together.xyz/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": messages,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload) as resp:
            if resp.status >= 400:
                body = await resp.text()
                raise RuntimeError(
                    f"Together API request failed with HTTP {resp.status}: {body[:500]}"
                )
            result = await resp.json()
    # Surface the API's own error payload instead of an opaque KeyError.
    choices = result.get("choices") if isinstance(result, dict) else None
    if not choices:
        raise RuntimeError(f"Together API returned no choices: {str(result)[:500]}")
    return choices[0]["message"]["content"]
# -------------------
# Query function
# -------------------
def retrieve_answer(query, k=3):
    """Answer ``query`` from the semantic cache, or via FAISS retrieval plus the LLM.

    1. If a sufficiently similar question is cached, return its answer with
       no sources and show a sidebar notice.
    2. Otherwise:
       - retrieve the ``k`` nearest passages from the FAISS index;
       - ask the LLM with them as context, plus the current chat's history;
       - record the exchange in that history;
       - cache the answer.

    Args:
        query: The user's question.
        k: Number of passages to retrieve from the index.

    Returns:
        ``(answer, sources)``, where ``sources`` lists the metadata source of
        each retrieved passage (empty on a cache hit).
    """
    # 1️⃣ Try the semantic cache first.
    cached_answer = search_cache(query, embed_model)
    if cached_answer:
        st.sidebar.success("⚑ Retrieved from cache")
        return cached_answer, []  # no FAISS sources

    # 2️⃣ Cache miss → normal FAISS pipeline.
    query_emb = embed_model.encode([query], convert_to_numpy=True)
    _, ids = index.search(query_emb, k)
    # FAISS reports missing neighbours as -1. A raw -1 would silently select
    # the *last* metadata entry, so drop those slots.
    hits = [i for i in ids[0] if i >= 0]
    retrieved = [metadata["texts"][i] for i in hits]
    sources = [metadata["sources"][i] for i in hits]
    context = "\n".join(retrieved)
    user_message = {
        "role": "user",
        "content": f"Answer based on the context below:\n\n{context}\n\nQuestion: {query}",
    }
    history = st.session_state.chats[st.session_state.current_chat]
    # Send history + this turn, but only commit the turn once the API call has
    # succeeded. Appending first would leave a dangling user message on failure.
    answer = asyncio.run(async_together_chat(history + [user_message]))
    history.append(user_message)
    history.append({"role": "assistant", "content": answer})

    # 3️⃣ Save the new query + embedding + answer into the cache.
    store_in_cache(query, answer, query_emb[0])
    return answer, sources
# -------------------
# Background news task
# -------------------
async def background_news_updater():
    """Refresh ``st.session_state.news_articles`` once an hour, forever.

    NOTE(review): nothing schedules this coroutine on a running event loop, so
    it is currently unused; update_news_hourly() performs the refresh during
    normal script runs. Writing to ``st.session_state`` from outside a script
    run may not be supported by Streamlit — confirm before wiring this up.
    """
    loop = asyncio.get_running_loop()
    while True:
        # fetch_news() calls asyncio.run(), which refuses to run inside an
        # already-running loop, so hand it off to a worker thread.
        st.session_state.news_articles = await loop.run_in_executor(None, fetch_news)
        await asyncio.sleep(3600)  # refresh every hour


# No background task is scheduled here. A task created on a loop from
# asyncio.new_event_loop() never executes unless that loop is run, so
# scheduling one would only leak the loop and a pending task into session
# state. News refresh happens via update_news_hourly() in the UI section.
# -------------------
# Streamlit UI
# -------------------
# Main page: news feed, question box, answer + sources, and chat history.
# Emoji in the labels were mis-encoded UTF-8 (mojibake) and are restored here.
st.title(st.session_state.current_chat)
update_news_hourly()
st.subheader("📰 Latest Health Updates")
articles = st.session_state.news_articles if "news_articles" in st.session_state else []
if not articles:
    st.caption("No health news available right now.")
for art in articles:
    # Two trailing spaces before "\n" force a Markdown hard line break.
    st.markdown(f"**{art['title']}**  \n[Read more]({art['link']})  \n*Published: {art['published']}*")
st.write("---")
user_query = st.text_input("Ask me about health, prevention, or awareness:")
if user_query:
    with st.spinner("Searching knowledge base..."):
        answer, sources = retrieve_answer(user_query)
    st.write("### 💡 Answer")
    st.write(answer)
    # Cache hits return no sources; don't render an empty section.
    if sources:
        st.write("### 📖 Sources")
        for src in sources:
            st.write(f"- {src}")
for msg in st.session_state.chats[st.session_state.current_chat]:
    if msg["role"] == "user":
        st.write(f"🧑 **You:** {msg['content']}")
    elif msg["role"] == "assistant":
        st.write(f"🤖 **Bot:** {msg['content']}")