Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import T5ForConditionalGeneration, T5TokenizerFast | |
| import nltk | |
| from nltk import tokenize | |
# Sentence-boundary model required by nltk.tokenize.sent_tokenize below.
nltk.download('punkt')

# Base Dutch T5 checkpoint supplies the tokenizer vocabulary.
checkpoint = "yhavinga/t5-base-dutch"
tokenizer = T5TokenizerFast.from_pretrained(checkpoint)
# Register '<sep>' as a token: the fine-tuned model emits it between
# generated questions, and decoding must keep it intact for splitting.
tokenizer.sep_token = '<sep>'
tokenizer.add_tokens(['<sep>'])
# Fine-tuned end-to-end Dutch question-generation model.
hfmodel = T5ForConditionalGeneration.from_pretrained("Michelvh/t5-end2end-questions-generation-dutch")
def hf_run_model(input_string, **generator_args):
    """Generate questions for one text frame with the T5 model.

    Parameters
    ----------
    input_string : str
        Dutch text to generate questions from.
    **generator_args
        Optional overrides for the default ``generate`` settings
        (e.g. ``num_beams=8``).

    Returns
    -------
    list[list[str]]
        One inner list per returned sequence, split on the '<sep>'
        marker the model emits between questions.
    """
    # Defaults; caller-supplied kwargs take precedence.  (The previous
    # version rebound ``generator_args`` to this dict, silently
    # discarding every argument the caller passed in.)
    args = {
        "max_length": 256,
        "num_beams": 4,
        "length_penalty": 1.5,
        "no_repeat_ngram_size": 3,
        "early_stopping": True,
        "num_return_sequences": 1,
    }
    args.update(generator_args)
    # Explicit end-of-sequence marker expected by this T5 fine-tune.
    input_ids = tokenizer.encode(input_string + " </s>", return_tensors="pt")
    res = hfmodel.generate(input_ids, **args)
    output = tokenizer.batch_decode(res, skip_special_tokens=True)
    # Each decoded sequence contains several questions joined by '<sep>'.
    return [item.split("<sep>") for item in output]
def chunk_text(text, framesize=5):
    """Split *text* into overlapping frames of consecutive sentences.

    A window of ``framesize`` sentences slides one sentence at a time,
    so consecutive frames overlap by ``framesize - 1`` sentences.

    Parameters
    ----------
    text : str
        Text to split.
    framesize : int
        Number of sentences per frame.

    Returns
    -------
    list[str]
        The frames.  A text with fewer than ``framesize`` sentences now
        yields one frame containing all of them (previously the range
        was empty and such texts produced no frames at all).
    """
    sentences = tokenize.sent_tokenize(text)
    # Short texts: emit a single frame instead of silently returning [].
    if len(sentences) <= framesize:
        return [" ".join(sentences)] if sentences else []
    frames = []
    lastindex = len(sentences) - framesize + 1
    for index in range(lastindex):
        frames.append(" ".join(sentences[index:index + framesize]))
    return frames
def flatten(l):
    """Flatten one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    flat = []
    for sub in l:
        flat.extend(sub)
    return flat
def run_model_with_frames(text, framesize=4, overlap=3, progress=gr.Progress()):
    """Generate questions for a whole text via overlapping sentence frames.

    Parameters
    ----------
    text : str
        Input text (Dutch).
    framesize : int
        Number of sentences per frame.
    overlap : int
        Sentences shared by consecutive frames; must be smaller than
        ``framesize``.
    progress : gr.Progress
        Gradio progress tracker (injected by the UI).

    Returns
    -------
    str
        Deduplicated questions, one per line, or an error message when
        the parameters are invalid.
    """
    # ``overlap == framesize`` makes the frame step size zero, which
    # would loop forever inside create_frames — reject it too (the old
    # check only caught ``overlap > framesize``).
    if overlap >= framesize:
        return "Overlap should be smaller than batch size"
    frames = create_frames(text, framesize, overlap)
    total_steps = len(frames)
    progress((0, total_steps), desc="Starting...")
    result = set()  # set deduplicates questions repeated across overlapping frames
    counter = 0
    for frame in frames:
        for question in flatten(hf_run_model(frame)):
            result.add(ensure_questionmark(question.strip()))
        counter += 1
        progress((counter, total_steps), desc="Generating...")
    progress((counter, total_steps), desc="Done")
    # One question per line, trailing newline after each — same shape as
    # before, but built with join instead of quadratic += concatenation.
    return "".join(entry + "\n" for entry in result)
def create_frames(text, framesize=4, overlap=3):
    """Split *text* into frames of ``framesize`` sentences overlapping
    by ``overlap`` sentences.

    Parameters
    ----------
    text : str
        Text to split into sentences (via nltk's sent_tokenize).
    framesize : int
        Number of sentences per frame.
    overlap : int
        Sentences shared by consecutive frames.

    Returns
    -------
    list[str]
        The frames; the final frame is anchored to the end of the text
        so it is always full-size (it may re-cover sentences already
        seen in the previous frame).
    """
    sentences = tokenize.sent_tokenize(text)
    frames = []
    # Clamp the step to at least 1: with overlap >= framesize the raw
    # step is <= 0 and the while-loop below would never terminate.
    stepsize = max(framesize - overlap, 1)
    index = 0
    sentenceslength = len(sentences)
    while index < sentenceslength:
        endindex = index + framesize
        if endindex >= sentenceslength:
            # Last window: take the final ``framesize`` sentences and stop.
            frame = " ".join(sentences[-framesize:])
            index = sentenceslength
        else:
            frame = " ".join(sentences[index:endindex])
            index += stepsize
        frames.append(frame)
    return frames
def ensure_questionmark(question):
    """Return *question* with a trailing "?" appended when it lacks one."""
    return question if question.endswith("?") else question + "?"
# Markdown shown at the top of the Gradio UI.  Fixes the "lenght" typo
# and the truncated sentence "as many questions, but" in the old text.
description = """
# Dutch question generator

Input some Dutch text and click the button to generate some questions!
The model is currently set up to generate as many questions as possible,
but this can take a couple of minutes so have some patience ;)
The optimal text length is probably around 8-10 lines. Longer text
will obviously take longer. Please keep in mind that this is a work in
progress and might still be a little bit buggy."""
# Gradio UI: input text plus frame/overlap tuning, questions out.
with gr.Blocks() as iface:
    gr.Markdown(description)
    context = gr.Textbox(label="Input text")
    frame_size = gr.Number(
        value=5,
        label="Batch size",
        info="Size of the subparts that are used to generate questions. Increase to speed up the generation",
        precision=0,
    )
    overlap = gr.Number(
        value=4,
        label="Overlap",
        # Fixed help text: the old copy said "bigger", contradicting the
        # validation in run_model_with_frames (overlap < batch size) and
        # the defaults shown here (4 < 5).
        info="Overlap between batches. Should be smaller than batch size. Decrease to speed up generation",
        precision=0,
    )
    questions = gr.Textbox(label="Questions")
    generate_btn = gr.Button("Generate questions")
    generate_btn.click(
        fn=run_model_with_frames,
        inputs=[context, frame_size, overlap],
        outputs=questions,
        api_name="generate_questions",
    )

iface.launch()