Trouter-Library AlexGall committed on
Commit 144a17b · verified · 1 Parent(s): eb06522

Create evaluate.py (#2)


- Create evaluate.py (2d6c171cd36860900f9d1e8cc9f2f70bd5c422be)


Co-authored-by: Alex Gall <[email protected]>

Files changed (1)
  1. evaluate.py +519 -0
evaluate.py ADDED
@@ -0,0 +1,519 @@
+ """
+ Helion-OSC Evaluation Script
+ Comprehensive evaluation suite for code generation and mathematical reasoning
+ """
+
+ import os
+ import json
+ import torch
+ import logging
+ import numpy as np
+ from typing import List, Dict, Any, Optional, Tuple
+ from dataclasses import dataclass, field
+ from tqdm import tqdm
+ import subprocess
+ import tempfile
+ import signal
+ from contextlib import contextmanager
+ import multiprocessing as mp
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from datasets import load_dataset
+ import re
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class EvaluationConfig:
+     """Configuration for evaluation"""
+     model_name: str = "DeepXR/Helion-OSC"
+     device: str = "cuda" if torch.cuda.is_available() else "cpu"
+     batch_size: int = 4
+     max_length: int = 2048
+     temperature: float = 0.7
+     top_p: float = 0.95
+     num_samples: int = 1
+     timeout: int = 5  # seconds for code execution
+     output_dir: str = "./evaluation_results"
+
+
+ class TimeoutException(Exception):
+     """Exception raised when code execution times out"""
+     pass
+
+
+ @contextmanager
+ def time_limit(seconds):
+     """Context manager for timing out code execution (POSIX-only: relies on SIGALRM)"""
+     def signal_handler(signum, frame):
+         raise TimeoutException("Code execution timed out")
+
+     signal.signal(signal.SIGALRM, signal_handler)
+     signal.alarm(seconds)
+     try:
+         yield
+     finally:
+         signal.alarm(0)
+
+
+ class CodeExecutor:
+     """Safe code execution environment"""
+
+     @staticmethod
+     def execute_python(code: str, timeout: int = 5) -> Tuple[bool, str]:
+         """
+         Execute Python code safely
+
+         Args:
+             code: Python code to execute
+             timeout: Timeout in seconds
+
+         Returns:
+             Tuple of (success, output/error)
+         """
+         with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
+             f.write(code)
+             temp_file = f.name
+
+         try:
+             result = subprocess.run(
+                 ['python', temp_file],
+                 capture_output=True,
+                 text=True,
+                 timeout=timeout
+             )
+
+             os.unlink(temp_file)
+
+             if result.returncode == 0:
+                 return True, result.stdout
+             else:
+                 return False, result.stderr
+
+         except subprocess.TimeoutExpired:
+             os.unlink(temp_file)
+             return False, "Execution timed out"
+         except Exception as e:
+             if os.path.exists(temp_file):
+                 os.unlink(temp_file)
+             return False, str(e)
+
+     @staticmethod
+     def check_syntax(code: str, language: str = "python") -> Tuple[bool, str]:
+         """
+         Check code syntax without execution
+
+         Args:
+             code: Code to check
+             language: Programming language
+
+         Returns:
+             Tuple of (is_valid, error_message)
+         """
+         if language.lower() == "python":
+             try:
+                 compile(code, '<string>', 'exec')
+                 return True, ""
+             except SyntaxError as e:
+                 return False, str(e)
+
+         return True, "Syntax checking not implemented for this language"
+
+
+ class HumanEvalEvaluator:
+     """Evaluator for HumanEval benchmark"""
+
+     def __init__(self, config: EvaluationConfig):
+         self.config = config
+         self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             config.model_name,
+             torch_dtype=torch.bfloat16 if config.device == "cuda" else torch.float32,
+             device_map="auto" if config.device == "cuda" else None
+         )
+         if config.device == "cpu":
+             self.model = self.model.to(config.device)
+         self.model.eval()
+         self.executor = CodeExecutor()
+
+     def load_humaneval(self) -> List[Dict]:
+         """Load HumanEval dataset"""
+         logger.info("Loading HumanEval dataset...")
+         dataset = load_dataset("openai_humaneval", split="test")
+         return list(dataset)
+
+     def generate_solution(self, prompt: str) -> str:
+         """Generate code solution for a prompt"""
+         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.config.device)
+
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **inputs,
+                 max_length=self.config.max_length,
+                 temperature=self.config.temperature,
+                 top_p=self.config.top_p,
+                 do_sample=True,
+                 pad_token_id=self.tokenizer.eos_token_id
+             )
+
+         generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+         # Extract only the new generation; keep leading whitespace so the
+         # completion still lines up with the function body in the prompt
+         solution = generated[len(prompt):]
+         return solution
+
+     def test_solution(self, prompt: str, solution: str, test_code: str, entry_point: str) -> bool:
+         """Test a completion against the HumanEval test cases"""
+         # The completion continues the prompt, and the test code only defines
+         # check(candidate), so prepend the prompt and invoke the check explicitly.
+         full_code = prompt + solution + "\n" + test_code + "\n" + f"check({entry_point})\n"
+         success, output = self.executor.execute_python(full_code, self.config.timeout)
+         return success
+
+     def evaluate(self) -> Dict[str, float]:
+         """Run HumanEval evaluation"""
+         logger.info("Starting HumanEval evaluation...")
+
+         problems = self.load_humaneval()
+         results = {
+             "total": len(problems),
+             "passed": 0,
+             "failed": 0,
+             "syntax_errors": 0,
+             "runtime_errors": 0,
+             "timeouts": 0
+         }
+
+         for problem in tqdm(problems, desc="Evaluating HumanEval"):
+             prompt = problem["prompt"]
+             test = problem["test"]
+             entry_point = problem["entry_point"]
+
+             # Generate solution
+             solution = self.generate_solution(prompt)
+
+             # Check syntax of the full function (prompt + completion)
+             is_valid, error = self.executor.check_syntax(prompt + solution)
+             if not is_valid:
+                 results["syntax_errors"] += 1
+                 results["failed"] += 1
+                 continue
+
+             # Test solution
+             try:
+                 if self.test_solution(prompt, solution, test, entry_point):
+                     results["passed"] += 1
+                 else:
+                     results["failed"] += 1
+                     results["runtime_errors"] += 1
+             except TimeoutException:
+                 results["failed"] += 1
+                 results["timeouts"] += 1
+
+         # Calculate pass@1
+         results["pass@1"] = results["passed"] / results["total"]
+
+         logger.info(f"HumanEval Results: {results}")
+         return results
+
+
+ class MBPPEvaluator:
+     """Evaluator for MBPP (Mostly Basic Python Problems) benchmark"""
+
+     def __init__(self, config: EvaluationConfig):
+         self.config = config
+         self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             config.model_name,
+             torch_dtype=torch.bfloat16 if config.device == "cuda" else torch.float32,
+             device_map="auto" if config.device == "cuda" else None
+         )
+         if config.device == "cpu":
+             self.model = self.model.to(config.device)
+         self.model.eval()
+         self.executor = CodeExecutor()
+
+     def load_mbpp(self) -> List[Dict]:
+         """Load MBPP dataset"""
+         logger.info("Loading MBPP dataset...")
+         dataset = load_dataset("mbpp", split="test")
+         return list(dataset)
+
+     def generate_solution(self, prompt: str) -> str:
+         """Generate code solution"""
+         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.config.device)
+
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **inputs,
+                 max_length=self.config.max_length,
+                 temperature=self.config.temperature,
+                 top_p=self.config.top_p,
+                 do_sample=True,
+                 pad_token_id=self.tokenizer.eos_token_id
+             )
+
+         generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+         solution = generated[len(prompt):].strip()
+         return solution
+
+     def evaluate(self) -> Dict[str, float]:
+         """Run MBPP evaluation"""
+         logger.info("Starting MBPP evaluation...")
+
+         problems = self.load_mbpp()
+         results = {
+             "total": len(problems),
+             "passed": 0,
+             "failed": 0
+         }
+
+         for problem in tqdm(problems, desc="Evaluating MBPP"):
+             prompt = problem["text"]
+             test_cases = problem["test_list"]
+
+             # Generate solution
+             solution = self.generate_solution(prompt)
+
+             # Test against all test cases
+             all_passed = True
+             for test in test_cases:
+                 test_code = solution + "\n" + test
+                 success, _ = self.executor.execute_python(test_code, self.config.timeout)
+                 if not success:
+                     all_passed = False
+                     break
+
+             if all_passed:
+                 results["passed"] += 1
+             else:
+                 results["failed"] += 1
+
+         results["pass@1"] = results["passed"] / results["total"]
+
+         logger.info(f"MBPP Results: {results}")
+         return results
+
+
+ class GSM8KEvaluator:
+     """Evaluator for GSM8K mathematical reasoning benchmark"""
+
+     def __init__(self, config: EvaluationConfig):
+         self.config = config
+         self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             config.model_name,
+             torch_dtype=torch.bfloat16 if config.device == "cuda" else torch.float32,
+             device_map="auto" if config.device == "cuda" else None
+         )
+         if config.device == "cpu":
+             self.model = self.model.to(config.device)
+         self.model.eval()
+
+     def load_gsm8k(self) -> List[Dict]:
+         """Load GSM8K dataset"""
+         logger.info("Loading GSM8K dataset...")
+         dataset = load_dataset("gsm8k", "main", split="test")
+         return list(dataset)
+
+     def extract_answer(self, text: str) -> Optional[float]:
+         """Extract numerical answer from text"""
+         # Drop thousands separators so values like "1,000" parse as one number
+         text = text.replace(",", "")
+         # Look for patterns like "#### 42" or "The answer is 42"
+         patterns = [
+             r'####\s*(-?\d+\.?\d*)',
+             r'answer is\s*(-?\d+\.?\d*)',
+             r'equals?\s*(-?\d+\.?\d*)',
+             r'=\s*(-?\d+\.?\d*)',
+             r'\$?\s*(-?\d+\.?\d*)\s*$'
+         ]
+
+         for pattern in patterns:
+             match = re.search(pattern, text, re.IGNORECASE)
+             if match:
+                 try:
+                     return float(match.group(1))
+                 except ValueError:
+                     continue
+
+         return None
+
+     def generate_solution(self, problem: str) -> str:
+         """Generate solution for math problem"""
+         prompt = f"Problem: {problem}\n\nLet's solve this step by step:\n"
+         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.config.device)
+
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **inputs,
+                 max_length=self.config.max_length,
+                 do_sample=False,  # greedy decoding for reproducible math answers
+                 pad_token_id=self.tokenizer.eos_token_id
+             )
+
+         generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return generated
+
+     def evaluate(self) -> Dict[str, float]:
+         """Run GSM8K evaluation"""
+         logger.info("Starting GSM8K evaluation...")
+
+         problems = self.load_gsm8k()
+         results = {
+             "total": len(problems),
+             "correct": 0,
+             "incorrect": 0,
+             "no_answer": 0
+         }
+
+         for problem in tqdm(problems, desc="Evaluating GSM8K"):
+             question = problem["question"]
+             correct_answer_text = problem["answer"]
+
+             # Extract correct answer
+             correct_answer = self.extract_answer(correct_answer_text)
+             if correct_answer is None:
+                 continue
+
+             # Generate solution
+             solution = self.generate_solution(question)
+
+             # Extract predicted answer
+             predicted_answer = self.extract_answer(solution)
+
+             if predicted_answer is None:
+                 results["no_answer"] += 1
+                 results["incorrect"] += 1
+             elif abs(predicted_answer - correct_answer) < 1e-5:
+                 results["correct"] += 1
+             else:
+                 results["incorrect"] += 1
+
+         results["accuracy"] = results["correct"] / results["total"]
+
+         logger.info(f"GSM8K Results: {results}")
+         return results
+
+
+ class ComprehensiveEvaluator:
+     """Run comprehensive evaluation across all benchmarks"""
+
+     def __init__(self, config: EvaluationConfig):
+         self.config = config
+         os.makedirs(config.output_dir, exist_ok=True)
+
+     def run_all_evaluations(self) -> Dict[str, Any]:
+         """Run all evaluation benchmarks"""
+         logger.info("Starting comprehensive evaluation...")
+
+         all_results = {}
+
+         # HumanEval
+         try:
+             logger.info("\n" + "="*80)
+             logger.info("Running HumanEval Evaluation")
+             logger.info("="*80)
+             humaneval_evaluator = HumanEvalEvaluator(self.config)
+             all_results["humaneval"] = humaneval_evaluator.evaluate()
+         except Exception as e:
+             logger.error(f"HumanEval evaluation failed: {e}")
+             all_results["humaneval"] = {"error": str(e)}
+
+         # MBPP
+         try:
+             logger.info("\n" + "="*80)
+             logger.info("Running MBPP Evaluation")
+             logger.info("="*80)
+             mbpp_evaluator = MBPPEvaluator(self.config)
+             all_results["mbpp"] = mbpp_evaluator.evaluate()
+         except Exception as e:
+             logger.error(f"MBPP evaluation failed: {e}")
+             all_results["mbpp"] = {"error": str(e)}
+
+         # GSM8K
+         try:
+             logger.info("\n" + "="*80)
+             logger.info("Running GSM8K Evaluation")
+             logger.info("="*80)
+             gsm8k_evaluator = GSM8KEvaluator(self.config)
+             all_results["gsm8k"] = gsm8k_evaluator.evaluate()
+         except Exception as e:
+             logger.error(f"GSM8K evaluation failed: {e}")
+             all_results["gsm8k"] = {"error": str(e)}
+
+         # Save results
+         self.save_results(all_results)
+
+         # Print summary
+         self.print_summary(all_results)
+
+         return all_results
+
+     def save_results(self, results: Dict[str, Any]):
+         """Save evaluation results to file"""
+         output_file = os.path.join(self.config.output_dir, "evaluation_results.json")
+         with open(output_file, 'w') as f:
+             json.dump(results, f, indent=2)
+         logger.info(f"Results saved to {output_file}")
+
+     def print_summary(self, results: Dict[str, Any]):
+         """Print evaluation summary"""
+         logger.info("\n" + "="*80)
+         logger.info("EVALUATION SUMMARY")
+         logger.info("="*80)
+
+         if "humaneval" in results and "pass@1" in results["humaneval"]:
+             logger.info(f"HumanEval Pass@1: {results['humaneval']['pass@1']:.3f}")
+
+         if "mbpp" in results and "pass@1" in results["mbpp"]:
+             logger.info(f"MBPP Pass@1: {results['mbpp']['pass@1']:.3f}")
+
+         if "gsm8k" in results and "accuracy" in results["gsm8k"]:
+             logger.info(f"GSM8K Accuracy: {results['gsm8k']['accuracy']:.3f}")
+
+         logger.info("="*80)
+
+
+ def main():
+     """Main evaluation script"""
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Evaluate Helion-OSC model")
+     parser.add_argument("--model_name", type=str, default="DeepXR/Helion-OSC")
+     parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+     parser.add_argument("--batch_size", type=int, default=4)
+     parser.add_argument("--max_length", type=int, default=2048)
+     parser.add_argument("--temperature", type=float, default=0.7)
+     parser.add_argument("--top_p", type=float, default=0.95)
+     parser.add_argument("--timeout", type=int, default=5)
+     parser.add_argument("--output_dir", type=str, default="./evaluation_results")
+     parser.add_argument("--benchmark", type=str, choices=["all", "humaneval", "mbpp", "gsm8k"], default="all")
+
+     args = parser.parse_args()
+
+     config = EvaluationConfig(
+         model_name=args.model_name,
+         device=args.device,
+         batch_size=args.batch_size,
+         max_length=args.max_length,
+         temperature=args.temperature,
+         top_p=args.top_p,
+         timeout=args.timeout,
+         output_dir=args.output_dir
+     )
+
+     if args.benchmark == "all":
+         evaluator = ComprehensiveEvaluator(config)
+         evaluator.run_all_evaluations()
+     elif args.benchmark == "humaneval":
+         evaluator = HumanEvalEvaluator(config)
+         evaluator.evaluate()
+     elif args.benchmark == "mbpp":
+         evaluator = MBPPEvaluator(config)
+         evaluator.evaluate()
+     elif args.benchmark == "gsm8k":
+         evaluator = GSM8KEvaluator(config)
+         evaluator.evaluate()
+
+
+ if __name__ == "__main__":
+     main()
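
Usage note: based on the argparse flags defined in main() above, the script can be invoked directly once its imports (transformers, datasets, torch, numpy, tqdm) are installed; for example:

    # run all benchmarks with default settings
    python evaluate.py --benchmark all --output_dir ./evaluation_results

    # run a single benchmark against a specific checkpoint with shorter generations
    python evaluate.py --benchmark humaneval --model_name DeepXR/Helion-OSC --max_length 1024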