| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- from fastapi import Depends, HTTPException, status
- from fastapi.security import OAuth2PasswordBearer
- from jose import JWTError, jwt
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy import select
- from backend.app.core.config import settings
- from backend.app.db.session import get_db
- from backend.app.models import User, RolePermission
- from backend.app.schemas.auth import TokenPayload
# Token endpoint advertised to clients (and to the OpenAPI docs UI) for the
# OAuth2 password flow. FastAPI does not create this route itself; it must be
# served by the auth router.
_TOKEN_URL = "/auth/login"

# Dependency that extracts the bearer token from the request's Authorization header.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=_TOKEN_URL)
def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error that carries the ``WWW-Authenticate: Bearer`` challenge."""
    # RFC 6750 (and FastAPI's OAuth2 guide): a 401 from a bearer-protected
    # endpoint should advertise the expected authentication scheme.
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user_by_token(token: str, db: AsyncSession) -> User:
    """Resolve the user identified by an access JWT.

    The token is decoded with ``settings.jwt_secret`` / ``settings.jwt_algorithm``
    and its claims are parsed into ``TokenPayload``. Only tokens whose ``type``
    claim is ``"access"`` and which carry a ``ver`` claim are accepted, and
    ``ver`` must equal the user's current ``token_version``. Changing that
    column therefore invalidates every token issued with the old value.

    Args:
        token: Raw bearer token string.
        db: Async SQLAlchemy session used to load the user.

    Returns:
        The ``User`` whose ``id`` matches the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if the token cannot be decoded or does not fit
            ``TokenPayload``, is not an access token, lacks ``ver``, refers to
            an unknown user, or carries a stale ``ver``.
    """
    try:
        payload = jwt.decode(token, settings.jwt_secret, algorithms=[settings.jwt_algorithm])
        token_data = TokenPayload(**payload)
    except (JWTError, ValueError) as exc:
        # JWTError covers bad signatures, expiry and claim errors. ValueError
        # covers a correctly signed payload that does not fit TokenPayload
        # (pydantic's ValidationError subclasses ValueError). Without it, that
        # case escaped as an unhandled exception (HTTP 500) instead of a 401.
        raise _unauthorized("无效凭证") from exc

    # Non-access tokens (presumably refresh tokens) must not authenticate API
    # calls. A token without a version claim cannot be checked for revocation.
    if token_data.type != "access" or token_data.ver is None:
        raise _unauthorized("无效凭证")

    result = await db.execute(select(User).where(User.id == token_data.sub))
    user = result.scalar_one_or_none()
    if user is None:
        raise _unauthorized("用户不存在")
    if user.token_version != token_data.ver:
        raise _unauthorized("凭证已失效")
    return user
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """FastAPI dependency returning the authenticated user for this request.

    The bearer token is supplied by ``oauth2_scheme`` and the session by
    ``get_db``. All validation is delegated to ``get_current_user_by_token``.
    """
    authenticated = await get_current_user_by_token(token, db)
    return authenticated
def _as_code_tuple(required) -> tuple[str, ...]:
    """Snapshot a permission-code argument into an immutable tuple.

    ``None`` or any other falsy value yields ``()``, meaning "no requirement".
    A bare ``str`` is treated as one code instead of being iterated character
    by character. Any other iterable is copied, so that mutating the caller's
    list later, or exhausting a one-shot iterator, cannot change the policy
    enforced by an already-built dependency.
    """
    if not required:
        return ()
    if isinstance(required, str):
        return (required,)
    return tuple(required)


async def _fetch_permission_codes(db: "AsyncSession", role_id) -> set[str]:
    """Return the permission codes linked to ``role_id`` through ``RolePermission``."""
    result = await db.execute(
        select(RolePermission.permission_code).where(RolePermission.role_id == role_id)
    )
    # Single-column select: scalars() yields the permission_code values directly.
    return set(result.scalars().all())


def require_permissions(required: list[str]):
    """Build a FastAPI dependency that demands *all* of ``required``.

    Args:
        required: Permission codes the current user's role must hold. An empty
            value (or ``None``) disables the check, and no query is made.

    Returns:
        An async dependency that resolves to the authenticated ``User``. It
        raises ``HTTPException`` 403 if any required code is missing.
    """
    needed = _as_code_tuple(required)

    async def checker(
        current_user: User = Depends(get_current_user),
        db: AsyncSession = Depends(get_db),
    ) -> User:
        if not needed:
            return current_user
        granted = await _fetch_permission_codes(db, current_user.role_id)
        if not granted.issuperset(needed):
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权限操作")
        return current_user

    return checker


def require_any_permissions(required: list[str]):
    """Build a FastAPI dependency that demands *at least one* of ``required``.

    Args:
        required: Candidate permission codes; holding any one of them is
            enough. An empty value (or ``None``) disables the check, and no
            query is made.

    Returns:
        An async dependency that resolves to the authenticated ``User``. It
        raises ``HTTPException`` 403 if the role holds none of the codes.
    """
    needed = _as_code_tuple(required)

    async def checker(
        current_user: User = Depends(get_current_user),
        db: AsyncSession = Depends(get_db),
    ) -> User:
        if not needed:
            return current_user
        granted = await _fetch_permission_codes(db, current_user.role_id)
        if granted.isdisjoint(needed):
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权限操作")
        return current_user

    return checker
|