import os

import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
# The embedding model must match the one used when the FAISS index was built.
embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type=None)
db = FAISS.load_local("data_first_faiss_index", embedder, allow_dangerous_deserialization=True)
# docs = db.similarity_search(query)

# ChatNVIDIA and NVIDIAEmbeddings read NVIDIA_API_KEY from the environment.
nvidia_api_key = os.environ.get("NVIDIA_API_KEY", "")
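# The index is assumed to have been built offline with the same embedder, along
# the lines of this illustrative sketch (the loader and file path are hypothetical):
#   from langchain_community.document_loaders import TextLoader
#   docs = TextLoader("data_first_docs.txt").load()
#   FAISS.from_documents(docs, embedder).save_local("data_first_faiss_index")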
# Available chat model names include, e.g.:
#   mixtral_8x7b
#   llama2_13b
llm = ChatNVIDIA(model="mixtral_8x7b")
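# If "mixtral_8x7b" is not exposed to your API key, the models available to you
# can be listed (illustrative; assumes a recent langchain_nvidia_ai_endpoints):
#   print([m.id for m in ChatNVIDIA.get_available_models()])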
initial_msg = (
    "Hello! I am a chatbot to help with any questions about Data First Company."
    "\nHow can I help you?"
)
context_prompt = ChatPromptTemplate.from_messages([
    ('system',
     "You are a chatbot helping customers with their inquiries about Data First Company."
     " Answer the question using only the context provided, and do not use phrases"
     " such as 'based on the context' or 'based on the documents provided' in your answer."
     " Remember that your job is to represent Data First Company, which creates data solutions."
     " Do not hallucinate any details, and avoid redundantly repeating the knowledge base."
     " If you do not know the answer or cannot find the information needed, say so."
     "\n\nQuestion: {question}\n\nContext: {context}"),
    ('user', "{question}")
])
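# Optional sanity check: render the template with toy values to confirm both
# placeholders resolve (the values here are purely illustrative):
#   print(context_prompt.format(question="What does Data First do?",
#                               context="Data First builds data solutions."))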
# Self-contained RAG chain: retrieves context for the raw question string,
# then formats the prompt and generates an answer.
chain = (
    {
        'context': db.as_retriever(search_type="similarity"),
        'question': (lambda x: x)
    }
    | context_prompt
    # | RPrint()
    | llm
    | StrOutputParser()
)
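# Example one-shot use of the self-contained chain (the question is illustrative):
#   answer = chain.invoke("What services does Data First Company offer?")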
# Variant that expects {'question', 'context'} to be supplied by the caller,
# used by chat_gen below so retrieval results can be inspected first.
conv_chain = (
    context_prompt
    # | RPrint()
    | llm
    | StrOutputParser()
)
def chat_gen(message, history, return_buffer=True):
    buffer = ""
    # Retrieve only documents scoring above the similarity threshold.
    doc_retriever = db.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.2},
    )
    retrieved_docs = doc_retriever.invoke(message)
    print(len(retrieved_docs))  # debug: number of retrieved documents
    print(retrieved_docs)       # debug: retrieved document contents

    if len(retrieved_docs) > 0:
        state = {
            'question': message,
            'context': retrieved_docs
        }
        # Stream tokens so the chatbot renders the reply incrementally.
        for token in conv_chain.stream(state):
            buffer += token
            yield buffer
    else:
        # Nothing cleared the threshold: fall back to a canned apology.
        passage = ("I am sorry. I do not have relevant information to answer "
                   "the question. Please try another question.")
        buffer += passage
        yield buffer if return_buffer else passage
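# gr.ChatInterface treats a generator callback as a streaming response, so each
# yielded buffer progressively replaces the bot message in the UI.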
chatbot = gr.Chatbot(value=[[None, initial_msg]])
iface = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
iface.launch()