""" |
|
|
Helion-OSC API Server |
|
|
FastAPI-based REST API for serving Helion-OSC model |
|
|
""" |
import asyncio
import json
import logging
import time
from threading import Thread
from typing import Optional, List, Dict, AsyncGenerator

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Helion-OSC API",
    description="REST API for the Helion-OSC Code Generation Model",
    version="1.0.0"
)

# NOTE: browsers will not accept a wildcard Access-Control-Allow-Origin for credentialed
# requests, so allow_credentials=True is ineffective with allow_origins=["*"]. Restrict
# allow_origins to known front-end origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model state, populated by load_model() at startup.
model = None
tokenizer = None
device = None

class GenerationRequest(BaseModel):
    """Request model for text generation"""
    prompt: str = Field(..., description="Input prompt for generation")
    max_length: int = Field(2048, ge=1, le=16384, description="Maximum total sequence length in tokens (prompt plus generated text)")
    temperature: float = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: float = Field(0.95, ge=0.0, le=1.0, description="Nucleus sampling parameter")
    top_k: int = Field(50, ge=0, le=200, description="Top-k sampling parameter")
    repetition_penalty: float = Field(1.05, ge=1.0, le=2.0, description="Repetition penalty")
    do_sample: bool = Field(True, description="Whether to use sampling")
    num_return_sequences: int = Field(1, ge=1, le=10, description="Number of sequences to generate")
    stop_sequences: Optional[List[str]] = Field(None, description="Stop generation at these sequences (accepted but not yet applied)")
    stream: bool = Field(False, description="Stream the response (use the /generate/stream endpoint)")
    task_type: Optional[str] = Field("code_generation", description="Task type label (set by the convenience endpoints; not interpreted by /generate)")
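# Example request body for /generate and /generate/stream built from the fields above
# (values are illustrative):
#
#   {
#       "prompt": "Write a Python function that reverses a string.",
#       "max_length": 512,
#       "temperature": 0.7,
#       "top_p": 0.95,
#       "do_sample": true,
#       "stream": false
#   }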
class GenerationResponse(BaseModel):
    """Response model for text generation"""
    generated_text: str
    prompt: str
    model: str
    generation_time: float
    tokens_generated: int


class ModelInfo(BaseModel):
    """Model information"""
    model_name: str
    model_type: str
    vocabulary_size: int
    hidden_size: int
    num_layers: int
    device: str
    dtype: str
    max_position_embeddings: int


class HealthResponse(BaseModel):
    """Health check response"""
    status: str
    model_loaded: bool
    device: str
    timestamp: float

@app.on_event("startup") |
|
|
async def load_model(): |
|
|
"""Load model on startup""" |
|
|
global model, tokenizer, device |
|
|
|
|
|
logger.info("Loading Helion-OSC model...") |
|
|
|
|
|
model_name = "DeepXR/Helion-OSC" |
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
try: |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
|
model_name, |
|
|
trust_remote_code=True |
|
|
) |
|
|
|
|
|
if tokenizer.pad_token is None: |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
model_name, |
|
|
torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32, |
|
|
device_map="auto" if device == "cuda" else None, |
|
|
trust_remote_code=True, |
|
|
low_cpu_mem_usage=True |
|
|
) |
|
|
|
|
|
if device == "cpu": |
|
|
model = model.to(device) |
|
|
|
|
|
model.eval() |
|
|
|
|
|
logger.info(f"Model loaded successfully on {device}") |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Failed to load model: {e}") |
|
|
raise |
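# Memory note (rule of thumb, not measured for this model): bfloat16 weights occupy
# roughly 2 bytes per parameter, plus activation and KV-cache overhead at inference
# time. On smaller GPUs, a quantized load inside load_model() is a possible
# alternative, sketched here under the assumption that the optional bitsandbytes
# package is installed:
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name,
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#       device_map="auto",
#       trust_remote_code=True,
#   )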
@app.get("/", response_model=Dict[str, str]) |
|
|
async def root(): |
|
|
"""Root endpoint""" |
|
|
return { |
|
|
"message": "Helion-OSC API Server", |
|
|
"version": "1.0.0", |
|
|
"documentation": "/docs" |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/health", response_model=HealthResponse) |
|
|
async def health_check(): |
|
|
"""Health check endpoint""" |
|
|
return HealthResponse( |
|
|
status="healthy" if model is not None else "unhealthy", |
|
|
model_loaded=model is not None, |
|
|
device=device, |
|
|
timestamp=time.time() |
|
|
) |
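# Quick liveness probe (assuming a local server on the default port; the device and
# timestamp values below are illustrative):
#
#   curl http://localhost:8000/health
#   # -> {"status": "healthy", "model_loaded": true, "device": "cuda", "timestamp": 1712345678.9}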
@app.get("/info", response_model=ModelInfo) |
|
|
async def model_info(): |
|
|
"""Get model information""" |
|
|
if model is None: |
|
|
raise HTTPException(status_code=503, detail="Model not loaded") |
|
|
|
|
|
config = model.config |
|
|
|
|
|
return ModelInfo( |
|
|
model_name="DeepXR/Helion-OSC", |
|
|
model_type=config.model_type, |
|
|
vocabulary_size=config.vocab_size, |
|
|
hidden_size=config.hidden_size, |
|
|
num_layers=config.num_hidden_layers, |
|
|
device=device, |
|
|
dtype=str(next(model.parameters()).dtype), |
|
|
max_position_embeddings=config.max_position_embeddings |
|
|
) |
|
|
|
|
|
|
|
|
@app.post("/generate", response_model=GenerationResponse) |
|
|
async def generate(request: GenerationRequest): |
|
|
"""Generate text based on prompt""" |
|
|
if model is None or tokenizer is None: |
|
|
raise HTTPException(status_code=503, detail="Model not loaded") |
|
|
|
|
|
if request.stream: |
|
|
raise HTTPException( |
|
|
status_code=400, |
|
|
detail="Use /generate/stream endpoint for streaming responses" |
|
|
) |
|
|
|
|
|
start_time = time.time() |
|
|
|
|
|
try: |
|
|
|
|
|
inputs = tokenizer(request.prompt, return_tensors="pt").to(device) |
|
|
input_length = inputs.input_ids.shape[1] |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = model.generate( |
|
|
**inputs, |
|
|
max_length=request.max_length, |
|
|
temperature=request.temperature, |
|
|
top_p=request.top_p, |
|
|
top_k=request.top_k, |
|
|
repetition_penalty=request.repetition_penalty, |
|
|
do_sample=request.do_sample, |
|
|
num_return_sequences=request.num_return_sequences, |
|
|
pad_token_id=tokenizer.pad_token_id, |
|
|
eos_token_id=tokenizer.eos_token_id |
|
|
) |
|
|
|
|
|
|
|
|
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
|
|
|
|
|
|
generated_text = generated_text[len(request.prompt):].strip() |
|
|
|
|
|
generation_time = time.time() - start_time |
|
|
tokens_generated = outputs.shape[1] - input_length |
|
|
|
|
|
return GenerationResponse( |
|
|
generated_text=generated_text, |
|
|
prompt=request.prompt, |
|
|
model="DeepXR/Helion-OSC", |
|
|
generation_time=generation_time, |
|
|
tokens_generated=tokens_generated |
|
|
) |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Generation error: {e}") |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
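# Example call (assuming a local server; prompt and sampling values are illustrative):
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "def fibonacci(n):", "max_length": 256, "temperature": 0.2}'
#
# The JSON response follows GenerationResponse, e.g.
#   {"generated_text": "...", "prompt": "def fibonacci(n):", "model": "DeepXR/Helion-OSC",
#    "generation_time": 1.42, "tokens_generated": 87}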
@app.post("/generate/stream") |
|
|
async def generate_stream(request: GenerationRequest): |
|
|
"""Generate text with streaming response""" |
|
|
if model is None or tokenizer is None: |
|
|
raise HTTPException(status_code=503, detail="Model not loaded") |
|
|
|
|
|
async def stream_generator() -> AsyncGenerator[str, None]: |
|
|
try: |
|
|
|
|
|
inputs = tokenizer(request.prompt, return_tensors="pt").to(device) |
|
|
|
|
|
|
|
|
streamer = TextIteratorStreamer( |
|
|
tokenizer, |
|
|
skip_prompt=True, |
|
|
skip_special_tokens=True |
|
|
) |
|
|
|
|
|
|
|
|
generation_kwargs = { |
|
|
**inputs, |
|
|
"max_length": request.max_length, |
|
|
"temperature": request.temperature, |
|
|
"top_p": request.top_p, |
|
|
"top_k": request.top_k, |
|
|
"repetition_penalty": request.repetition_penalty, |
|
|
"do_sample": request.do_sample, |
|
|
"pad_token_id": tokenizer.pad_token_id, |
|
|
"eos_token_id": tokenizer.eos_token_id, |
|
|
"streamer": streamer |
|
|
} |
|
|
|
|
|
|
|
|
thread = Thread(target=model.generate, kwargs=generation_kwargs) |
|
|
thread.start() |
|
|
|
|
|
|
|
|
for text in streamer: |
|
|
yield f"data: {json.dumps({'text': text})}\n\n" |
|
|
await asyncio.sleep(0) |
|
|
|
|
|
yield f"data: {json.dumps({'done': True})}\n\n" |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Streaming error: {e}") |
|
|
yield f"data: {json.dumps({'error': str(e)})}\n\n" |
|
|
|
|
|
return StreamingResponse( |
|
|
stream_generator(), |
|
|
media_type="text/event-stream" |
|
|
) |
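# Consuming the stream (assuming a local server). Each server-sent event is a "data:"
# line carrying a JSON chunk {"text": ...}, followed by a final {"done": true} event
# (or {"error": ...} on failure), exactly as emitted by stream_generator above.
# curl's -N flag disables output buffering so chunks appear as they arrive:
#
#   curl -N -X POST http://localhost:8000/generate/stream \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain recursion in one paragraph.", "max_length": 512}'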
@app.post("/code/complete") |
|
|
async def code_complete( |
|
|
code: str, |
|
|
language: Optional[str] = "python", |
|
|
max_length: int = 1024 |
|
|
): |
|
|
"""Code completion endpoint""" |
|
|
if model is None or tokenizer is None: |
|
|
raise HTTPException(status_code=503, detail="Model not loaded") |
|
|
|
|
|
request = GenerationRequest( |
|
|
prompt=code, |
|
|
max_length=max_length, |
|
|
temperature=0.6, |
|
|
top_p=0.92, |
|
|
do_sample=True, |
|
|
task_type="code_completion" |
|
|
) |
|
|
|
|
|
return await generate(request) |
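# Example call (note that code, language, and max_length travel as URL query parameters,
# not as a JSON body; the snippet assumes a local server and URL-encodes the space):
#
#   curl -X POST "http://localhost:8000/code/complete?code=def%20add(a,%20b):&language=python&max_length=256"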
@app.post("/code/explain") |
|
|
async def code_explain(code: str, language: Optional[str] = "python"): |
|
|
"""Code explanation endpoint""" |
|
|
if model is None or tokenizer is None: |
|
|
raise HTTPException(status_code=503, detail="Model not loaded") |
|
|
|
|
|
prompt = f"Explain the following {language} code in detail:\n\n```{language}\n{code}\n```\n\nExplanation:" |
|
|
|
|
|
request = GenerationRequest( |
|
|
prompt=prompt, |
|
|
max_length=2048, |
|
|
temperature=0.6, |
|
|
top_p=0.9, |
|
|
do_sample=True, |
|
|
task_type="code_explanation" |
|
|
) |
|
|
|
|
|
return await generate(request) |
|
|
|
|
|
|
|
|
@app.post("/code/debug") |
|
|
async def code_debug( |
|
|
code: str, |
|
|
error_message: Optional[str] = None, |
|
|
language: Optional[str] = "python" |
|
|
): |
|
|
"""Code debugging endpoint""" |
|
|
if model is None or tokenizer is None: |
|
|
raise HTTPException(status_code=503, detail="Model not loaded") |
|
|
|
|
|
prompt = f"Debug the following {language} code:\n\n```{language}\n{code}\n```" |
|
|
if error_message: |
|
|
prompt += f"\n\nError message: {error_message}" |
|
|
prompt += "\n\nProvide a detailed analysis and fixed code:" |
|
|
|
|
|
request = GenerationRequest( |
|
|
prompt=prompt, |
|
|
max_length=2048, |
|
|
temperature=0.4, |
|
|
top_p=0.88, |
|
|
do_sample=False, |
|
|
task_type="debugging" |
|
|
) |
|
|
|
|
|
return await generate(request) |
|
|
|
|
|
|
|
|
@app.post("/math/solve") |
|
|
async def math_solve(problem: str): |
|
|
"""Mathematical problem solving endpoint""" |
|
|
if model is None or tokenizer is None: |
|
|
raise HTTPException(status_code=503, detail="Model not loaded") |
|
|
|
|
|
prompt = f"Solve the following mathematical problem step by step:\n\n{problem}\n\nSolution:" |
|
|
|
|
|
request = GenerationRequest( |
|
|
prompt=prompt, |
|
|
max_length=2048, |
|
|
temperature=0.3, |
|
|
top_p=0.9, |
|
|
do_sample=False, |
|
|
task_type="mathematical_reasoning" |
|
|
) |
|
|
|
|
|
return await generate(request) |
|
|
|
|
|
|
|
|
@app.post("/algorithm/design") |
|
|
async def algorithm_design( |
|
|
problem: str, |
|
|
include_complexity: bool = True |
|
|
): |
|
|
"""Algorithm design endpoint""" |
|
|
if model is None or tokenizer is None: |
|
|
raise HTTPException(status_code=503, detail="Model not loaded") |
|
|
|
|
|
prompt = f"Design an efficient algorithm for the following problem:\n\n{problem}" |
|
|
if include_complexity: |
|
|
prompt += "\n\nInclude time and space complexity analysis." |
|
|
|
|
|
request = GenerationRequest( |
|
|
prompt=prompt, |
|
|
max_length=3072, |
|
|
temperature=0.5, |
|
|
top_p=0.93, |
|
|
do_sample=True, |
|
|
task_type="algorithm_design" |
|
|
) |
|
|
|
|
|
return await generate(request) |
|
|
|
|
|
|
|
|
def main():
    """Run the API server"""
    import argparse

    parser = argparse.ArgumentParser(description="Helion-OSC API Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    parser.add_argument("--reload", action="store_true", help="Enable auto-reload (development only)")
    parser.add_argument("--workers", type=int, default=1,
                        help="Number of worker processes (each worker loads its own copy of the model)")

    args = parser.parse_args()

    logger.info(f"Starting Helion-OSC API Server on {args.host}:{args.port}")

    # Note: uvicorn does not combine --reload with multiple workers; when reload is
    # enabled it runs a single worker.
    uvicorn.run(
        "api_server:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        workers=args.workers
    )


if __name__ == "__main__":
    main()