# Helion-OSC / api_server.py
# Trouter-Library's picture
# Create api_server.py (#3)
# eb06522 verified
"""
Helion-OSC API Server
FastAPI-based REST API for serving Helion-OSC model
"""
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any, AsyncGenerator
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import uvicorn
import logging
import time
import json
from queue import Queue
import asyncio
# Root logging at INFO; every module-level message goes through ``logger``.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application; ``main`` below serves it as "api_server:app".
app = FastAPI(
    version="1.0.0",
    title="Helion-OSC API",
    description="REST API for Helion-OSC Code Generation Model",
)

# Cross-origin access is opened to every origin, method and header.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# web page issue credentialed requests; narrow allow_origins before exposing
# this server beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Process-wide inference state, filled in by the ``load_model`` startup hook
# and left as None until that hook has run.
model = tokenizer = device = None
class GenerationRequest(BaseModel):
    """Request model for text generation.

    Every numeric bound below is enforced by pydantic whenever an instance is
    built, whether from an incoming JSON body or by the task-specific
    endpoints of this module.
    """

    # Text the model continues from; the only required field.
    prompt: str = Field(default=..., description="Input prompt for generation")
    # Forwarded to transformers as ``max_length``, which counts prompt tokens too.
    max_length: int = Field(default=2048, ge=1, le=16384, description="Maximum length of generation")
    # Softmax temperature applied when sampling.
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    # Cumulative-probability cutoff for nucleus sampling.
    top_p: float = Field(default=0.95, ge=0.0, le=1.0, description="Nucleus sampling parameter")
    # Number of highest-probability tokens kept when sampling.
    top_k: int = Field(default=50, ge=0, le=200, description="Top-k sampling parameter")
    # Values above 1.0 discourage repeating earlier tokens.
    repetition_penalty: float = Field(default=1.05, ge=1.0, le=2.0, description="Repetition penalty")
    # False requests deterministic (greedy) decoding.
    do_sample: bool = Field(default=True, description="Whether to use sampling")
    # Upper bound on how many completions the caller asks for.
    num_return_sequences: int = Field(default=1, ge=1, le=10, description="Number of sequences to generate")
    # Optional strings at which the completion should end.
    stop_sequences: Optional[List[str]] = Field(default=None, description="Stop generation at these sequences")
    # /generate rejects True; streaming goes through /generate/stream instead.
    stream: bool = Field(default=False, description="Stream the response")
    # Label set by the task endpoints (e.g. "code_completion", "debugging").
    # NOTE(review): no handler in this module reads it yet.
    task_type: Optional[str] = Field(default="code_generation", description="Task type for optimized parameters")
class GenerationResponse(BaseModel):
    """Response model for text generation, returned by /generate and the task endpoints that delegate to it."""
    # Completion text, with the prompt removed.
    generated_text: str
    # The prompt exactly as it was received.
    prompt: str
    # Identifier of the serving model (the generate handler reports "DeepXR/Helion-OSC").
    model: str
    # Wall-clock seconds (time.time() difference) covering tokenization, generation and decoding.
    generation_time: float
    # Number of token positions produced beyond the prompt.
    tokens_generated: int
class ModelInfo(BaseModel):
    """Model information returned by /info, read mostly from the loaded model's config."""
    # Model identifier that was passed to from_pretrained.
    # NOTE(review): under pydantic v2, field names starting with "model_" trigger a
    # protected-namespace warning; harmless, but confirm which pydantic major is deployed.
    model_name: str
    # config.model_type
    model_type: str
    # config.vocab_size
    vocabulary_size: int
    # config.hidden_size
    hidden_size: int
    # config.num_hidden_layers
    num_layers: int
    # Device string chosen at startup ("cuda" or "cpu").
    device: str
    # str() of the first parameter's torch dtype, e.g. "torch.bfloat16".
    dtype: str
    # config.max_position_embeddings
    max_position_embeddings: int
class HealthResponse(BaseModel):
    """Health check response returned by /health."""
    # "healthy" once the global model is set, otherwise "unhealthy".
    status: str
    # True once the startup hook has assigned the global model.
    model_loaded: bool
    # Device string chosen at startup ("cuda" or "cpu").
    device: str
    # Unix epoch seconds (time.time()) at which the response was built.
    timestamp: float
@app.on_event("startup")
async def load_model():
    """Load the Helion-OSC tokenizer and model into the module globals.

    Registered as a startup hook. Picks ``cuda`` when available (bfloat16
    weights placed with ``device_map="auto"``); otherwise it loads float32
    weights and moves them to the CPU. A tokenizer without a pad token reuses
    its EOS token so generation always receives a valid ``pad_token_id``.

    Raises:
        Exception: any loading failure is logged together with its traceback
            and re-raised, so application startup fails instead of serving
            without a model.
    """
    global model, tokenizer, device
    logger.info("Loading Helion-OSC model...")
    model_name = "DeepXR/Helion-OSC"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        # generate() needs a pad id; fall back to EOS for tokenizers lacking one.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
            # On GPU, accelerate places the weights; on CPU we move them explicitly below.
            device_map="auto" if device == "cuda" else None,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )
        if device == "cpu":
            model = model.to(device)
        model.eval()
        logger.info("Model loaded successfully on %s", device)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Failed to load model: %s", e)
        raise
@app.get("/", response_model=Dict[str, str])
async def root():
    """Return a short banner naming the service, its version and the docs path."""
    banner = dict(
        message="Helion-OSC API Server",
        version="1.0.0",
        documentation="/docs",
    )
    return banner
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the model has been loaded, plus the device and a timestamp.

    An unloaded model is reported through ``status``/``model_loaded`` in the
    response body rather than by raising an HTTPException.
    """
    loaded = model is not None
    if loaded:
        state = "healthy"
    else:
        state = "unhealthy"
    return HealthResponse(
        status=state,
        model_loaded=loaded,
        device=device,
        timestamp=time.time(),
    )
@app.get("/info", response_model=ModelInfo)
async def model_info():
    """Describe the loaded model using attributes of its ``model.config``.

    NOTE(review): assumes the config exposes model_type, vocab_size,
    hidden_size, num_hidden_layers and max_position_embeddings; a config
    missing one of them raises AttributeError (served as a 500).

    Raises:
        HTTPException: 503 when called before the model has been loaded.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    cfg = model.config
    # Keys are listed in the same order the original evaluated them.
    details = {
        "model_name": "DeepXR/Helion-OSC",
        "model_type": cfg.model_type,
        "vocabulary_size": cfg.vocab_size,
        "hidden_size": cfg.hidden_size,
        "num_layers": cfg.num_hidden_layers,
        "device": device,
        # The first parameter's dtype stands in for the whole model's dtype.
        "dtype": str(next(model.parameters()).dtype),
        "max_position_embeddings": cfg.max_position_embeddings,
    }
    return ModelInfo(**details)
@app.post("/generate", response_model=GenerationResponse)
async def generate(request: GenerationRequest):
    """Generate a single completion for ``request.prompt``.

    This is the non-streaming generation path; the /code, /math and
    /algorithm endpoints in this module delegate to it.

    Behavior:
      * ``model.generate`` blocks for the whole decode, so it runs in
        FastAPI's threadpool and the event loop keeps serving other requests.
        NOTE(review): concurrent requests can therefore generate in parallel;
        add a concurrency limit if GPU memory is tight.
      * Sampling settings (temperature/top_p/top_k) are forwarded only when
        ``do_sample`` is true and ``temperature > 0``. A zero temperature is
        treated as a request for greedy decoding, because transformers
        rejects it when sampling.
      * Exactly one sequence is generated, because the response carries a
        single ``generated_text``.
      * The completion is decoded from the newly generated token ids only.
      * ``stop_sequences`` are applied after generation by cutting the text
        at the earliest match; ``tokens_generated`` still counts every
        generated token.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 if ``stream`` is
            set, 500 if tokenization or generation fails.
    """
    from fastapi.concurrency import run_in_threadpool

    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if request.stream:
        raise HTTPException(
            status_code=400,
            detail="Use /generate/stream endpoint for streaming responses"
        )
    start_time = time.time()
    try:
        inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
        input_length = inputs.input_ids.shape[1]

        generation_kwargs: Dict[str, Any] = {
            **inputs,
            "max_length": request.max_length,
            "repetition_penalty": request.repetition_penalty,
            # Only outputs[0] is ever returned, so extra sequences would be
            # wasted work (and greedy decoding rejects values other than 1).
            "num_return_sequences": 1,
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        if request.do_sample and request.temperature > 0:
            generation_kwargs.update(
                do_sample=True,
                temperature=request.temperature,
                top_p=request.top_p,
                top_k=request.top_k,
            )
        else:
            # Greedy decoding; temperature == 0 with sampling would raise in transformers.
            generation_kwargs["do_sample"] = False

        def _run_generate():
            # Grad mode is thread-local, so disable it inside the worker thread.
            with torch.no_grad():
                return model.generate(**generation_kwargs)

        outputs = await run_in_threadpool(_run_generate)

        # Slicing the decoded string at len(prompt) breaks whenever
        # decode(encode(prompt)) != prompt; decode only the new token ids instead.
        new_token_ids = outputs[0][input_length:]
        generated_text = tokenizer.decode(new_token_ids, skip_special_tokens=True)

        # Sequential truncation leaves the text cut at the earliest stop match.
        for stop in request.stop_sequences or []:
            if not stop:
                continue  # "" matches at index 0 and would erase the output
            cut = generated_text.find(stop)
            if cut != -1:
                generated_text = generated_text[:cut]
        generated_text = generated_text.strip()

        generation_time = time.time() - start_time
        tokens_generated = outputs.shape[1] - input_length
        return GenerationResponse(
            generated_text=generated_text,
            prompt=request.prompt,
            model="DeepXR/Helion-OSC",
            generation_time=generation_time,
            tokens_generated=tokens_generated
        )
    except Exception as e:
        logger.exception("Generation error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/generate/stream")
async def generate_stream(request: GenerationRequest):
    """Stream a completion for ``request.prompt`` as Server-Sent Events.

    Each text chunk produced by ``TextIteratorStreamer`` is sent as
    ``data: {"text": ...}``. The stream ends with ``data: {"done": true}`` on
    success, or with ``data: {"error": ...}`` if tokenization or generation
    fails.

    ``model.generate`` runs in a background thread. The streamer is drained
    through ``iterate_in_threadpool`` so its blocking queue reads never
    stall the event loop. If the generation thread fails, it signals the end
    of the stream so the consumer is released, and the error is reported.

    Sampling settings follow the same rule as /generate: they are forwarded
    only when ``do_sample`` is true and ``temperature > 0``; otherwise
    decoding is greedy.

    NOTE(review): nothing stops the generation thread if the client
    disconnects; it runs until max_length/EOS.

    Raises:
        HTTPException: 503 if the model is not loaded (raised before the
            stream starts).
    """
    from fastapi.concurrency import iterate_in_threadpool

    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    async def stream_generator() -> AsyncGenerator[str, None]:
        try:
            inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
            streamer = TextIteratorStreamer(
                tokenizer,
                skip_prompt=True,
                skip_special_tokens=True
            )
            generation_kwargs: Dict[str, Any] = {
                **inputs,
                "max_length": request.max_length,
                "repetition_penalty": request.repetition_penalty,
                "pad_token_id": tokenizer.pad_token_id,
                "eos_token_id": tokenizer.eos_token_id,
                "streamer": streamer,
            }
            if request.do_sample and request.temperature > 0:
                generation_kwargs.update(
                    do_sample=True,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    top_k=request.top_k,
                )
            else:
                # Greedy decoding; temperature == 0 with sampling would raise in transformers.
                generation_kwargs["do_sample"] = False

            worker_errors: List[BaseException] = []

            def _run_generation() -> None:
                try:
                    # Grad mode is thread-local, so set it inside this thread.
                    with torch.no_grad():
                        model.generate(**generation_kwargs)
                except Exception as exc:
                    # Record the error first, then push the end-of-stream signal;
                    # without it the consumer would wait on the queue forever.
                    worker_errors.append(exc)
                    streamer.end()

            Thread(target=_run_generation).start()

            # Each blocking streamer read happens on a worker thread.
            async for text in iterate_in_threadpool(streamer):
                yield f"data: {json.dumps({'text': text})}\n\n"

            if worker_errors:
                raise worker_errors[0]
            yield f"data: {json.dumps({'done': True})}\n\n"
        except Exception as e:
            logger.exception("Streaming error: %s", e)
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream"
    )
@app.post("/code/complete")
async def code_complete(
    code: str,
    language: Optional[str] = "python",
    max_length: int = 1024
):
    """Continue a code snippet using fixed completion sampling settings.

    ``code`` is used verbatim as the prompt. ``language`` is accepted for
    symmetry with the other /code endpoints but is not used to build the
    prompt. ``max_length`` must satisfy the bounds declared on
    ``GenerationRequest.max_length``.

    Raises:
        HTTPException: 503 if the model is not loaded; 422 if ``max_length``
            is out of range (previously an unhandled pydantic
            ValidationError, i.e. a 500); plus anything raised by
            ``generate``.
    """
    from pydantic import ValidationError

    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        request = GenerationRequest(
            prompt=code,
            max_length=max_length,
            temperature=0.6,
            top_p=0.92,
            do_sample=True,
            task_type="code_completion"
        )
    except ValidationError as exc:
        # Invalid client input is a 4xx, not a server fault.
        raise HTTPException(status_code=422, detail=exc.errors()) from exc
    return await generate(request)
@app.post("/code/explain")
async def code_explain(code: str, language: Optional[str] = "python"):
    """Ask the model to explain a snippet of ``language`` code.

    The snippet is wrapped in a fenced block tagged with ``language`` and
    followed by an "Explanation:" cue. Generation is delegated to
    ``generate`` with fixed sampling settings.

    Raises:
        HTTPException: 503 if the model or tokenizer is not loaded, plus
            anything raised by ``generate``.
    """
    model_ready = model is not None and tokenizer is not None
    if not model_ready:
        raise HTTPException(status_code=503, detail="Model not loaded")

    explain_prompt = f"Explain the following {language} code in detail:\n\n```{language}\n{code}\n```\n\nExplanation:"
    explain_settings = dict(
        max_length=2048,
        temperature=0.6,
        top_p=0.9,
        do_sample=True,
        task_type="code_explanation",
    )
    return await generate(GenerationRequest(prompt=explain_prompt, **explain_settings))
@app.post("/code/debug")
async def code_debug(
    code: str,
    error_message: Optional[str] = None,
    language: Optional[str] = "python"
):
    """Ask the model to diagnose and fix a snippet of ``language`` code.

    The prompt contains the fenced snippet, the optional ``error_message``
    when one is given (a non-empty string), and a closing instruction asking
    for analysis plus fixed code. Decoding is requested as greedy
    (``do_sample=False``), so the temperature/top_p values below presumably
    have no effect.

    Raises:
        HTTPException: 503 if the model or tokenizer is not loaded, plus
            anything raised by ``generate``.
    """
    if any(component is None for component in (model, tokenizer)):
        raise HTTPException(status_code=503, detail="Model not loaded")

    sections = [f"Debug the following {language} code:\n\n```{language}\n{code}\n```"]
    if error_message:
        sections.append(f"\n\nError message: {error_message}")
    sections.append("\n\nProvide a detailed analysis and fixed code:")

    debug_request = GenerationRequest(
        prompt="".join(sections),
        max_length=2048,
        temperature=0.4,
        top_p=0.88,
        do_sample=False,
        task_type="debugging",
    )
    return await generate(debug_request)
@app.post("/math/solve")
async def math_solve(problem: str):
    """Ask the model for a step-by-step solution to ``problem``.

    Decoding is requested as greedy (``do_sample=False``) with a low
    temperature; generation is delegated to ``generate``.

    Raises:
        HTTPException: 503 if the model or tokenizer is not loaded, plus
            anything raised by ``generate``.
    """
    if not all(component is not None for component in (model, tokenizer)):
        raise HTTPException(status_code=503, detail="Model not loaded")

    solve_request = GenerationRequest(
        prompt=f"Solve the following mathematical problem step by step:\n\n{problem}\n\nSolution:",
        task_type="mathematical_reasoning",
        do_sample=False,
        max_length=2048,
        top_p=0.9,
        temperature=0.3,
    )
    return await generate(solve_request)
@app.post("/algorithm/design")
async def algorithm_design(
    problem: str,
    include_complexity: bool = True
):
    """Ask the model to design an efficient algorithm for ``problem``.

    When ``include_complexity`` is true, the prompt also requests time and
    space complexity analysis. Generation is delegated to ``generate`` with a
    larger length budget than the other task endpoints.

    Raises:
        HTTPException: 503 if the model or tokenizer is not loaded, plus
            anything raised by ``generate``.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    complexity_hint = "\n\nInclude time and space complexity analysis." if include_complexity else ""
    design_prompt = f"Design an efficient algorithm for the following problem:\n\n{problem}" + complexity_hint

    return await generate(
        GenerationRequest(
            prompt=design_prompt,
            max_length=3072,
            temperature=0.5,
            top_p=0.93,
            do_sample=True,
            task_type="algorithm_design",
        )
    )
def main():
    """Command-line entry point: parse server options and start uvicorn.

    uvicorn receives the import string "api_server:app", so this module must
    be importable under that name from the working directory.
    NOTE(review): with --workers > 1, each worker process presumably runs the
    startup hook and loads its own model copy; check the memory budget.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Helion-OSC API Server")
    option_table = (
        ("--host", {"type": str, "default": "0.0.0.0", "help": "Host to bind to"}),
        ("--port", {"type": int, "default": 8000, "help": "Port to bind to"}),
        ("--reload", {"action": "store_true", "help": "Enable auto-reload"}),
        ("--workers", {"type": int, "default": 1, "help": "Number of worker processes"}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    logger.info("Starting Helion-OSC API Server on %s:%s", opts.host, opts.port)
    uvicorn.run(
        "api_server:app",
        host=opts.host,
        port=opts.port,
        reload=opts.reload,
        workers=opts.workers,
    )


# Run the server only when executed as a script, not when imported by uvicorn.
if __name__ == "__main__":
    main()