Delta-Vector committed (verified)
Commit: 53a4250 · Parent(s): 317906a

Upload runner.py with huggingface_hub

Files changed (1):
  runner.py (+493, -0)
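The commit message says the file was pushed with the `huggingface_hub` client. As a rough sketch (not taken from this repo), such an upload is typically done with `HfApi.upload_file`; the repo id below is a placeholder:

# Hypothetical sketch of the upload described by the commit message; repo_id is a placeholder.
from huggingface_hub import HfApi

api = HfApi()  # authenticates via a saved token (e.g. `huggingface-cli login` or HF_TOKEN)
api.upload_file(
    path_or_fileobj="runner.py",            # local file to push
    path_in_repo="runner.py",               # destination path inside the repo
    repo_id="Delta-Vector/<repo-name>",     # placeholder repo id
    commit_message="Upload runner.py with huggingface_hub",
)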
runner.py ADDED
@@ -0,0 +1,493 @@
+ #!/usr/bin/env python3
+
+ import argparse
+ import json
+ import time
+ import sys
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ import requests
+
+
+ # ---------------------------
+ # Config data classes
+ # ---------------------------
+
+ @dataclass
+ class BackendConfig:
+     name: str
+     endpoint: str
+     concurrency: int = field(default=20)
+     timeout: int = field(default=60)
+     max_retries: int = field(default=5)
+     retry_delay_seconds: float = field(default=2.0)
+
+
+ @dataclass
+ class ModelConfig:
+     name: Optional[str]
+     backend: str
+     endpoint_model_name: str
+     system_prompt: Optional[str] = None
+
+
+ # ---------------------------
+ # Utilities
+ # ---------------------------
+
+ def fetch_default_model(endpoint: str) -> str:
+     headers = {
+         'Content-Type': 'application/json',
+     }
+     url = f"{endpoint}/v1/models"
+     resp = requests.get(url, headers=headers, timeout=30)
+     resp.raise_for_status()
+     data = resp.json()
+     # Prefer first model id
+     if isinstance(data, dict) and 'data' in data and isinstance(data['data'], list) and data['data']:
+         first = data['data'][0]
+         model_id = first.get('id') or first.get('name')
+         if not model_id:
+             raise RuntimeError('Could not determine model id from /v1/models response')
+         return str(model_id)
+     raise RuntimeError('No models returned from /v1/models')
+
+
+ def resolve_backend_and_model(cli_endpoint: str) -> tuple[BackendConfig, ModelConfig]:
+     if not cli_endpoint:
+         raise ValueError('--endpoint is required')
+
+     model_id = fetch_default_model(cli_endpoint)
+
+     backend = BackendConfig(
+         name='vllm',
+         endpoint=cli_endpoint,
+         concurrency=20,
+         timeout=60,
+         max_retries=5,
+         retry_delay_seconds=2.0,
+     )
+
+     model = ModelConfig(
+         name=None,
+         backend='vllm',
+         endpoint_model_name=model_id,
+         system_prompt=None,
+     )
+
+     return backend, model
+
+
+ def post_chat_completion(backend: BackendConfig, model: ModelConfig, payload: Dict[str, Any]) -> Dict[str, Any]:
+     headers = {
+         'Content-Type': 'application/json',
+     }
+
+     last_exc: Optional[Exception] = None
+     for attempt in range(backend.max_retries):
+         try:
+             response = requests.post(
+                 f"{backend.endpoint}/v1/chat/completions",
+                 json=payload,
+                 headers=headers,
+                 timeout=backend.timeout,
+             )
+             response.raise_for_status()
+             return response.json()
+         except requests.exceptions.RequestException as exc:
+             last_exc = exc
+             if attempt < backend.max_retries - 1:
+                 time.sleep(backend.retry_delay_seconds)
+             else:
+                 raise
+     if last_exc:
+         raise last_exc
+     raise RuntimeError('Request failed without exception')
+
+
+ # ---------------------------
+ # Resource loading
+ # ---------------------------
+
+ def read_json(path: Path) -> Dict[str, Any]:
+     with path.open('r', encoding='utf-8') as f:
+         return json.load(f)
+
+
+ def get_tool_calling_resource_paths() -> Dict[str, Path]:
+     base = Path(__file__).parent / 'resources'
+     return {
+         'simple-tool-calling': base / 'simple-tool-calling.json',
+         'image-generation': base / 'image-generation.json',
+         'internet-search': base / 'internet-search.json',
+         'weather': base / 'weather.json',
+         'currency-conversion': base / 'currency-conversion.json',
+     }
+
+
+ # ---------------------------
+ # Evaluation logic
+ # ---------------------------
+
+ def validate_argument_types(arguments: Dict[str, Any], properties: Dict[str, Any]) -> tuple[bool, Optional[str]]:
+     """
+     Validate that argument types match the schema.
+     Returns (valid, error_message)
+     """
+     for field_name, value in arguments.items():
+         if field_name not in properties:
+             continue  # Extra fields are okay
+
+         expected_type = properties[field_name].get('type')
+         if expected_type == 'string' and not isinstance(value, str):
+             return False, f"Field '{field_name}' should be string, got {type(value).__name__}"
+         elif expected_type == 'number' and not isinstance(value, (int, float)):
+             return False, f"Field '{field_name}' should be number, got {type(value).__name__}"
+         elif expected_type == 'boolean':
+             # Handle both boolean and string representations
+             if isinstance(value, bool):
+                 continue  # Valid boolean
+             elif isinstance(value, str) and value.lower() in ('true', 'false'):
+                 # Accept string booleans but flag as warning (still pass)
+                 continue
+             else:
+                 return False, f"Field '{field_name}' should be boolean, got {type(value).__name__} with value {repr(value)}"
+         elif expected_type == 'integer' and not isinstance(value, int):
+             return False, f"Field '{field_name}' should be integer, got {type(value).__name__}"
+
+     return True, None
+
+
+ def validate_specific_test_case(func_name: str, arguments: Dict[str, Any], messages: List[Dict[str, Any]]) -> tuple[bool, Optional[str]]:
+     """
+     Add sanity checks for specific test cases to ensure the model actually understood the request.
+     Returns (valid, warning_message)
+     """
+     # Extract user query from messages
+     user_content = ""
+     for msg in messages:
+         if msg.get('role') == 'user':
+             user_content += msg.get('content', '') + " "
+     user_content = user_content.lower()
+
+     # Weather test - should extract city name
+     if func_name == 'GetWeather':
+         city = arguments.get('city', '').lower()
+         if not city:
+             return False, "Weather: 'city' field is empty"
+         # Check if it extracted something reasonable from "New York City"
+         if 'new york' in user_content and 'new york' not in city and 'nyc' not in city:
+             return False, f"Weather: Expected 'New York' in city, got '{arguments.get('city')}'"
+
+     # Internet search - should extract query
+     elif func_name == 'InternetSearch':
+         query = arguments.get('query', '').lower()
+         if not query:
+             return False, "Search: 'query' field is empty"
+         # Should mention AI or artificial intelligence
+         if ('artificial intelligence' in user_content or 'ai' in user_content) and \
+                 ('ai' not in query and 'artificial intelligence' not in query):
+             return False, f"Search: Expected AI-related query, got '{arguments.get('query')}'"
+
+     # Currency conversion - should extract currencies and amount
+     elif func_name == 'CurrencyConverter':
+         from_curr = arguments.get('fromCurrency', '')
+         to_curr = arguments.get('toCurrency', '')
+         amount = arguments.get('amount')
+
+         if not from_curr or not to_curr:
+             return False, f"Currency: Missing currencies (from={from_curr}, to={to_curr})"
+         if amount is None:
+             return False, "Currency: 'amount' field is missing"
+         if amount <= 0:
+             return False, f"Currency: Invalid amount {amount}"
+
+         # Check if it matches the expected conversion
+         if 'usd' in user_content and 'eur' in user_content:
+             if from_curr.upper() != 'USD' or to_curr.upper() != 'EUR':
+                 return False, f"Currency: Expected USD->EUR, got {from_curr}->{to_curr}"
+             if '100' in user_content and amount != 100:
+                 return False, f"Currency: Expected amount 100, got {amount}"
+
+     # Image prompt enhancer
+     elif func_name == 'ImagePromptEnhancer':
+         enhanced = arguments.get('enhancedPrompt', '')
+         is_nsfw = arguments.get('isNSFW')
+
+         if not enhanced:
+             return False, "Image: 'enhancedPrompt' field is empty"
+         if is_nsfw is None:
+             return False, "Image: 'isNSFW' field is missing"
+
+         # Handle both boolean and string representations for isNSFW
+         if isinstance(is_nsfw, str):
+             if is_nsfw.lower() not in ('true', 'false'):
+                 return False, f"Image: 'isNSFW' should be boolean or 'true'/'false' string, got '{is_nsfw}'"
+         elif not isinstance(is_nsfw, bool):
+             return False, f"Image: 'isNSFW' should be boolean, got {type(is_nsfw).__name__}"
+
+         # Check if enhanced prompt is actually enhanced (longer than original)
+         if len(enhanced) < 10:
+             return False, f"Image: Enhanced prompt too short ({len(enhanced)} chars)"
+
+     return True, None
+
+
+ def evaluate_tool_call_response(message: Dict[str, Any], expected_tool: Dict[str, Any], messages: List[Dict[str, Any]]) -> tuple[bool, Optional[str]]:
+     """
+     Evaluate if the response contains valid tool calls matching the expected tool schema.
+     Returns (passed, debug_info)
+     """
+     tool_calls = message.get('tool_calls', [])
+
+     if not tool_calls:
+         return False, "No tool_calls in response"
+
+     if not isinstance(tool_calls, list):
+         return False, f"tool_calls should be a list, got {type(tool_calls).__name__}"
+
+     # Get the expected function name and required parameters
+     expected_func = expected_tool.get('function', {})
+     expected_name = expected_func.get('name')
+     expected_params = expected_func.get('parameters', {})
+     required_fields = expected_params.get('required', [])
+     properties = expected_params.get('properties', {})
+
+     # Check if any tool call matches
+     for i, tool_call in enumerate(tool_calls):
+         if not isinstance(tool_call, dict):
+             return False, f"tool_call[{i}] should be a dict, got {type(tool_call).__name__}"
+
+         # Validate tool_call structure
+         if 'function' not in tool_call:
+             return False, f"tool_call[{i}] missing 'function' field"
+
+         func = tool_call.get('function', {})
+         func_name = func.get('name')
+
+         if not func_name:
+             return False, f"tool_call[{i}] missing function name"
+
+         # Check if function name matches
+         if func_name != expected_name:
+             continue
+
+         # Parse and validate arguments
+         try:
+             arguments_str = func.get('arguments', '{}')
+             if isinstance(arguments_str, str):
+                 if not arguments_str.strip():
+                     arguments_str = '{}'
+                 arguments = json.loads(arguments_str)
+             else:
+                 arguments = arguments_str
+
+             if not isinstance(arguments, dict):
+                 return False, f"Arguments should be a dict, got {type(arguments).__name__}"
+
+             # Check all required fields are present
+             missing = [f for f in required_fields if f not in arguments]
+             if missing:
+                 return False, f"Missing required fields: {missing}"
+
+             # Validate argument types
+             types_valid, type_error = validate_argument_types(arguments, properties)
+             if not types_valid:
+                 return False, type_error
+
+             # Sanity check for specific test cases
+             sanity_valid, sanity_error = validate_specific_test_case(func_name, arguments, messages)
+             if not sanity_valid:
+                 return False, sanity_error
+
+             # All checks passed!
+             return True, None
+
+         except json.JSONDecodeError as e:
+             return False, f"Invalid JSON in arguments: {e}"
+
+     return False, f"No matching tool call found. Expected: {expected_name}, Got: {[tc.get('function', {}).get('name') for tc in tool_calls]}"
+
+
+ def infer_expected_tool(resource: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+     """Extract the first tool from the resource tools list."""
+     tools = resource.get('tools', [])
+     if tools and len(tools) > 0:
+         return tools[0]
+     return None
+
+
+ def build_payload(model: ModelConfig, resource: Dict[str, Any]) -> Dict[str, Any]:
+     messages = resource.get('messages', [])
+     if model.system_prompt:
+         sys_present = any(m.get('role') == 'system' for m in messages)
+         if not sys_present:
+             messages = [{'role': 'system', 'content': model.system_prompt}] + messages
+
+     payload: Dict[str, Any] = {
+         'model': model.endpoint_model_name,
+         'messages': messages,
+         'stream': False,
+         'max_tokens': 1024,
+         'temperature': 0,
+         'top_k': 1
+     }
+
+     # Use vLLM's native tool calling instead of response_format
+     if 'tools' in resource:
+         payload['tools'] = resource['tools']
+         # 'auto' lets the model decide whether to call one of the provided tools
+         payload['tool_choice'] = 'auto'
+
+     return payload
+
+
+ def run_single_task(task_name: str, resource_path: Path, backend: BackendConfig, model: ModelConfig) -> Dict[str, Any]:
+     resource = read_json(resource_path)
+     payload = build_payload(model, resource)
+     response_data: Optional[str] = None
+     debug_info: Optional[str] = None
+
+     try:
+         result_json = post_chat_completion(backend, model, payload)
+         message = result_json.get('choices', [{}])[0].get('message', {})
+
+         # Get the expected tool from resource
+         expected_tool = infer_expected_tool(resource)
+         messages = resource.get('messages', [])
+
+         if expected_tool:
+             # Evaluate tool calling response with sanity checks
+             passed, debug_info = evaluate_tool_call_response(message, expected_tool, messages)
+             # Store the full message as JSON for debugging
+             response_data = json.dumps(message, indent=2)
+         else:
+             # No tools defined, can't evaluate
+             passed = False
+             debug_info = "No tools defined in resource"
+             response_data = json.dumps(message, indent=2)
+
+     except requests.exceptions.RequestException as exc:
+         err_text = None
+         try:
+             if hasattr(exc, 'response') and exc.response is not None:
+                 err_text = exc.response.text
+         except Exception:
+             err_text = None
+         response_data = err_text or str(exc)
+         debug_info = "Request failed"
+         passed = False
+
+     # Match eval-suite per-task results structure
+     results = [{
+         'response': response_data,
+         'passed': passed,
+         'debug_info': debug_info,
+     }]
+
+     return {
+         'results': results,
+         'passed': int(passed),
+         'failed': int(not passed),
+         'total': 1,
+     }
+
+
+ def format_suite_result(model: ModelConfig, task_name: str, run_result: Dict[str, Any]) -> Dict[str, Any]:
+     return {
+         'model': {
+             'name': model.name or None,
+             'endpoint_model_name': model.endpoint_model_name,
+         },
+         'benchmark': {
+             'pathname': f"tool_calling/{task_name}",
+             'safename': task_name,
+             'version': 1,
+             'task_results': run_result['results'],
+         },
+         'passed': run_result['passed'],
+         'failed': run_result['failed'],
+         'total': run_result['total'],
+     }
+
+
+ def main() -> int:
+     parser = argparse.ArgumentParser(description='Run tool-calling tests against a vllm chat-completions endpoint using native tool calling.')
+     parser.add_argument('--endpoint', default='http://localhost:8002', help='Chat completions base URL (default: http://localhost:8002)')
+     parser.add_argument('--debug', action='store_true', help='Show debug details for failures')
+     parser.add_argument('--verbose', action='store_true', help='Show details for all tests (passed and failed)')
+     args = parser.parse_args()
+
+     backend, model = resolve_backend_and_model(args.endpoint)
+
+     resource_paths = get_tool_calling_resource_paths()
+     selected = list(resource_paths.items())
+
+     GREEN = '\x1b[32m'
+     RED = '\x1b[31m'
+     RESET = '\x1b[0m'
+
+     total_passed = 0
+     total_failed = 0
+
+     for name, path in selected:
+         run_result = run_single_task(name, path, backend, model)
+         suite_line = format_suite_result(model=model, task_name=name, run_result=run_result)
+
+         passed = suite_line['passed']
+         failed = suite_line['failed']
+         total_passed += passed
+         total_failed += failed
+
+         status = f"{GREEN}PASS{RESET}" if passed == 1 else f"{RED}FAIL{RESET}"
+         sys.stdout.write(f"[{status}] {name}\n")
+
+         # Show details based on flags
+         show_details = (failed and args.debug) or args.verbose
+         if show_details:
+             try:
+                 task_result = suite_line['benchmark']['task_results'][0]
+                 debug_info = task_result.get('debug_info')
+                 response_str = task_result.get('response')
+
+                 # Try to parse response and extract tool calls
+                 if response_str:
+                     try:
+                         response = json.loads(response_str)
+                         tool_calls = response.get('tool_calls', [])
+
+                         if tool_calls:
+                             for tc in tool_calls:
+                                 func = tc.get('function', {})
+                                 func_name = func.get('name', 'unknown')
+                                 args_str = func.get('arguments', '{}')
+                                 if isinstance(args_str, str):
+                                     tool_args = json.loads(args_str)
+                                 else:
+                                     tool_args = args_str
+                                 sys.stdout.write(f" Tool: {func_name}\n")
+                                 sys.stdout.write(f" Args: {json.dumps(tool_args, indent=4)}\n")
+                         elif passed == 0:
+                             sys.stdout.write(" No tool calls found\n")
+                     except json.JSONDecodeError:
+                         pass
+
+                 # Show debug info (especially for failures)
+                 if debug_info:
+                     sys.stdout.write(f" {'✗' if failed else '⚠'} {debug_info}\n")
+                 elif passed == 1:
+                     sys.stdout.write(" ✓ All sanity checks passed\n")
+
+             except Exception as e:
+                 sys.stdout.write(f" Error printing details: {e}\n")
+
+     summary_color = GREEN if total_failed == 0 else RED
+     sys.stdout.write(f"\nSummary: {summary_color}{total_passed} passed, {total_failed} failed{RESET}\n")
+     sys.stdout.flush()
+     return 0
+
+
+ if __name__ == '__main__':
+     raise SystemExit(main())
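For reference, the script is normally driven through `main()`'s CLI flags (`--endpoint`, `--debug`, `--verbose`), but the same pieces can be exercised directly. The sketch below assumes the file is importable as `runner` and that a vLLM server is listening on the placeholder URL:

# Hypothetical usage sketch: drive the runner's helpers directly instead of via the CLI.
# Assumes runner.py is importable as `runner`; the endpoint URL is a placeholder.
import runner

backend, model = runner.resolve_backend_and_model("http://localhost:8002")
for name, path in runner.get_tool_calling_resource_paths().items():
    result = runner.run_single_task(name, path, backend, model)
    status = "PASS" if result["passed"] else "FAIL"
    print(f"{status} {name}: {result['results'][0]['debug_info']}")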