| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180 |
import asyncio
import json
import logging
import time
from contextlib import suppress
from uuid import uuid4

from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.responses import Response

from backend.app.core.config import settings
from backend.app.core.response import build_envelope
from backend.app.db.session import engine, SessionLocal
from backend.app.db.base import Base
from backend.app import models  # noqa: F401
from backend.app.services.seed import seed_data
from backend.app.services.center_sync import run_center_sync_loop
from backend.app.routers import (
    auth,
    users,
    roles,
    campuses,
    departments,
    shifts,
    schedule,
    duty,
    statistics,
    adjust_logs,
    sync,
    ws,
)
app = FastAPI(title=settings.app_name)

# CORS for the local browser dev origins on port 5173
# (presumably the frontend dev server -- confirm before adding other hosts).
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",
        "http://127.0.0.1:5173",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    # The trace middleware stamps X-Trace-Id on every response. Cross-origin
    # JavaScript can only read non-safelisted response headers that are listed
    # in Access-Control-Expose-Headers, so expose it explicitly.
    expose_headers=["X-Trace-Id"],
)

# Log through uvicorn's "uvicorn.error" logger rather than a module logger.
logger = logging.getLogger("uvicorn.error")

# Stop signal and task handle for the background center-sync loop;
# on_startup creates the task, on_shutdown sets the event and cancels it.
sync_stop_event = asyncio.Event()
sync_task: asyncio.Task | None = None
def _clone_headers(response, trace_id: str):
    """Copy *response*'s headers for use on a rebuilt response.

    ``Headers.mutablecopy()`` is used instead of ``dict(response.headers)``
    because a dict keeps only one value per header name. That silently drops
    repeated headers such as several ``Set-Cookie`` lines. The copy is still a
    ``Mapping[str, str]``, and its ``items()`` yields every (name, value) pair,
    so it can be passed as ``headers=`` to a new Starlette response.

    ``content-length`` is removed because the body is re-serialized and the
    new response computes its own length. ``X-Trace-Id`` is set
    (case-insensitively, replacing any existing value) to *trace_id*.

    Args:
        response: the downstream Starlette response whose headers are copied.
        trace_id: value for the ``X-Trace-Id`` header.

    Returns:
        A mutable, case-insensitive header mapping independent of *response*.
    """
    headers = response.headers.mutablecopy()
    if "content-length" in headers:
        del headers["content-length"]
    headers["X-Trace-Id"] = trace_id
    return headers
@app.middleware("http")
async def add_trace_and_wrap(request: Request, call_next):
    """Tag every request with a trace id and wrap successful JSON bodies in the envelope.

    The trace id comes from the ``x-trace-id`` or ``traceId`` request header.
    If neither is present, a random uuid4 hex string is generated. It is stored
    on ``request.state.trace_id`` for handlers and returned in an
    ``X-Trace-Id`` response header.

    Responses are handled as follows:

    * Passed through unchanged (apart from the header): status >= 400,
      non-JSON content types, and statuses that must not carry a body
      (1xx, 204, 304).
    * Empty JSON body: becomes an envelope with ``None`` data.
    * JSON that is already an envelope (a dict with ``code``, ``message`` and
      ``data``): gets ``traceId`` and ``ts`` (epoch milliseconds) filled in
      if they are missing.
    * Any other JSON value: wrapped with ``build_envelope``.
    * Body that cannot be decoded as JSON: returned as-is.
    """
    trace_id = (
        request.headers.get("x-trace-id")
        or request.headers.get("traceId")
        or uuid4().hex
    )
    request.state.trace_id = trace_id

    response = await call_next(request)
    status_code = response.status_code
    content_type = response.headers.get("content-type", "")

    # Re-wrapping an empty 204/304 would attach a body that the HTTP protocol
    # forbids for those statuses, so they are passed through like errors and
    # non-JSON responses.
    body_forbidden = status_code < 200 or status_code in (204, 304)
    if (
        status_code >= 400
        or body_forbidden
        or "application/json" not in content_type
    ):
        response.headers["X-Trace-Id"] = trace_id
        return response

    # Drain the streamed body with a single join; repeated `bytes +=` is quadratic.
    body = b"".join([chunk async for chunk in response.body_iterator])
    headers = _clone_headers(response, trace_id)

    if not body:
        return JSONResponse(
            build_envelope(None, trace_id), status_code=status_code, headers=headers
        )

    try:
        data = json.loads(body)
    except ValueError:
        # ValueError covers both JSONDecodeError and the UnicodeDecodeError that
        # json.loads raises for bytes it cannot decode. Such bodies are
        # forwarded untouched rather than turned into a 500.
        return Response(
            content=body,
            status_code=status_code,
            headers=headers,
            media_type=response.media_type,
        )

    if isinstance(data, dict) and {"code", "message", "data"}.issubset(data.keys()):
        data.setdefault("traceId", trace_id)
        data.setdefault("ts", int(time.time() * 1000))
        return JSONResponse(data, status_code=status_code, headers=headers)
    return JSONResponse(
        build_envelope(data, trace_id), status_code=status_code, headers=headers
    )
def get_trace_id(request: Request) -> str:
    """Return the trace id stored on ``request.state`` by the HTTP middleware.

    When no (or an empty) trace id was stored, a fresh uuid4 hex string is
    returned instead. It is not saved back, so repeated calls in that case
    yield different ids.
    """
    stored = getattr(request.state, "trace_id", None)
    if stored:
        return stored
    return uuid4().hex
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Render request-validation failures as a 400 ``VALIDATION_ERROR`` envelope.

    Each entry of ``exc.errors()`` becomes ``{"field": <dotted loc>, "reason":
    <msg>}`` under ``details.fieldErrors``. The request-source marker (body,
    query, path, header, cookie) is dropped from the front of the location so
    clients see the field path only.
    """
    # Source markers that lead an error's ``loc`` tuple.
    request_locations = ("body", "query", "path", "header", "cookie")

    field_errors = []
    for err in exc.errors():
        loc = [str(item) for item in err.get("loc", ())]
        # Strip only the *leading* marker, and only when a field path follows it.
        # A nested field literally named "path" or "query" is then kept
        # (e.g. items.0.path), and a bare ("body",) error still reports "body".
        if len(loc) > 1 and loc[0] in request_locations:
            loc = loc[1:]
        field_errors.append({"field": ".".join(loc), "reason": err.get("msg")})

    trace_id = get_trace_id(request)
    payload = build_envelope(
        None,
        trace_id,
        code="VALIDATION_ERROR",
        message="参数校验失败",
        details={"fieldErrors": field_errors},
    )
    return JSONResponse(payload, status_code=400)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render HTTP errors as the standard error envelope.

    The handler is registered on Starlette's ``HTTPException``, the base class
    of FastAPI's. This envelopes errors raised by the router itself (unknown
    path -> 404, wrong method -> 405) as well as those raised explicitly in
    endpoints. Any headers attached to the exception (e.g. ``Allow`` or
    ``WWW-Authenticate``) are forwarded on the response.

    Status codes missing from ``code_map`` map to ``INTERNAL_ERROR``.
    NOTE(review): that includes 405; consider a dedicated code if the client
    contract allows one.
    """
    code_map = {
        400: "VALIDATION_ERROR",
        401: "UNAUTHORIZED",
        403: "FORBIDDEN",
        404: "NOT_FOUND",
        409: "CONFLICT",
        429: "RATE_LIMITED",
    }
    trace_id = get_trace_id(request)
    code = code_map.get(exc.status_code, "INTERNAL_ERROR")
    # Non-string details (dicts/lists) are replaced by a generic message.
    message = exc.detail if isinstance(exc.detail, str) else "请求失败"
    payload = build_envelope(None, trace_id, code=code, message=message)
    return JSONResponse(
        payload,
        status_code=exc.status_code,
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(IntegrityError)
async def integrity_exception_handler(request: Request, exc: IntegrityError):
    """Map a SQLAlchemy ``IntegrityError`` to a 409 ``CONFLICT`` envelope.

    The exception and its traceback are logged together with the trace id.
    The client only receives a generic conflict message.
    """
    tid = get_trace_id(request)
    logger.exception("IntegrityError traceId=%s", tid)
    return JSONResponse(
        build_envelope(None, tid, code="CONFLICT", message="数据冲突"),
        status_code=409,
    )
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the error and return a 500 ``INTERNAL_ERROR`` envelope.

    Starlette runs handlers registered for ``Exception`` in its outermost
    ``ServerErrorMiddleware``, outside the HTTP trace middleware. That
    middleware therefore never sees this response, so ``X-Trace-Id`` is set
    here explicitly.
    """
    trace_id = get_trace_id(request)
    # Pass the exception explicitly instead of relying on the implicit
    # "currently handled exception" lookup.
    logger.exception("Unhandled error traceId=%s", trace_id, exc_info=exc)
    payload = build_envelope(None, trace_id, code="INTERNAL_ERROR", message="系统异常,请稍后再试")
    return JSONResponse(payload, status_code=500, headers={"X-Trace-Id": trace_id})
@app.on_event("startup")
async def on_startup():
    """Prepare the database, seed data, and launch the center-sync background task."""
    global sync_task
    async with engine.begin() as conn:
        # gen_random_uuid() is provided by the pgcrypto extension. If the DB
        # role lacks permission to create extensions, startup fails right here
        # with a clear error, instead of a more obscure failure at insert time.
        await conn.execute(text('CREATE EXTENSION IF NOT EXISTS "pgcrypto";'))
        await conn.run_sync(Base.metadata.create_all)
    async with SessionLocal() as session:
        await seed_data(session)
    # on_shutdown sets this event and nothing resets it. A later startup in the
    # same process (e.g. a second test-client lifespan) must not hand the new
    # loop an already-set stop signal.
    sync_stop_event.clear()
    sync_task = asyncio.create_task(run_center_sync_loop(sync_stop_event))
@app.on_event("shutdown")
async def on_shutdown():
    """Signal and cancel the center-sync background task, then wait for it to finish."""
    global sync_task
    sync_stop_event.set()
    task = sync_task
    if task is None:
        return
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        # Expected outcome of the cancel() above.
        pass
    except Exception:
        # If the task had already died with an error, cancel() is a no-op and
        # awaiting it re-raises that error. Log it rather than failing shutdown.
        logger.exception("center sync task exited with an error")
    finally:
        sync_task = None
@app.get("/health")
async def health():
    """Health-check endpoint that always reports ``{"status": "ok"}``.

    The HTTP middleware in this module wraps this dict in the response
    envelope like any other successful JSON body.
    """
    status_payload = {"status": "ok"}
    return status_payload
# Mount each feature module's `router` on the app, in the original order.
for _router_module in (
    auth,
    users,
    roles,
    campuses,
    departments,
    shifts,
    schedule,
    duty,
    statistics,
    adjust_logs,
    sync,
    ws,
):
    app.include_router(_router_module.router)
del _router_module
|