Spaces:
Runtime error
Runtime error
| import argparse | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from transformers import DonutProcessor, VisionEncoderDecoderModel | |
# Task start token. It acts as a hint telling the model which kind of task it
# is performing, i.e. a prompt-style guide to what should be read from the
# image.
task_prompt = "<s_cord-v2>"

# Alternative checkpoint:
# pretrained_path = "gwkrsrch/donut-cord-v2-menu-sample-demo"
pretrained_path = "SoccerData/Industry-AI"

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The processor bundles the image pre-processing and the tokenizer.
processor = DonutProcessor.from_pretrained(pretrained_path)

# Load the weights, cast them to float16, move them to the selected device and
# switch to inference mode.
# NOTE(review): the float16 cast is applied even when `device` is the CPU;
# half-precision kernels may be missing or slow there -- confirm this is
# intended before deploying on CPU-only hardware.
pretrained_model = (
    VisionEncoderDecoderModel.from_pretrained(pretrained_path)
    .half()
    .to(device)
    .eval()
)
| import re | |
def token2json(tokens, is_inner_value=False):
    """
    Convert a (generated) token sequence into an ordered JSON-like structure.

    The sequence is expected to use ``<s_key>...</s_key>`` pairs for fields and
    ``<sep/>`` to separate sibling values.

    Args:
        tokens: The decoded token string to parse.
        is_inner_value: True when parsing the content of an enclosing field
            (set by the recursive calls); the result is then always a list.

    Returns:
        For a top-level call, a dict mapping field names to a string, a list
        of strings, a nested dict, or a list of nested dicts. When a
        ``<sep/>`` directly follows a closed field, a list of such dicts is
        returned instead. If no field could be parsed, the result is
        ``{"text_sequence": <remaining text>}`` at the top level and ``[]``
        for inner calls.
    """
    output = {}
    while tokens:
        start_match = re.search(r"<s_(.*?)>", tokens, re.IGNORECASE)
        if start_match is None:
            break
        key = start_match.group(1)
        start_token = start_match.group()

        # The key comes from model output, so escape it before embedding it in
        # a pattern: otherwise characters such as "." or "(" are interpreted
        # as regex syntax and either mis-match or raise re.error.
        end_match = re.search(rf"</s_{re.escape(key)}>", tokens, re.IGNORECASE)
        if end_match is None:
            # Unclosed field: drop every occurrence of its opening tag and
            # keep scanning the rest of the sequence.
            tokens = tokens.replace(start_token, "")
            continue

        end_token = end_match.group()
        # DOTALL lets a field value span several lines; without it a value
        # containing a newline would not match and the field would be lost.
        content_match = re.search(
            f"{re.escape(start_token)}(.*?){re.escape(end_token)}",
            tokens,
            re.IGNORECASE | re.DOTALL,
        )
        if content_match is not None:
            content = content_match.group(1).strip()
            if "<s_" in content and "</s_" in content:  # non-leaf node
                value = token2json(content, is_inner_value=True)
                if value:
                    # Unwrap single-element lists so a lone child is a dict.
                    output[key] = value[0] if len(value) == 1 else value
            else:  # leaf node
                leaves = [leaf.strip() for leaf in content.split("<sep/>")]
                output[key] = leaves[0] if len(leaves) == 1 else leaves

        # Continue after the first occurrence of the closing tag.
        tokens = tokens[tokens.find(end_token) + len(end_token):].strip()
        if tokens[:6] == "<sep/>":  # further sibling groups follow
            return [output] + token2json(tokens[6:], is_inner_value=True)

    if output:
        return [output] if is_inner_value else output
    return [] if is_inner_value else {"text_sequence": tokens}
def demo_process(input_img):
    """
    Run the model on one image and return the parsed prediction.

    Args:
        input_img: Image data accepted by ``PIL.Image.fromarray``
            (presumably the array Gradio passes for an "image" input --
            TODO confirm).

    Returns:
        The structure produced by ``token2json`` for the first generated
        sequence.
    """
    tokenizer = processor.tokenizer

    # Pre-process the image into a float16 tensor on the model's device.
    image = Image.fromarray(input_img)
    encoded = processor(image, return_tensors="pt")
    pixel_values = encoded.pixel_values.half().to(device)

    # Decoding starts from a single decoder start token (batch size 1).
    start_ids = torch.full(
        (1, 1),
        pretrained_model.config.decoder_start_token_id,
        device=device,
    )

    generation_kwargs = dict(
        decoder_input_ids=start_ids,
        max_length=pretrained_model.config.decoder.max_length,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[tokenizer.unk_token_id]],  # never emit <unk>
        return_dict_in_generate=True,
    )
    outputs = pretrained_model.generate(pixel_values, **generation_kwargs)

    def _clean(decoded):
        # Strip EOS / padding tokens, then remove the first task start token.
        decoded = decoded.replace(tokenizer.eos_token, "")
        decoded = decoded.replace(tokenizer.pad_token, "")
        return re.sub(r"<.*?>", "", decoded, count=1).strip()

    predictions = [_clean(seq) for seq in tokenizer.batch_decode(outputs.sequences)]
    return token2json(predictions[0])
# Web UI: one image input, JSON output produced by demo_process.
demo = gr.Interface(
    fn=demo_process,
    inputs="image",
    outputs="json",
    # The doughnut emoji is written as a real Unicode character so the title
    # renders correctly instead of showing mis-decoded bytes.
    title="Donut 🍩 demonstration",
)

# Start the Gradio server; debug=True is passed through to launch() unchanged.
demo.launch(debug=True)