from modules.chat.helpers import get_prompt
from modules.chat.chat_model_loader import ChatModelLoader
from modules.vectorstore.store_manager import VectorStoreManager
from modules.retriever.retriever import Retriever
from modules.chat.langchain.langchain_rag import (
    Langchain_RAG_V2,
    QuestionGenerator,
)


class LLMTutor:
    def __init__(self, config, user, logger=None):
        """
        Initialize the LLMTutor class.

        Args:
            config (dict): Configuration dictionary.
            user (str): User identifier.
            logger (Logger, optional): Logger instance. Defaults to None.
        """
        self.config = config
        self.llm = self.load_llm()
        self.user = user
        self.logger = logger
        self.vector_db = VectorStoreManager(config, logger=self.logger).load_database()
        self.qa_prompt = get_prompt(config, "qa")  # Initialize qa_prompt
        self.rephrase_prompt = get_prompt(
            config, "rephrase"
        )  # Initialize rephrase_prompt
        # TODO: Removed this functionality for now, don't know if we need it
        # if self.config["vectorstore"]["embedd_files"]:
        #     self.vector_db.create_database()
        #     self.vector_db.save_database()

    def update_llm(self, old_config, new_config):
        """
        Update the LLM and VectorStoreManager based on new configuration.

        Args:
            old_config (dict): Old configuration dictionary.
            new_config (dict): New configuration dictionary.
        """
        changes = self.get_config_changes(old_config, new_config)

        if "llm_params.llm_loader" in changes:
            self.llm = self.load_llm()  # Reinitialize the LLM if the loader changes

        if "vectorstore.db_option" in changes:
            self.vector_db = VectorStoreManager(
                self.config, logger=self.logger
            ).load_database()  # Reinitialize VectorStoreManager if the vector store changes
            # TODO: Removed this functionality for now, don't know if we need it
            # if self.config["vectorstore"]["embedd_files"]:
            #     self.vector_db.create_database()
            #     self.vector_db.save_database()

        if "llm_params.llm_style" in changes:
            self.qa_prompt = get_prompt(
                self.config, "qa"
            )  # Update qa_prompt if the response style (e.g. ELI5) changes

    def get_config_changes(self, old_config, new_config):
        """
        Get the changes between the old and new configuration.

        Args:
            old_config (dict): Old configuration dictionary.
            new_config (dict): New configuration dictionary.

        Returns:
            dict: Dictionary mapping dotted key paths to (old, new) value pairs.
        """
        changes = {}

        def compare_dicts(old, new, parent_key=""):
            for key in new:
                full_key = f"{parent_key}.{key}" if parent_key else key
                if isinstance(new[key], dict) and isinstance(old.get(key), dict):
                    compare_dicts(old.get(key, {}), new[key], full_key)
                elif old.get(key) != new[key]:
                    changes[full_key] = (old.get(key), new[key])
            # Include keys that are in old but not in new
            for key in old:
                if key not in new:
                    full_key = f"{parent_key}.{key}" if parent_key else key
                    changes[full_key] = (old[key], None)

        compare_dicts(old_config, new_config)
        return changes
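
    # Illustrative sketch (hypothetical values, not part of the project's
    # actual config): the flattened diff uses dotted key paths, matching the
    # string checks in update_llm above.
    #
    #   old = {"llm_params": {"llm_loader": "local_llm", "llm_style": "normal"}}
    #   new = {"llm_params": {"llm_loader": "openai", "llm_style": "normal"}}
    #   tutor.get_config_changes(old, new)
    #   # -> {"llm_params.llm_loader": ("local_llm", "openai")}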

    def retrieval_qa_chain(
        self, llm, qa_prompt, rephrase_prompt, db, memory=None, callbacks=None
    ):
        """
        Create a Retrieval QA Chain.

        Args:
            llm (LLM): The language model instance.
            qa_prompt (str): The QA prompt string.
            rephrase_prompt (str): The rephrase prompt string.
            db (VectorStore): The vector store instance.
            memory (Memory, optional): Memory instance. Defaults to None.
            callbacks (list, optional): Callback handlers for the chain. Defaults to None.

        Returns:
            Chain: The retrieval QA chain instance.
        """
        retriever = Retriever(self.config)._return_retriever(db)

        if self.config["llm_params"]["llm_arch"] == "langchain":
            self.qa_chain = Langchain_RAG_V2(
                llm=llm,
                memory=memory,
                retriever=retriever,
                qa_prompt=qa_prompt,
                rephrase_prompt=rephrase_prompt,
                config=self.config,
                callbacks=callbacks,
            )

            self.question_generator = QuestionGenerator()
        else:
            raise ValueError(
                f"Invalid LLM Architecture: {self.config['llm_params']['llm_arch']}"
            )

        return self.qa_chain

    def load_llm(self):
        """
        Load the language model.

        Returns:
            LLM: The loaded language model instance.
        """
        chat_model_loader = ChatModelLoader(self.config)
        llm = chat_model_loader.load_chat_model()
        return llm

    def qa_bot(self, memory=None, callbacks=None):
        """
        Create a QA bot instance.

        Args:
            memory (Memory, optional): Memory instance. Defaults to None.
            callbacks (list, optional): Callback handlers for the chain. Defaults to None.

        Returns:
            Chain: The QA bot chain instance.
        """
        # Sanity check to see if there are any documents in the database
        if len(self.vector_db) == 0:
            raise ValueError(
                "No documents in the database. Populate the database first."
            )

        qa = self.retrieval_qa_chain(
            self.llm,
            self.qa_prompt,
            self.rephrase_prompt,
            self.vector_db,
            memory,
            callbacks=callbacks,
        )
        return qa
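

# Minimal usage sketch (assumptions: the config keys below mirror those the
# class reads; real values come from the project's config files, and the
# vector store must already be populated, so this is illustrative only):
#
#   config = {
#       "llm_params": {"llm_arch": "langchain", "llm_loader": "...", "llm_style": "..."},
#       "vectorstore": {"db_option": "..."},
#   }
#   tutor = LLMTutor(config, user="demo-user")
#   qa_chain = tutor.qa_bot()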