"""
Codey Bryant 3.0 - SOTA RAG for Hugging Face Spaces
Maintains EXACT same architecture: HyDE + Query Rewriting + Multi-Query + Answer-Space Retrieval
"""
import os
import sys
import logging
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional, Iterator, Any
from threading import Thread
import warnings
# Configure logging for Hugging Face Spaces; log to a file only when the
# persistent /data volume exists (Spaces without persistent storage lack it)
_log_handlers = [logging.StreamHandler(sys.stdout)]
if os.path.isdir("/data"):
    _log_handlers.append(logging.FileHandler("/data/app.log"))
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore")
# Import core dependencies
import numpy as np
import torch
from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from sklearn.cluster import MiniBatchKMeans
import spacy
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
GenerationConfig,
TextIteratorStreamer,
BitsAndBytesConfig,
)
import gradio as gr
import pickle
import json
# Try to import FAISS
try:
import faiss
FAISS_AVAILABLE = True
except ImportError:
FAISS_AVAILABLE = False
logger.warning("FAISS not available, using numpy fallback")
# Environment setup for Hugging Face Spaces
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Use persistent storage when available; "/data" only exists on Spaces with
# persistent storage enabled, so fall back to a local directory otherwise
_default_artifact_dir = "/data/artifacts" if os.path.isdir("/data") else "./artifacts"
ARTIFACT_DIR = os.environ.get("ARTIFACT_DIR", _default_artifact_dir)
os.makedirs(ARTIFACT_DIR, exist_ok=True)
# Paths for artifacts
LLM_ARTIFACT_PATH = os.path.join(ARTIFACT_DIR, "llm_model")
EMBED_ARTIFACT_PATH = os.path.join(ARTIFACT_DIR, "embed_model")
BM25_ARTIFACT_PATH = os.path.join(ARTIFACT_DIR, "bm25.pkl")
CORPUS_DATA_PATH = os.path.join(ARTIFACT_DIR, "corpus_data.json")
CORPUS_EMBED_PATH = os.path.join(ARTIFACT_DIR, "corpus_embeddings.npy")
ANSWER_EMBED_PATH = os.path.join(ARTIFACT_DIR, "answer_embeddings.npy")
FAISS_INDEX_PATH = os.path.join(ARTIFACT_DIR, "faiss_index.bin")
# Device configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
logger.info("Using CPU")
# Model configuration (EXACT SAME AS BEFORE)
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
MAX_CORPUS_SIZE = 600
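# all-MiniLM-L6-v2 emits 384-dim vectors; with normalize_embeddings=True used
# throughout, cosine similarity reduces to a plain dot product everywhere below.
# MAX_CORPUS_SIZE caps the exemplar pool so hybrid search stays fast on CPU.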
# ========================
# 1) Dataset & Retrieval (EXACT SAME)
# ========================
def load_opc_datasets() -> Dict[str, Dataset]:
"""Load coding datasets - same function"""
try:
logger.info("Loading OPC datasets...")
ds_instruct = load_dataset("OpenCoder-LLM/opc-sft-stage2", "educational_instruct", split="train")
ds_evol = load_dataset("OpenCoder-LLM/opc-sft-stage2", "evol_instruct", split="train")
return {"educational_instruct": ds_instruct, "evol_instruct": ds_evol}
except Exception as e:
logger.warning(f"OPC failed ({e}), falling back to python_code_instructions...")
ds = load_dataset("iamtarun/python_code_instructions_18k_alpaca", split="train")
return {"python_code": ds}
def convo_to_io(example: Dict) -> Tuple[str, str]:
"""Convert conversation to input/output - same function"""
if "messages" in example:
msgs = example["messages"]
elif "conversations" in example:
msgs = example["conversations"]
else:
instr = example.get("instruction") or example.get("prompt") or ""
inp = example.get("input") or ""
out = example.get("output") or example.get("response") or ""
return (instr + "\n" + inp).strip(), out
user_text, assistant_text = "", ""
    for m in msgs:
role = (m.get("role") or m.get("from") or "").lower()
content = m.get("content") or m.get("value") or ""
if role in ("user", "human") and not user_text:
user_text = content
if role in ("assistant", "gpt") and user_text:
assistant_text = content
break
return user_text.strip(), assistant_text.strip()
@dataclass
class RetrievalSystem:
"""Retrieval system dataclass - same structure"""
embed_model: SentenceTransformer
bm25: BM25Okapi
corpus_texts: List[str]
corpus_answers: List[str]
corpus_embeddings: np.ndarray
answer_embeddings: np.ndarray
corpus_meta: List[Dict]
nlp: spacy.language.Language
    faiss_index: Optional[Any] = None
def build_retrieval_system(ds_map: Dict[str, Dataset]) -> RetrievalSystem:
"""Build retrieval system - EXACT SAME IMPLEMENTATION"""
# Try to load from artifacts first
required = [EMBED_ARTIFACT_PATH, BM25_ARTIFACT_PATH, CORPUS_DATA_PATH, CORPUS_EMBED_PATH, ANSWER_EMBED_PATH]
if FAISS_AVAILABLE:
required.append(FAISS_INDEX_PATH)
if all(os.path.exists(p) for p in required):
logger.info("Loading retrieval system from artifacts...")
embed_model = SentenceTransformer(EMBED_ARTIFACT_PATH, device=str(DEVICE))
with open(BM25_ARTIFACT_PATH, "rb") as f:
bm25 = pickle.load(f)
with open(CORPUS_DATA_PATH, "r", encoding="utf-8") as f:
data = json.load(f)
corpus_embeddings = np.load(CORPUS_EMBED_PATH)
answer_embeddings = np.load(ANSWER_EMBED_PATH)
faiss_index = faiss.read_index(FAISS_INDEX_PATH) if FAISS_AVAILABLE and os.path.exists(FAISS_INDEX_PATH) else None
nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])
return RetrievalSystem(
embed_model=embed_model, bm25=bm25,
corpus_texts=data["texts"], corpus_answers=data["answers"],
corpus_embeddings=corpus_embeddings, answer_embeddings=answer_embeddings,
corpus_meta=data["meta"], nlp=nlp, faiss_index=faiss_index
)
# Build from scratch (same implementation)
logger.info("Building retrieval system with answer-space support...")
all_questions, all_answers, all_metas = [], [], []
for name, ds in ds_map.items():
for ex in ds.select(range(min(len(ds), 1500))):
q, a = convo_to_io(ex)
if q and a and 50 < len(a) < 2000:
all_questions.append(q)
all_answers.append(a)
all_metas.append({"intent": name, "answer": a})
embed_model = SentenceTransformer(EMBED_MODEL, device=str(DEVICE))
question_embeddings = embed_model.encode(all_questions, batch_size=64, show_progress_bar=True, normalize_embeddings=True)
answer_embeddings = embed_model.encode(all_answers, batch_size=64, show_progress_bar=True, normalize_embeddings=True)
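    # Two parallel embedding spaces are kept: question vectors for matching user
    # queries directly, answer vectors for HyDE-style probes that look like answers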
# Clustering to reduce size (same)
if len(all_questions) > MAX_CORPUS_SIZE:
kmeans = MiniBatchKMeans(n_clusters=MAX_CORPUS_SIZE, random_state=42, batch_size=1000)
labels = kmeans.fit_predict(answer_embeddings)
selected = []
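        # Keep one exemplar per cluster: the answer closest to its centroid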
for i in range(MAX_CORPUS_SIZE):
mask = labels == i
if mask.any():
idx = np.where(mask)[0]
dists = np.linalg.norm(answer_embeddings[idx] - kmeans.cluster_centers_[i], axis=1)
selected.append(idx[np.argmin(dists)])
idxs = selected
else:
idxs = list(range(len(all_questions)))
texts = [all_questions[i] for i in idxs]
answers = [all_answers[i] for i in idxs]
metas = [all_metas[i] for i in idxs]
q_embs = question_embeddings[idxs]
a_embs = answer_embeddings[idxs]
tokenized = [t.lower().split() for t in texts]
bm25 = BM25Okapi(tokenized)
faiss_index = None
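    # IndexFlatIP performs exact inner-product search; since the embeddings are
    # L2-normalized, inner product here equals cosine similarity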
if FAISS_AVAILABLE:
faiss_index = faiss.IndexFlatIP(a_embs.shape[1])
faiss_index.add(a_embs.astype('float32'))
# Save everything
embed_model.save(EMBED_ARTIFACT_PATH)
with open(BM25_ARTIFACT_PATH, "wb") as f:
pickle.dump(bm25, f)
with open(CORPUS_DATA_PATH, "w", encoding="utf-8") as f:
json.dump({"texts": texts, "answers": answers, "meta": metas}, f)
np.save(CORPUS_EMBED_PATH, q_embs)
np.save(ANSWER_EMBED_PATH, a_embs)
    if faiss_index is not None:
faiss.write_index(faiss_index, FAISS_INDEX_PATH)
nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])
return RetrievalSystem(
embed_model=embed_model, bm25=bm25, corpus_texts=texts, corpus_answers=answers,
corpus_embeddings=q_embs, answer_embeddings=a_embs, corpus_meta=metas,
nlp=nlp, faiss_index=faiss_index
)
# ========================
# 2) Generative Core (EXACT SAME)
# ========================
@dataclass
class GenerativeCore:
"""Generative core dataclass - same structure"""
model: AutoModelForCausalLM
tokenizer: AutoTokenizer
generation_config: GenerationConfig
def build_generative_core():
"""Build generative core - EXACT SAME IMPLEMENTATION"""
# Always download fresh from HuggingFace for reliability
    logger.info("Downloading TinyLlama (4-bit NF4 quantization when CUDA is available)...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
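    # The template below reproduces the Zephyr-style chat format TinyLlama-Chat
    # was fine-tuned on: "<|role|>\n<content></s>\n" per turn, with a trailing
    # "<|assistant|>\n" cueing the model to respond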
tokenizer.chat_template = (
"{% for message in messages %}"
"{{'<|'+message['role']+'|>\\n'+message['content']+'</s>\\n'}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"<|assistant|>\n"
"{% endif %}"
)
quantization_config = None
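    # NF4 4-bit weights with double quantization fit the 1.1B model comfortably
    # in free-tier GPU memory; compute stays in fp32 for numerical stability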
if torch.cuda.is_available():
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float32,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=quantization_config,
device_map="auto" if torch.cuda.is_available() else None,
low_cpu_mem_usage=True
)
model.eval()
gen_cfg = GenerationConfig(
max_new_tokens=300,
temperature=0.7,
top_p=0.9,
do_sample=True,
repetition_penalty=1.15,
pad_token_id=tokenizer.pad_token_id
)
# Save for future use (optional)
if not os.path.exists(LLM_ARTIFACT_PATH):
os.makedirs(LLM_ARTIFACT_PATH, exist_ok=True)
tokenizer.save_pretrained(LLM_ARTIFACT_PATH)
gen_cfg.save_pretrained(LLM_ARTIFACT_PATH)
return GenerativeCore(model, tokenizer, gen_cfg)
# ========================
# 3) SOTA Enhanced Retrieval (EXACT SAME)
# ========================
class HybridCodeAssistant:
"""Main assistant class - EXACT SAME IMPLEMENTATION"""
def __init__(self):
self.retrieval = build_retrieval_system(load_opc_datasets())
self.generator = build_generative_core()
logger.info("Codey Bryant 3.0 ready with HyDE + Query Rewriting + Multi-Query + Answer-Space Retrieval!")
def generate_hyde(self, query: str) -> str:
"""Generate HyDE - same implementation"""
prompt = f"""Write a concise, direct Python code example or explanation that answers this question.
Only output the answer, no extra text.
Question: {query}
Answer:"""
inputs = self.generator.tokenizer(prompt, return_tensors="pt").to(DEVICE)
with torch.no_grad():
out = self.generator.model.generate(**inputs, max_new_tokens=128, temperature=0.3, do_sample=True)
return self.generator.tokenizer.decode(out[0], skip_special_tokens=True).split("Answer:")[-1].strip()
def rewrite_query(self, query: str) -> str:
"""Rewrite query - same implementation"""
prompt = f"""Rewrite this vague or casual programming question into a clear, specific one for better code retrieval.
Original: {query}
Improved:"""
inputs = self.generator.tokenizer(prompt, return_tensors="pt").to(DEVICE)
with torch.no_grad():
            out = self.generator.model.generate(**inputs, max_new_tokens=64, temperature=0.1, do_sample=True)
return self.generator.tokenizer.decode(out[0], skip_special_tokens=True).split("Improved:")[-1].strip()
def retrieve_enhanced(self, query: str, k: int = 3) -> List[Tuple[str, Dict, float]]:
"""Enhanced retrieval - EXACT SAME IMPLEMENTATION"""
# Use list of tuples instead of set to avoid hashability issues with dicts
results = []
def add_results(q_text: str, weight: float = 1.0):
try:
                # Long probes (HyDE drafts and other verbose texts) search answer
                # space; short queries search question space
                use_answer_space = len(q_text.split()) > 20
target_embs = self.retrieval.answer_embeddings if use_answer_space else self.retrieval.corpus_embeddings
# Encode query
q_emb = self.retrieval.embed_model.encode(q_text, normalize_embeddings=True)
                if self.retrieval.faiss_index is not None and use_answer_space:
                    # FAISS exact search over the answer space
                    query_vec = q_emb.astype('float32').reshape(1, -1)
                    scores_top, indices_top = self.retrieval.faiss_index.search(
                        query_vec, min(k * 3, len(self.retrieval.corpus_texts))
                    )
                    scores = scores_top[0]
                    idxs = indices_top[0]
                else:
                    # Numpy fallback or question space
                    dense = np.dot(target_embs, q_emb)
                    # Blend in BM25 lexical scores when searching question space
                    if not use_answer_space:
                        tokenized_query = q_text.lower().split()
                        bm25_scores = self.retrieval.bm25.get_scores(tokenized_query)
                        if bm25_scores.max() > 0:
                            bm25_scores = (bm25_scores - bm25_scores.min()) / (bm25_scores.max() - bm25_scores.min())
                        dense = 0.3 * bm25_scores + 0.7 * dense  # Hybrid lexical + semantic
                    # Select top candidates, then index into dense so each score
                    # stays aligned with its corpus index
                    idxs = np.argsort(-dense)[:k * 3]
                    scores = dense[idxs]
# Collect candidates (avoid duplicates by checking text)
seen_texts = set()
                for score, idx in zip(scores, idxs):
                    # FAISS pads with -1 when fewer than k results exist
                    if score > 0.15 and 0 <= idx < len(self.retrieval.corpus_texts):
text = self.retrieval.corpus_texts[idx]
if text not in seen_texts:
seen_texts.add(text)
results.append((text, self.retrieval.corpus_meta[idx], float(score * weight)))
except Exception as e:
logger.error(f"add_results failed for '{q_text}': {e}")
# 1. Original query
add_results(query, weight=1.0)
# 2. Rewritten query
try:
rw = self.rewrite_query(query)
if len(rw) > 8 and rw != query:
add_results(rw, weight=1.2)
except Exception as e:
logger.warning(f"Rewrite failed: {e}")
# 3. HyDE (strong weight in answer space!)
try:
hyde = self.generate_hyde(query)
if len(hyde) > 20:
add_results(hyde, weight=1.5) # Note: No " HyDE" suffix needed now
except Exception as e:
logger.warning(f"HyDE failed: {e}")
# 4. Multi-query variants (lighter weight)
variants = [
f"Python code for: {query}",
f"Fix error: {query}",
f"Explain in Python: {query}",
f"Best way to {query} in Python",
]
for v in variants:
add_results(v, weight=0.8)
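        # Probe weights encode confidence: HyDE 1.5 > rewritten 1.2 > original 1.0
        # > templated variants 0.8, all fused before the final rerank below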
        # Rerank by similarity to the original query, keeping the best score per text
        if not results:
            return []
        q_emb = self.retrieval.embed_model.encode(query, normalize_embeddings=True)
        best: Dict[str, Tuple[str, Dict, float]] = {}
        for text, meta, score in results:
            text_emb = self.retrieval.embed_model.encode(text, normalize_embeddings=True)
            sim = float(np.dot(q_emb, text_emb))
            fused = score + 0.3 * sim
            if text not in best or fused > best[text][2]:
                best[text] = (text, meta, fused)
        final = sorted(best.values(), key=lambda x: x[2], reverse=True)
        return final[:k]
def answer_stream(self, text: str) -> Iterator[str]:
"""Stream answer with proper message formatting"""
retrieved = self.retrieve_enhanced(text, k=3)
context = ""
if retrieved and retrieved[0][2] > 0.3:
q, meta, _ = retrieved[0]
ans = meta["answer"][:200]
context = f"Reference example:\nQ: {q}\nA: {ans}\n\n"
# Create properly formatted messages
system_content = "You are a concise, accurate Python coding assistant. " + context.strip()
# Format messages for TinyLlama chat template
messages = [
{"role": "user", "content": text}
]
# Add system message if context exists
if context:
messages.insert(0, {"role": "system", "content": system_content})
# Debug: Print messages format
logger.debug(f"Messages format: {messages}")
try:
# Apply chat template
prompt = self.generator.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
logger.debug(f"Generated prompt length: {len(prompt)}")
except Exception as e:
logger.error(f"Error applying chat template: {e}")
# Fallback: Use simple formatting
if context:
prompt = f"<|system|>\n{system_content}</s>\n<|user|>\n{text}</s>\n<|assistant|>\n"
else:
prompt = f"<|user|>\n{text}</s>\n<|assistant|>\n"
inputs = self.generator.tokenizer(prompt, return_tensors="pt").to(DEVICE)
streamer = TextIteratorStreamer(
self.generator.tokenizer,
skip_prompt=True,
skip_special_tokens=True
)
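        # TextIteratorStreamer must be drained while generate() runs, so generation
        # happens in a background thread; skip_prompt suppresses the echoed input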
generation_kwargs = dict(
**inputs,
streamer=streamer,
generation_config=self.generator.generation_config,
max_new_tokens=300
)
thread = Thread(target=self.generator.model.generate, kwargs=generation_kwargs)
thread.start()
for token in streamer:
yield token
thread.join()
# ========================
# 4) App Bootstrap & Chat Glue
# ========================
ASSISTANT: Optional[HybridCodeAssistant] = None
def initialize_assistant():
"""Initialize assistant with progress tracking"""
global ASSISTANT
if ASSISTANT is None:
yield "Initializing Codey Bryant 3.0..."
yield "Loading retrieval system..."
ASSISTANT = HybridCodeAssistant()
yield "Codey Bryant 3.0 Ready!"
yield "SOTA RAG Features: HyDE + Query Rewriting + Multi-Query + Answer-Space Retrieval"
yield "Ask coding questions like: 'it's not working', 'help with error', 'make it faster'"
else:
yield "Assistant already initialized!"
def chat(message: str, history: list):
    """Stream the assistant's reply into the last history entry.

    The caller (process_message) has already appended [message, ""] to
    history, so this function only fills in the assistant's side.
    """
    if ASSISTANT is None:
        history[-1][1] = "Please click 'Initialize Assistant' first!"
        yield history
        return
    try:
        response = ""
        for token in ASSISTANT.answer_stream(message):
            response += token
            history[-1][1] = response
            yield history
    except Exception as e:
        logger.error(f"Chat error: {e}")
        history[-1][1] = f"Error: {e}"
        yield history
# ========================
# 5) Main Entry Point - Simple Gradio UI
# ========================
if __name__ == "__main__":
# Configure for Hugging Face Spaces
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
server_port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
# SIMPLE, WORKING UI
with gr.Blocks(title="Codey Bryant 3.0") as demo:
gr.Markdown("""
    # 🤖 Codey Bryant 3.0
## **SOTA RAG Coding Assistant**
**Advanced Features:** HyDE + Query Rewriting + Multi-Query + Answer-Space Retrieval
""")
# Status display
status_output = gr.Textbox(
label="Status",
value="Click 'Initialize Assistant' to start",
interactive=False
)
# Initialize button
        init_btn = gr.Button("🚀 Initialize Assistant", variant="primary")
# Chat interface
chatbot = gr.Chatbot(label="Chat", height=500)
with gr.Row():
msg = gr.Textbox(
placeholder="Ask Python coding questions...",
label="Your Question",
lines=2,
scale=4
)
submit_btn = gr.Button("Send", variant="secondary", scale=1)
clear_btn = gr.Button("Clear Chat")
# Event handlers
        def on_init():
            """Run initialization, streaming each status update to the UI"""
            for status in initialize_assistant():
                yield status
init_btn.click(
fn=on_init,
outputs=status_output
)
def process_message(message, chat_history):
"""Process a new message"""
if not message.strip():
return "", chat_history
# Add user message
chat_history.append([message, ""])
return "", chat_history
        def generate_response(chat_history):
            """Stream the assistant's reply for the last user message"""
            # msg is cleared by process_message, so the question is read back
            # from the history rather than from the (now empty) textbox
            if not chat_history or not chat_history[-1][0].strip():
                yield chat_history
                return
            yield from chat(chat_history[-1][0], chat_history)
# Connect submit button
submit_btn.click(
fn=process_message,
inputs=[msg, chatbot],
outputs=[msg, chatbot]
).then(
fn=generate_response,
            inputs=chatbot,
outputs=chatbot
)
# Connect Enter key
msg.submit(
fn=process_message,
inputs=[msg, chatbot],
outputs=[msg, chatbot]
).then(
fn=generate_response,
            inputs=chatbot,
outputs=chatbot
)
# Clear chat
clear_btn.click(lambda: [], None, chatbot)
# Launch the app
logger.info(f"Starting Codey Bryant 3.0 on {server_name}:{server_port}")
logger.info("SOTA RAG Architecture: HyDE + Query Rewriting + Multi-Query + Answer-Space Retrieval")
demo.launch(
server_name=server_name,
server_port=server_port,
share=False,
debug=False
)