Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import torch | |
| from transformers import pipeline, set_seed | |
| from transformers import AutoTokenizer | |
| from transformers import GPT2LMHeadModel | |
| from mtranslate import translate | |
| import random | |
| import meta | |
| from normalizer import normalize | |
| from utils import ( | |
| remote_css, | |
| local_css, | |
| load_json | |
| ) | |
# Example {title, context, question, answer} entries shown in the UI picker.
EXAMPLES = load_json("examples.json")
# Persian section markers used to build and parse the model prompt:
# CK labels the context, QK the question, AK the answer
# (presumably "متن" = context, "پرسش" = question, "پاسخ" = answer — TODO confirm).
CK = "متن"
QK = "پرسش"
AK = "پاسخ"
class TextGeneration:
    """Thin wrapper around a Persian GPT-2 QA model.

    In debug mode no weights are loaded and ``generate`` returns a canned
    string; otherwise the tokenizer/model are fetched from the Hub and
    ``generate`` extracts the text following the answer marker (``AK``).
    """

    def __init__(self):
        # When True, load()/generate() skip the model entirely.
        self.debug = False
        self.dummy_output = "مخلوطی از ایتالیایی و انگلیسی"
        self.tokenizer = None
        self.model = None
        self.model_name_or_path = "m3hrdadfi/gpt2-persian-qa"
        # Extra tokens allowed beyond the prompt length when generating.
        self.length_margin = 100
        set_seed(42)  # deterministic generation across reruns

    def load(self):
        """Load tokenizer and model weights (no-op in debug mode)."""
        if not self.debug:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
            self.model = GPT2LMHeadModel.from_pretrained(self.model_name_or_path)

    def generate(self, prompt, generation_kwargs):
        """Generate an answer for ``prompt``.

        Args:
            prompt: full prompt string ending with the answer marker.
            generation_kwargs: kwargs forwarded to ``model.generate``;
                its ``max_length`` entry is overwritten here based on the
                tokenized prompt length plus ``length_margin``.

        Returns:
            The first non-empty segment after the ``AK: `` marker, or ""
            if the marker does not occur in the decoded output.
        """
        if not self.debug:
            input_ids = self.tokenizer([prompt], return_tensors="pt")["input_ids"]
            max_length = len(input_ids[0]) + self.length_margin
            generation_kwargs["max_length"] = max_length
            generated = self.model.generate(
                input_ids,
                **generation_kwargs,
            )[0]
            answer = self.tokenizer.decode(generated, skip_special_tokens=True)
            found = answer.find(f"{AK}: ")
            # BUG FIX: str.find returns -1 when the marker is absent and 0
            # when it matches at the start. The original `if not found`
            # rejected a valid match at index 0 and accepted -1 (not found).
            if found == -1:
                return ""
            answer = [a.strip() for a in answer[found:].split(f"{AK}: ") if a.strip()]
            answer = answer[0] if len(answer) > 0 else ""
            return answer

        return self.dummy_output
def load_text_generator():
    """Return a ready-to-use TextGeneration instance (weights loaded)."""
    text_generator = TextGeneration()
    text_generator.load()
    return text_generator
def main():
    """Render the Streamlit page.

    Sidebar: beam-search generation controls and a translation toggle.
    Body: example picker, context/question inputs, the reference answer,
    and (on button press) the model-generated answer, each optionally
    machine-translated from Persian to English.
    """
    st.set_page_config(
        page_title="GPT2 QA - Persian",
        page_icon="⁉️",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    remote_css("https://cdn.jsdelivr.net/gh/rastikerdar/vazir-font/dist/font-face.css")
    local_css("assets/rtl.css")
    generator = load_text_generator()

    st.sidebar.markdown(meta.SIDEBAR_INFO)
    num_beams = st.sidebar.slider(
        label='Number of Beam',
        help="Number of beams for beam search",
        min_value=4,
        max_value=15,
        value=5,
        step=1
    )
    repetition_penalty = st.sidebar.slider(
        label='Repetition Penalty',
        help="The parameter for repetition penalty",
        min_value=1.0,
        max_value=10.0,
        value=1.0,
        step=0.1
    )
    length_penalty = st.sidebar.slider(
        label='Length Penalty',
        help="Exponential penalty to the length",
        min_value=1.0,
        max_value=10.0,
        value=1.0,
        step=0.1
    )
    early_stopping = st.sidebar.selectbox(
        label='Early Stopping ?',
        options=(True, False),
        help="Whether to stop the beam search when at least num_beams sentences are finished per batch or not",
    )
    translated = st.sidebar.selectbox(
        label='Translation ?',
        options=(True, False),
        help="Will translate the result in English",
    )
    generation_kwargs = {
        "num_beams": num_beams,
        "early_stopping": early_stopping,
        "repetition_penalty": repetition_penalty,
        "length_penalty": length_penalty,
    }

    st.markdown(meta.HEADER_INFO)
    # "Custom" sentinel lets the user type their own context/question.
    prompts = [e["title"] for e in EXAMPLES] + ["Custom"]
    prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)
    if prompt == "Custom":
        prompt_box = {
            "context": meta.C_PROMPT_BOX,
            "question": meta.Q_PROMPT_BOX,
            "answer": meta.A_PROMPT_BOX,
        }
    else:
        prompt_box = next(e for e in EXAMPLES if e["title"] == prompt)

    context = st.text_area("Enter context", prompt_box["context"], height=250)
    question = st.text_area("Enter question", prompt_box["question"], height=100)
    answer = "پاسخ درست: " + prompt_box["answer"]
    # BUG FIX: the original emitted `<p>` where `</p>` was intended, leaving
    # the paragraph tag unclosed in the rendered HTML.
    st.markdown(
        f'<p class="rtl rtl-box">'
        f'{answer}'
        f'</p>',
        unsafe_allow_html=True
    )
    if translated:
        translated_answer = translate(answer, "en", "fa")
        st.markdown(
            f'<p class="ltr">'
            f'{translated_answer}'
            f'</p>',
            unsafe_allow_html=True
        )

    # Placeholder so the kwargs summary appears above the results once
    # generation starts.
    generation_kwargs_ph = st.empty()
    if st.button("Find the answer 🔎 "):
        with st.spinner(text="Searching ..."):
            generation_kwargs_ph.markdown(", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()]))
            context = normalize(context)
            question = normalize(question)
            if context and question:
                # Prompt layout: "<context> پرسش: <question> پاسخ:".
                text = f"{context} {QK}: {question} {AK}:"
                generated_answer = generator.generate(text, generation_kwargs)
                generated_answer = f"{AK}: {generated_answer}".strip()
                context = f"{CK}: {context}".strip()
                question = f"{QK}: {question}".strip()
                # BUG FIX: the original closed the first two spans with
                # `<span>` instead of `</span>`, producing malformed HTML.
                st.markdown(
                    f'<p class="rtl rtl-box">'
                    f'<span class="result-text">{context}</span><br/><br/>'
                    f'<span class="result-text">{question}</span><br/><br/>'
                    f'<span class="result-text generated-text">{generated_answer} </span>'
                    f'</p>',
                    unsafe_allow_html=True
                )
                if translated:
                    translated_context = translate(context, "en", "fa")
                    translated_question = translate(question, "en", "fa")
                    translated_generated_answer = translate(generated_answer, "en", "fa")
                    st.markdown(
                        f'<p class="ltr ltr-box">'
                        f'<span class="result-text">{translated_context}</span><br/><br/>'
                        f'<span class="result-text">{translated_question}</span><br/><br/>'
                        f'<span class="result-text generated-text">{translated_generated_answer}</span>'
                        f'</p>',
                        unsafe_allow_html=True
                    )


if __name__ == '__main__':
    main()