#!/usr/bin/env python3
"""
AI Coding Model Server
FastAPI server that hosts the 5B parameter coding model
"""
import torch
import spaces  # Hugging Face Spaces SDK; kept for Spaces runtime integration
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
import logging
import asyncio
from contextlib import asynccontextmanager

# Import model components
from models import CodeModel
from utils import format_code_response, validate_code_syntax

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model instance
code_model = None
model_loading = False

class ChatMessage(BaseModel):
    """Chat message model."""
    message: str = Field(..., description="User's message")
    history: List[Dict[str, str]] = Field(default_factory=list, description="Chat history")
    language: str = Field(default="python", description="Target programming language")
    temperature: float = Field(default=0.7, ge=0.1, le=1.0, description="Generation temperature")

class ChatResponse(BaseModel):
    """Chat response model."""
    choices: List[Dict[str, Dict[str, str]]] = Field(..., description="Generated responses")
    history: List[Dict[str, str]] = Field(..., description="Updated chat history")
    usage: Optional[Dict[str, int]] = Field(None, description="Token usage information")

class HealthResponse(BaseModel):
    """Health check response."""
    status: str
    model_loaded: bool
    model_name: str
    device: str
    memory_usage: Optional[Dict[str, Any]] = None

class ModelInfoResponse(BaseModel):
    """Model information response."""
    model_name: str
    parameter_count: str
    max_length: int
    device: str
    is_loaded: bool
    vocab_size: int
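
# Illustrative /api/chat request body matching ChatMessage above
# (example values only, not part of the server):
#   {
#       "message": "Write a function that reverses a string",
#       "history": [],
#       "language": "python",
#       "temperature": 0.7
#   }
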
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan management."""
    # Startup: load the model before serving requests
    logger.info("Starting up AI Coding Model Server...")
    await load_model()
    yield
    # Shutdown
    logger.info("Shutting down server...")

async def load_model():
    """Load the model in the background."""
    global code_model, model_loading
    if code_model is not None or model_loading:
        return
    model_loading = True
    logger.info("Loading coding model...")
    try:
        # Load the model in a worker thread to avoid blocking the event loop
        loop = asyncio.get_event_loop()
        code_model = await loop.run_in_executor(None, CodeModel)
        if code_model.is_loaded:
            logger.info(f"✅ Model loaded successfully: {code_model.model_name}")
        else:
            logger.error("❌ Failed to load model")
    except Exception as e:
        logger.error(f"❌ Error loading model: {e}")
        code_model = None
    finally:
        model_loading = False

def create_app() -> FastAPI:
    """Create and configure the FastAPI application."""
    # Create FastAPI app with lifespan management
    app = FastAPI(
        title="AI Coding Model Server",
        description="FastAPI server hosting a 5B parameter coding model",
        version="1.0.0",
        lifespan=lifespan
    )

    # Add CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],  # Configure appropriately for production
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.get("/")
    async def root():
        """Root endpoint."""
        return {
            "message": "AI Coding Model Server",
            "version": "1.0.0",
            "status": "running" if code_model and code_model.is_loaded else "loading"
        }

    @app.get("/health", response_model=HealthResponse)
    async def health_check():
        """Health check endpoint."""
        if model_loading:
            return HealthResponse(
                status="loading",
                model_loaded=False,
                model_name="Loading...",
                device="unknown"
            )
        if not code_model or not code_model.is_loaded:
            raise HTTPException(status_code=503, detail="Model not loaded")

        # Get GPU memory usage if available
        memory_info = None
        if torch.cuda.is_available():
            memory_info = {
                "allocated": torch.cuda.memory_allocated() / 1024**3,  # GB
                "cached": torch.cuda.memory_reserved() / 1024**3,  # GB
                "total": torch.cuda.get_device_properties(0).total_memory / 1024**3
            }
        return HealthResponse(
            status="healthy",
            model_loaded=True,
            model_name=code_model.model_name,
            device=code_model.device,
            memory_usage=memory_info
        )
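
    # Illustrative /health response once the model is loaded (values vary):
    #   {"status": "healthy", "model_loaded": true, "model_name": "...",
    #    "device": "cuda", "memory_usage": {"allocated": ..., "cached": ..., "total": ...}}
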
    @app.get("/model/info", response_model=ModelInfoResponse)
    async def model_info():
        """Get detailed model information."""
        if not code_model:
            raise HTTPException(status_code=503, detail="Model not loaded")
        info = code_model.get_model_info()
        return ModelInfoResponse(**info)

    @app.post("/api/chat", response_model=ChatResponse)
    async def chat(request: ChatMessage):
        """Main chat endpoint."""
        if model_loading:
            raise HTTPException(status_code=503, detail="Model is still loading")
        if not code_model or not code_model.is_loaded:
            raise HTTPException(status_code=503, detail="Model not loaded")
        try:
            # Generate a response using the model
            messages = request.history.copy()
            messages.append({"role": "user", "content": request.message})
            response_text = code_model.generate(
                messages=messages,
                temperature=request.temperature,
                max_new_tokens=2048,
                language=request.language
            )

            # Format the response
            formatted_response = format_code_response(response_text)

            # Update chat history
            new_history = request.history.copy()
            new_history.append({"role": "user", "content": request.message})
            new_history.append({"role": "assistant", "content": formatted_response})

            return ChatResponse(
                choices=[{"message": {"content": formatted_response}}],
                history=new_history
            )
        except Exception as e:
            logger.error(f"Chat error: {e}")
            raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")

    @app.post("/api/validate")
    async def validate_code(request: Dict[str, Any]):
        """Validate code syntax."""
        code = request.get("code", "")
        language = request.get("language", "python")
        if not code:
            raise HTTPException(status_code=400, detail="No code provided")
        validation_result = validate_code_syntax(code, language)
        return validation_result

    @app.get("/api/languages")
    async def get_supported_languages():
        """Get the list of supported programming languages."""
        return {
            "languages": [
                "python", "javascript", "java", "cpp", "c", "go", "rust",
                "typescript", "php", "ruby", "swift", "kotlin", "sql",
                "html", "css", "bash", "powershell"
            ]
        }

    return app
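
# Minimal client sketch (illustrative, not part of the server): one way to call
# /api/chat, assuming the server runs locally on port 8000 and the `requests`
# package is installed.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/api/chat",
#       json={"message": "Reverse a string in Python"},
#   )
#   print(resp.json()["choices"][0]["message"]["content"])
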
def run_server(host: str = "0.0.0.0", port: int = 8000, reload: bool = False):
    """Run the FastAPI server."""
    console_info = f"""
🚀 AI Coding Model Server Starting...

📊 Server Info:
   • Host: {host}
   • Port: {port}
   • Model: Loading...
   • Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}

🔗 Endpoints:
   • Health: http://{host}:{port}/health
   • Model Info: http://{host}:{port}/model/info
   • Chat: http://{host}:{port}/api/chat
   • API Docs: http://{host}:{port}/docs

💡 Usage:
   • Terminal client: python terminal_chatbot.py
   • API calls: POST to /api/chat with chat messages
   • Check status: GET /health

⚡ Server is ready! Press Ctrl+C to stop.
"""
    print(console_info)

    # Run the server; uvicorn loads the app via the factory import string,
    # which is what allows --reload to re-import it on changes
    uvicorn.run(
        "model_server:create_app",
        factory=True,
        host=host,
        port=port,
        reload=reload,
        log_level="info",
        access_log=True
    )

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="AI Coding Model Server")
    parser.add_argument("--host", default="0.0.0.0", help="Server host")
    parser.add_argument("--port", type=int, default=8000, help="Server port")
    parser.add_argument("--reload", action="store_true", help="Auto-reload on changes")
    args = parser.parse_args()

    run_server(
        host=args.host,
        port=args.port,
        reload=args.reload
    )
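
# Quick smoke test (illustrative) once the server is up:
#   curl http://localhost:8000/health
#   curl -X POST http://localhost:8000/api/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Reverse a string in Python"}'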