Spaces:
Runtime error
Runtime error
File size: 4,073 Bytes
26c0321 4f5bc6b b1cb14b 4f5bc6b 0959374 070f825 4f5bc6b 070f825 4f5bc6b b1cb14b f203e5b 0959374 b1cb14b ebb582a b1cb14b ebb582a b1cb14b 4f5bc6b 26c0321 4f5bc6b b1cb14b 4f5bc6b ebb582a 0959374 4f5bc6b 0959374 4f5bc6b 8fb934f 070f825 8fb934f 0959374 8fb934f 070f825 8fb934f 0959374 8fb934f 070f825 f203e5b 8fb934f 26c0321 070f825 8fb934f 4f5bc6b 8fb934f 4f5bc6b 0959374 070f825 4f5bc6b 708af3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import os
os.environ["OMP_NUM_THREADS"] = "1"
import gradio as gr
import torch
# 【关键修改】不再用 AutoTokenizer,改用 GPT2Tokenizer
from transformers import GPT2Tokenizer, SiglipImageProcessor, AutoModelForCausalLM
from PIL import Image
import re
# --- Model configuration ---
model_id = "starvector/starvector-1b-im2svg"
print(f"正在加载模型: {model_id} ...")

# 1. Load the image processor.
# Try the SigLIP processor class directly first; if that fails for this
# checkpoint, fall back to AutoImageProcessor. `except Exception` (rather than a
# bare `except:`) lets KeyboardInterrupt/SystemExit propagate, and the failure
# reason is printed instead of being silently discarded.
try:
    image_processor = SiglipImageProcessor.from_pretrained(model_id, trust_remote_code=True)
except Exception as exc:
    print(f"SiglipImageProcessor 加载失败 ({exc}),改用 AutoImageProcessor")
    from transformers import AutoImageProcessor

    image_processor = AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
# 2. Load the tokenizer.
print("正在加载 Tokenizer (强制使用 GPT2 Python 版)...")
# StarVector-1B is built on StarCoder (a GPT-2-style BPE model). Loading the
# slow GPT2Tokenizer class directly sidesteps the error AutoTokenizer raised
# when it insisted on reading tokenizer.json: this class is built from the
# vocabulary files (vocab.json / merges.txt) instead, which proved the most
# compatible option here.
tokenizer = GPT2Tokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
)

# GPT-2 tokenizers define no pad token by default; StarCoder-style setups
# usually reuse EOS for padding, so do the same whenever none is defined.
needs_pad_token = tokenizer.pad_token is None
if needs_pad_token:
    tokenizer.pad_token = tokenizer.eos_token
# 3. Load the main model.
# Use CUDA when it is available (FP16, exactly as before). Otherwise fall back
# to CPU so the app can still start instead of crashing at `model.to("cuda")`.
# FP32 is used on CPU because half-precision kernels there are slow and, in some
# PyTorch builds, missing for common ops.
use_cuda = torch.cuda.is_available()
model_device = "cuda" if use_cuda else "cpu"
model_dtype = torch.float16 if use_cuda else torch.float32

print(f"正在加载主模型 ({'FP16' if use_cuda else 'FP32'})...")
# At ~1B parameters the model fits in host memory, so it is loaded on CPU first
# and moved afterwards. device_map is deliberately not used, to avoid the
# meta-tensor errors it caused.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=model_dtype,
    trust_remote_code=True,
)
if use_cuda:
    print("正在移动模型到 GPU...")
    model.to(model_device)
else:
    print("未检测到 CUDA,模型将在 CPU 上运行(速度较慢)。")
print("模型加载完成!")
# --- Inference ---
def _to_rgb(image):
    """Return *image* as an RGB PIL image, flattening any transparency onto white.

    The image processor expects 3-channel input. A plain ``convert("RGB")`` on
    an image with an alpha channel keeps whatever colour sits under the
    transparent pixels, so alpha images are first composited onto an opaque
    white canvas.
    """
    if image.mode == "RGB":
        return image
    has_alpha = image.mode in ("RGBA", "LA", "PA") or (
        image.mode == "P" and "transparency" in image.info
    )
    if not has_alpha:
        return image.convert("RGB")
    rgba = image.convert("RGBA")
    canvas = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(canvas, rgba).convert("RGB")


def convert_image_to_svg(image):
    """Convert an uploaded raster image into SVG markup with the StarVector model.

    Args:
        image: A PIL image (the Gradio input uses ``type="pil"``), or ``None``
            when nothing has been uploaded.

    Returns:
        A ``(code_text, html_text)`` pair for the code and preview outputs.
        On success both items are the extracted SVG. On missing input or any
        failure, the first item is a user-facing message and the second is ``""``.
    """
    if image is None:
        return "请上传图片", ""
    try:
        # Follow the model's actual placement and precision rather than
        # hard-coding "cuda"/float16, so inputs always match how it was loaded.
        device = model.device

        # Preprocess the image (normalized to RGB first) and match device/dtype.
        image_inputs = image_processor(_to_rgb(image), return_tensors="pt")
        pixel_values = image_inputs.pixel_values.to(device=device, dtype=model.dtype)

        # Tokenize the fixed text prompt.
        text_prompt = "Generate SVG code"
        text_inputs = tokenizer(text_prompt, return_tensors="pt")
        input_ids = text_inputs.input_ids.to(device)
        attention_mask = text_inputs.attention_mask.to(device)

        # Greedy decoding. Padding uses EOS, matching the tokenizer setup.
        # NOTE(review): StarVector's published usage drives inference through
        # a model-specific method (generate_im2svg); confirm this checkpoint's
        # remote code really consumes `pixel_values` in `generate()`.
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values,
                max_new_tokens=2048,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode the first sequence and keep only its SVG portion.
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        svg_code = extract_svg_content(generated_text)
        return svg_code, svg_code
    except Exception as e:
        # UI boundary: print the full traceback server-side and show a short
        # message in the code panel instead of crashing the request.
        import traceback

        traceback.print_exc()
        return f"生成出错: {str(e)}", ""
def extract_svg_content(text):
    """Pull the SVG markup out of raw model output.

    Strategies, tried in order:

    1. the first complete ``<svg ...>...</svg>`` element;
    2. the body of a fenced block opened with ```xml;
    3. the body of a fenced block opened with ```svg;
    4. a truncated SVG: everything from the first ``<svg`` onward. This happens
       when generation stops (e.g. at ``max_new_tokens``) before the closing tag.

    Args:
        text: Decoded model output.

    Returns:
        The extracted SVG (or fenced-block) text, or *text* unchanged when
        nothing SVG-like is found.
    """
    match = re.search(r"(<svg.*?</svg>)", text, re.DOTALL)
    if match:
        return match.group(1)

    for fence in ("```xml", "```svg"):
        if fence in text:
            return text.split(fence)[1].split("```")[0]

    start = text.find("<svg")
    if start != -1:
        # Step 1 failed, so no closing tag follows this opening tag. Return the
        # truncated element rather than slicing with a bogus end index (the old
        # `rfind("</svg>") + 6` became 5 when the tag was absent).
        return text[start:]
    return text
# --- UI ---
with gr.Blocks(title="StarVector-1B") as demo:
    gr.Markdown("# StarVector-1B Image to SVG")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="上传图片")
            run_btn = gr.Button("生成 SVG", variant="primary")
        with gr.Column():
            html_output = gr.HTML(label="SVG 预览")
            # NOTE(review): gr.Code only accepts a fixed set of language names,
            # and "xml" is not among them in the Gradio versions checked (it
            # raises ValueError at construction). "html" is supported and
            # highlights SVG markup correctly.
            code_output = gr.Code(language="html", label="SVG 代码")
    # convert_image_to_svg returns (code_text, html_text), matching this order.
    run_btn.click(
        fn=convert_image_to_svg,
        inputs=[input_image],
        outputs=[code_output, html_output],
    )

demo.launch(server_name="0.0.0.0", server_port=7860)