dependencies.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from fastapi import Depends, HTTPException, status
  2. from fastapi.security import OAuth2PasswordBearer
  3. from jose import JWTError, jwt
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from sqlalchemy import select
  6. from backend.app.core.config import settings
  7. from backend.app.db.session import get_db
  8. from backend.app.models import User, RolePermission
  9. from backend.app.schemas.auth import TokenPayload
# OAuth2 password-flow bearer scheme; injected into get_current_user below to
# supply the raw token string. tokenUrl points at "/auth/login" — presumably
# the route that issues access tokens (confirm against the auth router).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
  11. async def get_current_user_by_token(token: str, db: AsyncSession) -> User:
  12. try:
  13. payload = jwt.decode(token, settings.jwt_secret, algorithms=[settings.jwt_algorithm])
  14. token_data = TokenPayload(**payload)
  15. if token_data.type != "access":
  16. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效凭证")
  17. if token_data.ver is None:
  18. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效凭证")
  19. except JWTError as exc:
  20. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效凭证") from exc
  21. result = await db.execute(select(User).where(User.id == token_data.sub))
  22. user = result.scalar_one_or_none()
  23. if not user:
  24. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在")
  25. if user.token_version != token_data.ver:
  26. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="凭证已失效")
  27. return user
  28. async def get_current_user(
  29. token: str = Depends(oauth2_scheme),
  30. db: AsyncSession = Depends(get_db),
  31. ) -> User:
  32. return await get_current_user_by_token(token, db)
  33. def require_permissions(required: list[str]):
  34. async def checker(
  35. current_user: User = Depends(get_current_user),
  36. db: AsyncSession = Depends(get_db),
  37. ) -> User:
  38. if not required:
  39. return current_user
  40. result = await db.execute(
  41. select(RolePermission.permission_code).where(RolePermission.role_id == current_user.role_id)
  42. )
  43. user_permissions = {row[0] for row in result.fetchall()}
  44. missing = [code for code in required if code not in user_permissions]
  45. if missing:
  46. raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权限操作")
  47. return current_user
  48. return checker
  49. def require_any_permissions(required: list[str]):
  50. async def checker(
  51. current_user: User = Depends(get_current_user),
  52. db: AsyncSession = Depends(get_db),
  53. ) -> User:
  54. if not required:
  55. return current_user
  56. result = await db.execute(
  57. select(RolePermission.permission_code).where(RolePermission.role_id == current_user.role_id)
  58. )
  59. user_permissions = {row[0] for row in result.fetchall()}
  60. if not user_permissions.intersection(required):
  61. raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权限操作")
  62. return current_user
  63. return checker