| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232 |
- import logging
- import secrets
- import httpx
- from fastapi import APIRouter, Depends, HTTPException, status
- from sqlalchemy import select
- from sqlalchemy.ext.asyncio import AsyncSession
- from backend.app.core.security import create_access_token, create_refresh_token, hash_password, verify_password
- from backend.app.core.config import settings
- from jose import jwt, JWTError
- from backend.app.db.session import get_db
- from backend.app.models import Role, RolePermission, User
- from backend.app.services.center_sync import (
- get_or_create_pending_role,
- resolve_department_by_external_id,
- resolve_role_by_external_ids,
- trigger_center_sync,
- upsert_sync_token,
- )
- from backend.app.schemas.auth import LoginRequest, RefreshRequest, SSOExchangeRequest, Token
# Log through the logger named "uvicorn.error" (presumably so messages show up in
# uvicorn's own output stream).
logger = logging.getLogger("uvicorn.error")
# Every endpoint below is mounted under /auth and grouped under the "auth" tag.
router = APIRouter(tags=["auth"], prefix="/auth")
- def _parse_role_ids(claims: dict) -> list[str]:
- raw = claims.get("role_ids") or claims.get("roleIds") or claims.get("roleId") or claims.get("roles")
- if raw is None:
- return []
- if isinstance(raw, list):
- return [str(item).strip() for item in raw if item is not None and str(item).strip()]
- raw_str = str(raw).strip()
- if not raw_str:
- return []
- if "," in raw_str:
- return [item.strip() for item in raw_str.split(",") if item.strip()]
- return [raw_str]
- def _normalize_text(value) -> str | None:
- if value is None:
- return None
- text = str(value).strip()
- return text or None
def _build_center_user_url() -> str:
    """Return the center system's "get user by token" endpoint URL.

    Trailing slashes on ``settings.center_base_url`` are removed first, so the
    base and the endpoint path are joined by exactly one ``/``.
    """
    endpoint_path = "/gateway/centerSys/user/getUserByToken"
    return settings.center_base_url.rstrip("/") + endpoint_path
async def _user_has_permission(db: AsyncSession, role_id, permission_code: str) -> bool:
    """Return whether ``role_id`` has a ``RolePermission`` row for ``permission_code``.

    At most one row is fetched (``LIMIT 1``); only its existence matters.
    """
    query = (
        select(RolePermission.permission_code)
        .where(RolePermission.role_id == role_id)
        .where(RolePermission.permission_code == permission_code)
        .limit(1)
    )
    first_row = (await db.execute(query)).first()
    return first_row is not None
@router.post("/login", response_model=Token)
async def login(payload: LoginRequest, db: AsyncSession = Depends(get_db)):
    """Authenticate with account and password and issue an access/refresh token pair.

    Both tokens are built from the user's id and current ``token_version``.

    Raises:
        HTTPException: 401 when the account is unknown or the password does not
            match; one message covers both cases.
    """
    query = select(User).where(User.account == payload.account)
    user = (await db.execute(query)).scalar_one_or_none()

    if user and verify_password(payload.password, user.password_hash):
        subject = str(user.id)
        return Token(
            access_token=create_access_token(subject, user.token_version),
            refresh_token=create_refresh_token(subject, user.token_version),
        )

    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="账号或密码错误")
def _decode_upstream_json_object(response: httpx.Response, detail: str) -> dict:
    """Decode ``response`` as a JSON object, or raise 502 with ``detail``.

    Non-JSON bodies and JSON values that are not objects (lists, strings,
    numbers) are both treated as a malformed upstream reply, so callers can
    safely use ``.get()`` on the result.
    """
    try:
        body = response.json()
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=detail) from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=detail)
    return body


async def _verify_sso_token(token: str) -> None:
    """Validate ``token`` against ``settings.sso_check_url`` (POST, ``token`` header).

    Raises:
        HTTPException: 502 when the SSO service is unreachable or returns a
            malformed body; 401 when the HTTP status is not 200 or the body's
            ``status`` field is not 200 (the body's ``msg`` is used when present).
    """
    try:
        async with httpx.AsyncClient(timeout=settings.sso_timeout_seconds) as client:
            response = await client.post(settings.sso_check_url, headers={"token": token})
    except httpx.RequestError as exc:
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="SSO服务不可用") from exc
    if response.status_code != status.HTTP_200_OK:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="SSO校验失败")
    data = _decode_upstream_json_object(response, "SSO响应异常")
    if data.get("status") != 200:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=data.get("msg") or "SSO校验失败")


async def _fetch_center_user_info(token: str) -> dict:
    """Fetch the center-system profile for ``token`` and return its ``data`` object.

    A GET is sent first; if the endpoint answers 405 the request is repeated as POST.

    Raises:
        HTTPException: 502 when the service is unreachable, answers with a
            non-200 HTTP status, or returns a malformed body or ``data`` field;
            401 when the body's ``status`` field is not 200.
    """
    user_url = _build_center_user_url()
    headers = {"token": token}
    try:
        async with httpx.AsyncClient(timeout=settings.sso_timeout_seconds) as client:
            response = await client.get(user_url, headers=headers)
            if response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED:
                response = await client.post(user_url, headers=headers)
    except httpx.RequestError as exc:
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="中台用户服务不可用") from exc
    if response.status_code != status.HTTP_200_OK:
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="中台用户信息获取失败")
    body = _decode_upstream_json_object(response, "中台用户信息返回异常")
    if body.get("status") != 200:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=body.get("msg") or "中台用户信息获取失败",
        )
    user_info = body.get("data") or {}
    if not isinstance(user_info, dict):
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="中台用户信息格式异常")
    return user_info


@router.post("/sso/exchange", response_model=Token)
async def exchange_sso_token(payload: SSOExchangeRequest, db: AsyncSession = Depends(get_db)):
    """Exchange a center-system SSO token for this service's access/refresh tokens.

    Flow:
    1. Verify the token with the SSO check endpoint.
    2. Fetch the user's profile from the center system.
    3. Create or update the local ``User``, matched first by external id and then by account.
    4. Store the token via ``upsert_sync_token`` and commit.
    5. If the user's role has the ``users.view`` permission, trigger a center sync
       (``wait=False``). Failures in this step are logged, not raised.

    Raises:
        HTTPException: 500 when the SSO URL, center base URL or admin role is
            not configured; 502/401 for upstream failures (see the helpers
            above); 400 when the profile carries no user id.
    """
    if not settings.sso_check_url:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="SSO未配置")
    if not settings.center_base_url:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="中台基础地址未配置")

    await _verify_sso_token(payload.token)
    user_info = await _fetch_center_user_info(payload.token)

    # Every value below is either a non-empty stripped str or None (see _normalize_text).
    external_user_id = _normalize_text(user_info.get("userId") or user_info.get("id"))
    if not external_user_id:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="用户信息缺少用户标识")
    account = _normalize_text(user_info.get("empNo") or user_info.get("account")) or external_user_id
    # Blank names fall back to the account instead of becoming "", and
    # non-string names are stringified rather than crashing on .strip().
    name = _normalize_text(user_info.get("name")) or account
    phone = _normalize_text(user_info.get("phoneNumber") or user_info.get("phone"))
    avatar = _normalize_text(user_info.get("avatarUrl"))
    title = _normalize_text(
        user_info.get("jobTitleName") or user_info.get("positionName") or user_info.get("professionalTitleName")
    )
    external_dept_id = _normalize_text(user_info.get("departmentId") or user_info.get("departmentCode"))
    external_role_id = _normalize_text(user_info.get("roleId") or user_info.get("roleCode"))
    tenant_id = _normalize_text(user_info.get("tenant_id") or user_info.get("tenantId"))
    tenant_name = _normalize_text(user_info.get("tenant_name") or user_info.get("tenantName"))
    hosp_id = _normalize_text(user_info.get("hospId") or user_info.get("hosp_id"))
    hosp_name = _normalize_text(user_info.get("hospName") or user_info.get("hosp_name"))

    resolved_role = None
    if external_role_id:
        resolved_role = await resolve_role_by_external_ids(db, [external_role_id])
    if not resolved_role and settings.seed_admin_account and account == settings.seed_admin_account:
        # The seed admin account gets the local "管理员" role when the center
        # system supplies no mappable role.
        role_result = await db.execute(select(Role).where(Role.name == "管理员"))
        resolved_role = role_result.scalar_one_or_none()
        if not resolved_role:
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="管理员角色未配置")
    pending_role = await get_or_create_pending_role(db)
    dept = await resolve_department_by_external_id(db, external_dept_id)

    result = await db.execute(select(User).where(User.external_user_id == external_user_id))
    user = result.scalar_one_or_none()
    if not user:
        # NOTE(review): an account match rebinds that user to this external id
        # even if it was already linked to a different one — confirm intended.
        result = await db.execute(select(User).where(User.account == account))
        user = result.scalar_one_or_none()

    if not user:
        user = User(
            name=name,
            account=account,
            phone=phone,
            title=title,
            avatar=avatar,
            role_id=resolved_role.id if resolved_role else pending_role.id,
            status="active",
            # SSO users get an unguessable random password hash.
            password_hash=hash_password(secrets.token_urlsafe(24)),
            token_version=1,
            external_user_id=external_user_id,
            tenant_id=tenant_id,
            tenant_name=tenant_name,
            hosp_id=hosp_id,
            hosp_name=hosp_name,
            dept_id=dept.id if dept else None,
            campus_id=dept.campus_id if dept else None,
        )
        db.add(user)
    else:
        user.name = name
        user.phone = phone
        user.title = title
        user.avatar = avatar
        user.external_user_id = external_user_id
        # Organisation fields keep their stored value when the profile omits them.
        user.tenant_id = tenant_id or user.tenant_id
        user.tenant_name = tenant_name or user.tenant_name
        user.hosp_id = hosp_id or user.hosp_id
        user.hosp_name = hosp_name or user.hosp_name
        if dept:
            user.dept_id = dept.id
            user.campus_id = dept.campus_id
        # The resolved role only replaces the role of the seed admin or of a
        # user still on the pending role. Presumably this stops each SSO login
        # from overwriting locally assigned roles.
        if resolved_role and (account == settings.seed_admin_account or user.role_id == pending_role.id):
            user.role_id = resolved_role.id

    await upsert_sync_token(db, payload.token)
    await db.commit()
    await db.refresh(user)

    try:
        if await _user_has_permission(db, user.role_id, "users.view"):
            await trigger_center_sync(payload.token, restart=False, wait=False)
    except Exception:
        # Best effort: a sync failure must not block the login itself.
        logger.exception("登录后触发中台同步失败")

    return Token(
        access_token=create_access_token(str(user.id), user.token_version),
        refresh_token=create_refresh_token(str(user.id), user.token_version),
    )
@router.post("/refresh", response_model=Token)
async def refresh_token(payload: RefreshRequest, db: AsyncSession = Depends(get_db)):
    """Issue a new access/refresh token pair from a valid refresh token.

    The token must:
    - decode with ``settings.jwt_secret`` / ``settings.jwt_algorithm``;
    - carry ``type == "refresh"``, a subject (``sub``) and a version (``ver``);
    - have a ``ver`` that still equals the user's current ``token_version``.

    Raises:
        HTTPException: 401 for an undecodable, wrong-type or incomplete token,
            an unknown user, or a stale token version.
    """
    # Keep the try minimal: only decoding can raise JWTError.
    try:
        claims = jwt.decode(payload.refresh_token, settings.jwt_secret, algorithms=[settings.jwt_algorithm])
    except JWTError as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效刷新令牌") from exc

    user_id = claims.get("sub")
    token_version = claims.get("ver")
    # Without a subject the token identifies nobody. Reject it here rather than
    # querying for a NULL id and reporting "user not found".
    if claims.get("type") != "refresh" or not user_id or token_version is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效刷新令牌")

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在")
    # A bumped token_version invalidates all previously issued refresh tokens.
    if user.token_version != token_version:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="刷新令牌已失效")

    return Token(
        access_token=create_access_token(str(user.id), user.token_version),
        refresh_token=create_refresh_token(str(user.id), user.token_version),
    )
|