# app.py — Paper Reader Assistant (Hugging Face Space, commit ba7e98a)
import gradio as gr
import fitz # PyMuPDF
from transformers import pipeline
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import re
# =============== Model initialization (runs once at import time) ===============
# Abstractive summarizer used by summarize_text().
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
# Text-to-text model used to generate RAG answers in answer_question().
qa_generator = pipeline("text2text-generation", model="google/flan-t5-base")
# Sentence embedder shared by FAISS indexing, retrieval and KeyBERT.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
kw_model = KeyBERT(model=embedder)
# Global state for the currently uploaded documents:
# CHUNKS — list of text chunks; INDEX — FAISS inner-product index over their embeddings.
CHUNKS = []
INDEX = None
# =============== 工具函数 ===============
def clean_text(t):
    """Normalize extracted text: drop NUL bytes and collapse whitespace runs."""
    sanitized = t.replace("\x00", " ")
    collapsed = re.sub(r"\s+", " ", sanitized)
    return collapsed.strip()
def pdf_to_text(file_bytes):
    """Extract plain text from a PDF given as raw bytes.

    Opens the document in memory with PyMuPDF, concatenates the text of
    every page, and returns it whitespace-normalized via clean_text().
    """
    # Context manager guarantees the document handle is closed even if a
    # page raises; the original leaked the handle on every call.
    with fitz.open(stream=file_bytes, filetype="pdf") as doc:
        texts = [page.get_text("text") for page in doc]
    return clean_text("\n".join(t for t in texts if t))
def chunk_text(text, chunk_size=800, overlap=120):
    """Split text into overlapping chunks, preferring sentence/line boundaries.

    Each chunk is at most chunk_size characters; consecutive chunks overlap
    by roughly `overlap` characters so retrieval doesn't lose context at the
    seams. Chunks of 10 characters or fewer are dropped.

    Bug fix vs. original: when the tail of the text was shorter than
    `overlap`, `start = boundary - overlap` could land at or before the
    previous start, so the loop re-emitted the same tail chunk forever.
    We now stop as soon as a chunk reaches the end of the text and force
    the start position to always move forward.
    """
    if len(text) <= chunk_size:
        return [text]
    n = len(text)
    chunks, start = [], 0
    while start < n:
        end = min(start + chunk_size, n)
        # Prefer the last sentence/paragraph break inside the window.
        boundary = max(
            text.rfind(". ", start, end),
            text.rfind("。", start, end),
            text.rfind("\n", start, end),
        )
        # No break found, or it would produce a uselessly short chunk.
        if boundary == -1 or boundary <= start + 200:
            boundary = end
        chunks.append(text[start:boundary].strip())
        if boundary >= n:
            break  # last chunk emitted — prevents infinite re-chunking of the tail
        next_start = max(boundary - overlap, 0)
        if next_start <= start:
            next_start = boundary  # guarantee forward progress
        start = next_start
    return [c for c in chunks if len(c) > 10]
def build_faiss(chunks):
    """Embed `chunks` and rebuild the global FAISS index over them.

    Embeddings are L2-normalized, so the inner-product index (IndexFlatIP)
    effectively ranks by cosine similarity.
    """
    global INDEX
    vectors = embedder.encode(chunks, normalize_embeddings=True)
    idx = faiss.IndexFlatIP(vectors.shape[1])
    idx.add(vectors.astype(np.float32))
    INDEX = idx
def retrieve(query, top_k=5):
    """Return up to top_k (chunk, similarity) pairs for `query`.

    Empty list when no index has been built yet. FAISS pads short result
    sets with id -1; those slots are skipped.
    """
    if INDEX is None or not CHUNKS:
        return []
    query_vec = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    scores, ids = INDEX.search(query_vec, top_k)
    hits = []
    for rank, chunk_id in enumerate(ids[0]):
        if chunk_id == -1:
            continue
        hits.append((CHUNKS[chunk_id], float(scores[0][rank])))
    return hits
# =============== 功能实现 ===============
def handle_upload(pdf_files):
    """Parse the uploaded PDFs, chunk the merged text and build the index.

    Returns (text_preview, status_message) for the two output widgets.
    """
    global CHUNKS
    if not pdf_files:
        return "", "未上传文件"
    texts = []
    for f in pdf_files:
        # Gradio's File component yields filepath strings (v4) or
        # tempfile-like objects with a .name attribute (v3). Re-opening by
        # path is robust in both cases and avoids reading from a file
        # handle whose cursor is already at EOF; fall back to .read() only
        # when no path is available.
        path = f if isinstance(f, str) else getattr(f, "name", None)
        if path is not None:
            with open(path, "rb") as fh:
                data = fh.read()
        else:
            data = f.read()
        texts.append(pdf_to_text(data))
    merged = "\n".join(texts)
    CHUNKS = chunk_text(merged)
    build_faiss(CHUNKS)
    return merged[:50000], f"✅ 已成功建立索引:{len(CHUNKS)} 个片段"
def summarize_text(text):
    """Map-reduce summarization: summarize ~1000-char chunks, then summarize
    the concatenated summaries into one final abstract.

    Returns a user-facing message (Chinese, matching the UI) for inputs that
    are too short or when summarization fails entirely.
    """
    if not text or len(text) < 50:
        return "请输入更长的文本或先上传 PDF。"
    parts = chunk_text(text, 1000, 100)
    summaries = []
    for p in parts:
        try:
            s = summarizer(p, max_length=150, min_length=40, do_sample=False)[0]["summary_text"]
            summaries.append(s)
        except Exception:
            # Best-effort per chunk: a single bad chunk shouldn't kill the run.
            continue
    # Bug fix vs. original: if every chunk failed, `combined` was "" and the
    # unguarded second summarizer call raised. Also skip the redundant second
    # pass when there is only one chunk summary.
    if not summaries:
        return "摘要生成失败,请缩短文本后重试。"
    combined = " ".join(summaries)
    if len(summaries) == 1:
        return combined
    try:
        return summarizer(combined, max_length=200, min_length=60, do_sample=False)[0]["summary_text"]
    except Exception:
        # Fall back to the concatenated chunk summaries rather than crashing.
        return combined
def extract_keywords(text, top_n=10):
    """Extract up to top_n keywords (KeyBERT) from the first 10k characters."""
    if not text or len(text) < 50:
        return "请输入更长的文本或先上传 PDF。"
    # KeyBERT returns (keyword, score) pairs; the UI only shows the words.
    keyword_scores = kw_model.extract_keywords(text[:10000], top_n=top_n)
    return ", ".join(word for word, _ in keyword_scores)
def answer_question(question, top_k=5):
    """RAG question answering over the uploaded documents.

    Retrieves the top_k most similar chunks, builds a grounded prompt for
    FLAN-T5, and returns (answer, citation_list). Citations report the
    retrieval similarity of each context snippet.
    """
    if not CHUNKS:
        return "请先上传 PDF 并建立索引。", ""
    hits = retrieve(question, top_k)
    numbered = [f"[{rank+1}] {snippet}" for rank, (snippet, _) in enumerate(hits)]
    context = "\n\n".join(numbered)
    prompt = (
        "You are a helpful research assistant. Answer the question strictly based on the CONTEXT. "
        "If the answer cannot be found, say 'Not found in the provided documents.'\n\n"
        f"CONTEXT:\n{context}\n\nQUESTION: {question}\nANSWER:"
    )
    answer = qa_generator(prompt, max_new_tokens=256)[0]["generated_text"]
    citations = "\n".join(
        f"[{rank+1}] 相似度={sim:.3f}" for rank, (_, sim) in enumerate(hits)
    )
    return answer, citations
# =============== Gradio 界面 ===============
# Layout: upload row, preview row, one action button, then three feature tabs.
with gr.Blocks(title="Paper Reader Assistant") as demo:
    gr.Markdown("""
# 📖 Paper Reader Assistant
上传 PDF,自动抽取文本,生成摘要、关键词,并支持基于内容的问答(RAG)。
""")
    with gr.Row():
        pdf_uploader = gr.File(label="上传 PDF(可多选)", file_count="multiple", file_types=[".pdf"])
        build_info = gr.Textbox(label="状态", interactive=False)
    # Change 1: plain Textboxes instead of a Dataframe for the previews.
    with gr.Row():
        doc_text = gr.Textbox(label="文档全文(前 50,000 字符预览)", lines=14)
        file_table = gr.Textbox(label="文件状态", lines=3)
    upload_btn = gr.Button("📥 解析 PDF 并建立索引")
    with gr.Tab("📝 摘要"):
        sum_btn = gr.Button("生成摘要")
        sum_out = gr.Textbox(label="摘要结果", lines=10)
    with gr.Tab("🔑 关键词"):
        kw_btn = gr.Button("提取关键词")
        kw_out = gr.Textbox(label="关键词", lines=8)
    with gr.Tab("❓ 问答 RAG"):
        question = gr.Textbox(label="你的问题", lines=2)
        qa_btn = gr.Button("回答问题")
        answer_out = gr.Textbox(label="答案", lines=10)
        cites_out = gr.Textbox(label="参考片段", lines=6)
    # Change 2: Dataframe output removed from the event bindings.
    # Wire each button to its handler; outputs map 1:1 onto the widgets above.
    upload_btn.click(handle_upload, inputs=[pdf_uploader], outputs=[doc_text, build_info])
    sum_btn.click(summarize_text, inputs=[doc_text], outputs=sum_out)
    kw_btn.click(extract_keywords, inputs=[doc_text], outputs=kw_out)
    qa_btn.click(answer_question, inputs=[question], outputs=[answer_out, cites_out])
# Change 3: allow remote access and request a public share link.
if __name__ == "__main__":
    # Bind to all interfaces on port 7860; share=True additionally requests
    # a temporary public gradio.live tunnel.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)