import gradio as gr
import fitz  # PyMuPDF
from transformers import pipeline
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import re

# =============== Model initialization ===============
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
qa_generator = pipeline("text2text-generation", model="google/flan-t5-base")
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
kw_model = KeyBERT(model=embedder)

# Global state (chunks and FAISS index for the uploaded documents)
CHUNKS = []
INDEX = None
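
# Note: all four models are loaded eagerly at import time, so the app's first
# start can be slow on CPU-only hosts while weights download and initialize.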

# =============== Utility functions ===============
def clean_text(t):
    t = t.replace("\x00", " ")
    t = re.sub(r"\s+", " ", t)
    return t.strip()

def pdf_to_text(file_bytes):
    texts = []
    # Context manager ensures the document is closed even if a page fails
    with fitz.open(stream=file_bytes, filetype="pdf") as doc:
        for page in doc:
            t = page.get_text("text")
            if t:
                texts.append(t)
    return clean_text("\n".join(texts))

def chunk_text(text, chunk_size=800, overlap=120):
    if len(text) <= chunk_size:
        return [text]
    chunks, start = [], 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        # Prefer to cut at a sentence or line boundary inside the window
        boundary = max(
            text.rfind(". ", start, end),
            text.rfind("。", start, end),
            text.rfind("\n", start, end),
        )
        if boundary == -1 or boundary <= start + 200:
            boundary = end
        chunks.append(text[start:boundary].strip())
        if boundary >= len(text):
            # Final chunk reached; without this break the overlap step would
            # re-enter the tail forever
            break
        start = max(boundary - overlap, 0)
        if start == boundary:
            start += 1
    return [c for c in chunks if len(c) > 10]
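
# Behavior sketch (illustrative, assuming the defaults above):
#   chunk_text("word " * 400)   # ~2000 chars, no sentence breaks
#   # -> 3 overlapping ~800-char chunks (hard cuts: no ". ", "。", or newline found)
#   chunk_text("short text")    # -> ["short text"], returned whole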

def build_faiss(chunks):
    global INDEX
    embs = embedder.encode(chunks, normalize_embeddings=True)
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs.astype(np.float32))
    INDEX = index

def retrieve(query, top_k=5):
    if INDEX is None or not CHUNKS:
        return []
    q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    D, I = INDEX.search(q, top_k)
    results = [(CHUNKS[i], float(D[0][j])) for j, i in enumerate(I[0]) if i != -1]
    return results
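
# Since the embeddings are L2-normalized, inner product (IndexFlatIP) equals
# cosine similarity, so scores land in [-1, 1] with higher meaning more relevant.
# retrieve() relies on handle_upload() having populated both CHUNKS and INDEX;
# called before any upload, it simply returns [].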

# =============== Features ===============
def handle_upload(pdf_files):
    global CHUNKS
    if not pdf_files:
        return "", "No files uploaded"
    texts = []
    for f in pdf_files:
        # Depending on the Gradio version, gr.File yields plain paths or
        # tempfile wrappers; read raw bytes either way
        path = f if isinstance(f, str) else f.name
        with open(path, "rb") as fh:
            b = fh.read()
        texts.append(pdf_to_text(b))
    merged = "\n".join(texts)
    CHUNKS = chunk_text(merged)
    build_faiss(CHUNKS)
    return merged[:50000], f"✅ Index built: {len(CHUNKS)} chunks"

def summarize_text(text):
    if not text or len(text) < 50:
        return "Please enter a longer text or upload a PDF first."
    parts = chunk_text(text, 1000, 100)
    summaries = []
    for p in parts:
        try:
            s = summarizer(p, max_length=150, min_length=40, do_sample=False)[0]["summary_text"]
            summaries.append(s)
        except Exception:
            continue
    if not summaries:
        return "Summarization failed; the text may be too short or malformed."
    combined = " ".join(summaries)
    final = summarizer(combined, max_length=200, min_length=60, do_sample=False)[0]["summary_text"]
    return final
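
# summarize_text is a simple map-reduce: summarize each ~1000-char chunk
# independently (map), then summarize the concatenated chunk summaries into
# one final abstract (reduce).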

def extract_keywords(text, top_n=10):
    if not text or len(text) < 50:
        return "Please enter a longer text or upload a PDF first."
    pairs = kw_model.extract_keywords(text[:10000], top_n=top_n)
    return ", ".join([k for k, _ in pairs])

def answer_question(question, top_k=5):
    if not CHUNKS:
        return "Please upload a PDF and build the index first.", ""
    docs = retrieve(question, top_k)
    context = "\n\n".join([f"[{i+1}] {c}" for i, (c, _) in enumerate(docs)])
    prompt = (
        "You are a helpful research assistant. Answer the question strictly based on the CONTEXT. "
        "If the answer cannot be found, say 'Not found in the provided documents.'\n\n"
        f"CONTEXT:\n{context}\n\nQUESTION: {question}\nANSWER:"
    )
    out = qa_generator(prompt, max_new_tokens=256)[0]["generated_text"]
    cites = "\n".join([f"[{i+1}] similarity={score:.3f}" for i, (_, score) in enumerate(docs)])
    return out, cites
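
# Caveat: flan-t5-base was trained on inputs of roughly 512 tokens, while
# top_k=5 chunks of ~800 characters can exceed that; overly long contexts may
# be truncated or degrade answer quality, so a smaller top_k can help.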

# =============== Gradio UI ===============
with gr.Blocks(title="Paper Reader Assistant") as demo:
    gr.Markdown("""
    # 📖 Paper Reader Assistant
    Upload PDFs to extract their text, generate summaries and keywords, and ask content-grounded questions (RAG).
    """)
    with gr.Row():
        pdf_uploader = gr.File(label="Upload PDFs (multiple allowed)", file_count="multiple", file_types=[".pdf"])
        build_info = gr.Textbox(label="Status", interactive=False)
    # Change 1: use a Textbox instead of a Dataframe
    with gr.Row():
        doc_text = gr.Textbox(label="Full document text (first 50,000-character preview)", lines=14)
        file_table = gr.Textbox(label="File status", lines=3)
    upload_btn = gr.Button("📥 Parse PDFs and build index")
| with gr.Tab("📝 摘要"): | |
| sum_btn = gr.Button("生成摘要") | |
| sum_out = gr.Textbox(label="摘要结果", lines=10) | |
| with gr.Tab("🔑 关键词"): | |
| kw_btn = gr.Button("提取关键词") | |
| kw_out = gr.Textbox(label="关键词", lines=8) | |
| with gr.Tab("❓ 问答 RAG"): | |
| question = gr.Textbox(label="你的问题", lines=2) | |
| qa_btn = gr.Button("回答问题") | |
| answer_out = gr.Textbox(label="答案", lines=10) | |
| cites_out = gr.Textbox(label="参考片段", lines=6) | |
    # Change 2: Dataframe output removed from the event binding
    upload_btn.click(handle_upload, inputs=[pdf_uploader], outputs=[doc_text, build_info])
    sum_btn.click(summarize_text, inputs=[doc_text], outputs=sum_out)
    kw_btn.click(extract_keywords, inputs=[doc_text], outputs=kw_out)
    qa_btn.click(answer_question, inputs=[question], outputs=[answer_out, cites_out])

# Change 3: allow remote access and generate a share link
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)