main.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
import asyncio
import json
import logging
import time
from contextlib import suppress
from uuid import uuid4

from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.responses import Response

from backend.app.core.config import settings
from backend.app.core.response import build_envelope
from backend.app.db.session import engine, SessionLocal
from backend.app.db.base import Base
from backend.app import models  # noqa: F401
from backend.app.services.seed import seed_data
from backend.app.services.center_sync import run_center_sync_loop
from backend.app.routers import auth, users, roles, campuses, departments, shifts, schedule, duty, statistics, adjust_logs, sync, ws
  22. app = FastAPI(title=settings.app_name)
  23. app.add_middleware(
  24. CORSMiddleware,
  25. allow_origins=[
  26. "http://localhost:5173",
  27. "http://127.0.0.1:5173"
  28. ],
  29. allow_credentials=True,
  30. allow_methods=["*"],
  31. allow_headers=["*"],
  32. )
  33. logger = logging.getLogger("uvicorn.error")
  34. sync_stop_event = asyncio.Event()
  35. sync_task: asyncio.Task | None = None
  36. def _clone_headers(response, trace_id: str) -> dict:
  37. headers = dict(response.headers)
  38. headers.pop("content-length", None)
  39. headers["X-Trace-Id"] = trace_id
  40. return headers
  41. @app.middleware("http")
  42. async def add_trace_and_wrap(request: Request, call_next):
  43. trace_id = request.headers.get("x-trace-id") or request.headers.get("traceId") or uuid4().hex
  44. request.state.trace_id = trace_id
  45. response = await call_next(request)
  46. content_type = response.headers.get("content-type", "")
  47. if response.status_code >= 400 or "application/json" not in content_type:
  48. response.headers["X-Trace-Id"] = trace_id
  49. return response
  50. body = b""
  51. async for chunk in response.body_iterator:
  52. body += chunk
  53. headers = _clone_headers(response, trace_id)
  54. if not body:
  55. return JSONResponse(build_envelope(None, trace_id), status_code=response.status_code, headers=headers)
  56. try:
  57. data = json.loads(body)
  58. except json.JSONDecodeError:
  59. return Response(
  60. content=body,
  61. status_code=response.status_code,
  62. headers=headers,
  63. media_type=response.media_type,
  64. )
  65. if isinstance(data, dict) and {"code", "message", "data"}.issubset(data.keys()):
  66. data.setdefault("traceId", trace_id)
  67. data.setdefault("ts", int(time.time() * 1000))
  68. return JSONResponse(data, status_code=response.status_code, headers=headers)
  69. return JSONResponse(build_envelope(data, trace_id), status_code=response.status_code, headers=headers)
  70. def get_trace_id(request: Request) -> str:
  71. trace_id = getattr(request.state, "trace_id", None)
  72. return trace_id or uuid4().hex
  73. @app.exception_handler(RequestValidationError)
  74. async def validation_exception_handler(request: Request, exc: RequestValidationError):
  75. field_errors = []
  76. for err in exc.errors():
  77. loc = [str(item) for item in err.get("loc", []) if item not in ("body", "query", "path")]
  78. field = ".".join(loc) if loc else ".".join([str(item) for item in err.get("loc", [])])
  79. field_errors.append({"field": field, "reason": err.get("msg")})
  80. trace_id = get_trace_id(request)
  81. payload = build_envelope(
  82. None,
  83. trace_id,
  84. code="VALIDATION_ERROR",
  85. message="参数校验失败",
  86. details={"fieldErrors": field_errors},
  87. )
  88. return JSONResponse(payload, status_code=400)
  89. @app.exception_handler(HTTPException)
  90. async def http_exception_handler(request: Request, exc: HTTPException):
  91. code_map = {
  92. 400: "VALIDATION_ERROR",
  93. 401: "UNAUTHORIZED",
  94. 403: "FORBIDDEN",
  95. 404: "NOT_FOUND",
  96. 409: "CONFLICT",
  97. 429: "RATE_LIMITED",
  98. }
  99. trace_id = get_trace_id(request)
  100. code = code_map.get(exc.status_code, "INTERNAL_ERROR")
  101. message = exc.detail if isinstance(exc.detail, str) else "请求失败"
  102. payload = build_envelope(None, trace_id, code=code, message=message)
  103. return JSONResponse(payload, status_code=exc.status_code)
  104. @app.exception_handler(IntegrityError)
  105. async def integrity_exception_handler(request: Request, exc: IntegrityError):
  106. trace_id = get_trace_id(request)
  107. logger.exception("IntegrityError traceId=%s", trace_id)
  108. payload = build_envelope(None, trace_id, code="CONFLICT", message="数据冲突")
  109. return JSONResponse(payload, status_code=409)
  110. @app.exception_handler(Exception)
  111. async def general_exception_handler(request: Request, exc: Exception):
  112. trace_id = get_trace_id(request)
  113. logger.exception("Unhandled error traceId=%s", trace_id)
  114. payload = build_envelope(None, trace_id, code="INTERNAL_ERROR", message="系统异常,请稍后再试")
  115. return JSONResponse(payload, status_code=500)
  116. @app.on_event("startup")
  117. async def on_startup():
  118. global sync_task
  119. async with engine.begin() as conn:
  120. # 依赖 pgcrypto 扩展提供 gen_random_uuid()
  121. # 若数据库账号无权限创建扩展,会在这里直接失败并提示,避免后续插入时报错更隐蔽
  122. await conn.execute(text('CREATE EXTENSION IF NOT EXISTS "pgcrypto";'))
  123. await conn.run_sync(Base.metadata.create_all)
  124. async with SessionLocal() as session:
  125. await seed_data(session)
  126. sync_task = asyncio.create_task(run_center_sync_loop(sync_stop_event))
  127. @app.on_event("shutdown")
  128. async def on_shutdown():
  129. sync_stop_event.set()
  130. if sync_task:
  131. sync_task.cancel()
  132. with suppress(asyncio.CancelledError):
  133. await sync_task
  134. @app.get("/health")
  135. async def health():
  136. return {"status": "ok"}
  137. app.include_router(auth.router)
  138. app.include_router(users.router)
  139. app.include_router(roles.router)
  140. app.include_router(campuses.router)
  141. app.include_router(departments.router)
  142. app.include_router(shifts.router)
  143. app.include_router(schedule.router)
  144. app.include_router(duty.router)
  145. app.include_router(statistics.router)
  146. app.include_router(adjust_logs.router)
  147. app.include_router(sync.router)
  148. app.include_router(ws.router)