"""
Helion-OSC API Server

FastAPI-based REST API for serving the Helion-OSC model.
"""

from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any, AsyncGenerator
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import uvicorn
import logging
import time
import json
from queue import Queue
import asyncio

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app (title/description/version feed the generated OpenAPI docs at /docs)
app = FastAPI(
    title="Helion-OSC API",
    description="REST API for Helion-OSC Code Generation Model",
    version="1.0.0"
)

# Add CORS middleware: every origin, method and header is allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a
# combination browsers and Starlette treat specially (credentialed requests
# are not allowed with a literal "*"). Nothing in this server uses cookies or
# auth headers, so consider allow_credentials=False or an explicit origin
# list before exposing the server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model variables. All three stay None until the startup hook
# `load_model` assigns them; handlers check for None and answer 503.
model = None      # the AutoModelForCausalLM instance
tokenizer = None  # the matching AutoTokenizer
device = None     # "cuda" or "cpu", chosen at startup


class GenerationRequest(BaseModel):
    """Request model for text generation.

    Field constraints (ge/le) are enforced by pydantic, so out-of-range
    values are rejected with a validation error before any handler runs.
    """
    # Text the model continues from.
    prompt: str = Field(..., description="Input prompt for generation")
    # Forwarded as the `max_length` argument of model.generate; in transformers
    # that limit counts the prompt tokens as well as the generated ones.
    max_length: int = Field(2048, ge=1, le=16384, description="Maximum length of generation")
    # Sampling controls; only meaningful when do_sample is True.
    temperature: float = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: float = Field(0.95, ge=0.0, le=1.0, description="Nucleus sampling parameter")
    top_k: int = Field(50, ge=0, le=200, description="Top-k sampling parameter")
    repetition_penalty: float = Field(1.05, ge=1.0, le=2.0, description="Repetition penalty")
    # False selects greedy decoding.
    do_sample: bool = Field(True, description="Whether to use sampling")
    num_return_sequences: int = Field(1, ge=1, le=10, description="Number of sequences to generate")
    stop_sequences: Optional[List[str]] = Field(None, description="Stop generation at these sequences")
    # The non-streaming /generate endpoint rejects requests with stream=True.
    stream: bool = Field(False, description="Stream the response")
    # Free-form task label; the task-specific endpoints set it when they build a request.
    task_type: Optional[str] = Field("code_generation", description="Task type for optimized parameters")
from concurrent.futures import ThreadPoolExecutor


class GenerationResponse(BaseModel):
    """Response model for text generation"""
    generated_text: str
    prompt: str
    model: str
    generation_time: float
    tokens_generated: int


class ModelInfo(BaseModel):
    """Model information"""
    model_name: str
    model_type: str
    vocabulary_size: int
    hidden_size: int
    num_layers: int
    device: str
    dtype: str
    max_position_embeddings: int


class HealthResponse(BaseModel):
    """Health check response"""
    status: str
    model_loaded: bool
    device: str
    timestamp: float


# Dedicated single worker for blocking, non-streaming `model.generate` calls.
# Running generation here keeps the asyncio event loop free (health checks,
# streaming responses, other requests keep being served) while still running
# non-streaming generations one at a time, as the former inline call did, so
# accelerator memory use does not multiply under concurrent load.
_GENERATION_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="helion-generate")


def _sampling_kwargs(request: GenerationRequest) -> Dict[str, Any]:
    """Translate a GenerationRequest into keyword arguments for `model.generate`.

    A temperature of 0 is treated as greedy decoding: transformers rejects a
    non-positive temperature when sampling, and the request schema allows 0.
    The sampling-only knobs (temperature, top_p, top_k) are forwarded only
    when sampling, which avoids transformers' "flag is ignored" warnings on
    greedy requests without changing the greedy result.

    Args:
        request: The validated generation request.

    Returns:
        A dict of generation keyword arguments (excluding the tokenized inputs).
    """
    do_sample = request.do_sample and request.temperature > 0
    kwargs: Dict[str, Any] = {
        # NOTE: in transformers `max_length` bounds prompt + new tokens together.
        "max_length": request.max_length,
        "repetition_penalty": request.repetition_penalty,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        kwargs.update(
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
        )
    return kwargs


def _truncate_at_stop(text: str, stop_sequences: Optional[List[str]]) -> str:
    """Return *text* cut just before the earliest occurrence of any stop sequence.

    Empty stop strings are ignored (they would otherwise match at index 0).
    """
    if not stop_sequences:
        return text
    hits = [pos for pos in (text.find(stop) for stop in stop_sequences if stop) if pos >= 0]
    return text[:min(hits)] if hits else text


@app.on_event("startup")
async def load_model():
    """Load the tokenizer and model into the module globals on startup.

    Uses bfloat16 weights with device_map="auto" when CUDA is available,
    otherwise float32 on CPU. Any loading failure is logged and re-raised,
    which aborts server startup.
    """
    global model, tokenizer, device

    logger.info("Loading Helion-OSC model...")

    model_name = "DeepXR/Helion-OSC"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )

        # Padding is needed by generate(); fall back to EOS when the tokenizer has none.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
            device_map="auto" if device == "cuda" else None,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        # device_map is only used on CUDA, so place the CPU model explicitly.
        if device == "cpu":
            model = model.to(device)

        model.eval()

        logger.info(f"Model loaded successfully on {device}")

    except Exception as e:
        logger.exception(f"Failed to load model: {e}")
        raise


@app.get("/", response_model=Dict[str, str])
async def root():
    """Root endpoint: basic server identification and a pointer to the docs."""
    return {
        "message": "Helion-OSC API Server",
        "version": "1.0.0",
        "documentation": "/docs"
    }


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint: reports whether the model is loaded and on which device."""
    return HealthResponse(
        status="healthy" if model is not None else "unhealthy",
        model_loaded=model is not None,
        device=device,
        timestamp=time.time()
    )


@app.get("/info", response_model=ModelInfo)
async def model_info():
    """Return basic architecture information read from the loaded model's config.

    Raises:
        HTTPException: 503 if the model has not been loaded.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    config = model.config

    return ModelInfo(
        model_name="DeepXR/Helion-OSC",
        model_type=config.model_type,
        vocabulary_size=config.vocab_size,
        hidden_size=config.hidden_size,
        num_layers=config.num_hidden_layers,
        device=device,
        dtype=str(next(model.parameters()).dtype),
        max_position_embeddings=config.max_position_embeddings
    )


@app.post("/generate", response_model=GenerationResponse)
async def generate(request: GenerationRequest):
    """Generate a completion for ``request.prompt`` (non-streaming).

    The blocking ``model.generate`` call runs on a dedicated worker thread so
    the event loop keeps serving other requests. Only the newly generated
    tokens are decoded, so the prompt never leaks into (or is mis-sliced out
    of) the returned text. The text is cut at the first of
    ``request.stop_sequences`` when any are given. When several sequences are
    requested, only the first is returned.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 if ``stream`` is
            set, 500 if tokenization or generation fails.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if request.stream:
        raise HTTPException(
            status_code=400,
            detail="Use /generate/stream endpoint for streaming responses"
        )

    start_time = time.time()

    try:
        # Tokenize input
        inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
        input_length = inputs.input_ids.shape[1]

        generation_kwargs = _sampling_kwargs(request)
        # Greedy decoding cannot return several sequences, and only the first
        # sequence is used below anyway.
        generation_kwargs["num_return_sequences"] = (
            request.num_return_sequences if generation_kwargs["do_sample"] else 1
        )

        def _run_generate():
            # torch.no_grad() is thread-local, so it must be entered on the
            # worker thread that actually runs the forward passes.
            with torch.no_grad():
                return model.generate(**inputs, **generation_kwargs)

        loop = asyncio.get_running_loop()
        outputs = await loop.run_in_executor(_GENERATION_EXECUTOR, _run_generate)

        # Decode only the tokens produced after the prompt.
        new_tokens = outputs[0][input_length:]
        generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        generated_text = _truncate_at_stop(generated_text, request.stop_sequences).strip()

        generation_time = time.time() - start_time
        tokens_generated = int(outputs.shape[1] - input_length)

        return GenerationResponse(
            generated_text=generated_text,
            prompt=request.prompt,
            model="DeepXR/Helion-OSC",
            generation_time=generation_time,
            tokens_generated=tokens_generated
        )

    except Exception as e:
        logger.exception(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/generate/stream")
async def generate_stream(request: GenerationRequest):
    """Generate text and stream it as ``text/event-stream`` chunks.

    Each chunk is ``data: <json>\\n\\n``. Text pieces arrive as ``{"text": ...}``,
    and completion is signalled by ``{"done": true}``. A failure at any point
    (including inside the generation thread) is reported as ``{"error": ...}``.
    ``stop_sequences`` and ``num_return_sequences`` are not applied here.

    Raises:
        HTTPException: 503 if the model is not loaded.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    async def stream_generator() -> AsyncGenerator[str, None]:
        try:
            # Tokenize input
            inputs = tokenizer(request.prompt, return_tensors="pt").to(device)

            # Setup streamer
            streamer = TextIteratorStreamer(
                tokenizer,
                skip_prompt=True,
                skip_special_tokens=True
            )

            generation_kwargs = {**inputs, **_sampling_kwargs(request), "streamer": streamer}
            errors: List[BaseException] = []

            def _run_generate() -> None:
                try:
                    model.generate(**generation_kwargs)
                except Exception as exc:
                    errors.append(exc)
                    # Push the streamer's stop signal; without it the consumer
                    # below would wait on the queue forever.
                    streamer.end()

            # Start generation in separate thread
            thread = Thread(target=_run_generate)
            thread.start()

            # Read from the streamer's blocking queue off the event loop. The
            # sentinel default keeps StopIteration from crossing the executor.
            loop = asyncio.get_running_loop()
            finished = object()
            while True:
                text = await loop.run_in_executor(None, next, streamer, finished)
                if text is finished:
                    break
                yield f"data: {json.dumps({'text': text})}\n\n"

            if errors:
                raise errors[0]

            yield f"data: {json.dumps({'done': True})}\n\n"

        except Exception as e:
            logger.error(f"Streaming error: {e}")
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream"
    )


@app.post("/code/complete")
async def code_complete(
    code: str,
    language: Optional[str] = "python",
    max_length: int = 1024
):
    """Continue the given code snippet.

    ``language`` is accepted for API symmetry but is not used in the prompt.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    request = GenerationRequest(
        prompt=code,
        max_length=max_length,
        temperature=0.6,
        top_p=0.92,
        do_sample=True,
        task_type="code_completion"
    )

    return await generate(request)


@app.post("/code/explain")
async def code_explain(code: str, language: Optional[str] = "python"):
    """Ask the model to explain a code snippet, fenced with ``language``."""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    prompt = f"Explain the following {language} code in detail:\n\n```{language}\n{code}\n```\n\nExplanation:"

    request = GenerationRequest(
        prompt=prompt,
        max_length=2048,
        temperature=0.6,
        top_p=0.9,
        do_sample=True,
        task_type="code_explanation"
    )

    return await generate(request)
@app.post("/code/debug")
async def code_debug(
    code: str,
    error_message: Optional[str] = None,
    language: Optional[str] = "python"
):
    """Ask the model to debug a code snippet, optionally given its error message.

    Uses greedy decoding (do_sample=False).
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    prompt = f"Debug the following {language} code:\n\n```{language}\n{code}\n```"
    if error_message:
        prompt += f"\n\nError message: {error_message}"
    prompt += "\n\nProvide a detailed analysis and fixed code:"

    request = GenerationRequest(
        prompt=prompt,
        max_length=2048,
        temperature=0.4,
        top_p=0.88,
        do_sample=False,
        task_type="debugging"
    )

    return await generate(request)


@app.post("/math/solve")
async def math_solve(problem: str):
    """Ask the model for a step-by-step solution (greedy decoding)."""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    prompt = f"Solve the following mathematical problem step by step:\n\n{problem}\n\nSolution:"

    request = GenerationRequest(
        prompt=prompt,
        max_length=2048,
        temperature=0.3,
        top_p=0.9,
        do_sample=False,
        task_type="mathematical_reasoning"
    )

    return await generate(request)


@app.post("/algorithm/design")
async def algorithm_design(
    problem: str,
    include_complexity: bool = True
):
    """Ask the model to design an algorithm, optionally with complexity analysis."""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    prompt = f"Design an efficient algorithm for the following problem:\n\n{problem}"
    if include_complexity:
        prompt += "\n\nInclude time and space complexity analysis."

    request = GenerationRequest(
        prompt=prompt,
        max_length=3072,
        temperature=0.5,
        top_p=0.93,
        do_sample=True,
        task_type="algorithm_design"
    )

    return await generate(request)


def main():
    """Parse command-line options and run the API server under uvicorn.

    uvicorn requires an import string when ``--reload`` or several
    ``--workers`` are used. That string is derived from this file's name
    instead of a hard-coded module name, so the server still starts if the
    file is renamed. In the plain single-process case the ``app`` object is
    passed directly.
    """
    import argparse
    import os

    parser = argparse.ArgumentParser(description="Helion-OSC API Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    parser.add_argument("--reload", action="store_true", help="Enable auto-reload")
    parser.add_argument("--workers", type=int, default=1, help="Number of worker processes")

    args = parser.parse_args()

    logger.info(f"Starting Helion-OSC API Server on {args.host}:{args.port}")

    if args.reload or args.workers > 1:
        module_name = os.path.splitext(os.path.basename(__file__))[0]
        app_target = f"{module_name}:app"
    else:
        app_target = app

    if args.workers > 1:
        # Each worker process runs the startup hook and loads its own model copy.
        logger.warning(
            "Starting %d workers: each one loads its own copy of the model", args.workers
        )

    uvicorn.run(
        app_target,
        host=args.host,
        port=args.port,
        reload=args.reload,
        workers=args.workers
    )


if __name__ == "__main__":
    main()