import gradio as gr
import fitz  # PyMuPDF
from transformers import pipeline
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import re
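
# Assumed dependencies (pip): gradio, pymupdf, transformers, keybert,
# sentence-transformers, faiss-cpu, numpy (torch is pulled in by transformers).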

# =============== Initialize models ===============
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
qa_generator = pipeline("text2text-generation", model="google/flan-t5-base")
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
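# Reuse the same sentence-transformer for KeyBERT so only one encoder is loaded.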
kw_model = KeyBERT(model=embedder)

# Global state: chunks and FAISS index for the uploaded documents
CHUNKS = []
INDEX = None


# =============== Utility functions ===============
def clean_text(t):
    t = t.replace("\x00", " ")
    t = re.sub(r"\s+", " ", t)
    return t.strip()


def pdf_to_text(file_bytes):
    texts = []
    # Context manager closes the document even if text extraction fails.
    with fitz.open(stream=file_bytes, filetype="pdf") as doc:
        for page in doc:
            t = page.get_text("text")
            if t:
                texts.append(t)
    return clean_text("\n".join(texts))


def chunk_text(text, chunk_size=800, overlap=120):
    if len(text) <= chunk_size:
        return [text]
    chunks, start = [], 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        # Last window: take the remainder and stop. Without this, the
        # `boundary - overlap` step can stall on the tail and loop forever.
        if end == len(text):
            chunks.append(text[start:end].strip())
            break
        # Prefer to cut at a sentence boundary (English period, Chinese
        # full stop, or newline) within the current window.
        boundary = max(
            text.rfind(". ", start, end),
            text.rfind("。", start, end),
            text.rfind("\n", start, end),
        )
        # Fall back to a hard cut when no boundary exists or it sits too
        # close to the window start.
        if boundary == -1 or boundary <= start + 200:
            boundary = end
        chunks.append(text[start:boundary].strip())
        # Step back by `overlap` characters to keep context across chunks,
        # but always make forward progress.
        start = max(boundary - overlap, start + 1)
    return [c for c in chunks if len(c) > 10]


def build_faiss(chunks):
    global INDEX
    embs = embedder.encode(chunks, normalize_embeddings=True)
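    # With normalized embeddings, inner product equals cosine similarity,
    # so the flat IP index below performs an exact cosine search.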
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs.astype(np.float32))
    INDEX = index


def retrieve(query, top_k=5):
    if INDEX is None or not CHUNKS:
        return []
    q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    D, I = INDEX.search(q, top_k)
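    # D holds similarity scores, I the chunk indices; FAISS pads I with -1
    # when fewer than top_k vectors are indexed.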
    results = [(CHUNKS[i], float(D[0][j])) for j, i in enumerate(I[0]) if i != -1]
    return results
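

# Example (hypothetical output, assuming an index was already built):
#   retrieve("What dataset was used?", top_k=3)
#   -> [("...chunk text...", 0.61), ...]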


# =============== Feature handlers ===============
def handle_upload(pdf_files):
    global CHUNKS
    if not pdf_files:
        return "", "No files uploaded."
    texts = []
    for f in pdf_files:
        # Depending on the Gradio version, each upload arrives as a filepath
        # string or a tempfile-like object exposing .name.
        path = f if isinstance(f, str) else f.name
        with open(path, "rb") as fh:
            texts.append(pdf_to_text(fh.read()))
    merged = "\n".join(texts)
    CHUNKS = chunk_text(merged)
    build_faiss(CHUNKS)
    return merged[:50000], f"✅ Index built: {len(CHUNKS)} chunks"


def summarize_text(text):
    if not text or len(text) < 50:
        return "Please provide a longer text or upload a PDF first."
    parts = chunk_text(text, 1000, 100)
    summaries = []
    for p in parts:
        try:
            s = summarizer(p, max_length=150, min_length=40, do_sample=False)[0]["summary_text"]
            summaries.append(s)
        except Exception:
            continue
    if not summaries:
        return "Summarization failed; the extracted text may be too short or noisy."
    combined = " ".join(summaries)
    if len(summaries) == 1:
        return combined
    # Second pass: condense the per-chunk summaries into one final summary.
    # Truncate the input to stay within BART's ~1024-token limit.
    final = summarizer(combined[:4000], max_length=200, min_length=60, do_sample=False)[0]["summary_text"]
    return final


def extract_keywords(text, top_n=10):
    if not text or len(text) < 50:
        return "请输入更长的文本或先上传 PDF。"
    pairs = kw_model.extract_keywords(text[:10000], top_n=top_n)
    return ", ".join([k for k, _ in pairs])


def answer_question(question, top_k=5):
    if not CHUNKS:
        return "请先上传 PDF 并建立索引。", ""
    docs = retrieve(question, top_k)
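    # Number the retrieved chunks so the citation list below can reference them.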
    context = "\n\n".join([f"[{i+1}] {c}" for i, (c, _) in enumerate(docs)])
    prompt = (
        "You are a helpful research assistant. Answer the question strictly based on the CONTEXT. "
        "If the answer cannot be found, say 'Not found in the provided documents.'\n\n"
        f"CONTEXT:\n{context}\n\nQUESTION: {question}\nANSWER:"
    )
    out = qa_generator(prompt, max_new_tokens=256)[0]["generated_text"]
    cites = "\n".join([f"[{i+1}] 相似度={score:.3f}" for i, (_, score) in enumerate(docs)])
    return out, cites


# =============== Gradio UI ===============
with gr.Blocks(title="Paper Reader Assistant") as demo:
    gr.Markdown("""
    # 📖 Paper Reader Assistant
    上传 PDF,自动抽取文本,生成摘要、关键词,并支持基于内容的问答(RAG)。
    """)

    with gr.Row():
        pdf_uploader = gr.File(label="Upload PDFs (multiple allowed)", file_count="multiple", file_types=[".pdf"])
        build_info = gr.Textbox(label="Status", interactive=False)

    # Change 1: use Textboxes instead of a Dataframe
    with gr.Row():
        doc_text = gr.Textbox(label="Full document text (first 50,000-character preview)", lines=14)
        file_table = gr.Textbox(label="File status", lines=3)

    upload_btn = gr.Button("📥 Parse PDFs and Build Index")

    with gr.Tab("📝 摘要"):
        sum_btn = gr.Button("生成摘要")
        sum_out = gr.Textbox(label="摘要结果", lines=10)

    with gr.Tab("🔑 关键词"):
        kw_btn = gr.Button("提取关键词")
        kw_out = gr.Textbox(label="关键词", lines=8)

    with gr.Tab("❓ 问答 RAG"):
        question = gr.Textbox(label="你的问题", lines=2)
        qa_btn = gr.Button("回答问题")
        answer_out = gr.Textbox(label="答案", lines=10)
        cites_out = gr.Textbox(label="参考片段", lines=6)

    # Change 2: the Dataframe output was removed from this binding
    upload_btn.click(handle_upload, inputs=[pdf_uploader], outputs=[doc_text, build_info])
    sum_btn.click(summarize_text, inputs=[doc_text], outputs=sum_out)
    kw_btn.click(extract_keywords, inputs=[doc_text], outputs=kw_out)
    qa_btn.click(answer_question, inputs=[question], outputs=[answer_out, cites_out])

# Change 3: serve on all interfaces and request a temporary public share link
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)