File size: 4,073 Bytes
26c0321
 
 
4f5bc6b
 
b1cb14b
 
4f5bc6b
 
 
0959374
070f825
4f5bc6b
070f825
4f5bc6b
b1cb14b
f203e5b
 
0959374
 
 
 
 
b1cb14b
ebb582a
b1cb14b
 
 
 
 
 
 
 
ebb582a
 
 
b1cb14b
4f5bc6b
 
26c0321
4f5bc6b
b1cb14b
4f5bc6b
 
ebb582a
0959374
 
4f5bc6b
 
0959374
4f5bc6b
 
 
 
8fb934f
070f825
8fb934f
0959374
 
8fb934f
070f825
8fb934f
 
0959374
 
8fb934f
 
 
 
 
 
 
070f825
f203e5b
8fb934f
 
 
 
 
 
 
 
 
 
26c0321
070f825
8fb934f
4f5bc6b
 
 
 
 
 
 
 
 
 
 
 
8fb934f
 
 
 
 
 
4f5bc6b
 
0959374
 
070f825
4f5bc6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
708af3a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
os.environ["OMP_NUM_THREADS"] = "1"

import gradio as gr
import torch
# 【关键修改】不再用 AutoTokenizer,改用 GPT2Tokenizer
from transformers import GPT2Tokenizer, SiglipImageProcessor, AutoModelForCausalLM
from PIL import Image
import re

# --- Model configuration ---
# Hugging Face Hub id of the checkpoint used for every from_pretrained() call below.
model_id = "starvector/starvector-1b-im2svg"

print(f"正在加载模型: {model_id} ...")

# 1. Load the image processor.
# Try the SigLIP processor first; if that fails for any ordinary reason, fall
# back to AutoImageProcessor. Catch `Exception` rather than using a bare
# `except:` so Ctrl-C / SystemExit during the download are not swallowed, and
# report why the fallback was taken instead of hiding the error.
try:
    image_processor = SiglipImageProcessor.from_pretrained(model_id, trust_remote_code=True)
except Exception as exc:
    from transformers import AutoImageProcessor

    print(f"SiglipImageProcessor 加载失败 ({exc!r}),改用 AutoImageProcessor ...")
    image_processor = AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)

# 2. Load the tokenizer.
print("正在加载 Tokenizer (强制使用 GPT2 Python 版)...")
# Author's note: StarVector-1B is based on StarCoder (GPT-2 architecture).
# GPT2Tokenizer is used directly instead of AutoTokenizer to avoid the error
# AutoTokenizer raised when it insisted on reading tokenizer.json; the
# pure-Python GPT2Tokenizer reads vocab.json instead.
tokenizer = GPT2Tokenizer.from_pretrained(model_id, trust_remote_code=True)

# Make sure a pad token exists: reuse the EOS token when none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 3. Load the model.
print("正在加载主模型 (FP16)...")
# Author's note: the 1B model is small enough to be materialised on the CPU
# first without running out of memory.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    trust_remote_code=True
    # device_map is deliberately not used, to avoid meta-tensor errors.
)

print("正在移动模型到 GPU...")
# The rest of the script sends its inputs to "cuda", so the model must live there too.
model.to("cuda")

print("模型加载完成!")

# --- 推理函数 ---
def convert_image_to_svg(image):
    """Generate SVG markup for an uploaded image.

    Args:
        image: The image handed over by the UI, or ``None`` when nothing
            was uploaded.

    Returns:
        A 2-tuple. On success both elements are the extracted SVG string.
        When no image is given, or when any step raises, the first element
        is a user-facing message and the second is an empty string.
    """
    # Guard clause: nothing to do without an image.
    if image is None:
        return "请上传图片", ""

    try:
        # Image -> pixel tensor, moved to the GPU and cast to fp16.
        processed = image_processor(image, return_tensors="pt")
        pixel_values = processed.pixel_values.to("cuda").to(torch.float16)

        # Fixed text prompt -> token ids and attention mask on the GPU.
        encoded = tokenizer("Generate SVG code", return_tensors="pt")
        generate_kwargs = dict(
            input_ids=encoded.input_ids.to("cuda"),
            attention_mask=encoded.attention_mask.to("cuda"),
            pixel_values=pixel_values,
            max_new_tokens=2048,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Gradients are not needed for generation.
        with torch.no_grad():
            output_ids = model.generate(**generate_kwargs)

        # Decode the first (only) sequence and isolate the SVG portion.
        decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        svg = extract_svg_content(decoded)
    except Exception as err:
        # Print the full traceback to the console; show a short message in the UI.
        import traceback
        traceback.print_exc()
        return f"生成出错: {str(err)}", ""
    else:
        return svg, svg

def extract_svg_content(text):
    """Extract SVG markup from raw generated text.

    The following strategies are tried in order; the first that applies wins:

    1. The first complete ``<svg ...>...</svg>`` element (non-greedy match
       that may span several lines).
    2. The contents of the first ```` ```xml ```` or ```` ```svg ```` fenced
       block.
    3. An unterminated ``<svg`` fragment in text that also contains
       ``xmlns=``: everything from ``<svg`` to the end of the text.
    4. Otherwise the text is returned unchanged.

    Args:
        text: The decoded model output.

    Returns:
        The extracted SVG string, or ``text`` itself if nothing SVG-like
        was found.
    """
    match = re.search(r"(<svg.*?</svg>)", text, re.DOTALL)
    if match:
        return match.group(1)

    # Markdown code fences, checked in the same order as before (xml, then svg).
    for fence in ("```xml", "```svg"):
        if fence in text:
            return text.split(fence)[1].split("```")[0]

    start = text.find("<svg")
    if start != -1 and "xmlns=" in text:
        # The regex above failed, so no "</svg>" follows this "<svg".
        # The previous code sliced up to rfind("</svg>") + 6, which in this
        # situation always produced an empty or truncated string (a missing
        # tag gives -1 + 6 == 5, and the "!= -1" check never fired).
        # Return the unterminated fragment instead.
        return text[start:]

    return text

# --- User interface ---
# Two-column Gradio layout: image upload and trigger button on the left,
# rendered SVG preview and SVG source on the right.
with gr.Blocks(title="StarVector-1B") as demo:
    gr.Markdown("# StarVector-1B Image to SVG")
    
    with gr.Row():
        with gr.Column():
            # type="pil" asks Gradio to pass the upload to the callback as a PIL image.
            input_image = gr.Image(type="pil", label="上传图片")
            run_btn = gr.Button("生成 SVG", variant="primary")
        
        with gr.Column():
            html_output = gr.HTML(label="SVG 预览")
            code_output = gr.Code(language="xml", label="SVG 代码")

    # The callback returns a 2-tuple; its first element goes to the code box
    # and its second to the HTML preview, matching the order of `outputs`.
    run_btn.click(
        fn=convert_image_to_svg,
        inputs=[input_image],
        outputs=[code_output, html_output]
    )

# Start the web server on all network interfaces, port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)