auth.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. import logging
  2. import secrets
  3. import httpx
  4. from fastapi import APIRouter, Depends, HTTPException, status
  5. from sqlalchemy import select
  6. from sqlalchemy.ext.asyncio import AsyncSession
  7. from backend.app.core.security import create_access_token, create_refresh_token, hash_password, verify_password
  8. from backend.app.core.config import settings
  9. from jose import jwt, JWTError
  10. from backend.app.db.session import get_db
  11. from backend.app.models import Role, RolePermission, User
  12. from backend.app.services.center_sync import (
  13. get_or_create_pending_role,
  14. resolve_department_by_external_id,
  15. resolve_role_by_external_ids,
  16. trigger_center_sync,
  17. upsert_sync_token,
  18. )
  19. from backend.app.schemas.auth import LoginRequest, RefreshRequest, SSOExchangeRequest, Token
  20. router = APIRouter(prefix="/auth", tags=["auth"])
  21. logger = logging.getLogger("uvicorn.error")
  22. def _parse_role_ids(claims: dict) -> list[str]:
  23. raw = claims.get("role_ids") or claims.get("roleIds") or claims.get("roleId") or claims.get("roles")
  24. if raw is None:
  25. return []
  26. if isinstance(raw, list):
  27. return [str(item).strip() for item in raw if item is not None and str(item).strip()]
  28. raw_str = str(raw).strip()
  29. if not raw_str:
  30. return []
  31. if "," in raw_str:
  32. return [item.strip() for item in raw_str.split(",") if item.strip()]
  33. return [raw_str]
  34. def _normalize_text(value) -> str | None:
  35. if value is None:
  36. return None
  37. text = str(value).strip()
  38. return text or None
  39. def _build_center_user_url() -> str:
  40. base = settings.center_base_url.rstrip("/")
  41. return f"{base}/gateway/centerSys/user/getUserByToken"
  42. async def _user_has_permission(db: AsyncSession, role_id, permission_code: str) -> bool:
  43. result = await db.execute(
  44. select(RolePermission.permission_code)
  45. .where(RolePermission.role_id == role_id, RolePermission.permission_code == permission_code)
  46. .limit(1)
  47. )
  48. return result.first() is not None
  49. @router.post("/login", response_model=Token)
  50. async def login(payload: LoginRequest, db: AsyncSession = Depends(get_db)):
  51. result = await db.execute(select(User).where(User.account == payload.account))
  52. user = result.scalar_one_or_none()
  53. if not user or not verify_password(payload.password, user.password_hash):
  54. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="账号或密码错误")
  55. return Token(
  56. access_token=create_access_token(str(user.id), user.token_version),
  57. refresh_token=create_refresh_token(str(user.id), user.token_version),
  58. )
  59. @router.post("/sso/exchange", response_model=Token)
  60. async def exchange_sso_token(payload: SSOExchangeRequest, db: AsyncSession = Depends(get_db)):
  61. if not settings.sso_check_url:
  62. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="SSO未配置")
  63. if not settings.center_base_url:
  64. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="中台基础地址未配置")
  65. try:
  66. async with httpx.AsyncClient(timeout=settings.sso_timeout_seconds) as client:
  67. response = await client.post(settings.sso_check_url, headers={"token": payload.token})
  68. except httpx.RequestError as exc:
  69. raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="SSO服务不可用") from exc
  70. if response.status_code != status.HTTP_200_OK:
  71. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="SSO校验失败")
  72. try:
  73. data = response.json()
  74. except ValueError as exc:
  75. raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="SSO响应异常") from exc
  76. if data.get("status") != 200:
  77. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=data.get("msg") or "SSO校验失败")
  78. try:
  79. async with httpx.AsyncClient(timeout=settings.sso_timeout_seconds) as client:
  80. user_url = _build_center_user_url()
  81. user_response = await client.get(user_url, headers={"token": payload.token})
  82. if user_response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED:
  83. user_response = await client.post(user_url, headers={"token": payload.token})
  84. except httpx.RequestError as exc:
  85. raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="中台用户服务不可用") from exc
  86. if user_response.status_code != status.HTTP_200_OK:
  87. raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="中台用户信息获取失败")
  88. try:
  89. user_payload = user_response.json()
  90. except ValueError as exc:
  91. raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="中台用户信息返回异常") from exc
  92. if user_payload.get("status") != 200:
  93. raise HTTPException(
  94. status_code=status.HTTP_401_UNAUTHORIZED,
  95. detail=user_payload.get("msg") or "中台用户信息获取失败",
  96. )
  97. user_info = user_payload.get("data") or {}
  98. if not isinstance(user_info, dict):
  99. raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="中台用户信息格式异常")
  100. external_user_id = _normalize_text(user_info.get("userId") or user_info.get("id"))
  101. if not external_user_id:
  102. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户信息缺少用户标识")
  103. account = _normalize_text(user_info.get("empNo") or user_info.get("account")) or external_user_id
  104. name = (user_info.get("name") or account).strip()
  105. phone = _normalize_text(user_info.get("phoneNumber") or user_info.get("phone"))
  106. avatar = _normalize_text(user_info.get("avatarUrl"))
  107. title = _normalize_text(
  108. user_info.get("jobTitleName") or user_info.get("positionName") or user_info.get("professionalTitleName")
  109. )
  110. external_dept_id = _normalize_text(user_info.get("departmentId") or user_info.get("departmentCode"))
  111. role_id = _normalize_text(user_info.get("roleId") or user_info.get("roleCode"))
  112. tenant_id = _normalize_text(user_info.get("tenant_id") or user_info.get("tenantId"))
  113. tenant_name = _normalize_text(user_info.get("tenant_name") or user_info.get("tenantName"))
  114. hosp_id = _normalize_text(user_info.get("hospId") or user_info.get("hosp_id"))
  115. hosp_name = _normalize_text(user_info.get("hospName") or user_info.get("hosp_name"))
  116. resolved_role = None
  117. if role_id:
  118. resolved_role = await resolve_role_by_external_ids(db, [role_id])
  119. if not resolved_role and settings.seed_admin_account and account == settings.seed_admin_account:
  120. role_result = await db.execute(select(Role).where(Role.name == "管理员"))
  121. resolved_role = role_result.scalar_one_or_none()
  122. if not resolved_role:
  123. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="管理员角色未配置")
  124. pending_role = await get_or_create_pending_role(db)
  125. dept = await resolve_department_by_external_id(db, str(external_dept_id) if external_dept_id is not None else None)
  126. result = await db.execute(select(User).where(User.external_user_id == str(external_user_id)))
  127. user = result.scalar_one_or_none()
  128. if not user and account:
  129. result = await db.execute(select(User).where(User.account == account))
  130. user = result.scalar_one_or_none()
  131. if not user:
  132. role_id = resolved_role.id if resolved_role else pending_role.id
  133. user = User(
  134. name=name,
  135. account=account,
  136. phone=phone,
  137. title=title,
  138. avatar=avatar,
  139. role_id=role_id,
  140. status="active",
  141. password_hash=hash_password(secrets.token_urlsafe(24)),
  142. token_version=1,
  143. external_user_id=str(external_user_id),
  144. tenant_id=str(tenant_id) if tenant_id is not None else None,
  145. tenant_name=tenant_name,
  146. hosp_id=str(hosp_id) if hosp_id is not None else None,
  147. hosp_name=hosp_name,
  148. dept_id=dept.id if dept else None,
  149. campus_id=dept.campus_id if dept else None
  150. )
  151. db.add(user)
  152. else:
  153. user.name = name
  154. user.phone = phone
  155. user.title = title
  156. user.avatar = avatar
  157. user.external_user_id = str(external_user_id)
  158. user.tenant_id = str(tenant_id) if tenant_id is not None else user.tenant_id
  159. user.tenant_name = tenant_name or user.tenant_name
  160. user.hosp_id = str(hosp_id) if hosp_id is not None else user.hosp_id
  161. user.hosp_name = hosp_name or user.hosp_name
  162. if dept:
  163. user.dept_id = dept.id
  164. user.campus_id = dept.campus_id
  165. if resolved_role and (account == settings.seed_admin_account or user.role_id == pending_role.id):
  166. user.role_id = resolved_role.id
  167. await upsert_sync_token(db, payload.token)
  168. await db.commit()
  169. await db.refresh(user)
  170. try:
  171. if await _user_has_permission(db, user.role_id, "users.view"):
  172. await trigger_center_sync(payload.token, restart=False, wait=False)
  173. except Exception:
  174. logger.exception("登录后触发中台同步失败")
  175. return Token(
  176. access_token=create_access_token(str(user.id), user.token_version),
  177. refresh_token=create_refresh_token(str(user.id), user.token_version)
  178. )
  179. @router.post("/refresh", response_model=Token)
  180. async def refresh_token(payload: RefreshRequest, db: AsyncSession = Depends(get_db)):
  181. try:
  182. data = jwt.decode(payload.refresh_token, settings.jwt_secret, algorithms=[settings.jwt_algorithm])
  183. if data.get("type") != "refresh":
  184. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效刷新令牌")
  185. user_id = data.get("sub")
  186. token_version = data.get("ver")
  187. if token_version is None:
  188. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效刷新令牌")
  189. except JWTError as exc:
  190. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效刷新令牌") from exc
  191. result = await db.execute(select(User).where(User.id == user_id))
  192. user = result.scalar_one_or_none()
  193. if not user:
  194. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在")
  195. if user.token_version != token_version:
  196. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="刷新令牌已失效")
  197. return Token(
  198. access_token=create_access_token(str(user.id), user.token_version),
  199. refresh_token=create_refresh_token(str(user.id), user.token_version)
  200. )