Update app.py
app.py
CHANGED
@@ -1,296 +1,226 @@
-#
-#
-# ---------------------------------------------------------------
-
 import gradio as gr
 import spaces
-import torch
-from transformers import (AutoModelForCausalLM,
-                          AutoTokenizer,
-                          TextIteratorStreamer)
 from threading import Thread
-
-# ────────────────────────────────────────────────────────────────
-model_name = "FractalAIResearch/Fathom-R1-14B"
-
-try:
-    # 1-line 4-bit loading (needs bitsandbytes, already in HF Space image)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        device_map="auto",
-        load_in_4bit=True,
-        trust_remote_code=True
-    )
-except RuntimeError:
-    # fallback to fp16 if 4-bit isn't available
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        trust_remote_code=True
-    )
-
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-device = next(model.parameters()).device  # usually cuda:0
-
-# ────────────────────────────────────────────────────────────────
-# 2. Helpers
-# ────────────────────────────────────────────────────────────────
-def format_math(text: str) -> str:
-    "Replace [...]/\\(...\\) with $$...$$ for nicer math rendering"
-    text = re.sub(r"\[(.*?)\]", r"$$\1$$", text, flags=re.DOTALL)
-    return text.replace(r"\(", "$").replace(r"\)", "$")
-
-def generate_conversation_id() -> str:
-    return str(uuid.uuid4())[:8]
-
-# tiktoken – we just keep it to count tokens during streaming
-enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
-
-# Build a prompt that Fathom-R1 understands
-BOS, SEP, EOS = "<|im_start|>", "<|im_sep|>", "<|im_end|>"
-
-system_message = (
-    "Your role as an assistant involves thoroughly exploring questions "
-    "through a systematic thinking process before providing the final "
-    "precise and accurate solutions. …"  # same text you used before
 )

-    prompt = f"{BOS}system{SEP}{system_message}{EOS}"
-    for m in history:
-        role = m["role"]
-        prompt += f"{BOS}{role}{SEP}{m['content']}{EOS}"
-    prompt += f"{BOS}user{SEP}{user_msg}{EOS}{BOS}assistant{SEP}"
-    return prompt
-
-#
-#
-#
-def generate_response(user_message,
-                      max_tokens,
-                      temperature,
-                      top_p,
-                      history_state):
     """
     """
     if not user_message.strip():
         return history_state, history_state

-    prompt
-
-        skip_special_tokens=True)
-
         input_ids=inputs["input_ids"],
         attention_mask=inputs["attention_mask"],
         max_new_tokens=int(max_tokens),
         temperature=float(temperature),
         top_p=float(top_p),
-        pad_token_id=tokenizer.eos_token_id,
-        streamer=streamer
     )

-    #
-    Thread(target=model.generate, kwargs=

     assistant_response = ""
     new_history = history_state + [
         {"role": "user", "content": user_message},
-        {"role": "assistant", "content": ""}
     ]

         yield new_history, new_history
-        if tokens_seen >= token_budget:
-            break

-    # final return
     yield new_history, new_history

-#
-#
-# ────────────────────────────────────────────────────────────────
 example_messages = {
-    "
-    ),
-    "IIT-JEE 2025 Physics": (
-        "A person sitting inside an elevator performs a weighing experiment …"
-    ),
-    "Goldman Sachs Interview Puzzle": (
-        "Four friends need to cross a dangerous bridge at night …"
-    ),
-    "IIT-JEE 2025 Mathematics": (
-        "Let S be the set of all seven-digit numbers that can be formed …"
-    )
 }

 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    conversations_state = gr.State({})
-    current_convo_id = gr.State(generate_conversation_id())
-    history_state = gr.State([])
-
-    # Header
-    gr.HTML(
         """
-        </div>
         """
     )

-    with gr.Sidebar():
-        gr.Markdown("## Conversations")
-        conversation_selector = gr.Radio(choices=[], label="Select Conversation", interactive=True)
-        new_convo_button = gr.Button("New Conversation ➕")

     with gr.Row():
         with gr.Column(scale=1):
-            # intro text
-            gr.Markdown(
-                """
-                Welcome to the Fathom R1 14B Chatbot, developed by **Fractal AI Research**!
-                This model excels at reasoning tasks in mathematics and science …
-
-                Once you close this demo window, all currently saved conversations will be lost.
-                """
-            )
-
-            # Settings
             gr.Markdown("### Settings")
-            max_tokens_slider = gr.Slider(
-            temperature_slider = gr.Slider(0.1, 2.0, value=0.6, label="Temperature")
-            top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
-
-            gr.Markdown(
-                """
-                We sincerely acknowledge [VIDraft](https://huggingface.co/VIDraft) …
-                """
             )
-
         with gr.Column(scale=4):
-            chatbot = gr.Chatbot(label="Chat", type="messages"
             with gr.Row():
-                user_input = gr.Textbox(
-                clear_button = gr.Button("Clear", scale=1)
-
-            # examples
             gr.Markdown("**Try these examples:**")
             with gr.Row():
-                example1_button = gr.Button("
-                example2_button = gr.Button("
-                example3_button = gr.Button("
-                example4_button = gr.Button("IIT-JEE 2024 Mathematics")
-
-    # ───────── conversation-management helpers ──────────────────
-    def update_conversation_list(conversations):
-        return [conversations[cid]["title"] for cid in conversations]
-
-    def start_new_conversation(conversations):
-        new_id = generate_conversation_id()
-        conversations[new_id] = {"title": f"New Conversation {new_id}", "messages": []}
-        return new_id, [], gr.update(choices=update_conversation_list(conversations),
-                                     value=conversations[new_id]["title"]), conversations
-
-    def load_conversation(selected_title, conversations):
-        for cid, convo in conversations.items():
-            if convo["title"] == selected_title:
-                return cid, convo["messages"], convo["messages"]
-        return current_convo_id.value, history_state.value, history_state.value

-    #
-    def send_message(user_message, max_tokens, temperature, top_p,
-                     convo_id, history, conversations):
-        if convo_id not in conversations:
-            title = " ".join(user_message.strip().split()[:5])
-            conversations[convo_id] = {"title": title, "messages": history}
-        if conversations[convo_id]["title"].startswith("New Conversation"):
-            conversations[convo_id]["title"] = " ".join(user_message.strip().split()[:5])
-
-        # call the streamer generator and forward its yields
-        for updated_history, new_history in generate_response(
-                user_message, max_tokens, temperature, top_p, history):
-            conversations[convo_id]["messages"] = new_history
-            yield (updated_history, new_history,
-                   gr.update(choices=update_conversation_list(conversations),
-                             value=conversations[convo_id]["title"]),
-                   conversations)
-
-    # ───────── UI ↔ functions wiring ─────────────────────────────
     submit_button.click(
-        fn=
-        inputs=[
     ).then(
         fn=lambda: gr.update(value=""),
         inputs=None,
-        outputs=user_input
     )

-    clear_button.click(
-        outputs=[current_convo_id, history_state,
-                 conversation_selector, conversations_state])
-
-    conversation_selector.change(fn=load_conversation,
-                                 inputs=[conversation_selector, conversations_state],
-                                 outputs=[current_convo_id, history_state, chatbot])

-#
-#
-#
 if __name__ == "__main__":
-    demo.
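The removed version above assembled the ChatML prompt by hand from <|im_start|> / <|im_sep|> / <|im_end|> markers; the added version below delegates that formatting to the tokenizer's chat template. A minimal sketch of the two approaches, assuming the Fathom-R1 tokenizer ships a ChatML-style template (the toy messages are illustrative, and the exact markers depend on the template bundled with the checkpoint):

# Sketch: hand-built ChatML prompt vs. tokenizer.apply_chat_template.
# Assumption: the checkpoint's tokenizer defines a ChatML-style chat template;
# the example messages below are illustrative only.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("FractalAIResearch/Fathom-R1-14B", trust_remote_code=True)

messages = [
    {"role": "system", "content": "You are a careful math assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]

# Old approach: concatenate the special markers manually.
BOS, SEP, EOS = "<|im_start|>", "<|im_sep|>", "<|im_end|>"
manual = "".join(f"{BOS}{m['role']}{SEP}{m['content']}{EOS}" for m in messages)
manual += f"{BOS}assistant{SEP}"

# New approach: let the tokenizer's template decide the exact markers.
templated = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

print(manual)
print(templated)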
+# app.py – Gradio chatbot for FractalAIResearch/Fathom-R1-14B
+# ---------------------------------------------------------------------
 import gradio as gr
 import spaces
+import torch
 from threading import Thread
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    TextIteratorStreamer,
 )

+# ---------------------------------------------------------------------
+# 1. Model & tokenizer
+# ---------------------------------------------------------------------
+MODEL_NAME = "FractalAIResearch/Fathom-R1-14B"
+
+print("⏳ Loading model … (this may take a couple of minutes)")
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    device_map="auto",          # dispatch across any available device(s)
+    trust_remote_code=True,     # Fathom uses custom modelling code
+    low_cpu_mem_usage=True,
+)
+tokenizer = AutoTokenizer.from_pretrained(
+    MODEL_NAME,
+    trust_remote_code=True,
+)

+print("✅ Model loaded")

+# ---------------------------------------------------------------------
+# 2. Helper: build a prompt with the tokenizer's chat_template
+# ---------------------------------------------------------------------
+def build_chat_prompt(history, user_message, system_message):
     """
+    history : list[dict(role, content)]
+    user_message : str
+    system_message : str
+    returns a single prompt string (not tokenised)
     """
+    msgs = []
+    if system_message:
+        msgs.append({"role": "system", "content": system_message})
+    msgs.extend(history)
+    msgs.append({"role": "user", "content": user_message})
+
+    return tokenizer.apply_chat_template(
+        msgs,
+        tokenize=False,              # return pure text
+        add_generation_prompt=True,
+    )
+
+# ---------------------------------------------------------------------
+# 3. Generation endpoint
+# ---------------------------------------------------------------------
+@spaces.GPU(duration=60)   # short GPU reservation if available
+def generate_response(
+    user_message,
+    max_tokens,
+    temperature,
+    top_k,
+    top_p,
+    repetition_penalty,
+    history_state,
+):
+    # Empty input → nothing to do
     if not user_message.strip():
         return history_state, history_state

+    # System prompt (kept from your Phi-4 version)
+    system_message = (
+        "Your role as an assistant involves thoroughly exploring questions through a "
+        "systematic thinking process before providing the final precise and accurate "
+        "solutions. Please structure your response into two main sections: "
+        "<think> … </think> and Solution."
+    )
+
+    prompt = build_chat_prompt(history_state, user_message, system_message)
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

+    # Stream tokens as they come
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

+    generation_kwargs = dict(
         input_ids=inputs["input_ids"],
         attention_mask=inputs["attention_mask"],
         max_new_tokens=int(max_tokens),
+        do_sample=True,
         temperature=float(temperature),
+        top_k=int(top_k),
         top_p=float(top_p),
+        repetition_penalty=float(repetition_penalty),
+        streamer=streamer,
     )

+    # Run generate in a background thread so the UI stays responsive
+    Thread(target=model.generate, kwargs=generation_kwargs).start()

     assistant_response = ""
     new_history = history_state + [
         {"role": "user", "content": user_message},
+        {"role": "assistant", "content": ""},
     ]

+    for token in streamer:
+        # strip any stray special tokens the model may output
+        cleaned = (
+            token.replace("<|im_start|>", "")
+            .replace("<|im_end|>", "")
+            .replace("<|im_sep|>", "")
+        )
+        assistant_response += cleaned
+        new_history[-1]["content"] = assistant_response.strip()
         yield new_history, new_history

     yield new_history, new_history

+# ---------------------------------------------------------------------
+# 4. Example questions (unchanged)
+# ---------------------------------------------------------------------
 example_messages = {
+    "Math reasoning": "If a rectangular prism has a length of 6 cm, a width of 4 cm, and a height of 5 cm, what is the length of the longest line segment that can be drawn from one vertex to another?",
+    "Logic puzzle": "Four people (Alex, Blake, Casey, and Dana) each have a different favorite color (red, blue, green, yellow) and a different favorite fruit (apple, banana, cherry, date). Given the following clues: 1) The person who likes red doesn't like dates. 2) Alex likes yellow. 3) The person who likes blue likes cherries. 4) Blake doesn't like apples or bananas. 5) Casey doesn't like yellow or green. Who likes what color and what fruit?",
+    "Physics problem": "A ball is thrown upward with an initial velocity of 15 m/s from a height of 2 meters above the ground. Assuming the acceleration due to gravity is 9.8 m/s², determine: 1) The maximum height the ball reaches. 2) The total time the ball is in the air before hitting the ground. 3) The velocity with which the ball hits the ground.",
 }

+# ---------------------------------------------------------------------
+# 5. Gradio UI (identical to the original, just lower default max_tokens)
+# ---------------------------------------------------------------------
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
         """
+        # Fathom-R1-14B Chatbot
+        The model excels at multi-step reasoning in mathematics, logic, and science.
+
+        It returns two sections:\n
+        1. **<think>** – detailed chain-of-thought (reasoning)\n
+        2. **Solution** – concise, final answer
         """
     )

+    history_state = gr.State([])

     with gr.Row():
+        # Settings panel
         with gr.Column(scale=1):
             gr.Markdown("### Settings")
+            max_tokens_slider = gr.Slider(
+                minimum=64, maximum=4096, step=256, value=1024, label="Max Tokens"
             )
+            with gr.Accordion("Advanced Settings", open=False):
+                temperature_slider = gr.Slider(
+                    minimum=0.1, maximum=2.0, value=0.8, label="Temperature"
+                )
+                top_k_slider = gr.Slider(
+                    minimum=1, maximum=100, step=1, value=50, label="Top-k"
+                )
+                top_p_slider = gr.Slider(
+                    minimum=0.1, maximum=1.0, value=0.95, label="Top-p"
+                )
+                repetition_penalty_slider = gr.Slider(
+                    minimum=1.0, maximum=2.0, value=1.0, label="Repetition Penalty"
+                )
+
+        # Chat area
         with gr.Column(scale=4):
+            chatbot = gr.Chatbot(label="Chat", type="messages")
             with gr.Row():
+                user_input = gr.Textbox(
+                    label="Your message", placeholder="Type your message here…", scale=3
+                )
+                submit_button = gr.Button("Send", variant="primary", scale=1)
+                clear_button = gr.Button("Clear", scale=1)
             gr.Markdown("**Try these examples:**")
             with gr.Row():
+                example1_button = gr.Button("Math reasoning")
+                example2_button = gr.Button("Logic puzzle")
+                example3_button = gr.Button("Physics problem")

+    # Button wiring
     submit_button.click(
+        fn=generate_response,
+        inputs=[
+            user_input,
+            max_tokens_slider,
+            temperature_slider,
+            top_k_slider,
+            top_p_slider,
+            repetition_penalty_slider,
+            history_state,
+        ],
+        outputs=[chatbot, history_state],
     ).then(
         fn=lambda: gr.update(value=""),
         inputs=None,
+        outputs=user_input,
     )

+    clear_button.click(
+        fn=lambda: ([], []),
+        inputs=None,
+        outputs=[chatbot, history_state],
+    )

+    example1_button.click(
+        fn=lambda: gr.update(value=example_messages["Math reasoning"]),
+        inputs=None,
+        outputs=user_input,
+    )
+    example2_button.click(
+        fn=lambda: gr.update(value=example_messages["Logic puzzle"]),
+        inputs=None,
+        outputs=user_input,
+    )
+    example3_button.click(
+        fn=lambda: gr.update(value=example_messages["Physics problem"]),
+        inputs=None,
+        outputs=user_input,
+    )

+# ---------------------------------------------------------------------
+# 6. Launch
+# ---------------------------------------------------------------------
 if __name__ == "__main__":
+    demo.launch(ssr_mode=False)
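The generation endpoint in the new file streams by running model.generate in a background Thread and iterating a TextIteratorStreamer, so the Gradio callback can yield partial chat history while decoding is still in progress. A minimal console sketch of that same pattern, assuming gpt2 as a stand-in checkpoint so it runs quickly on CPU:

# Sketch of the streaming pattern used above: generate() runs in a background
# thread while the main thread consumes decoded text chunks from the streamer.
# Assumption: "gpt2" stands in for the 14B model so the snippet runs on CPU.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

Thread(
    target=lm.generate,
    kwargs=dict(**inputs, max_new_tokens=40, do_sample=False, streamer=streamer),
).start()

for chunk in streamer:          # blocks until the next piece of text is ready
    print(chunk, end="", flush=True)
print()

Each loop iteration receives the next decoded chunk, which mirrors how the app appends to new_history[-1]["content"] inside its own "for token in streamer" loop and yields the updated history to the UI.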