import json
import re
import time
from datetime import datetime

import gradio as gr

import chat_client

CHAT_URL = "wss://chat.petals.dev/api/v2/generate"
# CHAT_URL = "ws://localhost:8000/api/v2/generate"
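
# NOTE: chat_client is assumed to be the WebSocket client module shipped
# alongside this Space. The code below relies only on this surface:
#   client = chat_client.ModelClient(CHAT_URL)
#   client.open_session(model, max_length)
#   client.is_session() -> bool
#   client.generate(prompt, max_new_tokens=..., stop_sequences=...) -> iterator of str
#   client.close_session()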

EMPTY_STATE = {
    "generate": False,
    "model": None,
    "client": None,
    "history": [],
}
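
# Per-browser-session state (stored in gr.State below):
#   "generate" - True while the streaming loop in _generate() is running
#   "model"    - model id of the currently open inference session
#   "client"   - the open chat_client.ModelClient, or None
#   "history"  - list of [user_message, ai_message] pairs for gr.Chatbot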


def generate(state, prompt, model, context, output, *args):
    # Flag that we're inside the generating loop
    state["generate"] = True
    try:
        yield from _generate(state, prompt, model, context, output, *args)
    except (json.decoder.JSONDecodeError, BrokenPipeError):
        # Broken session, try to renew it
        # TODO This is a bit fragile because of the recursive call...
        print("Retrying session...")
        context = output
        output = ""
        yield from generate(state, prompt, model, context, output, *args)
    finally:
        state["generate"] = False


def _generate(
    state,
    prompt,
    model,
    context,
    output,
    endseq,
    max_length,
    do_sample,
    top_k,
    top_p,
    temperature,
):
    start = time.time()
    cnt = 0  # Tokens generated

    def stats():
        # Produces inline stats for generation speed
        if cnt == 0:
            return "\u2026 | ? sec/t"
        if cnt > time.time() - start:
            items_per_sec = cnt / (time.time() - start)
            return f" | {items_per_sec:.1f} t/sec"
        sec_per_item = (time.time() - start) / cnt
        return f" | {sec_per_item:.1f} sec/t"

    eos = "</s>\n" if "bloomz" in model else "\n\n"
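    # Assumption behind this choice: bloomz emits an explicit </s>
    # end-of-text token, while the Llama-family chats here separate turns
    # with a blank line, so "\n\n" doubles as their end-of-sequence marker.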
| if state["model"] != model and output: | |
| # If the connection is resumed, output is truncated in generate(). | |
| # So this executes when user change model. | |
| context = output | |
| output = "" | |
| # Update widgets even before we get the first response | |
| print("prompt", prompt) | |
| yield state, state["history"] + [[prompt, stats()]], "", output | |

    if (
        state["model"] != model
        or state["client"] is None
        or not state["client"].is_session()
    ):
        try:
            state["client"] = chat_client.ModelClient(CHAT_URL)
            state["client"].open_session(model, max_length)
            state["model"] = model
        except Exception as e:
            print(datetime.now(), str(e)[-500:])
            raise gr.Error(str(e)[-500:])
    else:
        context = ""

    client = state["client"]
    context += eos

    # Fix a possible eos-token mismatch and append the eos token
    # to both the context and the prompt
    if "bloomz" in model:
        context = context.replace("\n\n", eos)
        prompt2 = prompt.replace("\n\n", eos) + "</s>\n"
    else:
        context = context.replace("</s>", eos)
        context = re.sub(r"\n\n+", "\n\n", context)
        prompt2 = prompt.replace("</s>", eos) + "\n\n"

    prompt2 = f"{context}Human: {prompt2}AI:"

    # Translate checkbox items to actual stop sequences
    seq = []
    for s in endseq:
        if s == "Human:":
            seq.append("Human:")
        elif s == "AI:":
            seq.append("AI:")
        elif s == "\\n":
            seq.append("\n")
        elif s == "</s>":
            seq.append("</s>")
        elif s == "? (question mark)":
            seq.append("?")
        elif s == ". (dot)":
            seq.append(".")

    # Only one of top_k and top_p may be set; top_p wins if both are given
    if top_k == 0:
        top_k = None
    if top_p == 0:
        top_p = None
    if top_p and top_k:
        top_k = None
    if temperature == 0:
        temperature = 1.0
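    # 0 serves as the "unset" sentinel here because the gr.Number widgets
    # always deliver a number; presumably the backend treats None as a
    # disabled sampler parameter.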

    output += prompt2
    orig_history = state["history"]
    new_line = ""
    try:
        for out in client.generate(
            prompt2,
            max_new_tokens=1,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            stop_sequences=seq,
        ):
            if not state["generate"]:
                # Generation was cancelled (e.g. via "Reset session"),
                # so stop and clear the widgets
                client.close_session()
                yield state, [], "", ""
                return

            cnt += 1
            new_line += out

            # Detect end sequences and finish the generation
            # prematurely if one is found
            for s in seq:
                spl = new_line.split(s)
                new_line = spl[0]
                if len(spl) > 1:
                    state["history"] = orig_history + [[prompt, new_line]]
                    output += new_line
                    yield state, state["history"], "", output
                    # Stopping generation
                    return

            # Keep the original history untouched; every iteration
            # re-appends the whole chunk grown so far plus live stats
            state["history"] = orig_history + [[prompt, new_line + stats()]]
            yield state, state["history"], "", output

            # Stop a few tokens early so the server-side generate() doesn't
            # throw once the session length is exhausted, which would
            # surface as a UI error
            if cnt >= max_length - 6:  # FIXME magic safety margin
                break

        # Final line without the statistics suffix
        yield state, state["history"], "", output
    except (json.decoder.JSONDecodeError, BrokenPipeError):
        # The session was interrupted; the retry is handled
        # in the upstream generate()
        client.close_session()
        state["client"] = None
        state["model"] = None
        print("Broken session!")
        raise
    except Exception as e:
        client.close_session()
        state["client"] = None
        state["model"] = None
        print(datetime.now(), str(e)[-500:])
        raise gr.Error(str(e)[-500:])


def reset(state):
    """Resets the session and clears the chat window."""
    state.update(EMPTY_STATE)
    return state, [], ""


# ---------------------------------------------------------
# Defining the Gradio layout

with gr.Blocks() as iface_chat:
    gr.Markdown("""**Let's talk to AI in a chat!**""")

    with gr.Row():
        model = gr.Radio(
            [
                "stabilityai/StableBeluga2",
                "meta-llama/Llama-2-70b-chat-hf",
                "bigscience/bloomz",
            ],
            value="stabilityai/StableBeluga2",
            label="Use model",
        )

        # Additional end sequences at which generation should stop
        endseq = gr.CheckboxGroup(
            ["Human:", "AI:", "\\n", "</s>", "? (question mark)", ". (dot)"],
            value=["Human:", "AI:", "</s>"],
            label="Extra end sequences",
        )

        # Maximum length of the inference session
        max_length = gr.Radio(
            [64, 128, 256, 512, 1024, 2048],
            value=1024,
            interactive=True,
            label="Max length",
        )

    with gr.Row():
        with gr.Column():
            # Switch between sampling and greedy generation
            do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
            context = gr.Textbox(
                lines=3,
                label="Initial context:",
                interactive=True,
                value="A Human talks to a powerful AI that follows "
                "the Human's instructions.\n"
                "AI is talkative, friendly, positive and provides "
                "detailed answers to any question.</s>\n"
                "Human: Hi!</s>\n"
                "AI: How can I help you?",
            )

        # Only one of top_k and top_p can be set. Requires do_sample=True to work.
        top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
        top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
        # TODO num_beams

        # Generation temperature
        temperature = gr.Number(
            value=0.75, precision=2, interactive=True, label="Temperature"
        )

    chat = gr.Chatbot(label="Chat window")
    prompt = gr.Textbox(
        show_label=False,
        label="Prompt",
        placeholder="Type a prompt here and press Enter...",
    ).style(container=False)
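    # NOTE: .style(container=False) is the Gradio 3.x API; on Gradio 4+
    # pass container=False to the Textbox constructor instead.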

    with gr.Row():
        button_generate = gr.Button("Generate")
        button_reset = gr.Button("Reset session")

    with gr.Accordion("Raw prompt log", open=False):
        output = gr.Textbox(lines=3, show_label=False).style(container=False)

    # Per-session chat state (see EMPTY_STATE above)
    state = gr.State(EMPTY_STATE)

    # Define button actions
    inputs = [
        state,
        prompt,
        model,
        context,
        output,
        endseq,
        max_length,
        do_sample,
        top_k,
        top_p,
        temperature,
    ]
    outputs = [state, chat, prompt, output]

    prompt.submit(generate, inputs=inputs, outputs=outputs)
    button_generate.click(generate, inputs=inputs, outputs=outputs)
    button_reset.click(reset, inputs=[state], outputs=[state, chat, output])
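
    # "Reset session" works as a cancel button too: reset() copies
    # EMPTY_STATE into the state, flipping state["generate"] to False,
    # which the streaming loop in _generate() checks on every token.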

    examples = gr.Examples(
        inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
        examples=[
            [
                "Human talks to a powerful AI that follows the Human's instructions. "
                "AI is a smart, talkative, friendly, honest, helpful, harmless assistant to Human. "
                "AI has instant access to an online encyclopedia containing all the facts about the world "
                "and answers any question in detail. AI never says common misconceptions, "
                "outdated information, lies, fiction, myths, jokes, or memes.</s>\n"
                "AI: Hi! How can I help you?</s>\n",
                "Could you please remind me who Neil Armstrong was?",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
            ],
            # Czech example; in English: "Human talks to a powerful,
            # intelligent and omniscient AI that follows Human's
            # instructions. AI is eloquent, friendly, positive and gives
            # detailed answers to any question. / Human: Hi! / AI: Hi! How
            # can I help you?" with the prompt "Could you please remind me
            # who Neil Armstrong was?"
            [
                "Human mluví s mocnou, inteligentní a vševědoucí AI, která plní instrukce od Human. "
                "AI je výřečná, přátelská, pozitivní a poskytuje detailní odpovědi na jakoukoliv otázku.</s>\n"
                "Human: Ahoj!</s>\n"
                "AI: Ahoj! Jak ti mohu pomoci?",
                "Můžeš mi prosím připomenout, kdo byl Neil Armstrong?",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
            ],
        ],
    )
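

# This module only defines iface_chat; a minimal sketch for running it
# standalone (assuming nothing else mounts it into a larger app). queue()
# is needed so Gradio can stream the generator outputs.
if __name__ == "__main__":
    iface_chat.queue().launch()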