alexshen1979's picture
Update app.py
b1cb14b verified
import os
os.environ["OMP_NUM_THREADS"] = "1"
import gradio as gr
import torch
# 【关键修改】不再用 AutoTokenizer,改用 GPT2Tokenizer
from transformers import GPT2Tokenizer, SiglipImageProcessor, AutoModelForCausalLM
from PIL import Image
import re
# --- Model configuration ---
# Everything below runs once at import time and leaves three module-level
# objects behind for the inference function: image_processor, tokenizer, model.
model_id = "starvector/starvector-1b-im2svg"
print(f"正在加载模型: {model_id} ...")

# 1. Load the image processor.
try:
    image_processor = SiglipImageProcessor.from_pretrained(model_id, trust_remote_code=True)
except Exception as exc:
    # Fall back to the generic auto class when the Siglip-specific loader
    # fails. Catching Exception (not a bare except) lets Ctrl-C and
    # SystemExit propagate, and the reason for the fallback is reported
    # instead of being silently discarded.
    from transformers import AutoImageProcessor

    print(f"SiglipImageProcessor 加载失败 ({exc}),改用 AutoImageProcessor ...")
    image_processor = AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)

# 2. Load the tokenizer.
print("正在加载 Tokenizer (强制使用 GPT2 Python 版)...")
# Original author's note: StarVector-1B is based on StarCoder (GPT-2
# architecture); using GPT2Tokenizer directly sidesteps an error raised when
# AutoTokenizer insists on reading tokenizer.json.
tokenizer = GPT2Tokenizer.from_pretrained(model_id, trust_remote_code=True)
# Make sure a pad token exists; reuse the EOS token when none is configured.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 3. Load the model in half precision.
print("正在加载主模型 (FP16)...")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    # device_map is deliberately not passed (original note: avoids
    # meta-tensor errors); the model is moved to the GPU explicitly below.
)
print("正在移动模型到 GPU...")
model.to("cuda")
print("模型加载完成!")
# --- Inference function ---
def convert_image_to_svg(image):
    """Generate SVG markup for an uploaded image.

    Args:
        image: The image handed over by the UI, or None when nothing was
            uploaded.

    Returns:
        A 2-tuple of strings. On success both elements are the extracted SVG
        markup. When ``image`` is None, or when any step raises, the first
        element is a user-facing message and the second is an empty string.
        Exceptions are printed with a traceback and never propagate.
    """
    if image is None:
        return "请上传图片", ""

    try:
        # Image -> pixel tensor, moved to the GPU and cast to float16 so it
        # matches the dtype the model was loaded with.
        pixels = image_processor(image, return_tensors="pt").pixel_values
        pixels = pixels.to("cuda").to(torch.float16)

        # Tokenize the fixed text prompt and collect all generate() arguments.
        prompt = tokenizer("Generate SVG code", return_tensors="pt")
        generation_kwargs = dict(
            input_ids=prompt.input_ids.to("cuda"),
            attention_mask=prompt.attention_mask.to("cuda"),
            pixel_values=pixels,
            max_new_tokens=2048,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Generation only; no gradients are needed.
        with torch.no_grad():
            output_ids = model.generate(**generation_kwargs)

        # Decode the first (only) sequence and keep just the SVG part.
        decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        svg = extract_svg_content(decoded)
        return svg, svg
    except Exception as err:
        import traceback

        traceback.print_exc()
        return f"生成出错: {str(err)}", ""
def extract_svg_content(text):
    """Pull the SVG markup out of raw generated text.

    Strategies are tried in order and the first that applies wins:

    1. The first complete ``<svg ...>...</svg>`` element (may span lines).
    2. The contents of a ```xml fenced block.
    3. The contents of a ```svg fenced block.
    4. An unterminated ``<svg`` element that carries an ``xmlns=`` attribute
       (for example a generation cut off before ``</svg>``): everything from
       ``<svg`` to the end of the text.
    5. Otherwise the text is returned unchanged.

    Args:
        text: Decoded model output.

    Returns:
        The extracted SVG string, or ``text`` itself when nothing matched.
    """
    match = re.search(r"(<svg.*?</svg>)", text, re.DOTALL)
    if match:
        return match.group(1)

    for fence in ("```xml", "```svg"):
        if fence in text:
            return text.split(fence)[1].split("```")[0]

    if "xmlns=" in text and "<svg" in text:
        # Reaching this point means no "<svg" is followed by "</svg>", so
        # there is no closing tag to slice up to. Slicing to
        # rfind("</svg>") + 6 here would yield an empty or near-empty string;
        # keep the whole unterminated element instead.
        return text[text.find("<svg"):]

    return text
# --- UI ---
# NOTE(review): the nesting below is reconstructed from a source whose
# indentation was lost; confirm the two-column layout matches the intent.
with gr.Blocks(title="StarVector-1B") as demo:
    gr.Markdown("# StarVector-1B Image to SVG")

    with gr.Row():
        # Left column: image upload and the trigger button.
        with gr.Column():
            input_image = gr.Image(type="pil", label="上传图片")
            run_btn = gr.Button("生成 SVG", variant="primary")

        # Right column: rendered preview and the SVG source.
        with gr.Column():
            html_output = gr.HTML(label="SVG 预览")
            code_output = gr.Code(language="xml", label="SVG 代码")

    # The handler returns a 2-tuple; the first element goes to the code box,
    # the second to the HTML preview.
    run_btn.click(
        convert_image_to_svg,
        inputs=[input_image],
        outputs=[code_output, html_output],
    )

# Serve on all interfaces, port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)