from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from backend.app.core.config import settings
from backend.app.db.session import get_db
from backend.app.models import User, RolePermission
from backend.app.schemas.auth import TokenPayload

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error carrying the Bearer authentication challenge.

    RFC 6750 expects a ``WWW-Authenticate`` header on 401 responses from a
    Bearer-protected resource, so every token-validation failure uses this.
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user_by_token(token: str, db: AsyncSession) -> User:
    """Resolve and return the user identified by an access token.

    The token is decoded with ``settings.jwt_secret`` / ``settings.jwt_algorithm``,
    parsed into a ``TokenPayload``, and accepted only if it is an ``"access"``
    token whose ``ver`` claim equals the user's current ``token_version``.
    Changing a user's ``token_version`` therefore invalidates tokens issued
    with the old value.

    Args:
        token: The raw encoded JWT.
        db: Session used to load the user.

    Returns:
        The ``User`` row whose id matches the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if the token cannot be decoded or validated, is not
            an access token, lacks ``ver``, names a user that does not exist,
            or carries a version that no longer matches the user's.
    """
    try:
        payload = jwt.decode(token, settings.jwt_secret, algorithms=[settings.jwt_algorithm])
        token_data = TokenPayload(**payload)
    except (JWTError, ValueError) as exc:
        # pydantic's ValidationError subclasses ValueError; catching only
        # JWTError let a validly signed token with unexpected claims surface
        # as a 500 instead of a 401.
        raise _unauthorized("无效凭证") from exc

    # Refresh tokens (or anything else) must not authenticate API calls, and a
    # token without a version cannot be checked against revocation.
    if token_data.type != "access" or token_data.ver is None:
        raise _unauthorized("无效凭证")

    result = await db.execute(select(User).where(User.id == token_data.sub))
    user = result.scalar_one_or_none()
    if user is None:
        raise _unauthorized("用户不存在")
    # A mismatch means the token predates the user's current token_version.
    if user.token_version != token_data.ver:
        raise _unauthorized("凭证已失效")
    return user


async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """FastAPI dependency: authenticate the request's Bearer token and return the user."""
    return await get_current_user_by_token(token, db)


async def _load_permission_codes(db: AsyncSession, role_id) -> set[str]:
    """Return the set of ``RolePermission.permission_code`` values granted to ``role_id``."""
    result = await db.execute(
        select(RolePermission.permission_code).where(RolePermission.role_id == role_id)
    )
    return set(result.scalars().all())


def require_permissions(required: list[str]):
    """Build a dependency that requires the current user to hold *all* of ``required``.

    An empty ``required`` admits any authenticated user without querying
    permissions. The codes are snapshotted when the dependency is built, so
    later mutation of the caller's list has no effect.

    The returned dependency returns the authenticated ``User`` and raises
    ``HTTPException`` 403 if any required code is missing from the user's role.
    """
    required_codes = frozenset(required)

    async def checker(
        current_user: User = Depends(get_current_user),
        db: AsyncSession = Depends(get_db),
    ) -> User:
        if not required_codes:
            return current_user
        granted = await _load_permission_codes(db, current_user.role_id)
        if not required_codes <= granted:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权限操作")
        return current_user

    return checker


def require_any_permissions(required: list[str]):
    """Build a dependency that requires the current user to hold *at least one* of ``required``.

    An empty ``required`` admits any authenticated user without querying
    permissions. The codes are snapshotted when the dependency is built.

    The returned dependency returns the authenticated ``User`` and raises
    ``HTTPException`` 403 if the user's role grants none of the codes.
    """
    required_codes = frozenset(required)

    async def checker(
        current_user: User = Depends(get_current_user),
        db: AsyncSession = Depends(get_db),
    ) -> User:
        if not required_codes:
            return current_user
        granted = await _load_permission_codes(db, current_user.role_id)
        if required_codes.isdisjoint(granted):
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权限操作")
        return current_user

    return checker