
FastAPI for ML Model Serving: A Production Engineering Guide

How to build a production-ready ML inference service with FastAPI — covering async workers, model lifecycle management, caching strategies, health checks, and Docker deployment.

Rishabh Bhartiya · 9 min read

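FastAPI has become a go-to choice for ML model serving, and for good reason: it is fast, async-native, generates OpenAPI docs automatically, and uses Pydantic for request validation. Most FastAPI ML tutorials, though, stop at a toy example. A production service also has to load models efficiently, keep the event loop responsive, cache repeated work, report its health, and ship as a reproducible container.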

The Production FastAPI ML Stack

  • FastAPI — async HTTP framework, request validation, OpenAPI docs
  • Uvicorn + Gunicorn — ASGI server with multiple worker processes
  • Redis — prediction caching and job queue for async workloads
  • Docker — containerized deployment for environment consistency
  • Prometheus + Grafana — latency tracking, throughput, error rates

Model Lifecycle: Load Once, Serve Many


import logging
from contextlib import asynccontextmanager

import joblib
from fastapi import FastAPI

logger = logging.getLogger(__name__)

# Process-wide model registry, filled once per worker at startup
models = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load on startup, once per worker process (4 workers = 4 copies in memory)
    models["classifier"] = joblib.load("models/classifier_v2.pkl")
    models["scaler"]     = joblib.load("models/scaler_v2.pkl")
    logger.info("Models loaded: %s", list(models))
    yield
    # Cleanup on shutdown
    models.clear()

app = FastAPI(lifespan=lifespan, title="ML Inference API")
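To see what the lifespan pattern buys you, here is a stdlib-only sketch that mimics it without FastAPI. The hypothetical `fake_load` stands in for `joblib.load`, and a counter confirms the loader runs once at startup no matter how many requests are served. `fake_load`, `handle_request`, and `serve` are illustrative names, and unlike FastAPI's hook this `lifespan` takes no `app` argument.

```python
import asyncio
from contextlib import asynccontextmanager

load_calls = 0          # counts how many times the (hypothetical) loader ran
models: dict = {}

def fake_load(path: str) -> str:
    # Hypothetical stand-in for joblib.load: pretend loading is expensive
    global load_calls
    load_calls += 1
    return f"model<{path}>"

@asynccontextmanager
async def lifespan():
    # Startup: populate the registry exactly once
    models["classifier"] = fake_load("models/classifier_v2.pkl")
    yield
    # Shutdown: release references
    models.clear()

async def handle_request(x: int) -> str:
    # Per-request work only reads the registry; it never reloads
    return f"{models['classifier']} scored {x}"

async def serve(n_requests: int) -> list:
    async with lifespan():
        return await asyncio.gather(*(handle_request(i) for i in range(n_requests)))

responses = asyncio.run(serve(100))
print(load_calls, len(responses), models)  # 1 100 {}
```

Loading inside the request handler instead would multiply the disk read and deserialization cost by the request count.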

Request Validation with Pydantic


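import math
from typing import List

from pydantic import BaseModel, ConfigDict, field_validator

class PredictionRequest(BaseModel):
    # Pydantic v2 reserves the "model_" prefix; allow the model_version field
    model_config = ConfigDict(protected_namespaces=())

    features: List[float]
    model_version: str = "v2"

    @field_validator("features")
    @classmethod
    def validate_features(cls, v: List[float]) -> List[float]:
        if len(v) != 12:
            raise ValueError(f"Expected 12 features, got {len(v)}")
        if not all(math.isfinite(f) for f in v):   # rejects NaN and +/-inf
            raise ValueError("Features contain NaN or infinite values")
        return v

class PredictionResponse(BaseModel):
    model_config = ConfigDict(protected_namespaces=())

    prediction: float
    confidence: float
    model_version: str
    latency_ms: float
    cached: bool = False   # set to True when served from the Redis cache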
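The self-comparison trick `f != f` works because NaN is the only float that is unequal to itself, but it lets infinities through. A quick stdlib comparison (the helper names are illustrative) shows the difference; `math.isfinite` rejects NaN and both infinities.

```python
import math

def passes_self_compare(values: list) -> bool:
    # Only NaN satisfies f != f, so +/-inf slip past this check
    return not any(f != f for f in values)

def passes_isfinite(values: list) -> bool:
    # Stricter: rejects NaN, +inf and -inf
    return all(math.isfinite(f) for f in values)

samples = {
    "clean": [0.1, 2.5, -3.0],
    "nan": [0.1, float("nan")],
    "inf": [0.1, float("inf")],
}

for name, values in samples.items():
    print(name, passes_self_compare(values), passes_isfinite(values))
# clean True True
# nan False False
# inf True False
```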

Async Inference: Don't Block the Event Loop


import asyncio
import time

import numpy as np

def run_model(features: np.ndarray) -> np.ndarray:
    # Scale + predict together so one thread-pool hop covers the whole pipeline
    scaled = models["scaler"].transform(features)
    return models["classifier"].predict_proba(scaled)

async def run_inference(request: PredictionRequest) -> PredictionResponse:
    start = time.perf_counter()
    features = np.array(request.features).reshape(1, -1)

    # CPU-bound inference: run in a worker thread so the event loop stays free
    proba = await asyncio.to_thread(run_model, features)

    latency_ms = (time.perf_counter() - start) * 1000
    return PredictionResponse(
        prediction=float(proba[0][1]),        # probability of the positive class
        confidence=float(np.max(proba[0])),
        model_version=request.model_version,
        latency_ms=latency_ms,
    )

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    return await run_inference(request)
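The effect of offloading is easy to measure without a model. In this stdlib-only sketch, `time.sleep` stands in for a blocking `predict` call (an assumption: real NumPy or scikit-learn code may hold the GIL for part of its run), and a heartbeat coroutine counts how often the event loop gets to run while inference is in flight. `fake_inference`, `heartbeat`, and `measure` are hypothetical names.

```python
import asyncio
import time

def fake_inference(duration: float = 0.2) -> float:
    # Hypothetical stand-in for model.predict_proba: blocks its thread
    time.sleep(duration)
    return 0.87

async def heartbeat(ticks: list, stop: asyncio.Event) -> None:
    # Simulates other requests: it only runs when the event loop is free
    while not stop.is_set():
        ticks[0] += 1
        await asyncio.sleep(0.01)

async def measure(offload: bool) -> int:
    ticks = [0]
    stop = asyncio.Event()
    task = asyncio.create_task(heartbeat(ticks, stop))
    await asyncio.sleep(0)  # give the heartbeat its first turn
    if offload:
        await asyncio.to_thread(fake_inference)  # loop keeps running
    else:
        fake_inference()                          # loop is frozen
    stop.set()
    await task
    return ticks[0]

blocked = asyncio.run(measure(offload=False))
offloaded = asyncio.run(measure(offload=True))
print(f"heartbeats while blocking: {blocked}, while offloaded: {offloaded}")
```

On a real server the heartbeat is every other in-flight request, and the blocking version stalls all of them for the full inference time. Threads keep the loop responsive but do not add CPU parallelism for GIL-bound code; that comes from Gunicorn's worker processes. FastAPI also runs endpoints declared with plain `def` in a thread pool automatically, which is a simpler option when the handler has nothing to `await`.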

Redis Caching for Repeated Predictions


import hashlib
import json

import redis.asyncio as redis   # async client (redis-py >= 4.2): no event-loop blocking

cache = redis.Redis(host="redis", port=6379, decode_responses=True)
CACHE_TTL = 3600  # 1 hour

def get_cache_key(features: list, version: str) -> str:
    payload = json.dumps({"features": features, "version": version}, sort_keys=True)
    return f"pred:{hashlib.sha256(payload.encode()).hexdigest()[:16]}"

# Use this handler in place of the plain /predict endpoint: if both are
# registered, FastAPI dispatches to whichever route was added first.
@app.post("/predict", response_model=PredictionResponse)
async def predict_with_cache(request: PredictionRequest):
    cache_key = get_cache_key(request.features, request.model_version)

    cached = await cache.get(cache_key)
    if cached is not None:
        result = json.loads(cached)
        result["cached"] = True
        return PredictionResponse(**result)

    # Cache miss: run inference as in the previous section
    result = await run_inference(request)
    await cache.setex(cache_key, CACHE_TTL, json.dumps(result.model_dump()))
    return result
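The caching logic can be exercised without a Redis server. This sketch swaps Redis for a hypothetical in-memory `TTLCache` with an injectable clock (an assumption made purely for testability) while reusing the same SHA-256 key scheme. It shows that identical inputs share a key, a different model version does not, and entries expire after the TTL.

```python
import hashlib
import json
import time
from typing import Callable, Dict, Optional, Tuple

def get_cache_key(features: list, version: str) -> str:
    # Same scheme as above: canonical JSON -> SHA-256 -> 16-hex-char prefix
    payload = json.dumps({"features": features, "version": version}, sort_keys=True)
    return f"pred:{hashlib.sha256(payload.encode()).hexdigest()[:16]}"

class TTLCache:
    """Hypothetical in-memory stand-in for Redis GET/SETEX."""

    def __init__(self, clock: Callable[[], float] = time.monotonic):
        self._clock = clock
        self._store: Dict[str, Tuple[float, str]] = {}

    def setex(self, key: str, ttl: float, value: str) -> None:
        self._store[key] = (self._clock() + ttl, value)

    def get(self, key: str) -> Optional[str]:
        entry = self._store.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if self._clock() >= expires_at:
            del self._store[key]  # lazily evict expired entries
            return None
        return value

now = [0.0]  # fake clock, advanced manually
cache = TTLCache(clock=lambda: now[0])

key = get_cache_key([0.5] * 12, "v2")
cache.setex(key, 3600, json.dumps({"prediction": 0.87}))

print(key == get_cache_key([0.5] * 12, "v2"))  # True: same input, same key
print(key == get_cache_key([0.5] * 12, "v3"))  # False: version is part of the key
print(cache.get(key))                          # {"prediction": 0.87}
now[0] = 3600.0
print(cache.get(key))                          # None: expired
```

Including the model version in the key means a model upgrade never serves stale predictions from the previous version.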

Health Checks and Readiness Probes


from fastapi import HTTPException

@app.get("/health")
async def health():
    # Liveness: the process is up and answering requests
    return {"status": "healthy", "models_loaded": list(models.keys())}

@app.get("/ready")
async def ready():
    # Readiness: only accept traffic once models are in memory
    if not models:
        raise HTTPException(status_code=503, detail="Models not loaded")
    return {"status": "ready"}
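The check `if not models` passes as soon as any single model is loaded. A stricter readiness rule verifies that every required model is present. This stdlib sketch keeps the decision in a plain function that returns a status code and body; `REQUIRED_MODELS` and `readiness` are hypothetical names, and in the FastAPI endpoint the 503 case would be raised via `HTTPException`.

```python
from typing import Dict, Tuple

# Hypothetical: the set of models this service cannot run without
REQUIRED_MODELS = frozenset({"classifier", "scaler"})

def readiness(models: Dict[str, object]) -> Tuple[int, dict]:
    # Ready only when every required model is loaded, not just any one of them
    missing = sorted(REQUIRED_MODELS.difference(models))
    if missing:
        return 503, {"status": "not ready", "missing": missing}
    return 200, {"status": "ready"}

print(readiness({}))                              # (503, {'status': 'not ready', 'missing': ['classifier', 'scaler']})
print(readiness({"scaler": object()}))            # (503, {'status': 'not ready', 'missing': ['classifier']})
print(readiness({"classifier": 1, "scaler": 2}))  # (200, {'status': 'ready'})
```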

Docker: Multi-Stage Build for Small Images


# Stage 1: build dependencies
FROM python:3.11-slim AS builder
WORKDIR /app
COPY requirements.txt .
RUN pip install --user --no-cache-dir -r requirements.txt

# Stage 2: production image
FROM python:3.11-slim
WORKDIR /app
COPY --from=builder /root/.local /root/.local
COPY . .

ENV PATH=/root/.local/bin:$PATH
ENV PYTHONUNBUFFERED=1

EXPOSE 8000
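# gunicorn and uvicorn must be in requirements.txt. Each worker runs the
# lifespan hook, so 4 workers hold 4 copies of the models in memory.
CMD ["gunicorn", "main:app", "-w", "4", "-k", "uvicorn.workers.UvicornWorker", \
     "--bind", "0.0.0.0:8000", "--timeout", "120"]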

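Hard-coding `-w 4` ties memory use to a fixed worker count, and every worker loads its own copy of the models. One alternative, sketched here under assumptions, is a `gunicorn.conf.py` that sizes workers from the CPU count and a memory budget. `MODEL_MEMORY_MB` and `MEMORY_BUDGET_MB` are hypothetical environment variables with placeholder defaults; `bind`, `workers`, `worker_class`, and `timeout` are standard Gunicorn settings. The container would then start with `gunicorn main:app -c gunicorn.conf.py`.

```python
# gunicorn.conf.py: a sketch, not a drop-in file
import multiprocessing
import os

def worker_count(cpus: int, model_mb: int, budget_mb: int) -> int:
    # Each worker holds a full copy of the models, so memory caps the count
    by_memory = budget_mb // model_mb if model_mb > 0 else cpus
    return max(1, min(cpus, by_memory))

# Hypothetical env vars; the defaults are placeholders to tune per deployment
_model_mb = int(os.environ.get("MODEL_MEMORY_MB", "500"))
_budget_mb = int(os.environ.get("MEMORY_BUDGET_MB", "2000"))

bind = "0.0.0.0:8000"
worker_class = "uvicorn.workers.UvicornWorker"
workers = worker_count(multiprocessing.cpu_count(), _model_mb, _budget_mb)
timeout = 120
```

With a 500 MB model and a 2000 MB budget, an 8-core machine gets 4 workers, while a 3000 MB model falls back to a single worker rather than running out of memory.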
Production Checklist

  • Models loaded once per worker at startup via lifespan, not per request
  • CPU-bound inference runs in a thread pool, off the event loop
  • Pydantic input validation rejects bad data (wrong length, NaN) before inference
  • Redis caching for inputs that repeat often (a low-cardinality request stream)
  • /health and /ready endpoints for Kubernetes liveness and readiness probes
  • Multi-stage Docker build keeps the image under 500MB
  • Prometheus middleware records request latencies for p50/p95/p99 tracking
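The p50/p95/p99 figures in the last item are percentiles of request latency. In production, Prometheus typically estimates them from histogram buckets with `histogram_quantile`; this stdlib sketch instead computes exact nearest-rank percentiles over raw samples, just to make the definitions concrete. The latencies are synthetic.

```python
from typing import List

def percentile(samples: List[float], p: int) -> float:
    """Nearest-rank percentile: smallest sample with at least p% of samples <= it."""
    if not samples or not 0 < p <= 100:
        raise ValueError("need samples and 0 < p <= 100")
    ordered = sorted(samples)
    rank = -(-p * len(ordered) // 100)  # ceil(p * n / 100) with integer math
    return ordered[rank - 1]

# Synthetic latencies: 1 ms .. 100 ms, one sample each
latencies_ms = [float(ms) for ms in range(1, 101)]
for p in (50, 95, 99):
    print(f"p{p}: {percentile(latencies_ms, p)} ms")
# p50: 50.0 ms
# p95: 95.0 ms
# p99: 99.0 ms
```

Tail percentiles matter because averages hide slow requests: a handful of multi-second calls barely moves the mean but dominates p99.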

Tags

FastAPI · MLOps · Model Serving · Docker · Redis · Production ML

Author

Rishabh Bhartiya

ML Engineer · NatrajX
