arahrooh committed
Commit: 2437458
Parent(s): bdad35b

Fix: Use text_generation API directly for better reliability

Files changed (2):
  1. app.py +26 -49
  2. requirements.txt +1 -1
app.py CHANGED
@@ -193,39 +193,27 @@ class InferenceAPIBot:
     def generate_answer(self, prompt: str, **kwargs) -> str:
         """Generate answer using Inference API"""
         try:
-            # Use text generation API (more reliable than chat.completions)
-            # The InferenceClient supports both formats, but text_generation is more stable
             max_tokens = kwargs.get('max_new_tokens', 512)
             temperature = kwargs.get('temperature', 0.2)
             top_p = kwargs.get('top_p', 0.9)
 
-            # Try chat.completions first (newer API)
-            try:
-                messages = [{"role": "user", "content": prompt}]
-                completion = self.client.chat.completions.create(
-                    model=self.current_model,
-                    messages=messages,
-                    max_tokens=max_tokens,
-                    temperature=temperature,
-                    top_p=top_p,
-                )
-                answer = completion.choices[0].message.content
-                return answer
-            except (AttributeError, TypeError) as e:
-                # Fallback to text generation API if chat.completions not available
-                logger.warning(f"chat.completions not available, using text_generation: {e}")
-                response = self.client.text_generation(
-                    prompt,
-                    model=self.current_model,
-                    max_new_tokens=max_tokens,
-                    temperature=temperature,
-                    top_p=top_p,
-                    return_full_text=False,
-                )
-                return response
+            # Use text_generation API directly (more reliable and widely supported)
+            logger.info(f"Calling Inference API for model: {self.current_model}")
+            response = self.client.text_generation(
+                prompt,
+                model=self.current_model,
+                max_new_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                return_full_text=False,
+            )
+            logger.info(f"Inference API response received (length: {len(response) if response else 0})")
+            return response
         except Exception as e:
             logger.error(f"Error calling Inference API: {e}", exc_info=True)
-            return f"Error generating answer: {str(e)}"
+            import traceback
+            logger.error(f"Traceback: {traceback.format_exc()}")
+            return f"Error generating answer: {str(e)}. Please check the logs for details."
 
     def enhance_readability(self, answer: str, target_level: str = "middle_school") -> Tuple[str, float]:
         """Enhance readability using Inference API"""
@@ -279,31 +267,20 @@ class InferenceAPIBot:
             {"role": "user", "content": user_message}
         ]
 
-        # Call Inference API
+        # Call Inference API using text_generation (more reliable)
         max_tokens = 512 if target_level in ["college", "doctoral"] else 384
         temperature = 0.4 if target_level in ["college", "doctoral"] else 0.3
 
-        try:
-            # Try chat.completions first
-            completion = self.client.chat.completions.create(
-                model=self.current_model,
-                messages=messages,
-                max_tokens=max_tokens,
-                temperature=temperature,
-            )
-            enhanced_answer = completion.choices[0].message.content
-        except (AttributeError, TypeError) as e:
-            # Fallback to text generation
-            logger.warning(f"chat.completions not available for readability, using text_generation: {e}")
-            # Combine system and user messages for text generation
-            combined_prompt = f"{system_message}\n\n{user_message}"
-            enhanced_answer = self.client.text_generation(
-                combined_prompt,
-                model=self.current_model,
-                max_new_tokens=max_tokens,
-                temperature=temperature,
-                return_full_text=False,
-            )
+        # Combine system and user messages for text generation
+        combined_prompt = f"{system_message}\n\n{user_message}"
+        logger.info(f"Enhancing readability for {target_level} level")
+        enhanced_answer = self.client.text_generation(
+            combined_prompt,
+            model=self.current_model,
+            max_new_tokens=max_tokens,
+            temperature=temperature,
+            return_full_text=False,
+        )
         # Clean the answer (same as bot.py)
         cleaned = self.bot._clean_readability_answer(enhanced_answer, target_level)
 
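For reference, a minimal standalone sketch of the call path this commit settles on: a single InferenceClient.text_generation request with a plain string prompt, with system and user content flattened into one prompt the way the readability path now does. The model id, prompt text, and token handling below are illustrative placeholders, not values taken from this Space.

from huggingface_hub import InferenceClient

# Uses HF_TOKEN from the environment or the cached login, if present.
client = InferenceClient()

# Placeholder model id, for illustration only.
model_id = "HuggingFaceH4/zephyr-7b-beta"

# Flatten system + user content into one prompt, since text_generation takes a plain string.
system_message = "You answer questions clearly and concisely."
user_message = "Explain what a serverless inference endpoint is."
prompt = f"{system_message}\n\n{user_message}"

response = client.text_generation(
    prompt,
    model=model_id,
    max_new_tokens=512,
    temperature=0.2,
    top_p=0.9,
    return_full_text=False,  # return only the generated continuation, not the echoed prompt
)
print(response)
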
requirements.txt CHANGED
@@ -6,7 +6,7 @@
 # Core ML/AI Framework
 torch>=2.0.0 # PyTorch for model inference
 transformers>=4.30.0 # Hugging Face transformers
-huggingface_hub>=0.20.0 # Hugging Face Hub API (for Inference API)
+huggingface_hub>=0.23.0 # Hugging Face Hub API (for Inference API)
 accelerate>=0.20.0 # Model loading optimization
 safetensors>=0.3.0 # Safe model loading
 
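
A quick local sanity check after bumping the pin, to confirm the installed client is new enough and exposes text_generation (a verification sketch, not part of the repo):

import huggingface_hub
from huggingface_hub import InferenceClient

# Expect a version of 0.23.0 or newer, and the text_generation method on the client.
print(huggingface_hub.__version__)
print(hasattr(InferenceClient, "text_generation"))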