Spaces:
Build error
Build error
| import argparse | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from donut import DonutModel | |
def demo_process_vqa(input_img, question):
    """Answer a free-text question about a document image.

    Args:
        input_img: The uploaded document. A ``PIL.Image.Image`` is used
            as-is; anything else is converted with ``Image.fromarray``
            (Gradio's "image" input presumably delivers a NumPy array --
            the sibling ``demo_process`` converts the same way).
        question: The user's question; ``None`` is treated as an empty string.

    Returns:
        The first entry of the model's ``"predictions"`` list.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_img is None:
        # Submitting the form without an image hands the callback ``None``;
        # report that to the user instead of failing inside the conversion.
        raise gr.Error("Please provide a document image.")

    # The conversion used to be commented out, so a raw array reached the
    # model. Restore it, but leave inputs that are already PIL images alone.
    if not isinstance(input_img, Image.Image):
        input_img = Image.fromarray(input_img)

    # Module-level ``task_prompt`` and ``pretrained_model`` are only read here,
    # so no ``global`` declaration is needed.
    user_prompt = task_prompt.replace("{user_input}", question or "")
    result = pretrained_model.inference(input_img, prompt=user_prompt)
    return result["predictions"][0]
def demo_process(input_img):
    """Classify a document image and, if it is an invoice, parse it.

    The image is first run through ``security_layer`` with the
    ``<s_rvlcdip>`` prompt. Only when that prediction's ``"class"`` equals
    ``"invoice"`` is the image passed to ``pretrained_model`` with the
    ``<s_cord-v2>`` prompt.

    Args:
        input_img: The uploaded document. A ``PIL.Image.Image`` is used
            as-is; anything else is converted with ``Image.fromarray``.

    Returns:
        The parser's first prediction for invoices; otherwise the
        classifier's first prediction, returned unchanged.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_img is None:
        # Submitting the form without an image hands the callback ``None``;
        # report that to the user instead of failing inside the conversion.
        raise gr.Error("Please provide a document image.")

    if not isinstance(input_img, Image.Image):
        input_img = Image.fromarray(input_img)

    sec = security_layer.inference(image=input_img, prompt="<s_rvlcdip>")["predictions"][0]
    print(sec)

    # A prediction without a "class" key (e.g. when the generated sequence
    # could not be parsed) must not raise KeyError: treat it as "not an
    # invoice" and return it as-is so the caller can see what came back.
    is_invoice = isinstance(sec, dict) and sec.get("class") == "invoice"
    if not is_invoice:
        return sec

    return pretrained_model.inference(image=input_img, prompt="<s_cord-v2>")["predictions"][0]
# Task served by this demo; it selects the prompt template built below.
task_name = "cord-v2"

# DocVQA needs a template with a slot for the user's question; every other
# task (rvlcdip, cord, ...) is prompted with its bare start token.
task_prompt = (
    "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
    if task_name == "docvqa"
    else f"<s_{task_name}>"
)

# ``security_layer`` classifies the document type; ``pretrained_model``
# extracts the structured fields.
security_layer = DonutModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
pretrained_model = DonutModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")

if not torch.cuda.is_available():
    # No GPU: only the encoders are cast to bfloat16.
    pretrained_model.encoder.to(torch.bfloat16)
    security_layer.encoder.to(torch.bfloat16)
else:
    # GPU: switch each model to half precision and move it to the CUDA device.
    device = torch.device("cuda")
    pretrained_model.half()
    pretrained_model.to(device)
    security_layer.half()
    security_layer.to(device)

# Both models are used for inference only.
pretrained_model.eval()
security_layer.eval()
# Choose the callback and its input widgets together so the two can never
# disagree: DocVQA takes an image plus a question, other tasks only an image.
if task_name == "docvqa":
    _handler, _input_components = demo_process_vqa, ["image", "text"]
else:
    _handler, _input_components = demo_process, "image"

demo = gr.Interface(
    fn=_handler,
    inputs=_input_components,
    outputs="json",
    title=f"Donut 🍩 demonstration for `{task_name}` task",
    description="Get invoice details if invoice",
    concurrency_limit=10,
)

# Enable the request queue (at most 5 pending requests), then start the app.
demo.queue(default_concurrency_limit=2, max_size=5)
demo.launch(debug=True, share=True, inline=False)