"""Gradio demo that converts an uploaded image to SVG code with StarVector-1B.

Importing/running this module loads the image processor, tokenizer and model,
moves the model to the GPU, builds the Gradio UI and starts the web server.
"""

import os

# Set before the numeric libraries below are imported so the setting is
# already in the environment when they initialize.
os.environ["OMP_NUM_THREADS"] = "1"

import re
import traceback

import gradio as gr
import torch
# Key change: GPT2Tokenizer is used instead of AutoTokenizer (see below).
from transformers import GPT2Tokenizer, SiglipImageProcessor, AutoModelForCausalLM
from PIL import Image

# --- Model configuration ---
model_id = "starvector/starvector-1b-im2svg"

# Device and dtype used for the model and for every tensor passed to it.
DEVICE = "cuda"
DTYPE = torch.float16

# First complete <svg ...>...</svg> element; DOTALL lets it span lines.
_SVG_ELEMENT_RE = re.compile(r"(<svg.*?</svg>)", re.DOTALL)

print(f"正在加载模型: {model_id} ...")

# 1. Load the image processor, falling back to AutoImageProcessor if the
#    Siglip-specific class cannot load this checkpoint.
try:
    image_processor = SiglipImageProcessor.from_pretrained(
        model_id, trust_remote_code=True
    )
except Exception:
    from transformers import AutoImageProcessor

    image_processor = AutoImageProcessor.from_pretrained(
        model_id, trust_remote_code=True
    )

# 2. Load the tokenizer.
print("正在加载 Tokenizer (强制使用 GPT2 Python 版)...")
# Core fix (original author's note): StarVector-1B is based on StarCoder
# (GPT2 architecture). Using GPT2Tokenizer directly avoids the error caused by
# AutoTokenizer insisting on reading tokenizer.json; this class only reads
# vocab.json, which gives the best compatibility.
tokenizer = GPT2Tokenizer.from_pretrained(model_id, trust_remote_code=True)

# Make sure a pad token exists (GPT2 has none by default; StarCoder usually
# uses EOS as the pad token).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 3. Load the model.
print("正在加载主模型 (FP16)...")
# The 1B model is small, so loading it onto the CPU first is fine and will
# not run out of memory.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=DTYPE,
    trust_remote_code=True,
    # device_map is deliberately not used, to avoid meta-tensor errors.
)

print("正在移动模型到 GPU...")
model.to(DEVICE)
print("模型加载完成!")


# --- Inference functions ---
def extract_svg_content(text):
    """Pull the SVG markup out of the raw generated text.

    Strategies are tried in order:

    1. the first complete ``<svg ...>...</svg>`` element;
    2. the contents of a ```` ```xml ```` fenced block;
    3. the contents of a ```` ```svg ```` fenced block;
    4. if the text mentions ``xmlns=`` and has an opening ``<svg`` tag but no
       closing tag after it (e.g. generation was cut off), everything from
       the opening tag onward;
    5. otherwise the text unchanged.

    Args:
        text: Decoded model output.

    Returns:
        The extracted SVG string, or ``text`` itself when nothing matched.
    """
    match = _SVG_ELEMENT_RE.search(text)
    if match:
        return match.group(1)

    for fence in ("```xml", "```svg"):
        if fence in text:
            return text.split(fence)[1].split("```")[0]

    # Reaching this point means there is no closing tag after any opening
    # tag, so the best available result is the unterminated element.
    start = text.find("<svg")
    if "xmlns=" in text and start != -1:
        return text[start:]

    return text


def convert_image_to_svg(image):
    """Generate SVG code for an uploaded image.

    Args:
        image: The image supplied by the Gradio ``Image`` component
            (configured with ``type="pil"``), or ``None`` if nothing was
            uploaded.

    Returns:
        A ``(code, html)`` tuple feeding the code viewer and the HTML preview.
        On success both entries are the extracted SVG. When no image was
        given, or when generation fails, the first entry carries a message
        and the second is an empty string.
    """
    if image is None:
        return "请上传图片", ""

    try:
        # Preprocess the image; move it to the model's device and dtype.
        image_inputs = image_processor(image, return_tensors="pt")
        pixel_values = image_inputs.pixel_values.to(DEVICE).to(DTYPE)

        # Preprocess the text prompt.
        text_prompt = "Generate SVG code"
        text_inputs = tokenizer(text_prompt, return_tensors="pt")
        input_ids = text_inputs.input_ids.to(DEVICE)
        attention_mask = text_inputs.attention_mask.to(DEVICE)

        # Greedy generation; no gradients are needed for inference.
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values,
                max_new_tokens=2048,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode and extract the SVG markup.
        generated_text = tokenizer.decode(
            generated_ids[0], skip_special_tokens=True
        )
        svg_code = extract_svg_content(generated_text)
        return svg_code, svg_code
    except Exception as e:
        # UI boundary: log the full traceback to the console and show a
        # short message to the user instead of crashing the app.
        traceback.print_exc()
        return f"生成出错: {str(e)}", ""


# --- UI ---
with gr.Blocks(title="StarVector-1B") as demo:
    gr.Markdown("# StarVector-1B Image to SVG")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="上传图片")
            run_btn = gr.Button("生成 SVG", variant="primary")
        with gr.Column():
            html_output = gr.HTML(label="SVG 预览")
            code_output = gr.Code(language="xml", label="SVG 代码")

    run_btn.click(
        fn=convert_image_to_svg,
        inputs=[input_image],
        outputs=[code_output, html_output],
    )

demo.launch(server_name="0.0.0.0", server_port=7860)