Spaces:
Sleeping
Sleeping
import json
import time
import uuid
from functools import wraps
from typing import List, Optional

from flask import Flask, request, jsonify, Response
from pydantic import BaseModel, ValidationError

from API_provider import API_Inference
from core_logic import (
    check_api_key_validity,
    update_request_count,
    get_rate_limit_status,
    get_subscription_status,
    get_available_models,
    get_model_info,
)
# Flask application instance for this API service.
app = Flask(__name__)
class Message(BaseModel):
    """A single chat message: a role string and its text content.

    Roles are passed through to the inference backend unchanged
    (presumably "user"/"assistant"/"system" — not validated here).
    """
    role: str
    content: str
class ChatCompletionRequest(BaseModel):
    """Request body for the chat-completions endpoint (OpenAI-style schema)."""
    model: str                              # public model alias or provider model id
    messages: List[Message]                 # conversation history, oldest first
    stream: Optional[bool] = False          # True => server-sent-events streaming
    max_tokens: Optional[int] = 4000
    temperature: Optional[float] = 0.5
    top_p: Optional[float] = 0.95
def get_api_key():
    """Extract the Bearer token from the request's Authorization header.

    Returns:
        The token string, or None when the header is absent or does not
        use the Bearer scheme.
    """
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return None
    # maxsplit=1 keeps the whole token intact even if it contains spaces;
    # the previous split(' ')[1] silently truncated at the second space.
    return auth_header.split(' ', 1)[1]
def requires_api_key(func):
    """Decorator for view functions that require authentication.

    Rejects requests lacking a Bearer token with 401; otherwise injects
    the token into the view as the `api_key` keyword argument.
    """
    # @wraps preserves the wrapped view's __name__. Without it, Flask
    # derives the endpoint name from the wrapper ("decorated") and raises
    # on the second route registered with this decorator. `wraps` was
    # imported at the top of the file but never applied.
    @wraps(func)
    def decorated(*args, **kwargs):
        api_key = get_api_key()
        if not api_key:
            return jsonify({'detail': 'Not authenticated'}), 401
        kwargs['api_key'] = api_key
        return func(*args, **kwargs)
    return decorated
# No view in this file was registered with the app, so the server exposed
# no routes at all; register the health-check at the root path.
@app.route('/')
def index():
    """Simple liveness-check endpoint."""
    return 'Hello, World!'
@app.route('/v1/chat/completions', methods=['POST'])
@requires_api_key
def chat_completions(api_key):
    """Serve an OpenAI-compatible chat completion.

    The JSON body must match ChatCompletionRequest. Supports streaming
    (SSE) and non-streaming responses. `api_key` is injected by the
    requires_api_key decorator; the route was never registered in the
    original file, so this view was unreachable.
    """
    try:
        # Parse and validate the request body. silent=True turns a
        # missing/non-JSON body into None -> {} so validation reports
        # missing fields as a 400 instead of crashing into the 500 path.
        try:
            data = request.get_json(silent=True) or {}
            chat_request = ChatCompletionRequest(**data)
        except ValidationError as e:
            return jsonify({'detail': e.errors()}), 400

        # Check API key validity and rate limit.
        is_valid, error_message = check_api_key_validity(api_key)
        if not is_valid:
            return jsonify({'detail': error_message}), 401

        messages = [{"role": msg.role, "content": msg.content}
                    for msg in chat_request.messages]

        # Reject models the backend does not know about.
        model_info = get_model_info(chat_request.model)
        if not model_info:
            return jsonify({'detail': 'Invalid model specified'}), 400

        # Map public aliases to provider-side model identifiers.
        # NOTE(review): "claude-3.5-sonnet" maps to a claude-3 model id —
        # confirm this is intentional.
        model_mapping = {
            "meta-llama-405b-turbo": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
            "claude-3.5-sonnet": "claude-3-sonnet-20240229",
        }
        model_name = model_mapping.get(chat_request.model, chat_request.model)

        # Credits charged per request; models not listed cost 0.
        credits_reduction = {
            "gpt-4o": 1,
            "claude-3-sonnet-20240229": 1,
            "gemini-1.5-pro": 1,
            "gemini-1-5-flash": 1,
            "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": 1,
            "o1-mini": 2,
            "o1-preview": 3,
        }.get(model_name, 0)

        if chat_request.stream:
            def generate():
                try:
                    for chunk in API_Inference(messages, model=model_name, stream=True,
                                               max_tokens=chat_request.max_tokens,
                                               temperature=chat_request.temperature,
                                               top_p=chat_request.top_p):
                        payload = json.dumps({'choices': [{'delta': {'content': chunk}}]})
                        yield f"data: {payload}\n\n"
                    # Emit credits as a proper SSE data event, then the
                    # standard terminator. The original appended free text
                    # after "[DONE]" on the same event, which breaks SSE
                    # clients that match the terminator exactly.
                    yield f"data: {json.dumps({'credits_used': credits_reduction})}\n\n"
                    yield "data: [DONE]\n\n"
                    update_request_count(api_key, credits_reduction)
                except Exception as e:
                    yield f"data: [ERROR] {str(e)}\n\n"
            return Response(generate(), mimetype='text/event-stream')

        response = API_Inference(messages, model=model_name, stream=False,
                                 max_tokens=chat_request.max_tokens,
                                 temperature=chat_request.temperature,
                                 top_p=chat_request.top_p)
        update_request_count(api_key, credits_reduction)

        # Whitespace word counts — a rough approximation, not a real
        # tokenizer count.
        prompt_tokens = sum(len(msg['content'].split()) for msg in messages)
        completion_tokens = len(response.split())

        return jsonify({
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion",
            # Unix epoch seconds. The original used uuid1().time // 1e7,
            # which counts 100-ns ticks since 1582-10-15 and is NOT an
            # epoch timestamp.
            "created": int(time.time()),
            "model": model_name,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            },
            "credits_used": credits_reduction
        })
    except Exception as e:
        # Top-level boundary: surface unexpected failures as a 500.
        return jsonify({'detail': str(e)}), 500
# NOTE(review): route path assumed — confirm against API docs/clients.
@app.route('/v1/rate_limit/status', methods=['GET'])
@requires_api_key
def get_rate_limit_status_endpoint(api_key):
    """Return the caller's current rate-limit status.

    Validates the key without consuming rate limit (check_rate_limit=False).
    The route was never registered in the original file.
    """
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        return jsonify({'detail': error_message}), 401
    return jsonify(get_rate_limit_status(api_key))
# NOTE(review): route path assumed — confirm against API docs/clients.
@app.route('/v1/subscription/status', methods=['GET'])
@requires_api_key
def get_subscription_status_endpoint(api_key):
    """Return the caller's subscription status.

    Validates the key without consuming rate limit (check_rate_limit=False).
    The route was never registered in the original file.
    """
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        return jsonify({'detail': error_message}), 401
    return jsonify(get_subscription_status(api_key))
# NOTE(review): route path assumed (OpenAI-style model listing) — confirm.
@app.route('/v1/models', methods=['GET'])
@requires_api_key
def get_available_models_endpoint(api_key):
    """List available models in an OpenAI-style {"data": [{"id": ...}]} payload.

    Validates the key without consuming rate limit (check_rate_limit=False).
    The route was never registered in the original file.
    """
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        return jsonify({'detail': error_message}), 401
    return jsonify({"data": [{"id": model} for model in get_available_models().values()]})
if __name__ == "__main__":
    # Flask development server, bound on all interfaces, port 8000.
    # Use a production WSGI server (e.g. gunicorn) for deployment.
    app.run(host="0.0.0.0", port=8000)