RichardHu commited on
Commit
f2f8da5
Β·
verified Β·
1 Parent(s): 9570ac3

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +56 -44
retriever.py CHANGED
@@ -1,53 +1,65 @@
1
- from smolagents import Tool
2
- from langchain_community.retrievers import BM25Retriever
3
- from langchain.docstore.document import Document
4
- import datasets
5
-
6
-
7
- class GuestInfoRetrieverTool(Tool):
8
- name = "guest_info_retriever"
9
- description = "Retrieves detailed information about gala guests based on their name or relation."
10
- inputs = {
11
- "query": {
12
- "type": "string",
13
- "description": "The name or relation of the guest you want information about."
14
- }
15
- }
16
- output_type = "string"
17
-
18
- def __init__(self, docs):
19
- self.is_initialized = False
20
- self.retriever = BM25Retriever.from_documents(docs)
21
 
22
 
23
- def forward(self, query: str):
24
- results = self.retriever.get_relevant_documents(query)
25
- if results:
26
- return "\n\n".join([doc.page_content for doc in results[:3]])
27
- else:
28
- return "No matching guest information found."
 
29
 
 
 
 
30
 
31
- def load_guest_dataset():
32
- # Load the dataset
33
- guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
 
 
 
 
 
 
 
 
 
 
34
 
35
- # Convert dataset entries into Document objects
36
- docs = [
37
- Document(
38
- page_content="\n".join([
39
- f"Name: {guest['name']}",
40
- f"Relation: {guest['relation']}",
41
- f"Description: {guest['description']}",
42
- f"Email: {guest['email']}"
43
- ]),
44
- metadata={"name": guest["name"]}
45
- )
46
- for guest in guest_dataset
47
- ]
48
 
49
- # Return the tool
50
- return GuestInfoRetrieverTool(docs)
51
 
 
 
52
 
 
 
 
 
 
 
 
 
 
53
 
 
1
+ # from smolagents import Tool
2
+ # from langchain_community.retrievers import BM25Retriever
3
+ # from langchain.docstore.document import Document
4
+ # import datasets
5
+
6
+
7
+ # class GuestInfoRetrieverTool(Tool):
8
+ # name = "guest_info_retriever"
9
+ # description = "Retrieves detailed information about gala guests based on their name or relation."
10
+ # inputs = {
11
+ # "query": {
12
+ # "type": "string",
13
+ # "description": "The name or relation of the guest you want information about."
14
+ # }
15
+ # }
16
+ # output_type = "string"
17
+
18
+ # def __init__(self, docs):
19
+ # self.is_initialized = False
20
+ # self.retriever = BM25Retriever.from_documents(docs)
21
 
22
 
23
+ # def forward(self, query: str):
24
+ # results = self.retriever.get_relevant_documents(query)
25
+ # if results:
26
+ # return "\n\n".join([doc.page_content for doc in results[:3]])
27
+ # else:
28
+ # return "No matching guest information found."
29
+
30
 
31
+ # def load_guest_dataset():
32
+ # # Load the dataset
33
+ # guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
34
 
35
+ # # Convert dataset entries into Document objects
36
+ # docs = [
37
+ # Document(
38
+ # page_content="\n".join([
39
+ # f"Name: {guest['name']}",
40
+ # f"Relation: {guest['relation']}",
41
+ # f"Description: {guest['description']}",
42
+ # f"Email: {guest['email']}"
43
+ # ]),
44
+ # metadata={"name": guest["name"]}
45
+ # )
46
+ # for guest in guest_dataset
47
+ # ]
48
 
49
+ # # Return the tool
50
+ # return GuestInfoRetrieverTool(docs)
 
 
 
 
 
 
 
 
 
 
 
51
 
 
 
52
 
53
+ from langchain_community.vectorstores import Chroma
54
+ from langchain_openai import OpenAIEmbeddings
55
 
56
+ def get_retriever():
57
+ """εˆ›ε»ΊεΉΆθΏ”ε›žζ£€η΄’ε™¨"""
58
+ embeddings = OpenAIEmbeddings()
59
+ vectorstore = Chroma(
60
+ embedding_function=embeddings,
61
+ persist_directory="./chroma_db",
62
+ collection_name="rag_docs"
63
+ )
64
+ return vectorstore.as_retriever(search_kwargs={"k": 5})
65