WENior commited on
Commit
4d8544f
·
verified ·
1 Parent(s): bfe0937

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -10
app.py CHANGED
@@ -1,21 +1,164 @@
1
  import gradio as gr
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return f"Hello, {name}! 👋"
 
 
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  with gr.Blocks(title="Paper Reader Assistant") as demo:
7
  gr.Markdown("""
8
  # 📖 Paper Reader Assistant
9
- 欢迎使用论文辅助阅读网站!
10
- (这是一个示例界面。上传PDF、摘要、关键词、问答功能可在完整版中添加。)
11
  """)
12
-
13
  with gr.Row():
14
- name_input = gr.Textbox(label="请输入你的名字")
15
- greet_btn = gr.Button("打招呼")
16
- output = gr.Textbox(label="结果")
17
-
18
- greet_btn.click(fn=greet, inputs=name_input, outputs=output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  if __name__ == "__main__":
21
  demo.launch()
 
1
  import gradio as gr
2
+ import fitz # PyMuPDF
3
+ from transformers import pipeline
4
+ from keybert import KeyBERT
5
+ from sentence_transformers import SentenceTransformer
6
+ import faiss
7
+ import numpy as np
8
+ import re
9
 
10
+ # =============== 初始化模型 ===============
11
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
12
+ qa_generator = pipeline("text2text-generation", model="google/flan-t5-base")
13
+ embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
14
+ kw_model = KeyBERT(model=embedder)
15
 
16
+ # 全局变量(保存上传文献的分块与索引)
17
+ CHUNKS = []
18
+ INDEX = None
19
+
20
+
21
+ # =============== 工具函数 ===============
22
+ def clean_text(t):
23
+ t = t.replace("\x00", " ")
24
+ t = re.sub(r"\s+", " ", t)
25
+ return t.strip()
26
+
27
+
28
+ def pdf_to_text(file_bytes):
29
+ doc = fitz.open(stream=file_bytes, filetype="pdf")
30
+ texts = []
31
+ for page in doc:
32
+ t = page.get_text("text")
33
+ if t:
34
+ texts.append(t)
35
+ return clean_text("\n".join(texts))
36
+
37
+
38
+ def chunk_text(text, chunk_size=800, overlap=120):
39
+ if len(text) <= chunk_size:
40
+ return [text]
41
+ chunks, start = [], 0
42
+ while start < len(text):
43
+ end = min(start + chunk_size, len(text))
44
+ boundary = max(
45
+ text.rfind(". ", start, end),
46
+ text.rfind("。", start, end),
47
+ text.rfind("\n", start, end),
48
+ )
49
+ if boundary == -1 or boundary <= start + 200:
50
+ boundary = end
51
+ chunks.append(text[start:boundary].strip())
52
+ start = max(boundary - overlap, 0)
53
+ if start == boundary:
54
+ start += 1
55
+ return [c for c in chunks if len(c) > 10]
56
+
57
+
58
+ def build_faiss(chunks):
59
+ global INDEX
60
+ embs = embedder.encode(chunks, normalize_embeddings=True)
61
+ index = faiss.IndexFlatIP(embs.shape[1])
62
+ index.add(embs.astype(np.float32))
63
+ INDEX = index
64
+
65
+
66
+ def retrieve(query, top_k=5):
67
+ if INDEX is None or not CHUNKS:
68
+ return []
69
+ q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
70
+ D, I = INDEX.search(q, top_k)
71
+ results = [(CHUNKS[i], float(D[0][j])) for j, i in enumerate(I[0]) if i != -1]
72
+ return results
73
+
74
+
75
+ # =============== 功能实现 ===============
76
+ def handle_upload(pdf_files):
77
+ global CHUNKS
78
+ if not pdf_files:
79
+ return "", "未上传文件"
80
+ texts = []
81
+ for f in pdf_files:
82
+ b = f.read()
83
+ txt = pdf_to_text(b)
84
+ texts.append(txt)
85
+ merged = "\n".join(texts)
86
+ CHUNKS = chunk_text(merged)
87
+ build_faiss(CHUNKS)
88
+ return merged[:50000], f"✅ 已成功建立索引:{len(CHUNKS)} 个片段"
89
+
90
+
91
+ def summarize_text(text):
92
+ if not text or len(text) < 50:
93
+ return "请输入更长的文本或先上传 PDF。"
94
+ parts = chunk_text(text, 1000, 100)
95
+ summaries = []
96
+ for p in parts:
97
+ try:
98
+ s = summarizer(p, max_length=150, min_length=40, do_sample=False)[0]["summary_text"]
99
+ summaries.append(s)
100
+ except Exception:
101
+ continue
102
+ combined = " ".join(summaries)
103
+ final = summarizer(combined, max_length=200, min_length=60, do_sample=False)[0]["summary_text"]
104
+ return final
105
+
106
+
107
+ def extract_keywords(text, top_n=10):
108
+ if not text or len(text) < 50:
109
+ return "请输入更长的文本或先上传 PDF。"
110
+ pairs = kw_model.extract_keywords(text[:10000], top_n=top_n)
111
+ return ", ".join([k for k, _ in pairs])
112
+
113
+
114
+ def answer_question(question, top_k=5):
115
+ if not CHUNKS:
116
+ return "请先上传 PDF 并建立索引。", ""
117
+ docs = retrieve(question, top_k)
118
+ context = "\n\n".join([f"[{i+1}] {c}" for i, (c, _) in enumerate(docs)])
119
+ prompt = (
120
+ "You are a helpful research assistant. Answer the question strictly based on the CONTEXT. "
121
+ "If the answer cannot be found, say 'Not found in the provided documents.'\n\n"
122
+ f"CONTEXT:\n{context}\n\nQUESTION: {question}\nANSWER:"
123
+ )
124
+ out = qa_generator(prompt, max_new_tokens=256)[0]["generated_text"]
125
+ cites = "\n".join([f"[{i+1}] 相似度={score:.3f}" for i, (_, score) in enumerate(docs)])
126
+ return out, cites
127
+
128
+
129
+ # =============== Gradio 界面 ===============
130
  with gr.Blocks(title="Paper Reader Assistant") as demo:
131
  gr.Markdown("""
132
  # 📖 Paper Reader Assistant
133
+ 上传 PDF,自动抽取文本,生成摘要、关键词,并支持基于内容的问答(RAG)。
 
134
  """)
135
+
136
  with gr.Row():
137
+ pdf_uploader = gr.File(label="上传 PDF(可多选)", file_count="multiple", file_types=[".pdf"])
138
+ build_info = gr.Textbox(label="状态", interactive=False)
139
+
140
+ doc_text = gr.Textbox(label="文档内容预览(前 50,000 字符)", lines=14)
141
+
142
+ upload_btn = gr.Button("📥 解析 PDF 并建立索引")
143
+
144
+ with gr.Tab("📝 摘要"):
145
+ sum_btn = gr.Button("生成摘要")
146
+ sum_out = gr.Textbox(label="摘要结果", lines=10)
147
+
148
+ with gr.Tab("🔑 关键词"):
149
+ kw_btn = gr.Button("提取关键词")
150
+ kw_out = gr.Textbox(label="关键词", lines=8)
151
+
152
+ with gr.Tab("❓ 问答 RAG"):
153
+ question = gr.Textbox(label="你的问题", lines=2)
154
+ qa_btn = gr.Button("回答问题")
155
+ answer_out = gr.Textbox(label="答案", lines=10)
156
+ cites_out = gr.Textbox(label="参考片段", lines=6)
157
+
158
+ upload_btn.click(handle_upload, inputs=[pdf_uploader], outputs=[doc_text, build_info])
159
+ sum_btn.click(summarize_text, inputs=[doc_text], outputs=sum_out)
160
+ kw_btn.click(extract_keywords, inputs=[doc_text], outputs=kw_out)
161
+ qa_btn.click(answer_question, inputs=[question], outputs=[answer_out, cites_out])
162
 
163
  if __name__ == "__main__":
164
  demo.launch()