File size: 4,101 Bytes
a0b30e5
 
 
 
 
 
 
 
 
 
 
 
 
49fff35
a0b30e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import argparse
import gradio as gr
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Task-start prompt: a hint that tells the model which kind of task it is
# performing, i.e. guides (in prompt form) what should be read from the image.
task_prompt = "<s_cord-v2>"

# Alternative checkpoint: "gwkrsrch/donut-cord-v2-menu-sample-demo"
pretrained_path = "SoccerData/Industry-AI"

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = DonutProcessor.from_pretrained(pretrained_path)

# Load the weights, cast them to half precision, move them to the chosen
# device and switch to inference mode (same order as separate statements:
# half -> to(device) -> eval; each call returns the model itself).
pretrained_model = (
    VisionEncoderDecoderModel.from_pretrained(pretrained_path)
    .half()
    .to(device)
    .eval()
)

import re

def token2json(tokens, is_inner_value=False):
    """Convert a (generated) token sequence into an ordered JSON-like structure.

    The sequence is expected to use ``<s_key>...</s_key>`` pairs for fields and
    ``<sep/>`` to separate sibling values or sibling groups.

    Args:
        tokens: The decoded token string to parse.
        is_inner_value: ``True`` when parsing the content of an enclosing
            field (recursive calls); the result is then always a list.

    Returns:
        * For a top-level call: a dict mapping each key to a string (single
          leaf), a list of strings (``<sep/>``-separated leaves), a dict
          (single nested group) or a list of dicts (several nested groups).
          If no field could be parsed, ``{"text_sequence": <remaining text>}``.
        * A list of dicts when ``is_inner_value`` is true, or when a parsed
          field is directly followed by ``<sep/>`` (sibling groups).
    """
    output = {}

    while tokens:
        start_match = re.search(r"<s_(.*?)>", tokens, re.IGNORECASE)
        if start_match is None:
            break
        key = start_match.group(1)
        start_token = start_match.group()

        # The key is arbitrary text, so escape it before building the
        # end-token pattern; otherwise keys containing regex metacharacters
        # (e.g. "+", ".", "(") fail to match or raise re.error.
        end_match = re.search(rf"</s_{re.escape(key)}>", tokens, re.IGNORECASE)
        if end_match is None:
            # Unclosed field: drop the dangling start token and keep scanning.
            tokens = tokens.replace(start_token, "")
            continue

        end_token = end_match.group()
        # DOTALL lets a field value span several lines; without it a value
        # containing a newline would not match and would be silently dropped.
        content = re.search(
            f"{re.escape(start_token)}(.*?){re.escape(end_token)}",
            tokens,
            re.IGNORECASE | re.DOTALL,
        )
        if content is not None:
            content = content.group(1).strip()
            if "<s_" in content and "</s_" in content:  # non-leaf node
                value = token2json(content, is_inner_value=True)
                if value:
                    if len(value) == 1:
                        value = value[0]
                    output[key] = value
            else:  # leaf node(s)
                leaves = [leaf.strip() for leaf in content.split("<sep/>")]
                output[key] = leaves[0] if len(leaves) == 1 else leaves

        # Continue after the first occurrence of the closing token.
        tokens = tokens[tokens.find(end_token) + len(end_token):].strip()
        if tokens[:6] == "<sep/>":  # a sibling group follows
            return [output] + token2json(tokens[6:], is_inner_value=True)

    if output:
        return [output] if is_inner_value else output
    return [] if is_inner_value else {"text_sequence": tokens}

def demo_process(input_img):
    """Run the model on a single image and return the parsed prediction.

    Args:
        input_img: The image as an array accepted by ``Image.fromarray``
            (the value handed over by the ``"image"`` input of the
            interface), or ``None`` when no image was supplied.

    Returns:
        The structure produced by ``token2json`` for the decoded sequence.

    Raises:
        gr.Error: If ``input_img`` is ``None``.
    """
    if input_img is None:
        # Fail with a readable message instead of an opaque error from
        # Image.fromarray(None), e.g. when the form is submitted empty.
        raise gr.Error("Please provide an image before submitting.")

    # Normalise to 3-channel RGB so grayscale / RGBA arrays are handled too.
    image = Image.fromarray(input_img).convert("RGB")

    # Follow the model's own dtype instead of hard-coding half precision, so
    # the input always matches the weights whatever precision was loaded.
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(
        device=device, dtype=pretrained_model.dtype
    )

    # Decoding starts from the decoder start token configured on the model.
    decoder_input_ids = torch.full(
        (1, 1), pretrained_model.config.decoder_start_token_id, device=device
    )

    tokenizer = processor.tokenizer
    outputs = pretrained_model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=pretrained_model.config.decoder.max_length,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Only one image is processed per call, so only the first sequence matters.
    sequence = tokenizer.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
    return token2json(sequence)

# Web UI: one image input, JSON output, served by demo_process.
demo = gr.Interface(
    fn=demo_process,
    inputs="image",
    outputs="json",
    # The title previously contained a mis-decoded (mojibake) byte sequence
    # in place of the doughnut emoji; the intended character is restored.
    title="Donut 🍩 demonstration",
)
demo.launch(debug=True)