# rag.py
import os
import json
import pickle
import logging
from typing import List, Optional, Tuple

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

from config import VECTORSTORE_DIR, EMBEDDING_MODEL

log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class RAGAgent:
    """
    Loads a FAISS index + metadata from VECTORSTORE_DIR (config).
    Provides retrieve(query, k) -> (contexts: List[str], sources: List[dict]).
    """

    def __init__(self, vectorstore_dir: Optional[str] = None, embedding_model: Optional[str] = None):
        self.vectorstore_dir = vectorstore_dir or str(VECTORSTORE_DIR)
        self.embedding_model_name = embedding_model or EMBEDDING_MODEL
        self.index: Optional[faiss.Index] = None
        self.metadata: Optional[List[dict]] = None
        self._embedder: Optional[SentenceTransformer] = None
        self._loaded = False
    def _find_index_file(self) -> str:
        if not os.path.isdir(self.vectorstore_dir):
            raise FileNotFoundError(f"Vectorstore dir not found: {self.vectorstore_dir}")
        for fname in os.listdir(self.vectorstore_dir):
            # Match only known index extensions; a bare startswith("index")
            # check would also pick up metadata files such as index.pkl.
            if fname.endswith((".faiss", ".index", ".bin")):
                return os.path.join(self.vectorstore_dir, fname)
        raise FileNotFoundError(f"No FAISS index file (.faiss/.index/.bin) found in {self.vectorstore_dir}")
    def _find_meta_file(self) -> str:
        for candidate in ("index.pkl", "metadata.pkl", "index_meta.pkl", "metadata.json", "index.json"):
            p = os.path.join(self.vectorstore_dir, candidate)
            if os.path.exists(p):
                return p
        for fname in os.listdir(self.vectorstore_dir):
            if fname.endswith(".pkl"):
                return os.path.join(self.vectorstore_dir, fname)
        raise FileNotFoundError(f"No metadata (.pkl/.json) found in {self.vectorstore_dir}")
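
    # Assumed on-disk layout, inferred from the lookups above (exact file
    # names depend on how the vectorstore was built elsewhere):
    #   <VECTORSTORE_DIR>/
    #     index.faiss      # faiss.write_index(...) output
    #     metadata.pkl     # pickled list of per-chunk dicts, aligned with
    #                      # the index's vector ids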
    @property
    def embedder(self) -> SentenceTransformer:
        # Declared as a property so retrieve() can call self.embedder.encode(...);
        # the model is lazy-loaded on first access.
        if self._embedder is None:
            log.info("Loading embedder: %s", self.embedding_model_name)
            self._embedder = SentenceTransformer(self.embedding_model_name)
        return self._embedder
    def load(self) -> None:
        """Load index and metadata into memory (idempotent)."""
        if self._loaded:
            return
        idx_path = self._find_index_file()
        meta_path = self._find_meta_file()
        log.info("Loading FAISS index from: %s", idx_path)
        try:
            self.index = faiss.read_index(idx_path)
        except Exception as e:
            raise RuntimeError(f"Failed to read FAISS index {idx_path}: {e}") from e
        log.info("Loading metadata from: %s", meta_path)
        if meta_path.endswith(".json"):
            with open(meta_path, "r", encoding="utf-8") as f:
                self.metadata = json.load(f)
        else:
            with open(meta_path, "rb") as f:
                self.metadata = pickle.load(f)
        # Normalize metadata to a list so positions line up with FAISS ids.
        if not isinstance(self.metadata, list):
            if isinstance(self.metadata, dict):
                # Dict keyed by vector id: order values by sorted key.
                try:
                    self.metadata = [self.metadata[k] for k in sorted(self.metadata.keys())]
                except Exception:
                    self.metadata = list(self.metadata.values())
            else:
                self.metadata = list(self.metadata)
        log.info("Loaded index and metadata: metadata length=%d", len(self.metadata))
        self._loaded = True
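
    # Example of an assumed metadata entry, matching the keys probed in
    # retrieve() below (only the text field is required; all other keys are
    # passed through verbatim in the returned sources):
    #   {"text": "chunk body ...", "source": "docs/guide.pdf", "page": 3}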
    def retrieve(self, query: str, k: int = 3) -> Tuple[List[str], List[dict]]:
        """
        Return two lists:
          - contexts: [str, ...] top-k chunk texts (may be fewer)
          - sources:  [{meta..., "score": float}, ...]
        """
        if not self._loaded:
            self.load()
        if self.index is None or self.metadata is None:
            return [], []
        # FAISS requires float32; cast before normalizing, since
        # faiss.normalize_L2 rejects float64 input.
        q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")
        # Normalize on the assumption that the index was built on
        # L2-normalized embeddings (the common sentence-transformers setup).
        faiss.normalize_L2(q_emb)
        # Safe search call.
        try:
            D, I = self.index.search(q_emb, k)
        except Exception as e:
            log.warning("FAISS search error: %s", e)
            return [], []
        # Ensure shapes.
        if I is None or D is None:
            return [], []
        indices = np.array(I).reshape(-1)[:k].tolist()
        scores = np.array(D).reshape(-1)[:k].tolist()
        contexts = []
        sources = []
        for idx, score in zip(indices, scores):
            idx = int(idx)
            # FAISS pads missing results with -1 when fewer than k hits exist.
            if idx < 0:
                continue
            # Guard against idx falling outside metadata bounds.
            if idx >= len(self.metadata):
                log.debug("Index %s >= metadata length %d; skipping", idx, len(self.metadata))
                continue
            meta = self.metadata[idx]
            # Extract text from common keys.
            text = None
            for key in ("text", "page_content", "content", "chunk_text", "source_text"):
                if isinstance(meta, dict) and key in meta and meta[key]:
                    text = meta[key]
                    break
            if text is None:
                # Fallback if metadata itself is a string or nests the text.
                if isinstance(meta, str):
                    text = meta
                elif isinstance(meta, dict) and "metadata" in meta and isinstance(meta["metadata"], dict):
                    # Sometimes nested one level down; fall back to str(meta)
                    # so we never append None to contexts.
                    text = meta["metadata"].get("text") or meta["metadata"].get("page_content") or str(meta)
                else:
                    text = str(meta)
            contexts.append(text)
            sources.append({"meta": meta, "score": float(score)})
        return contexts, sources
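

if __name__ == "__main__":
    # Minimal usage sketch. Assumes VECTORSTORE_DIR (from config.py) already
    # contains a FAISS index plus matching metadata built elsewhere; the
    # query string below is illustrative only.
    agent = RAGAgent()
    contexts, sources = agent.retrieve("What does this project do?", k=3)
    for ctx, src in zip(contexts, sources):
        print(f"[score={src['score']:.4f}] {ctx[:120]}")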