import gradio as gr
from transformers import AutoModel, AutoTokenizer
import torch
import os
import tempfile
from pathlib import Path
import re
from PIL import Image
import fitz # PyMuPDF
import io
import sys
import threading
import time
from queue import Queue, Empty
import ast
# os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# Initialize the model and tokenizer
print("Loading model...")
model_name = 'deepseek-ai/DeepSeek-OCR'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(
model_name,
_attn_implementation='flash_attention_2',
trust_remote_code=True,
use_safetensors=True
)
model = model.eval() # .cuda().to(torch.bfloat16)
print("Model loaded successfully!")
class StreamCapture:
"""捕获并流式输出标准输出"""
def __init__(self):
self.queue = Queue()
self.captured_text = []
self.original_stdout = None
self.is_capturing = False
def write(self, text):
"""捕获write调用"""
if text and text.strip():
self.captured_text.append(text)
self.queue.put(text)
            # Also echo to the original terminal
if self.original_stdout:
self.original_stdout.write(text)
def flush(self):
"""flush方法"""
if self.original_stdout:
self.original_stdout.flush()
def start_capture(self):
"""开始捕获"""
self.original_stdout = sys.stdout
sys.stdout = self
self.is_capturing = True
self.captured_text = []
def stop_capture(self):
"""停止捕获"""
if self.is_capturing:
sys.stdout = self.original_stdout
self.is_capturing = False
return ''.join(self.captured_text)
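# Because sys.stdout is process-global, redirecting it here also captures prints made by the
# inference thread started below, which is what makes the streaming UI updates possible.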
def pdf_to_images(pdf_path, dpi=144):
"""将PDF转换为图片列表"""
images = []
pdf_document = fitz.open(pdf_path)
zoom = dpi / 72.0
matrix = fitz.Matrix(zoom, zoom)
for page_num in range(pdf_document.page_count):
page = pdf_document[page_num]
pixmap = page.get_pixmap(matrix=matrix, alpha=False)
Image.MAX_IMAGE_PIXELS = None
img_data = pixmap.tobytes("png")
img = Image.open(io.BytesIO(img_data))
images.append(img)
pdf_document.close()
return images
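# Minimal usage sketch (hypothetical path): PDFs have a native resolution of 72 dpi, so
# dpi=144 renders each page at 2x scale before it is sent to the OCR model.
#   pages = pdf_to_images("input.pdf", dpi=144)
#   pages[0].save("page_0.png")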
def extract_image_refs(text):
"""提取图片引用标签"""
pattern = r'(<\|ref\|>image<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
matches = re.findall(pattern, text, re.DOTALL)
return matches
def extract_coordinates(det_text):
"""从det标签中提取坐标"""
try:
coords_list = eval(det_text)
return coords_list
except:
return None
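# Example of the grounding payload this parses (format inferred from the regex above):
#   "<|ref|>image<|/ref|><|det|>[[120, 80, 560, 400]]<|/det|>"
# where each box is [x1, y1, x2, y2] on a 0-999 normalized grid, rescaled to pixels below.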
def crop_images_from_text(original_image, ocr_text, output_dir, page_idx=0):
"""根据OCR文本中的坐标裁切图片"""
if original_image is None or ocr_text is None:
print("crop_images_from_text: 输入为空")
return []
print(f"crop_images_from_text: 开始处理页面 {page_idx}")
print(f"输出目录: {output_dir}")
# 提取图片引用
image_refs = extract_image_refs(ocr_text)
print(f"找到 {len(image_refs)} 个图片引用")
if not image_refs:
return []
    # Make sure the output directory exists
    os.makedirs(output_dir, exist_ok=True)
    # Load the source image
    if isinstance(original_image, str):
        img = Image.open(original_image)
        print(f"Loaded image from file: {original_image}")
    else:
        img = original_image
        print("Using a PIL Image object")
image_width, image_height = img.size
print(f"图片尺寸: {image_width} x {image_height}")
cropped_images = []
for idx, (full_match, coords_str) in enumerate(image_refs):
print(f"\n处理图片引用 {idx+1}: {coords_str}")
coords_list = extract_coordinates(coords_str)
if coords_list is None:
print(" 坐标解析失败")
continue
print(f" 解析出 {len(coords_list)} 个坐标框")
# 处理每个坐标框
for coord_idx, coord in enumerate(coords_list):
try:
x1, y1, x2, y2 = coord
print(f" 框 {coord_idx+1}: 原始坐标 ({x1}, {y1}) -> ({x2}, {y2})")
# 将归一化坐标转换为实际像素坐标
px1 = int(x1 / 999 * image_width)
py1 = int(y1 / 999 * image_height)
px2 = int(x2 / 999 * image_width)
py2 = int(y2 / 999 * image_height)
print(f" 框 {coord_idx+1}: 像素坐标 ({px1}, {py1}) -> ({px2}, {py2})")
# 裁切图片
cropped = img.crop((px1, py1, px2, py2))
print(f" 裁切后尺寸: {cropped.size}")
# 保存图片
img_filename = f"page{page_idx}_img{len(cropped_images)}.jpg"
img_path = os.path.join(output_dir, img_filename)
cropped.save(img_path, quality=95)
print(f" 保存到: {img_path}")
cropped_images.append(img_path)
except Exception as e:
print(f" 裁切图片出错: {e}")
import traceback
traceback.print_exc()
continue
print(f"\n总共裁切了 {len(cropped_images)} 张图片")
return cropped_images
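# Worked example of the 0-999 scaling above (illustrative numbers): on a 2000x1000 px page,
# the box [100, 100, 500, 500] maps to pixel coordinates (200, 100) -> (1001, 500).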
def clean_ocr_output(text, image_refs_replacement=None):
"""清理OCR输出文本"""
if text is None:
return ""
# 移除或格式化调试信息
# 匹配开头的调试块:=====================\nBASE: ...\nPATCHES: ...\n=====================
text = re.sub(r'={5,}\s*BASE:\s+.*?\s+PATCHES:\s+.*?\s+={5,}\s*', '', text, flags=re.DOTALL)
# 匹配结尾的调试块:==================================================\nimage size: ...\n...==================================================
text = re.sub(r'={10,}\s*image size:.*?={10,}\s*', '', text, flags=re.DOTALL)
# 移除grounding标签
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
matches = re.findall(pattern, text, re.DOTALL)
    # If a replacement template is provided, substitute it for image tags
img_idx = 0
for match in matches:
if '<|ref|>image<|/ref|>' in match[0]:
if image_refs_replacement:
                # Replace the tag with the actual image link
text = text.replace(match[0], image_refs_replacement.format(img_idx), 1)
img_idx += 1
else:
text = text.replace(match[0], '')
else:
text = text.replace(match[0], '')
    # Normalize special characters
    text = text.replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
    text = text.replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n')
    # Remove the end-of-sequence token
if '<|end▁of▁sentence|>' in text:
text = text.replace('<|end▁of▁sentence|>', '')
return text.strip()
def create_markdown_with_images(ocr_text, image_paths):
"""创建包含图片链接的Markdown文本"""
if ocr_text is None:
return ""
result_text = ocr_text
# 移除或格式化调试信息
# 匹配开头的调试块:=====================\nBASE: ...\nPATCHES: ...\n=====================
result_text = re.sub(r'={5,}\s*BASE:\s+.*?\s+PATCHES:\s+.*?\s+={5,}\s*', '', result_text, flags=re.DOTALL)
# 匹配结尾的调试块:==================================================\nimage size: ...\n...==================================================
result_text = re.sub(r'={10,}\s*image size:.*?={10,}\s*', '', result_text, flags=re.DOTALL)
# 移除grounding标签,但保留图片位置
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
matches = re.findall(pattern, result_text, re.DOTALL)
img_idx = 0
for match in matches:
if '<|ref|>image<|/ref|>' in match[0]:
            # Replace the tag with an inline image
            if image_paths and img_idx < len(image_paths):
                img_path = image_paths[img_idx]
                # Gradio's Markdown view cannot load arbitrary local paths directly,
                # so embed the image as a base64 data URI inside an HTML <img> tag
try:
import base64
with open(img_path, 'rb') as f:
img_data = f.read()
img_base64 = base64.b64encode(img_data).decode('utf-8')
                    # Pick the MIME type from the file extension
img_ext = os.path.splitext(img_path)[1].lower()
if img_ext == '.png':
mime_type = 'image/png'
elif img_ext in ['.jpg', '.jpeg']:
mime_type = 'image/jpeg'
else:
mime_type = 'image/jpeg'
                    # Build the <img> tag with the base64-embedded image
                    img_markdown = f'\n\n<img src="data:{mime_type};base64,{img_base64}" alt="Extracted image {img_idx+1}" style="max-width: 100%; height: auto; border: 1px solid #ddd; border-radius: 4px; padding: 5px;">\n\n'
                    print(f"Image {img_idx+1} embedded as base64 ({len(img_base64)} bytes)")
                except Exception as e:
                    print(f"Failed to embed image {img_idx+1} as base64: {e}")
                    img_markdown = f'\n\n[Image {img_idx+1}: {os.path.basename(img_path)}]\n\n'
result_text = result_text.replace(match[0], img_markdown, 1)
img_idx += 1
else:
                result_text = result_text.replace(match[0], '\n[Image]\n', 1)
        else:
            # Remove other annotations
            result_text = result_text.replace(match[0], '')
    # Normalize special characters
    result_text = result_text.replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
    result_text = result_text.replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n')
    # Remove the end-of-sequence token
if '<|end▁of▁sentence|>' in result_text:
result_text = result_text.replace('<|end▁of▁sentence|>', '')
return result_text.strip()
def create_pdf_markdown_with_images(raw_results_list, all_image_paths):
"""为PDF创建包含图片的Markdown"""
if not raw_results_list:
return ""
# 合并所有原始结果
combined_text = "\n".join(raw_results_list)
# 使用create_markdown_with_images处理
return create_markdown_with_images(combined_text, all_image_paths)
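# This assumes all_image_paths is ordered the same way the image tags appear across pages,
# which holds here because crops are appended page by page as each page is processed.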
def process_image_ocr_stream(image, prompt_type, base_size, image_size, crop_mode):
"""处理单张图片的OCR - 生成器版本,支持流式输出"""
if image is None:
yield "请上传图片", "", "请上传图片", None
return
# 设置prompt
if prompt_type == "Free OCR":
prompt = "<image>\nFree OCR. "
else: # Convert to Markdown
prompt = "<image>\n<|grounding|>Convert the document to markdown. "
# 创建持久化的输出目录
script_dir = os.path.dirname(os.path.abspath(__file__))
output_base_dir = os.path.join(script_dir, "gradio_outputs")
session_dir = os.path.join(output_base_dir, f"session_{int(time.time() * 1000)}")
os.makedirs(session_dir, exist_ok=True)
try:
tmpdir = session_dir
        # Save the uploaded image
        temp_image_path = os.path.join(tmpdir, "input_image.png")
        if isinstance(image, str):
            # Already a file path
            temp_image_path = image
        else:
            # PIL Image or numpy array
            if not isinstance(image, Image.Image):
                image = Image.fromarray(image)
            image.save(temp_image_path)
        yield "Processing image...", "", "Processing image...", None
        # Create the stream-capture object
        stream_capture = StreamCapture()
        # Run model inference in a separate thread
result_container = {'result': None, 'error': None}
def run_inference():
try:
stream_capture.start_capture()
result = model.infer(
tokenizer,
prompt=prompt,
image_file=temp_image_path,
output_path=tmpdir,
base_size=int(base_size),
image_size=int(image_size),
crop_mode=crop_mode,
save_results=False,
test_compress=True
)
result_container['result'] = result
except Exception as e:
result_container['error'] = e
finally:
stream_capture.stop_capture()
        # Start the inference thread
        inference_thread = threading.Thread(target=run_inference)
        inference_thread.start()
        # Stream output while inference runs
accumulated_text = ""
last_update = time.time()
while inference_thread.is_alive() or not stream_capture.queue.empty():
try:
                # Try to get new text from the queue
                text = stream_capture.queue.get(timeout=0.1)
                accumulated_text += text
                # Clean the accumulated text
                cleaned_text = clean_ocr_output(accumulated_text)
                # Update the UI at most every 0.1 s (no images during the streaming phase)
                current_time = time.time()
                if current_time - last_update >= 0.1:
                    yield cleaned_text, accumulated_text, cleaned_text + "\n\n*Recognizing...*", None
                    last_update = current_time
            except Empty:
                # Queue is empty; keep waiting
                time.sleep(0.05)
        # Wait for the inference thread to finish
        inference_thread.join()
        # Check for errors
        if result_container['error']:
            import traceback
            error_msg = f"OCR failed: {str(result_container['error'])}\n\n{traceback.format_exc()}"
            yield error_msg, "", error_msg, None
            return
        # Collect the final result
        final_captured = stream_capture.stop_capture()
        # Prefer the returned result; fall back to the captured stdout
if result_container['result'] is not None and result_container['result'] != "":
final_result = result_container['result']
else:
final_result = final_captured
        if not final_result:
            yield "OCR finished, but the model returned no result.", "", "OCR finished, but the model returned no result.", None
            return
        # Show the recognized text first; image extraction happens next
        cleaned_temp = clean_ocr_output(final_result)
        yield cleaned_temp, final_result, cleaned_temp + "\n\n*Extracting images...*", None
        # Crop embedded images
        images_dir = os.path.join(tmpdir, "extracted_images")
        cropped_image_paths = crop_images_from_text(temp_image_path, final_result, images_dir, page_idx=0)
        print(f"Number of extracted images: {len(cropped_image_paths)}")
        print(f"Image paths: {cropped_image_paths}")
        # Clean the output (plain text)
        cleaned_result = clean_ocr_output(final_result)
        # Build the Markdown version with embedded image links (crops are ready at this point)
        markdown_result = create_markdown_with_images(final_result, cropped_image_paths)
        print(f"First 500 characters of the Markdown result:\n{markdown_result[:500]}")
        # Build the image gallery
        gallery_images = cropped_image_paths if cropped_image_paths else None
        # Final output, including images
yield cleaned_result, final_result, markdown_result, gallery_images
except Exception as e:
import traceback
error_msg = f"OCR处理出错: {str(e)}\n\n详细错误:\n{traceback.format_exc()}"
yield error_msg, "", error_msg, None
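# The streaming above depends on model.infer() printing decoded text to stdout while it
# generates; if a model revision stops doing that, the UI will simply update once at the
# end with the returned result.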
def process_pdf_ocr_stream(pdf_file, prompt_type, base_size, image_size, crop_mode):
"""处理PDF文件的OCR - 生成器版本,支持流式输出"""
if pdf_file is None:
yield "请上传PDF文件", "", "请上传PDF文件", None
return
# 设置prompt
if prompt_type == "Free OCR":
prompt = "<image>\nFree OCR. "
else: # Convert to Markdown
prompt = "<image>\n<|grounding|>Convert the document to markdown. "
# 创建持久化的输出目录
script_dir = os.path.dirname(os.path.abspath(__file__))
output_base_dir = os.path.join(script_dir, "gradio_outputs")
session_dir = os.path.join(output_base_dir, f"session_{int(time.time() * 1000)}")
os.makedirs(session_dir, exist_ok=True)
try:
tmpdir = session_dir
        # Convert the PDF into page images
        yield "Loading PDF...", "", "Loading PDF...", None
        images = pdf_to_images(pdf_file.name)
        yield f"PDF loaded: {len(images)} page(s). Starting recognition...", "", f"PDF loaded: {len(images)} page(s)", None
        all_results = []
        all_raw_results = []
        all_cleaned_results = []
        all_cropped_images = []  # All cropped images across pages
images_base_dir = os.path.join(tmpdir, "extracted_images")
for idx, img in enumerate(images):
temp_image_path = os.path.join(tmpdir, f"page_{idx}.png")
img.save(temp_image_path)
page_header = f"\n{'='*50}\n第 {idx+1}/{len(images)} 页\n{'='*50}\n"
# 更新进度
progress_msg = "\n".join(all_cleaned_results) + page_header + "正在识别..."
yield progress_msg, "", progress_msg, all_cropped_images if all_cropped_images else None
# 创建流式捕获对象
stream_capture = StreamCapture()
# 在新线程中运行模型推理
result_container = {'result': None, 'error': None}
def run_inference():
try:
stream_capture.start_capture()
result = model.infer(
tokenizer,
prompt=prompt,
image_file=temp_image_path,
output_path=tmpdir,
base_size=int(base_size),
image_size=int(image_size),
crop_mode=crop_mode,
save_results=False,
test_compress=True
)
result_container['result'] = result
except Exception as e:
result_container['error'] = e
finally:
stream_capture.stop_capture()
            # Start the inference thread
            inference_thread = threading.Thread(target=run_inference)
            inference_thread.start()
            # Stream output for the current page
page_text = ""
last_update = time.time()
while inference_thread.is_alive() or not stream_capture.queue.empty():
                try:
                    text = stream_capture.queue.get(timeout=0.1)
                    page_text += text
                    # Clean the accumulated page text
                    cleaned_page = clean_ocr_output(page_text)
                    # Update the UI
                    current_time = time.time()
                    if current_time - last_update >= 0.1:
                        current_display = "\n".join(all_cleaned_results) + page_header + cleaned_page
                        yield current_display, "", current_display, all_cropped_images if all_cropped_images else None
                        last_update = current_time
                except Empty:
                    time.sleep(0.05)
            # Wait for the inference thread to finish
            inference_thread.join()
            # Collect the final result
final_captured = stream_capture.stop_capture()
if result_container['result'] is not None and result_container['result'] != "":
final_result = result_container['result']
else:
final_result = final_captured
            if not final_result:
                final_result = f"[Page {idx+1} failed to process]"
            # Crop the current page's images
            page_cropped_images = crop_images_from_text(temp_image_path, final_result, images_base_dir, page_idx=idx)
            all_cropped_images.extend(page_cropped_images)
            print(f"Page {idx+1}: extracted {len(page_cropped_images)} image(s)")
            # Clean the output
            cleaned_result = clean_ocr_output(final_result)
            all_cleaned_results.append(page_header + cleaned_result)
            all_raw_results.append(f"{page_header}(raw output)\n{final_result}")
final_clean = "\n".join(all_cleaned_results)
final_raw = "\n".join(all_raw_results)
print(f"总共提取图片数: {len(all_cropped_images)}")
print(f"图片路径: {all_cropped_images}")
# 创建包含图片的Markdown版本
# 对于PDF,我们需要重新构建带图片的Markdown
final_markdown = create_pdf_markdown_with_images(all_raw_results, all_cropped_images)
yield final_clean, final_raw, final_markdown, all_cropped_images if all_cropped_images else None
except Exception as e:
import traceback
error_msg = f"PDF处理出错: {str(e)}\n{traceback.format_exc()}"
yield error_msg, "", error_msg, None
def process_file_stream(file, prompt_type, base_size, image_size, crop_mode):
"""统一处理图片或PDF文件 - 生成器版本"""
if file is None:
yield "请上传文件", "", "请上传文件", None
return
# 判断文件类型
file_path = file.name if hasattr(file, 'name') else file
file_ext = Path(file_path).suffix.lower()
if file_ext == '.pdf':
yield from process_pdf_ocr_stream(file, prompt_type, base_size, image_size, crop_mode)
elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']:
        # For image files, pass the path directly
yield from process_image_ocr_stream(file_path, prompt_type, base_size, image_size, crop_mode)
    else:
        yield "Unsupported file format. Please upload an image (jpg/png/bmp, etc.) or a PDF file", "", "Unsupported file format", None
# Build the Gradio interface
with gr.Blocks(title="DeepSeek OCR", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔍 DeepSeek OCR Document Recognition")
    gr.Markdown("Upload an image or PDF for OCR, with live streaming output and Markdown rendering")
with gr.Row():
with gr.Column(scale=1):
            # File upload
            file_input = gr.File(
                label="📁 Upload an image or PDF file",
file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp", ".pdf"]
)
            # OCR parameter settings
            prompt_type = gr.Radio(
                choices=["Free OCR", "Convert to Markdown"],
                value="Convert to Markdown",
                label="🎯 OCR mode"
)
with gr.Accordion("⚙️ 高级设置", open=False):
base_size = gr.Slider(
minimum=512,
maximum=1280,
value=1024,
step=64,
label="Base Size (基础尺寸)"
)
image_size = gr.Slider(
minimum=512,
maximum=1280,
value=640,
step=64,
label="Image Size (图像尺寸)"
)
crop_mode = gr.Checkbox(
value=True,
label="Crop Mode (裁剪模式)"
)
gr.Markdown("""
                **Model size presets (reference):**
- Tiny: base_size=512, image_size=512, crop_mode=False
- Small: base_size=640, image_size=640, crop_mode=False
- Base: base_size=1024, image_size=1024, crop_mode=False
- Large: base_size=1280, image_size=1280, crop_mode=False
                - Gundam (recommended): base_size=1024, image_size=640, crop_mode=True
""")
            # Action buttons
            submit_btn = gr.Button("🚀 Start recognition", variant="primary", size="lg")
            stop_btn = gr.Button("⏹️ Stop", variant="stop", size="lg")
with gr.Column(scale=2):
            # Result display
with gr.Tabs():
with gr.Tab("📝 Markdown渲染"):
output_markdown = gr.Markdown(
label="Markdown渲染结果",
value="等待识别结果..."
)
with gr.Tab("📄 纯文本"):
output_clean = gr.Textbox(
label="OCR识别结果(已清理)",
lines=30,
max_lines=50,
show_copy_button=True
)
with gr.Tab("🔧 原始输出"):
output_raw = gr.Textbox(
label="原始OCR输出(包含标注信息)",
lines=30,
max_lines=50,
show_copy_button=True
)
with gr.Tab("🖼️ 提取的图片"):
output_gallery = gr.Gallery(
label="文档中识别并提取的图片",
columns=3,
rows=2,
height="auto",
object_fit="contain",
show_download_button=True
)
    # Examples
    gr.Markdown("### 📚 Example files")
    # Directory containing this script
script_dir = os.path.dirname(os.path.abspath(__file__))
example_image = os.path.join(script_dir, "示例图片.png")
example_pdf = os.path.join(script_dir, "DeepSeek_OCR_paper-p1.pdf")
    # Only add examples that actually exist on disk
examples_list = []
if os.path.exists(example_image):
examples_list.append([example_image, "Convert to Markdown", 1024, 640, True])
if os.path.exists(example_pdf):
examples_list.append([example_pdf, "Convert to Markdown", 1024, 640, True])
if examples_list:
gr.Examples(
examples=examples_list,
inputs=[file_input, prompt_type, base_size, image_size, crop_mode],
label="点击示例快速体验"
)
gr.Markdown("### 💡 使用提示")
gr.Markdown("""
    - Results are **streamed live** while recognition runs
    - The **Rendered Markdown** tab shows the formatted result
    - The **Plain text** tab lets you copy the text content
    - The **Raw output** tab contains the model's original grounding annotations
    - The **Extracted images** tab shows every image detected in the document and supports downloading
    - Multi-page PDFs are recognized page by page, with progress shown along the way
""")
    # Wire up events
submit_event = submit_btn.click(
fn=process_file_stream,
inputs=[file_input, prompt_type, base_size, image_size, crop_mode],
outputs=[output_clean, output_raw, output_markdown, output_gallery]
)
    # Stop button (cancels the running task)
stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[submit_event])
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True
)
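# launch() notes: server_name="0.0.0.0" makes the app reachable on all network interfaces,
# and share=True additionally asks Gradio to create a temporary public share link.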