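"""Paper Reader Assistant.

A Gradio app that extracts text from uploaded PDFs, produces summaries and
keywords, and answers questions over the documents with a small
retrieval-augmented generation (RAG) pipeline: sentence-transformers
embeddings + FAISS for retrieval, flan-t5 for answer generation.
"""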
import gradio as gr
import fitz # PyMuPDF
from transformers import pipeline
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import re
# =============== Model initialization ===============
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
qa_generator = pipeline("text2text-generation", model="google/flan-t5-base")
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
kw_model = KeyBERT(model=embedder)
# Global state (chunks and FAISS index for the uploaded documents)
CHUNKS = []
INDEX = None
# =============== Utility functions ===============
def clean_text(t):
    t = t.replace("\x00", " ")
    t = re.sub(r"\s+", " ", t)
    return t.strip()
def pdf_to_text(file_bytes):
    doc = fitz.open(stream=file_bytes, filetype="pdf")
    texts = []
    for page in doc:
        t = page.get_text("text")
        if t:
            texts.append(t)
    return clean_text("\n".join(texts))
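# Note: page.get_text("text") only returns the embedded text layer; scanned
# PDFs without one yield empty pages and would need OCR, which this app does
# not attempt.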
def chunk_text(text, chunk_size=800, overlap=120):
    if len(text) <= chunk_size:
        return [text]
    chunks, start = [], 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        # Prefer to break at a sentence or line boundary inside the window.
        boundary = max(
            text.rfind(". ", start, end),
            text.rfind("。", start, end),
            text.rfind("\n", start, end),
        )
        if boundary == -1 or boundary <= start + 200:
            boundary = end
        chunks.append(text[start:boundary].strip())
        if boundary >= len(text):
            break  # last chunk reached; stepping back by `overlap` here would loop forever
        start = max(boundary - overlap, start + 1)  # overlap windows, but always make progress
    return [c for c in chunks if len(c) > 10]
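# Example: a 2,000-character string with no sentence breaks is split into three
# windows of at most 800 characters, each starting 120 characters before the
# previous window ended, so text cut at a seam still appears whole in at least
# one chunk.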
def build_faiss(chunks):
    global INDEX
    embs = embedder.encode(chunks, normalize_embeddings=True)
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs.astype(np.float32))
    INDEX = index
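# Because the embeddings are L2-normalized, inner product (IndexFlatIP) is
# exactly cosine similarity, so retrieval scores range from -1 to 1.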
def retrieve(query, top_k=5):
    if INDEX is None or not CHUNKS:
        return []
    q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    D, I = INDEX.search(q, top_k)
    results = [(CHUNKS[i], float(D[0][j])) for j, i in enumerate(I[0]) if i != -1]
    return results
# =============== Core features ===============
def handle_upload(pdf_files):
    global CHUNKS
    if not pdf_files:
        return "", "No files uploaded"
    texts = []
    for f in pdf_files:
        # gr.File yields file paths (or, in older Gradio versions, objects
        # exposing .name) rather than open handles, so read bytes from disk.
        path = f if isinstance(f, str) else f.name
        with open(path, "rb") as fh:
            b = fh.read()
        texts.append(pdf_to_text(b))
    merged = "\n".join(texts)
    CHUNKS = chunk_text(merged)
    build_faiss(CHUNKS)
    return merged[:50000], f"✅ Index built: {len(CHUNKS)} chunks"
def summarize_text(text):
    if not text or len(text) < 50:
        return "Please provide longer text or upload a PDF first."
    parts = chunk_text(text, 1000, 100)
    summaries = []
    for p in parts:
        try:
            s = summarizer(p, max_length=150, min_length=40, do_sample=False)[0]["summary_text"]
            summaries.append(s)
        except Exception:
            continue
    if not summaries:
        return "Summarization failed; the extracted text may be too short or malformed."
    if len(summaries) == 1:
        return summaries[0]
    combined = " ".join(summaries)
    final = summarizer(combined, max_length=200, min_length=60, do_sample=False)[0]["summary_text"]
    return final
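# summarize_text is a simple map-reduce: summarize each ~1,000-character part,
# then summarize the concatenation of the partial summaries. BART's input is
# capped at 1,024 tokens, so very long documents lose detail in the reduce step.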
def extract_keywords(text, top_n=10):
    if not text or len(text) < 50:
        return "Please provide longer text or upload a PDF first."
    pairs = kw_model.extract_keywords(text[:10000], top_n=top_n)
    return ", ".join([k for k, _ in pairs])
def answer_question(question, top_k=5):
    if not CHUNKS:
        return "Please upload a PDF and build the index first.", ""
    docs = retrieve(question, top_k)
    context = "\n\n".join([f"[{i+1}] {c}" for i, (c, _) in enumerate(docs)])
    prompt = (
        "You are a helpful research assistant. Answer the question strictly based on the CONTEXT. "
        "If the answer cannot be found, say 'Not found in the provided documents.'\n\n"
        f"CONTEXT:\n{context}\n\nQUESTION: {question}\nANSWER:"
    )
    out = qa_generator(prompt, max_new_tokens=256)[0]["generated_text"]
    cites = "\n".join([f"[{i+1}] similarity={score:.3f}" for i, (_, score) in enumerate(docs)])
    return out, cites
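# answer_question is a minimal RAG loop: embed the question, retrieve the
# top-k chunks, stuff them into the prompt, and generate. flan-t5-base was
# trained on 512-token inputs, so an oversized context may be truncated;
# lower top_k or chunk_size if answers look cut off.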
# =============== Gradio UI ===============
with gr.Blocks(title="Paper Reader Assistant") as demo:
    gr.Markdown("""
    # 📖 Paper Reader Assistant
    Upload PDFs to extract their text, generate summaries and keywords, and ask
    content-grounded questions (RAG).
    """)
    with gr.Row():
        pdf_uploader = gr.File(label="Upload PDFs (multiple allowed)", file_count="multiple", file_types=[".pdf"])
        build_info = gr.Textbox(label="Status", interactive=False)
    # Change 1: use a Textbox instead of a Dataframe
    with gr.Row():
        doc_text = gr.Textbox(label="Full document text (first 50,000 characters)", lines=14)
        file_table = gr.Textbox(label="File status", lines=3)  # display-only; not wired to any handler
    upload_btn = gr.Button("📥 Parse PDFs and build index")
    with gr.Tab("📝 Summary"):
        sum_btn = gr.Button("Generate summary")
        sum_out = gr.Textbox(label="Summary", lines=10)
    with gr.Tab("🔑 Keywords"):
        kw_btn = gr.Button("Extract keywords")
        kw_out = gr.Textbox(label="Keywords", lines=8)
    with gr.Tab("❓ Q&A (RAG)"):
        question = gr.Textbox(label="Your question", lines=2)
        qa_btn = gr.Button("Answer question")
        answer_out = gr.Textbox(label="Answer", lines=10)
        cites_out = gr.Textbox(label="Retrieved chunks", lines=6)
    # Change 2: drop the Dataframe output from the event bindings
    upload_btn.click(handle_upload, inputs=[pdf_uploader], outputs=[doc_text, build_info])
    sum_btn.click(summarize_text, inputs=[doc_text], outputs=sum_out)
    kw_btn.click(extract_keywords, inputs=[doc_text], outputs=kw_out)
    qa_btn.click(answer_question, inputs=[question], outputs=[answer_out, cites_out])

# Change 3: allow remote access and generate a share link
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
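# A minimal sketch of driving the app programmatically with gradio_client
# (assumes the server above is running locally and that the Q&A handler is
# still the fourth registered event, fn_index=3; adjust if the UI changes):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   answer, citations = client.predict("What method does the paper propose?", fn_index=3)
#   print(answer)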