FractalAIR committed on
Commit e30ac3f · verified · 1 Parent(s): f14967b

Update app.py

Files changed (1)
  1. app.py +168 -238
app.py CHANGED
@@ -1,296 +1,226 @@
- # ---------------------------------------------------------------
- # Fathom-R1-14B ZeroGPU chat-demo (Gradio Blocks)
- # ---------------------------------------------------------------
-
  import gradio as gr
  import spaces
- import torch, re, uuid, tiktoken
- from transformers import (AutoModelForCausalLM,
-                           AutoTokenizer,
-                           TextIteratorStreamer)
  from threading import Thread
-
- # ────────────────────────────────────────────────────────────────
- # 1. Load the model on the single GPU supplied by ZeroGPU
- #    (4-bit to stay well below the 24 GB VRAM of an A10G)
- # ────────────────────────────────────────────────────────────────
- model_name = "FractalAIResearch/Fathom-R1-14B"
-
- try:
-     # 1-line 4-bit loading (needs bitsandbytes, already in HF Space image)
-     model = AutoModelForCausalLM.from_pretrained(
-         model_name,
-         device_map="auto",
-         load_in_4bit=True,
-         trust_remote_code=True
-     )
- except RuntimeError:
-     # fallback to fp16 if 4-bit isn't available
-     model = AutoModelForCausalLM.from_pretrained(
-         model_name,
-         torch_dtype=torch.float16,
-         device_map="auto",
-         trust_remote_code=True
-     )
-
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- device = next(model.parameters()).device   # usually cuda:0
-
-
- # ────────────────────────────────────────────────────────────────
- # 2. Helpers
- # ────────────────────────────────────────────────────────────────
- def format_math(text: str) -> str:
-     "Replace [...]/\\(...\\) with $$...$$ for nicer math rendering"
-     text = re.sub(r"\[(.*?)\]", r"$$\1$$", text, flags=re.DOTALL)
-     return text.replace(r"\(", "$").replace(r"\)", "$")
-
-
- def generate_conversation_id() -> str:
-     return str(uuid.uuid4())[:8]
-
-
- # tiktoken – we just keep it to count tokens during streaming
- enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
-
-
- # Build a prompt that Fathom-R1 understands
- BOS, SEP, EOS = "<|im_start|>", "<|im_sep|>", "<|im_end|>"
-
- system_message = (
-     "Your role as an assistant involves thoroughly exploring questions "
-     "through a systematic thinking process before providing the final "
-     "precise and accurate solutions. …"   # same text you used before
  )


- def build_prompt(history, user_msg: str) -> str:
-     prompt = f"{BOS}system{SEP}{system_message}{EOS}"
-     for m in history:
-         role = m["role"]
-         prompt += f"{BOS}{role}{SEP}{m['content']}{EOS}"
-     prompt += f"{BOS}user{SEP}{user_msg}{EOS}{BOS}assistant{SEP}"
-     return prompt
-

- # ────────────────────────────────────────────────────────────────
- # 3. Generation (runs on the GPU for 60 s max per call)
- # ────────────────────────────────────────────────────────────────
- @spaces.GPU(duration=60)
- def generate_response(user_message,
-                       max_tokens,
-                       temperature,
-                       top_p,
-                       history_state):
      """
-     Takes exactly the same signature the rest of the UI expects:
-     returns (visible_chatbot, history_state)
      """
      if not user_message.strip():
          return history_state, history_state

-     prompt = build_prompt(history_state, user_message)
-     inputs = tokenizer(prompt, return_tensors="pt").to(device)

-     streamer = TextIteratorStreamer(tokenizer,
-                                     skip_prompt=True,
-                                     skip_special_tokens=True)

-     gen_kwargs = dict(
          input_ids=inputs["input_ids"],
          attention_mask=inputs["attention_mask"],
          max_new_tokens=int(max_tokens),
          temperature=float(temperature),
          top_p=float(top_p),
-         do_sample=True,
-         eos_token_id=tokenizer.eos_token_id,
-         pad_token_id=tokenizer.eos_token_id,
-         streamer=streamer
      )

-     # run generate in a background thread – lets us stream tokens
-     Thread(target=model.generate, kwargs=gen_kwargs).start()

      assistant_response = ""
      new_history = history_state + [
          {"role": "user", "content": user_message},
-         {"role": "assistant", "content": ""}
      ]

-     # live-stream tokens to the UI
-     tokens_seen = 0
-     token_budget = int(max_tokens)
-
-     for new_tok in streamer:
-         assistant_response += new_tok
-         tokens_seen += len(enc.encode(new_tok))
-         new_history[-1]["content"] = format_math(assistant_response.strip())
          yield new_history, new_history
-         if tokens_seen >= token_budget:
-             break

-     # final return
      yield new_history, new_history

-
- # ────────────────────────────────────────────────────────────────
- # 4. Demo UI – identical to your current one
- # ────────────────────────────────────────────────────────────────
  example_messages = {
-     "IIT-JEE 2024 Mathematics": (
-         "A student appears for a quiz consisting of only true-false type "
-         "questions and answers all the questions. …"
-     ),
-     "IIT-JEE 2025 Physics": (
-         "A person sitting inside an elevator performs a weighing experiment …"
-     ),
-     "Goldman Sachs Interview Puzzle": (
-         "Four friends need to cross a dangerous bridge at night …"
-     ),
-     "IIT-JEE 2025 Mathematics": (
-         "Let S be the set of all seven-digit numbers that can be formed …"
-     )
  }

  with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     # session-scoped states
-     conversations_state = gr.State({})
-     current_convo_id = gr.State(generate_conversation_id())
-     history_state = gr.State([])
-
-     # Header
-     gr.HTML(
          """
-         <div style="display:flex;align-items:center;gap:16px;margin-bottom:1em">
-             <div style="background-color:black;padding:6px;border-radius:8px">
-                 <img src="https://framerusercontent.com/images/j0KjQQyrUfkFw4NwSaxQOLAoBU.png"
-                      style="height:48px">
-             </div>
-             <h1 style="margin:0;">Fathom R1 14B Chatbot</h1>
-         </div>
          """
      )

-     # Sidebar
-     with gr.Sidebar():
-         gr.Markdown("## Conversations")
-         conversation_selector = gr.Radio(choices=[], label="Select Conversation", interactive=True)
-         new_convo_button = gr.Button("New Conversation ➕")

      with gr.Row():
          with gr.Column(scale=1):
-             # intro text
-             gr.Markdown(
-                 """
-                 Welcome to the Fathom R1 14B Chatbot, developed by **Fractal AI Research**!
-                 This model excels at reasoning tasks in mathematics and science …
-
-                 Once you close this demo window, all currently saved conversations will be lost.
-                 """
-             )
-
-             # Settings
              gr.Markdown("### Settings")
-             max_tokens_slider = gr.Slider(6144, 32768, step=1024, value=16384, label="Max Tokens")
-             with gr.Accordion("Advanced Settings", open=True):
-                 temperature_slider = gr.Slider(0.1, 2.0, value=0.6, label="Temperature")
-                 top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
-
-             gr.Markdown(
-                 """
-                 We sincerely acknowledge [VIDraft](https://huggingface.co/VIDraft) …
-                 """
              )
-
          with gr.Column(scale=4):
-             chatbot = gr.Chatbot(label="Chat", type="messages", height=520)
              with gr.Row():
-                 user_input = gr.Textbox(label="User Input",
-                                         placeholder="Type your question here…",
-                                         lines=3, scale=8)
-                 with gr.Column():
-                     submit_button = gr.Button("Send", variant="primary", scale=1)
-                     clear_button = gr.Button("Clear", scale=1)
-
-             # examples
              gr.Markdown("**Try these examples:**")
              with gr.Row():
-                 example1_button = gr.Button("IIT-JEE 2025 Mathematics")
-                 example2_button = gr.Button("IIT-JEE 2025 Physics")
-                 example3_button = gr.Button("Goldman Sachs Interview Puzzle")
-                 example4_button = gr.Button("IIT-JEE 2024 Mathematics")
-
-     # ───────── conversation-management helpers ──────────────────
-     def update_conversation_list(conversations):
-         return [conversations[cid]["title"] for cid in conversations]
-
-     def start_new_conversation(conversations):
-         new_id = generate_conversation_id()
-         conversations[new_id] = {"title": f"New Conversation {new_id}", "messages": []}
-         return new_id, [], gr.update(choices=update_conversation_list(conversations),
-                                      value=conversations[new_id]["title"]), conversations
-
-     def load_conversation(selected_title, conversations):
-         for cid, convo in conversations.items():
-             if convo["title"] == selected_title:
-                 return cid, convo["messages"], convo["messages"]
-         return current_convo_id.value, history_state.value, history_state.value

-     # main "send" wrapper: keeps conversations dict in sync
-     def send_message(user_message, max_tokens, temperature, top_p,
-                      convo_id, history, conversations):
-         if convo_id not in conversations:
-             title = " ".join(user_message.strip().split()[:5])
-             conversations[convo_id] = {"title": title, "messages": history}
-         if conversations[convo_id]["title"].startswith("New Conversation"):
-             conversations[convo_id]["title"] = " ".join(user_message.strip().split()[:5])
-
-         # call the streamer generator and forward its yields
-         for updated_history, new_history in generate_response(
-                 user_message, max_tokens, temperature, top_p, history):
-             conversations[convo_id]["messages"] = new_history
-             yield (updated_history, new_history,
-                    gr.update(choices=update_conversation_list(conversations),
-                              value=conversations[convo_id]["title"]),
-                    conversations)
-
-     # ───────── UI → functions wiring ────────────────────────────
      submit_button.click(
-         fn=send_message,
-         inputs=[user_input, max_tokens_slider, temperature_slider, top_p_slider,
-                 current_convo_id, history_state, conversations_state],
-         outputs=[chatbot, history_state, conversation_selector, conversations_state],
-         concurrency_limit=16
      ).then(
          fn=lambda: gr.update(value=""),
          inputs=None,
-         outputs=user_input
      )

-     clear_button.click(fn=lambda: ([], []), inputs=None,
-                        outputs=[chatbot, history_state])
-
-     new_convo_button.click(fn=start_new_conversation,
-                            inputs=[conversations_state],
-                            outputs=[current_convo_id, history_state,
-                                     conversation_selector, conversations_state])
-
-     conversation_selector.change(fn=load_conversation,
-                                  inputs=[conversation_selector, conversations_state],
-                                  outputs=[current_convo_id, history_state, chatbot])

-     # example buttons
-     example1_button.click(lambda: gr.update(value=example_messages["IIT-JEE 2025 Mathematics"]),
-                           None, user_input)
-     example2_button.click(lambda: gr.update(value=example_messages["IIT-JEE 2025 Physics"]),
-                           None, user_input)
-     example3_button.click(lambda: gr.update(value=example_messages["Goldman Sachs Interview Puzzle"]),
-                           None, user_input)
-     example4_button.click(lambda: gr.update(value=example_messages["IIT-JEE 2024 Mathematics"]),
-                           None, user_input)

-
- # ────────────────────────────────────────────────────────────────
- # 5. Launch
- # ────────────────────────────────────────────────────────────────
  if __name__ == "__main__":
-     demo.queue().launch(share=True, ssr_mode=False)

+ # app.py – Gradio chatbot for FractalAIResearch/Fathom-R1-14B
+ # ---------------------------------------------------------------------
  import gradio as gr
  import spaces
+ import torch
  from threading import Thread
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TextIteratorStreamer,
  )

+ # ---------------------------------------------------------------------
+ # 1. Model & tokenizer
+ # ---------------------------------------------------------------------
+ MODEL_NAME = "FractalAIResearch/Fathom-R1-14B"
+
+ print("⏳ Loading model … (this may take a couple of minutes)")
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME,
+     device_map="auto",        # dispatch across any available device(s)
+     trust_remote_code=True,   # Fathom uses custom modelling code
+     low_cpu_mem_usage=True,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(
+     MODEL_NAME,
+     trust_remote_code=True,
+ )

+ print("✅ Model loaded")

+ # ---------------------------------------------------------------------
+ # 2. Helper: build a prompt with the tokenizer's chat_template
+ # ---------------------------------------------------------------------
+ def build_chat_prompt(history, user_message, system_message):
      """
+     history        : list[dict(role, content)]
+     user_message   : str
+     system_message : str
+     returns a single prompt string (not tokenised)
      """
+     msgs = []
+     if system_message:
+         msgs.append({"role": "system", "content": system_message})
+     msgs.extend(history)
+     msgs.append({"role": "user", "content": user_message})
+
+     return tokenizer.apply_chat_template(
+         msgs,
+         tokenize=False,              # return pure text
+         add_generation_prompt=True,
+     )
+
+ # ---------------------------------------------------------------------
+ # 3. Generation endpoint
+ # ---------------------------------------------------------------------
+ @spaces.GPU(duration=60)      # short GPU reservation if available
+ def generate_response(
+     user_message,
+     max_tokens,
+     temperature,
+     top_k,
+     top_p,
+     repetition_penalty,
+     history_state,
+ ):
+     # Empty input → nothing to do
      if not user_message.strip():
          return history_state, history_state

+     # System prompt (kept from your Phi-4 version)
+     system_message = (
+         "Your role as an assistant involves thoroughly exploring questions through a "
+         "systematic thinking process before providing the final precise and accurate "
+         "solutions. Please structure your response into two main sections: "
+         "<think> … </think> and Solution."
+     )
+
+     prompt = build_chat_prompt(history_state, user_message, system_message)
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

+     # Stream tokens as they come
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

+     generation_kwargs = dict(
          input_ids=inputs["input_ids"],
          attention_mask=inputs["attention_mask"],
          max_new_tokens=int(max_tokens),
+         do_sample=True,
          temperature=float(temperature),
+         top_k=int(top_k),
          top_p=float(top_p),
+         repetition_penalty=float(repetition_penalty),
+         streamer=streamer,
      )

+     # Run generate in a background thread so the UI stays responsive
+     Thread(target=model.generate, kwargs=generation_kwargs).start()

      assistant_response = ""
      new_history = history_state + [
          {"role": "user", "content": user_message},
+         {"role": "assistant", "content": ""},
      ]

+     for token in streamer:
+         # strip any stray special tokens the model may output
+         cleaned = (
+             token.replace("<|im_start|>", "")
+             .replace("<|im_end|>", "")
+             .replace("<|im_sep|>", "")
+         )
+         assistant_response += cleaned
+         new_history[-1]["content"] = assistant_response.strip()
          yield new_history, new_history

      yield new_history, new_history

+ # ---------------------------------------------------------------------
+ # 4. Example questions (unchanged)
+ # ---------------------------------------------------------------------
  example_messages = {
+     "Math reasoning": "If a rectangular prism has a length of 6 cm, a width of 4 cm, and a height of 5 cm, what is the length of the longest line segment that can be drawn from one vertex to another?",
+     "Logic puzzle": "Four people (Alex, Blake, Casey, and Dana) each have a different favorite color (red, blue, green, yellow) and a different favorite fruit (apple, banana, cherry, date). Given the following clues: 1) The person who likes red doesn't like dates. 2) Alex likes yellow. 3) The person who likes blue likes cherries. 4) Blake doesn't like apples or bananas. 5) Casey doesn't like yellow or green. Who likes what color and what fruit?",
+     "Physics problem": "A ball is thrown upward with an initial velocity of 15 m/s from a height of 2 meters above the ground. Assuming the acceleration due to gravity is 9.8 m/s², determine: 1) The maximum height the ball reaches. 2) The total time the ball is in the air before hitting the ground. 3) The velocity with which the ball hits the ground.",
  }

+ # ---------------------------------------------------------------------
+ # 5. Gradio UI (identical to the original, just lower default max_tokens)
+ # ---------------------------------------------------------------------
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
          """
+         # Fathom-R1-14B Chatbot
+         The model excels at multi-step reasoning in mathematics, logic, and science.
+
+         It returns two sections:\n
+         1. **<think>** – detailed chain-of-thought (reasoning)\n
+         2. **Solution** – concise, final answer
          """
      )

+     history_state = gr.State([])

      with gr.Row():
+         # Settings panel
          with gr.Column(scale=1):
              gr.Markdown("### Settings")
+             max_tokens_slider = gr.Slider(
+                 minimum=64, maximum=4096, step=256, value=1024, label="Max Tokens"
              )
+             with gr.Accordion("Advanced Settings", open=False):
+                 temperature_slider = gr.Slider(
+                     minimum=0.1, maximum=2.0, value=0.8, label="Temperature"
+                 )
+                 top_k_slider = gr.Slider(
+                     minimum=1, maximum=100, step=1, value=50, label="Top-k"
+                 )
+                 top_p_slider = gr.Slider(
+                     minimum=0.1, maximum=1.0, value=0.95, label="Top-p"
+                 )
+                 repetition_penalty_slider = gr.Slider(
+                     minimum=1.0, maximum=2.0, value=1.0, label="Repetition Penalty"
+                 )
+
+         # Chat area
          with gr.Column(scale=4):
+             chatbot = gr.Chatbot(label="Chat", type="messages")
              with gr.Row():
+                 user_input = gr.Textbox(
+                     label="Your message", placeholder="Type your message here…", scale=3
+                 )
+                 submit_button = gr.Button("Send", variant="primary", scale=1)
+                 clear_button = gr.Button("Clear", scale=1)
              gr.Markdown("**Try these examples:**")
              with gr.Row():
+                 example1_button = gr.Button("Math reasoning")
+                 example2_button = gr.Button("Logic puzzle")
+                 example3_button = gr.Button("Physics problem")

+     # Button wiring
      submit_button.click(
+         fn=generate_response,
+         inputs=[
+             user_input,
+             max_tokens_slider,
+             temperature_slider,
+             top_k_slider,
+             top_p_slider,
+             repetition_penalty_slider,
+             history_state,
+         ],
+         outputs=[chatbot, history_state],
      ).then(
          fn=lambda: gr.update(value=""),
          inputs=None,
+         outputs=user_input,
      )

+     clear_button.click(
+         fn=lambda: ([], []),
+         inputs=None,
+         outputs=[chatbot, history_state],
+     )

+     example1_button.click(
+         fn=lambda: gr.update(value=example_messages["Math reasoning"]),
+         inputs=None,
+         outputs=user_input,
+     )
+     example2_button.click(
+         fn=lambda: gr.update(value=example_messages["Logic puzzle"]),
+         inputs=None,
+         outputs=user_input,
+     )
+     example3_button.click(
+         fn=lambda: gr.update(value=example_messages["Physics problem"]),
+         inputs=None,
+         outputs=user_input,
+     )

+ # ---------------------------------------------------------------------
+ # 6. Launch
+ # ---------------------------------------------------------------------
  if __name__ == "__main__":
+     demo.launch(ssr_mode=False)