import re

import faiss
import fitz  # PyMuPDF
import gradio as gr
import numpy as np
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# =============== Model initialization ===============
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
qa_generator = pipeline("text2text-generation", model="google/flan-t5-base")
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
kw_model = KeyBERT(model=embedder)

# Global state shared across Gradio callbacks:
# CHUNKS holds the text fragments of the uploaded documents,
# INDEX is the FAISS inner-product index built over their embeddings.
CHUNKS = []
INDEX = None


# =============== Utility functions ===============
def clean_text(t):
    """Strip NUL bytes, collapse all whitespace runs to single spaces, and trim."""
    t = t.replace("\x00", " ")
    t = re.sub(r"\s+", " ", t)
    return t.strip()


def pdf_to_text(file_bytes):
    """Extract plain text from a PDF given as raw bytes.

    Opens the document in-memory via PyMuPDF and concatenates the text of
    every page, then normalizes whitespace with clean_text().
    """
    doc = fitz.open(stream=file_bytes, filetype="pdf")
    texts = []
    for page in doc:
        t = page.get_text("text")
        if t:
            texts.append(t)
    return clean_text("\n".join(texts))


def chunk_text(text, chunk_size=800, overlap=120):
    """Split *text* into overlapping chunks of roughly *chunk_size* characters.

    Each cut prefers a natural boundary (". ", "。", or a newline) searched
    backwards from the chunk end; if no boundary is found, or the boundary
    would produce a chunk shorter than ~200 chars, the chunk is cut hard at
    chunk_size. Consecutive chunks overlap by *overlap* characters so that
    sentences spanning a cut are not lost. Chunks of 10 chars or fewer are
    discarded.
    """
    if len(text) <= chunk_size:
        return [text]
    chunks, start = [], 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        boundary = max(
            text.rfind(". ", start, end),
            text.rfind("。", start, end),
            text.rfind("\n", start, end),
        )
        # Too-early (or missing) boundary: fall back to a hard cut at `end`.
        if boundary == -1 or boundary <= start + 200:
            boundary = end
        chunks.append(text[start:boundary].strip())
        start = max(boundary - overlap, 0)
        if start == boundary:  # safety: always make forward progress
            start += 1
    return [c for c in chunks if len(c) > 10]


def build_faiss(chunks):
    """Embed *chunks* and (re)build the global FAISS index over them.

    Embeddings are L2-normalized, so the inner-product index (IndexFlatIP)
    yields cosine similarities.
    """
    global INDEX
    embs = embedder.encode(chunks, normalize_embeddings=True)
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs.astype(np.float32))
    INDEX = index


def retrieve(query, top_k=5):
    """Return up to *top_k* (chunk, score) pairs most similar to *query*.

    Returns an empty list when no index has been built yet. FAISS reports
    -1 ids when fewer than top_k vectors exist; those slots are skipped.
    """
    if INDEX is None or not CHUNKS:
        return []
    q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    D, I = INDEX.search(q, top_k)
    results = [(CHUNKS[i], float(D[0][j])) for j, i in enumerate(I[0]) if i != -1]
    return results


def _read_upload(f):
    """Return the raw bytes of an uploaded file.

    Newer Gradio versions deliver gr.File values as filepath strings, while
    older versions deliver tempfile-like objects with a .read() method (and a
    .name path). Support all three shapes so handle_upload works across
    Gradio releases.
    """
    if isinstance(f, str):
        with open(f, "rb") as fh:
            return fh.read()
    if hasattr(f, "read"):
        return f.read()
    with open(f.name, "rb") as fh:
        return fh.read()


# =============== Feature implementations ===============
def handle_upload(pdf_files):
    """Parse the uploaded PDFs, chunk their text, and build the FAISS index.

    Returns (preview_text, status_message): the first 50,000 characters of
    the merged document text, and a human-readable status string.
    """
    global CHUNKS
    if not pdf_files:
        return "", "未上传文件"
    texts = []
    for f in pdf_files:
        b = _read_upload(f)
        txt = pdf_to_text(b)
        texts.append(txt)
    merged = "\n".join(texts)
    CHUNKS = chunk_text(merged)
    build_faiss(CHUNKS)
    return merged[:50000], f"✅ 已成功建立索引:{len(CHUNKS)} 个片段"


def summarize_text(text):
    """Produce a two-stage summary of *text*.

    Stage 1 summarizes each ~1000-char chunk independently (chunks whose
    summarization raises are skipped); stage 2 summarizes the concatenation
    of the partial summaries. Returns a user-facing message when the input
    is too short or when every chunk summary failed.
    """
    if not text or len(text) < 50:
        return "请输入更长的文本或先上传 PDF。"
    parts = chunk_text(text, 1000, 100)
    summaries = []
    for p in parts:
        try:
            s = summarizer(p, max_length=150, min_length=40, do_sample=False)[0]["summary_text"]
            summaries.append(s)
        except Exception:
            continue  # best-effort: a failing chunk should not abort the whole summary
    if not summaries:
        # Every chunk failed; without this guard the second pass would crash
        # on empty input.
        return "摘要生成失败,请重试。"
    if len(summaries) == 1:
        # A single partial summary is already short; re-summarizing it with
        # min_length=60 could exceed the input length and error out.
        return summaries[0]
    combined = " ".join(summaries)
    final = summarizer(combined, max_length=200, min_length=60, do_sample=False)[0]["summary_text"]
    return final


def extract_keywords(text, top_n=10):
    """Extract up to *top_n* keywords from the first 10,000 chars of *text*."""
    if not text or len(text) < 50:
        return "请输入更长的文本或先上传 PDF。"
    pairs = kw_model.extract_keywords(text[:10000], top_n=top_n)
    return ", ".join([k for k, _ in pairs])


def answer_question(question, top_k=5):
    """Answer *question* via RAG over the indexed chunks.

    Retrieves the top_k most similar chunks, builds a grounded prompt for the
    FLAN-T5 generator, and returns (answer, citations) where citations lists
    the similarity score of each retrieved chunk.
    """
    if not CHUNKS:
        return "请先上传 PDF 并建立索引。", ""
    docs = retrieve(question, top_k)
    context = "\n\n".join([f"[{i+1}] {c}" for i, (c, _) in enumerate(docs)])
    prompt = (
        "You are a helpful research assistant. Answer the question strictly based on the CONTEXT. "
        "If the answer cannot be found, say 'Not found in the provided documents.'\n\n"
        f"CONTEXT:\n{context}\n\nQUESTION: {question}\nANSWER:"
    )
    out = qa_generator(prompt, max_new_tokens=256)[0]["generated_text"]
    cites = "\n".join([f"[{i+1}] 相似度={score:.3f}" for i, (_, score) in enumerate(docs)])
    return out, cites


# =============== Gradio UI ===============
with gr.Blocks(title="Paper Reader Assistant") as demo:
    gr.Markdown("""
    # 📖 Paper Reader Assistant
    上传 PDF,自动抽取文本,生成摘要、关键词,并支持基于内容的问答(RAG)。
    """)

    with gr.Row():
        pdf_uploader = gr.File(label="上传 PDF(可多选)", file_count="multiple", file_types=[".pdf"])
        build_info = gr.Textbox(label="状态", interactive=False)

    # Change 1: use Textbox instead of Dataframe
    with gr.Row():
        doc_text = gr.Textbox(label="文档全文(前 50,000 字符预览)", lines=14)
        file_table = gr.Textbox(label="文件状态", lines=3)

    upload_btn = gr.Button("📥 解析 PDF 并建立索引")

    with gr.Tab("📝 摘要"):
        sum_btn = gr.Button("生成摘要")
        sum_out = gr.Textbox(label="摘要结果", lines=10)

    with gr.Tab("🔑 关键词"):
        kw_btn = gr.Button("提取关键词")
        kw_out = gr.Textbox(label="关键词", lines=8)

    with gr.Tab("❓ 问答 RAG"):
        question = gr.Textbox(label="你的问题", lines=2)
        qa_btn = gr.Button("回答问题")
        answer_out = gr.Textbox(label="答案", lines=10)
        cites_out = gr.Textbox(label="参考片段", lines=6)

    # Change 2: Dataframe output removed from the bindings
    upload_btn.click(handle_upload, inputs=[pdf_uploader], outputs=[doc_text, build_info])
    sum_btn.click(summarize_text, inputs=[doc_text], outputs=sum_out)
    kw_btn.click(extract_keywords, inputs=[doc_text], outputs=kw_out)
    qa_btn.click(answer_question, inputs=[question], outputs=[answer_out, cites_out])

# Change 3: allow remote access and generate a share link
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)