"""Endpoints for listing, creating, updating and toggling shifts."""

from datetime import time
from typing import Optional
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import and_, extract, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from backend.app.core.dependencies import require_any_permissions, require_permissions
from backend.app.db.session import get_db
from backend.app.models import Shift
from backend.app.schemas.pagination import PaginatedResponse
from backend.app.schemas.shift import ShiftCreate, ShiftResponse, ShiftUpdate

router = APIRouter(prefix="/shifts", tags=["shifts"])


def _parse_time(value, field_name: str) -> time:
    """Convert an ISO-format time string into a ``datetime.time``.

    Values that are already ``time`` instances are returned unchanged.

    Raises:
        HTTPException: 422 when ``value`` cannot be parsed, so that a
            malformed request does not surface as an unhandled server error.
    """
    if isinstance(value, time):
        return value
    try:
        return time.fromisoformat(value)
    except (TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=422,
            detail=f"{field_name} 不是有效的时间格式",
        ) from exc


@router.get(
    "",
    response_model=list[ShiftResponse] | PaginatedResponse[ShiftResponse],
    dependencies=[Depends(require_any_permissions(["shifts.view", "schedule.view"]))],
)
async def list_shifts(
    db: AsyncSession = Depends(get_db),
    keyword: Optional[str] = Query(default=None),
    status: Optional[str] = Query(default=None),
    shift_type: Optional[str] = Query(default=None, alias="type"),
    page: Optional[int] = Query(default=None, ge=1),
    page_size: Optional[int] = Query(default=None, ge=1, le=200, alias="pageSize"),
    size: Optional[int] = Query(default=None, ge=1, le=200),
):
    """List shifts, optionally filtered and paginated.

    Filters:
        keyword: case-insensitive substring match on name or remark.
        status: ``"enabled"`` or ``"disabled"``; other values apply no filter.
        type: ``"cross"`` (end before start), ``"night"`` or ``"day"``;
            other values apply no filter.

    When none of ``page``, ``pageSize`` or ``size`` is supplied, the full
    result is returned as a plain list. Otherwise a dict with ``items``,
    ``total``, ``page`` and ``pageSize`` is returned; the page defaults to 1
    and the page size to 10, with ``pageSize`` taking precedence over
    ``size``.
    """
    query = select(Shift)

    # Strip first so a whitespace-only keyword does not add a no-op filter.
    term = keyword.strip() if keyword else ""
    if term:
        like = f"%{term}%"
        query = query.where(or_(Shift.name.ilike(like), Shift.remark.ilike(like)))

    if status == "enabled":
        query = query.where(Shift.enabled.is_(True))
    elif status == "disabled":
        query = query.where(Shift.enabled.is_(False))

    start_hour = extract("hour", Shift.start_time)
    if shift_type == "cross":
        # The end time is earlier than the start time.
        query = query.where(Shift.end_time < Shift.start_time)
    elif shift_type == "night":
        # End before start, or a start hour of 18 or later, or before 6.
        query = query.where(
            or_(
                Shift.end_time < Shift.start_time,
                start_hour >= 18,
                start_hour < 6,
            )
        )
    elif shift_type == "day":
        # End not before start, and a start hour in [6, 18).
        query = query.where(
            and_(
                Shift.end_time >= Shift.start_time,
                start_hour >= 6,
                start_hour < 18,
            )
        )

    # Shift.id breaks ties between equal names so that the ordering is total;
    # without it, OFFSET/LIMIT pages may repeat or skip rows.
    ordered = query.order_by(Shift.name, Shift.id)

    if page is None and page_size is None and size is None:
        result = await db.execute(ordered)
        return result.scalars().all()

    page_value = page or 1
    size_value = page_size or size or 10

    total = await db.scalar(select(func.count()).select_from(query.subquery()))
    result = await db.execute(
        ordered.offset((page_value - 1) * size_value).limit(size_value)
    )
    return {
        "items": result.scalars().all(),
        "total": total or 0,
        "page": page_value,
        "pageSize": size_value,
    }


@router.post(
    "",
    response_model=ShiftResponse,
    dependencies=[Depends(require_permissions(["shifts.create"]))],
)
async def create_shift(payload: ShiftCreate, db: AsyncSession = Depends(get_db)):
    """Create a shift from the payload and return the persisted row.

    Raises:
        HTTPException: 422 when ``start_time`` or ``end_time`` is not a
            valid ISO-format time.
    """
    shift = Shift(
        name=payload.name,
        start_time=_parse_time(payload.start_time, "start_time"),
        end_time=_parse_time(payload.end_time, "end_time"),
        enabled=payload.enabled,
        remark=payload.remark,
    )
    db.add(shift)
    await db.commit()
    await db.refresh(shift)
    return shift


@router.put(
    "/{shift_id}",
    response_model=ShiftResponse,
    dependencies=[Depends(require_permissions(["shifts.edit"]))],
)
async def update_shift(
    shift_id: UUID,
    payload: ShiftUpdate,
    db: AsyncSession = Depends(get_db),
):
    """Apply the fields explicitly set in the payload to an existing shift.

    Raises:
        HTTPException: 404 when no shift has the given id; 422 when a
            supplied ``start_time`` or ``end_time`` is not a valid
            ISO-format time.
    """
    result = await db.execute(select(Shift).where(Shift.id == shift_id))
    shift = result.scalar_one_or_none()
    if not shift:
        raise HTTPException(status_code=404, detail="班次不存在")

    # Only fields the client actually sent are touched.
    update_data = payload.model_dump(exclude_unset=True)
    for field_name in ("start_time", "end_time"):
        # Empty/None values are passed through unparsed, as before.
        if update_data.get(field_name):
            update_data[field_name] = _parse_time(update_data[field_name], field_name)

    for key, value in update_data.items():
        setattr(shift, key, value)

    await db.commit()
    await db.refresh(shift)
    return shift


@router.patch(
    "/{shift_id}/toggle",
    response_model=ShiftResponse,
    dependencies=[Depends(require_permissions(["shifts.edit"]))],
)
async def toggle_shift(shift_id: UUID, enabled: bool, db: AsyncSession = Depends(get_db)):
    """Set the ``enabled`` flag of a shift to the given value.

    Raises:
        HTTPException: 404 when no shift has the given id.
    """
    result = await db.execute(select(Shift).where(Shift.id == shift_id))
    shift = result.scalar_one_or_none()
    if not shift:
        raise HTTPException(status_code=404, detail="班次不存在")

    shift.enabled = enabled
    await db.commit()
    await db.refresh(shift)
    return shift