"""FastAPI application entry point.

Wires up CORS, a trace-id / response-envelope middleware, the global exception
handlers, the startup/shutdown hooks and all API routers.
"""

import asyncio
import json
import logging
import time
from uuid import uuid4
from contextlib import suppress

# HTTPException is no longer referenced below (the handler is registered on
# Starlette's base class instead) but the import is kept so the module
# namespace stays unchanged for anything importing it from here.
from fastapi import FastAPI, HTTPException, Request  # noqa: F401
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.responses import Response
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError

from backend.app.core.config import settings
from backend.app.core.response import build_envelope
from backend.app.db.session import engine, SessionLocal
from backend.app.db.base import Base
# NOTE(review): presumably imported only for its side effect of registering
# the ORM models on Base.metadata before create_all runs -- confirm.
from backend.app import models  # noqa: F401
from backend.app.services.seed import seed_data
from backend.app.services.center_sync import run_center_sync_loop
from backend.app.routers import (
    auth,
    users,
    roles,
    campuses,
    departments,
    shifts,
    schedule,
    duty,
    statistics,
    adjust_logs,
    sync,
    ws,
)

app = FastAPI(title=settings.app_name)

app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",
        "http://127.0.0.1:5173",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

logger = logging.getLogger("uvicorn.error")

# Signals the background sync loop to stop; set in on_shutdown.
sync_stop_event = asyncio.Event()
# Handle of the background sync task created in on_startup.
sync_task: asyncio.Task | None = None

# HTTP status -> envelope error code. Statuses not listed here fall back to
# "INTERNAL_ERROR" in http_exception_handler.
_HTTP_ERROR_CODES = {
    400: "VALIDATION_ERROR",
    401: "UNAUTHORIZED",
    403: "FORBIDDEN",
    404: "NOT_FOUND",
    409: "CONFLICT",
    429: "RATE_LIMITED",
}


def _clone_headers(response: Response, trace_id: str) -> dict:
    """Copy *response*'s headers for use on a rebuilt response.

    ``content-length`` is dropped because the rebuilt body usually has a
    different size, and ``X-Trace-Id`` is set to *trace_id*.

    NOTE(review): ``dict()`` keeps a single value per header name, so repeated
    headers (e.g. several ``Set-Cookie``) are likely collapsed to one here --
    confirm whether any JSON endpoint relies on repeated headers.
    """
    headers = dict(response.headers)
    headers.pop("content-length", None)
    headers["X-Trace-Id"] = trace_id
    return headers


@app.middleware("http")
async def add_trace_and_wrap(request: Request, call_next):
    """Attach a trace id to the request and wrap successful JSON responses.

    The trace id comes from the ``x-trace-id`` or ``traceId`` request header,
    or is freshly generated. It is stored on ``request.state.trace_id`` and
    echoed back as the ``X-Trace-Id`` response header.

    Responses with status >= 400 or a non-JSON content type are passed through
    untouched apart from that header. Other responses are buffered and:

    * an empty body becomes ``build_envelope(None, trace_id)``;
    * a body that cannot be parsed as JSON is returned unchanged;
    * a dict that already has ``code``/``message``/``data`` keys only gets
      ``traceId`` and ``ts`` filled in when missing;
    * anything else is passed through ``build_envelope``.
    """
    trace_id = (
        request.headers.get("x-trace-id")
        or request.headers.get("traceId")
        or uuid4().hex
    )
    request.state.trace_id = trace_id

    response = await call_next(request)
    content_type = response.headers.get("content-type", "")
    if response.status_code >= 400 or "application/json" not in content_type:
        response.headers["X-Trace-Id"] = trace_id
        return response

    # Drain the streamed body in one pass; join avoids repeated bytes copies.
    body = b"".join([chunk async for chunk in response.body_iterator])
    headers = _clone_headers(response, trace_id)
    status_code = response.status_code

    if not body:
        return JSONResponse(
            build_envelope(None, trace_id),
            status_code=status_code,
            headers=headers,
        )

    try:
        data = json.loads(body)
    except ValueError:
        # ValueError covers both json.JSONDecodeError and the
        # UnicodeDecodeError that json.loads raises for undecodable bytes.
        # The body iterator is already consumed, so replay the raw bytes.
        return Response(
            content=body,
            status_code=status_code,
            headers=headers,
            media_type=response.media_type,
        )

    already_enveloped = isinstance(data, dict) and {"code", "message", "data"}.issubset(
        data.keys()
    )
    if already_enveloped:
        data.setdefault("traceId", trace_id)
        data.setdefault("ts", int(time.time() * 1000))
        return JSONResponse(data, status_code=status_code, headers=headers)

    return JSONResponse(
        build_envelope(data, trace_id),
        status_code=status_code,
        headers=headers,
    )


def get_trace_id(request: Request) -> str:
    """Return the trace id stored on ``request.state``, or a fresh one.

    A new ``uuid4().hex`` is generated when the attribute is missing or empty.
    """
    trace_id = getattr(request.state, "trace_id", None)
    return trace_id or uuid4().hex


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Turn request validation errors into a 400 ``VALIDATION_ERROR`` envelope.

    Each error is reported as ``{"field", "reason"}``. The field is the dotted
    error location with the ``body``/``query``/``path`` segments removed; if
    nothing is left after removal, the full location is used instead.
    """
    field_errors = []
    for err in exc.errors():
        location = err.get("loc", [])
        parts = [str(item) for item in location if item not in ("body", "query", "path")]
        if not parts:
            parts = [str(item) for item in location]
        field_errors.append({"field": ".".join(parts), "reason": err.get("msg")})

    trace_id = get_trace_id(request)
    payload = build_envelope(
        None,
        trace_id,
        code="VALIDATION_ERROR",
        message="参数校验失败",
        details={"fieldErrors": field_errors},
    )
    return JSONResponse(payload, status_code=400)


# Registered on Starlette's HTTPException (the base class of FastAPI's) so that
# errors raised by Starlette itself, such as the 404 for an unknown route, are
# rendered with the same envelope as errors raised by application code.
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render an HTTPException as an error envelope with the same status code.

    The envelope code is looked up by status code and defaults to
    ``INTERNAL_ERROR``. A string ``detail`` is used as the message; any other
    detail type is replaced by a generic message. Headers carried by the
    exception (e.g. ``WWW-Authenticate``) are forwarded on the response.
    """
    trace_id = get_trace_id(request)
    code = _HTTP_ERROR_CODES.get(exc.status_code, "INTERNAL_ERROR")
    message = exc.detail if isinstance(exc.detail, str) else "请求失败"
    payload = build_envelope(None, trace_id, code=code, message=message)
    return JSONResponse(
        payload,
        status_code=exc.status_code,
        headers=getattr(exc, "headers", None),
    )


@app.exception_handler(IntegrityError)
async def integrity_exception_handler(request: Request, exc: IntegrityError):
    """Log a SQLAlchemy IntegrityError and answer with a 409 ``CONFLICT`` envelope."""
    trace_id = get_trace_id(request)
    logger.exception("IntegrityError traceId=%s", trace_id)
    payload = build_envelope(None, trace_id, code="CONFLICT", message="数据冲突")
    return JSONResponse(payload, status_code=409)


@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Log any otherwise unhandled exception and answer with a 500 envelope."""
    trace_id = get_trace_id(request)
    logger.exception("Unhandled error traceId=%s", trace_id)
    payload = build_envelope(None, trace_id, code="INTERNAL_ERROR", message="系统异常,请稍后再试")
    return JSONResponse(payload, status_code=500)


@app.on_event("startup")
async def on_startup():
    """Prepare the database, seed data and start the background sync task."""
    global sync_task
    async with engine.begin() as conn:
        # Relies on the pgcrypto extension to provide gen_random_uuid().
        # If the database account is not allowed to create extensions, this
        # fails right here with a clear error instead of a more obscure
        # failure on a later insert.
        await conn.execute(text('CREATE EXTENSION IF NOT EXISTS "pgcrypto";'))
        await conn.run_sync(Base.metadata.create_all)
    async with SessionLocal() as session:
        await seed_data(session)
    sync_task = asyncio.create_task(run_center_sync_loop(sync_stop_event))


@app.on_event("shutdown")
async def on_shutdown():
    """Signal the sync loop to stop, then cancel the task and wait for it."""
    sync_stop_event.set()
    if sync_task:
        sync_task.cancel()
        with suppress(asyncio.CancelledError):
            await sync_task


@app.get("/health")
async def health():
    """Liveness probe; returns a static ``{"status": "ok"}`` payload."""
    return {"status": "ok"}


app.include_router(auth.router)
app.include_router(users.router)
app.include_router(roles.router)
app.include_router(campuses.router)
app.include_router(departments.router)
app.include_router(shifts.router)
app.include_router(schedule.router)
app.include_router(duty.router)
app.include_router(statistics.router)
app.include_router(adjust_logs.router)
app.include_router(sync.router)
app.include_router(ws.router)