import gradio as gr
from transformers import AutoModel, AutoTokenizer
import torch
import os
import tempfile
from pathlib import Path
import re
import ast
from PIL import Image
import fitz # PyMuPDF
import io
import sys
import threading
import time
from queue import Queue, Empty
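# Gradio demo for DeepSeek-OCR: upload an image or PDF, stream the OCR output
# live, render it as Markdown, and extract any figures the model detects.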
# os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# Initialize the model
print("Loading model...")
model_name = 'deepseek-ai/DeepSeek-OCR'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(
model_name,
_attn_implementation='flash_attention_2',
trust_remote_code=True,
use_safetensors=True
)
model = model.eval() # .cuda().to(torch.bfloat16)
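# Note: _attn_implementation='flash_attention_2' assumes the flash-attn package and a
# CUDA GPU; on GPU, the commented .cuda().to(torch.bfloat16) chain above moves the
# model into bfloat16 on the device.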
print("Model loaded successfully!")
class StreamCapture:
"""捕获并流式输出标准输出"""
def __init__(self):
self.queue = Queue()
self.captured_text = []
self.original_stdout = None
self.is_capturing = False
def write(self, text):
"""捕获write调用"""
if text and text.strip():
self.captured_text.append(text)
self.queue.put(text)
            # Also echo to the real terminal
if self.original_stdout:
self.original_stdout.write(text)
def flush(self):
"""flush方法"""
if self.original_stdout:
self.original_stdout.flush()
def start_capture(self):
"""开始捕获"""
self.original_stdout = sys.stdout
sys.stdout = self
self.is_capturing = True
self.captured_text = []
def stop_capture(self):
"""停止捕获"""
if self.is_capturing:
sys.stdout = self.original_stdout
self.is_capturing = False
return ''.join(self.captured_text)
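# StreamCapture is swapped in for sys.stdout while model.infer runs in a worker
# thread; the generator thread drains .queue to stream partial output to the UI,
# while the original stdout still receives everything.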
def pdf_to_images(pdf_path, dpi=144):
"""将PDF转换为图片列表"""
images = []
pdf_document = fitz.open(pdf_path)
zoom = dpi / 72.0
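    # dpi=144 renders at 2x the PDF's native 72 dpi, trading memory for OCR legibility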
matrix = fitz.Matrix(zoom, zoom)
for page_num in range(pdf_document.page_count):
page = pdf_document[page_num]
pixmap = page.get_pixmap(matrix=matrix, alpha=False)
Image.MAX_IMAGE_PIXELS = None
img_data = pixmap.tobytes("png")
img = Image.open(io.BytesIO(img_data))
images.append(img)
pdf_document.close()
return images
def extract_image_refs(text):
"""提取图片引用标签"""
pattern = r'(<\|ref\|>image<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
matches = re.findall(pattern, text, re.DOTALL)
return matches
def extract_coordinates(det_text):
    """Parse the list of bounding boxes from a det tag."""
    try:
        # ast.literal_eval is a safer way to parse the bracketed list than eval
        coords_list = ast.literal_eval(det_text)
        return coords_list
    except (ValueError, SyntaxError):
        return None
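# The det payload is a list of boxes in 0-999 normalized coordinates, e.g.
# (values illustrative): <|ref|>image<|/ref|><|det|>[[85, 121, 597, 388]]<|/det|>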
def crop_images_from_text(original_image, ocr_text, output_dir, page_idx=0):
    """Crop figures out of the page image using the coordinates in the OCR text."""
    if original_image is None or ocr_text is None:
        print("crop_images_from_text: empty input")
        return []
    print(f"crop_images_from_text: processing page {page_idx}")
    print(f"Output directory: {output_dir}")
    # Extract image references
    image_refs = extract_image_refs(ocr_text)
    print(f"Found {len(image_refs)} image reference(s)")
    if not image_refs:
        return []
    # Make sure the output directory exists
    os.makedirs(output_dir, exist_ok=True)
    # Load the original image
    if isinstance(original_image, str):
        img = Image.open(original_image)
        print(f"Loaded image from file: {original_image}")
    else:
        img = original_image
        print("Using PIL Image object")
    image_width, image_height = img.size
    print(f"Image size: {image_width} x {image_height}")
    cropped_images = []
    for idx, (full_match, coords_str) in enumerate(image_refs):
        print(f"\nProcessing image reference {idx+1}: {coords_str}")
        coords_list = extract_coordinates(coords_str)
        if coords_list is None:
            print("  Failed to parse coordinates")
            continue
        print(f"  Parsed {len(coords_list)} bounding box(es)")
        # Process each bounding box
        for coord_idx, coord in enumerate(coords_list):
            try:
                x1, y1, x2, y2 = coord
                print(f"  Box {coord_idx+1}: normalized coords ({x1}, {y1}) -> ({x2}, {y2})")
                # Convert normalized (0-999) coordinates to pixel coordinates
                px1 = int(x1 / 999 * image_width)
                py1 = int(y1 / 999 * image_height)
                px2 = int(x2 / 999 * image_width)
                py2 = int(y2 / 999 * image_height)
                print(f"  Box {coord_idx+1}: pixel coords ({px1}, {py1}) -> ({px2}, {py2})")
                # Crop the region
                cropped = img.crop((px1, py1, px2, py2))
                print(f"  Cropped size: {cropped.size}")
                # Save the crop (JPEG cannot store an alpha channel)
                if cropped.mode != "RGB":
                    cropped = cropped.convert("RGB")
                img_filename = f"page{page_idx}_img{len(cropped_images)}.jpg"
                img_path = os.path.join(output_dir, img_filename)
                cropped.save(img_path, quality=95)
                print(f"  Saved to: {img_path}")
                cropped_images.append(img_path)
            except Exception as e:
                print(f"  Error cropping image: {e}")
                import traceback
                traceback.print_exc()
                continue
    print(f"\nCropped {len(cropped_images)} image(s) in total")
    return cropped_images
def clean_ocr_output(text, image_refs_replacement=None):
    """Clean the raw OCR output; optionally replace image tags via a format string."""
    if text is None:
        return ""
    # Strip the debug blocks printed by the model
    # Leading block: =====================\nBASE: ...\nPATCHES: ...\n=====================
    text = re.sub(r'={5,}\s*BASE:\s+.*?\s+PATCHES:\s+.*?\s+={5,}\s*', '', text, flags=re.DOTALL)
    # Trailing block: ==================================================\nimage size: ...\n...==================================================
    text = re.sub(r'={10,}\s*image size:.*?={10,}\s*', '', text, flags=re.DOTALL)
    # Remove grounding tags
    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    matches = re.findall(pattern, text, re.DOTALL)
    # If a replacement template is given, substitute image tags with it
    img_idx = 0
    for match in matches:
        if '<|ref|>image<|/ref|>' in match[0]:
            if image_refs_replacement:
                # Replace with the actual image link
                text = text.replace(match[0], image_refs_replacement.format(img_idx), 1)
                img_idx += 1
            else:
                text = text.replace(match[0], '')
        else:
            text = text.replace(match[0], '')
    # Normalize special characters
    text = text.replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
    text = text.replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n')
    # Drop the end-of-sentence marker
    if '<|end▁of▁sentence|>' in text:
        text = text.replace('<|end▁of▁sentence|>', '')
    return text.strip()
def create_markdown_with_images(ocr_text, image_paths):
    """Build Markdown with the extracted images embedded at their original positions."""
    if ocr_text is None:
        return ""
    result_text = ocr_text
    # Strip the debug blocks printed by the model
    # Leading block: =====================\nBASE: ...\nPATCHES: ...\n=====================
    result_text = re.sub(r'={5,}\s*BASE:\s+.*?\s+PATCHES:\s+.*?\s+={5,}\s*', '', result_text, flags=re.DOTALL)
    # Trailing block: ==================================================\nimage size: ...\n...==================================================
    result_text = re.sub(r'={10,}\s*image size:.*?={10,}\s*', '', result_text, flags=re.DOTALL)
    # Remove grounding tags but keep the image positions
    pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
    matches = re.findall(pattern, result_text, re.DOTALL)
    img_idx = 0
    for match in matches:
        if '<|ref|>image<|/ref|>' in match[0]:
            # Replace with Markdown/HTML image syntax
            if image_paths and img_idx < len(image_paths):
                img_path = image_paths[img_idx]
                # Gradio needs an HTML img tag to show local images,
                # so embed the image data as base64
                try:
                    import base64
                    with open(img_path, 'rb') as f:
                        img_data = f.read()
                    img_base64 = base64.b64encode(img_data).decode('utf-8')
                    # Detect the image format
                    img_ext = os.path.splitext(img_path)[1].lower()
                    if img_ext == '.png':
                        mime_type = 'image/png'
                    elif img_ext in ['.jpg', '.jpeg']:
                        mime_type = 'image/jpeg'
                    else:
                        mime_type = 'image/jpeg'
                    # Build an img tag with the base64 data embedded
                    img_markdown = f'\n\n<img src="data:{mime_type};base64,{img_base64}" alt="Image {img_idx+1}" style="max-width: 100%;" />\n\n'
print(f"图片{img_idx+1}已转换为base64嵌入 (大小: {len(img_base64)} bytes)")
except Exception as e:
print(f"转换图片{img_idx+1}为base64失败: {e}")
img_markdown = f'\n\n[图片{img_idx+1}: {os.path.basename(img_path)}]\n\n'
result_text = result_text.replace(match[0], img_markdown, 1)
img_idx += 1
else:
result_text = result_text.replace(match[0], '\n[图片]\n', 1)
else:
# 移除其他标注
result_text = result_text.replace(match[0], '')
# 清理特殊字符
result_text = result_text.replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
result_text = result_text.replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n')
# 移除结束标记
if '<|end▁of▁sentence|>' in result_text:
result_text = result_text.replace('<|end▁of▁sentence|>', '')
return result_text.strip()
def create_pdf_markdown_with_images(raw_results_list, all_image_paths):
    """Build the combined Markdown (with images) for a multi-page PDF."""
    if not raw_results_list:
        return ""
    # Concatenate all raw per-page results
    combined_text = "\n".join(raw_results_list)
    # Reuse create_markdown_with_images for the actual substitution
    return create_markdown_with_images(combined_text, all_image_paths)
def process_image_ocr_stream(image, prompt_type, base_size, image_size, crop_mode):
    """OCR a single image. Generator version that streams partial results."""
    if image is None:
        yield "Please upload an image", "", "Please upload an image", None
        return
    # Build the prompt (the "<image>" placeholder marks where the image is injected)
    if prompt_type == "Free OCR":
        prompt = "<image>\nFree OCR. "
    else:  # Convert to Markdown
        prompt = "<image>\n<|grounding|>Convert the document to markdown. "
    # Create a persistent output directory for this session
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_base_dir = os.path.join(script_dir, "gradio_outputs")
    session_dir = os.path.join(output_base_dir, f"session_{int(time.time() * 1000)}")
    os.makedirs(session_dir, exist_ok=True)
    try:
        tmpdir = session_dir
        # Save the uploaded image
        temp_image_path = os.path.join(tmpdir, "input_image.png")
        if isinstance(image, str):
            # Already a file path
            temp_image_path = image
        else:
            # PIL Image or numpy array
            if not isinstance(image, Image.Image):
                image = Image.fromarray(image)
            image.save(temp_image_path)
        yield "Processing image...", "", "Processing image...", None
        # Set up the stdout capture
        stream_capture = StreamCapture()
        # Run model inference in a worker thread
result_container = {'result': None, 'error': None}
def run_inference():
try:
stream_capture.start_capture()
result = model.infer(
tokenizer,
prompt=prompt,
image_file=temp_image_path,
output_path=tmpdir,
base_size=int(base_size),
image_size=int(image_size),
crop_mode=crop_mode,
save_results=False,
test_compress=True
)
result_container['result'] = result
except Exception as e:
result_container['error'] = e
finally:
stream_capture.stop_capture()
        # Start the inference thread
        inference_thread = threading.Thread(target=run_inference)
        inference_thread.start()
        # Stream the output as it arrives
        accumulated_text = ""
        last_update = time.time()
        while inference_thread.is_alive() or not stream_capture.queue.empty():
            try:
                # Try to fetch newly captured text from the queue
                text = stream_capture.queue.get(timeout=0.1)
                accumulated_text += text
                # Cleaned version for display
                cleaned_text = clean_ocr_output(accumulated_text)
                # Refresh the UI at most every 0.1 s (no images during streaming)
                current_time = time.time()
                if current_time - last_update >= 0.1:
                    yield cleaned_text, accumulated_text, cleaned_text + "\n\n*Recognizing...*", None
                    last_update = current_time
            except Empty:
                # Queue is empty for now; keep waiting
                time.sleep(0.05)
        # Wait for the worker thread to finish
        inference_thread.join()
        # Check for errors from the worker
        if result_container['error']:
            import traceback
            error_msg = f"OCR error: {str(result_container['error'])}\n\n{traceback.format_exc()}"
            yield error_msg, "", error_msg, None
            return
        # Collect the final result
        final_captured = stream_capture.stop_capture()
        # Prefer the return value; fall back to the captured stdout
        if result_container['result'] is not None and result_container['result'] != "":
            final_result = result_container['result']
        else:
            final_result = final_captured
        if not final_result:
            yield "OCR finished, but the model returned no result.", "", "OCR finished, but the model returned no result.", None
            return
        # Show the recognized text first; images are still being extracted
        cleaned_temp = clean_ocr_output(final_result)
        yield cleaned_temp, final_result, cleaned_temp + "\n\n*Extracting images...*", None
        # Crop out the embedded figures
        images_dir = os.path.join(tmpdir, "extracted_images")
        cropped_image_paths = crop_images_from_text(temp_image_path, final_result, images_dir, page_idx=0)
        print(f"Number of extracted images: {len(cropped_image_paths)}")
        print(f"Image paths: {cropped_image_paths}")
        # Cleaned plain-text output
        cleaned_result = clean_ocr_output(final_result)
        # Markdown version with embedded image links (crops are ready now)
        markdown_result = create_markdown_with_images(final_result, cropped_image_paths)
        print(f"First 500 characters of the Markdown result:\n{markdown_result[:500]}")
        # Image gallery
        gallery_images = cropped_image_paths if cropped_image_paths else None
        # Final output, including images
        yield cleaned_result, final_result, markdown_result, gallery_images
    except Exception as e:
        import traceback
        error_msg = f"OCR error: {str(e)}\n\nDetails:\n{traceback.format_exc()}"
        yield error_msg, "", error_msg, None
def process_pdf_ocr_stream(pdf_file, prompt_type, base_size, image_size, crop_mode):
    """OCR a PDF file. Generator version that streams partial results."""
    if pdf_file is None:
        yield "Please upload a PDF file", "", "Please upload a PDF file", None
        return
    # Build the prompt (the "<image>" placeholder marks where the image is injected)
    if prompt_type == "Free OCR":
        prompt = "<image>\nFree OCR. "
    else:  # Convert to Markdown
        prompt = "<image>\n<|grounding|>Convert the document to markdown. "
    # Create a persistent output directory for this session
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_base_dir = os.path.join(script_dir, "gradio_outputs")
    session_dir = os.path.join(output_base_dir, f"session_{int(time.time() * 1000)}")
    os.makedirs(session_dir, exist_ok=True)
    try:
        tmpdir = session_dir
        # Convert the PDF into page images
        yield "Loading PDF...", "", "Loading PDF...", None
        images = pdf_to_images(pdf_file.name)
        yield f"PDF loaded, {len(images)} page(s), starting recognition...", "", f"PDF loaded, {len(images)} page(s)", None
        all_results = []
        all_raw_results = []
        all_cleaned_results = []
        all_cropped_images = []  # every cropped figure across all pages
        images_base_dir = os.path.join(tmpdir, "extracted_images")
        for idx, img in enumerate(images):
            temp_image_path = os.path.join(tmpdir, f"page_{idx}.png")
            img.save(temp_image_path)
            page_header = f"\n{'='*50}\nPage {idx+1}/{len(images)}\n{'='*50}\n"
            # Progress update
            progress_msg = "\n".join(all_cleaned_results) + page_header + "Recognizing..."
            yield progress_msg, "", progress_msg, all_cropped_images if all_cropped_images else None
            # Set up the stdout capture
            stream_capture = StreamCapture()
            # Run model inference in a worker thread
            result_container = {'result': None, 'error': None}
def run_inference():
try:
stream_capture.start_capture()
result = model.infer(
tokenizer,
prompt=prompt,
image_file=temp_image_path,
output_path=tmpdir,
base_size=int(base_size),
image_size=int(image_size),
crop_mode=crop_mode,
save_results=False,
test_compress=True
)
result_container['result'] = result
except Exception as e:
result_container['error'] = e
finally:
stream_capture.stop_capture()
            # Start the inference thread
            inference_thread = threading.Thread(target=run_inference)
            inference_thread.start()
            # Stream the current page's output
            page_text = ""
            last_update = time.time()
            while inference_thread.is_alive() or not stream_capture.queue.empty():
                try:
                    text = stream_capture.queue.get(timeout=0.1)
                    page_text += text
                    # Cleaned version for display
                    cleaned_page = clean_ocr_output(page_text)
                    # Refresh the UI
                    current_time = time.time()
                    if current_time - last_update >= 0.1:
                        current_display = "\n".join(all_cleaned_results) + page_header + cleaned_page
                        yield current_display, "", current_display, all_cropped_images if all_cropped_images else None
                        last_update = current_time
                except Empty:
                    time.sleep(0.05)
            # Wait for the worker thread to finish
            inference_thread.join()
            # Collect the final result for this page
            final_captured = stream_capture.stop_capture()
            if result_container['result'] is not None and result_container['result'] != "":
                final_result = result_container['result']
            else:
                final_result = final_captured
            if not final_result:
                final_result = f"[Page {idx+1} failed to process]"
            # Crop figures from the current page
            page_cropped_images = crop_images_from_text(temp_image_path, final_result, images_base_dir, page_idx=idx)
            all_cropped_images.extend(page_cropped_images)
            print(f"Page {idx+1}: extracted {len(page_cropped_images)} image(s)")
            # Clean the output
            cleaned_result = clean_ocr_output(final_result)
            all_cleaned_results.append(page_header + cleaned_result)
            all_raw_results.append(f"{page_header}(raw output)\n{final_result}")
        final_clean = "\n".join(all_cleaned_results)
        final_raw = "\n".join(all_raw_results)
        print(f"Total extracted images: {len(all_cropped_images)}")
        print(f"Image paths: {all_cropped_images}")
        # Build the Markdown version with embedded images
        # For PDFs, rebuild the Markdown from the raw per-page results
        final_markdown = create_pdf_markdown_with_images(all_raw_results, all_cropped_images)
        yield final_clean, final_raw, final_markdown, all_cropped_images if all_cropped_images else None
    except Exception as e:
        import traceback
        error_msg = f"PDF processing error: {str(e)}\n{traceback.format_exc()}"
        yield error_msg, "", error_msg, None
def process_file_stream(file, prompt_type, base_size, image_size, crop_mode):
    """Dispatch an uploaded image or PDF to the matching generator."""
    if file is None:
        yield "Please upload a file", "", "Please upload a file", None
        return
    # Determine the file type from its extension
    file_path = file.name if hasattr(file, 'name') else file
    file_ext = Path(file_path).suffix.lower()
    if file_ext == '.pdf':
        yield from process_pdf_ocr_stream(file, prompt_type, base_size, image_size, crop_mode)
    elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']:
        # For image files, pass the path straight through
        yield from process_image_ocr_stream(file_path, prompt_type, base_size, image_size, crop_mode)
    else:
        yield "Unsupported file format; please upload an image (jpg/png/bmp, etc.) or a PDF", "", "Unsupported file format", None
# Build the Gradio UI
with gr.Blocks(title="DeepSeek OCR", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔍 DeepSeek OCR Document Recognition")
    gr.Markdown("Upload an image or PDF for OCR, with live streaming output and Markdown rendering")
    with gr.Row():
        with gr.Column(scale=1):
            # File upload
            file_input = gr.File(
                label="📁 Upload an image or PDF",
                file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp", ".pdf"]
            )
            # OCR parameters
            prompt_type = gr.Radio(
                choices=["Free OCR", "Convert to Markdown"],
                value="Convert to Markdown",
                label="🎯 OCR mode"
            )
            with gr.Accordion("⚙️ Advanced settings", open=False):
                base_size = gr.Slider(
                    minimum=512,
                    maximum=1280,
                    value=1024,
                    step=64,
                    label="Base Size"
                )
                image_size = gr.Slider(
                    minimum=512,
                    maximum=1280,
                    value=640,
                    step=64,
                    label="Image Size"
                )
                crop_mode = gr.Checkbox(
                    value=True,
                    label="Crop Mode"
                )
                gr.Markdown("""
                **Model size presets:**
                - Tiny: base_size=512, image_size=512, crop_mode=False
                - Small: base_size=640, image_size=640, crop_mode=False
                - Base: base_size=1024, image_size=1024, crop_mode=False
                - Large: base_size=1280, image_size=1280, crop_mode=False
                - Gundam (recommended): base_size=1024, image_size=640, crop_mode=True
                """)
            # Action buttons
            submit_btn = gr.Button("🚀 Start recognition", variant="primary", size="lg")
            stop_btn = gr.Button("⏹️ Stop", variant="stop", size="lg")
        with gr.Column(scale=2):
            # Results
            with gr.Tabs():
                with gr.Tab("📝 Markdown"):
                    output_markdown = gr.Markdown(
                        label="Rendered Markdown",
                        value="Waiting for results..."
                    )
                with gr.Tab("📄 Plain text"):
                    output_clean = gr.Textbox(
                        label="OCR result (cleaned)",
                        lines=30,
                        max_lines=50,
                        show_copy_button=True
                    )
                with gr.Tab("🔧 Raw output"):
                    output_raw = gr.Textbox(
                        label="Raw OCR output (with grounding annotations)",
                        lines=30,
                        max_lines=50,
                        show_copy_button=True
                    )
                with gr.Tab("🖼️ Extracted images"):
                    output_gallery = gr.Gallery(
                        label="Images detected and extracted from the document",
                        columns=3,
                        rows=2,
                        height="auto",
                        object_fit="contain",
                        show_download_button=True
                    )
    # Examples
    gr.Markdown("### 📚 Example files")
    # Directory containing this script
    script_dir = os.path.dirname(os.path.abspath(__file__))
    example_image = os.path.join(script_dir, "示例图片.png")
    example_pdf = os.path.join(script_dir, "DeepSeek_OCR_paper-p1.pdf")
    # Only add examples whose files actually exist
    examples_list = []
    if os.path.exists(example_image):
        examples_list.append([example_image, "Convert to Markdown", 1024, 640, True])
    if os.path.exists(example_pdf):
        examples_list.append([example_pdf, "Convert to Markdown", 1024, 640, True])
    if examples_list:
        gr.Examples(
            examples=examples_list,
            inputs=[file_input, prompt_type, base_size, image_size, crop_mode],
            label="Click an example to try it"
        )
    gr.Markdown("### 💡 Tips")
    gr.Markdown("""
    - Results are **streamed live** while recognition is running
    - The **Markdown** tab shows the rendered, formatted output
    - The **Plain text** tab lets you copy the text content
    - The **Raw output** tab contains the model's original grounding annotations
    - The **Extracted images** tab shows every figure detected in the document and supports downloads
    - Multi-page PDFs are recognized page by page, with progress shown as each page completes
    """)
    # Wire up events
    submit_event = submit_btn.click(
        fn=process_file_stream,
        inputs=[file_input, prompt_type, base_size, image_size, crop_mode],
        outputs=[output_clean, output_raw, output_markdown, output_gallery]
    )
    # Stop button (cancels the running task)
    stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[submit_event])
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True
)
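    # share=True requests a temporary public *.gradio.live link in addition to the
    # local server on port 7860; set share=False to keep the app local-only.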