shifts.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. from typing import Optional
  2. from uuid import UUID
  3. from datetime import time
  4. from fastapi import APIRouter, Depends, HTTPException, Query
  5. from sqlalchemy import and_, extract, func, or_, select
  6. from sqlalchemy.ext.asyncio import AsyncSession
  7. from backend.app.core.dependencies import require_any_permissions, require_permissions
  8. from backend.app.db.session import get_db
  9. from backend.app.models import Shift
  10. from backend.app.schemas.pagination import PaginatedResponse
  11. from backend.app.schemas.shift import ShiftCreate, ShiftResponse, ShiftUpdate
  12. router = APIRouter(prefix="/shifts", tags=["shifts"])
  13. @router.get(
  14. "",
  15. response_model=list[ShiftResponse] | PaginatedResponse[ShiftResponse],
  16. dependencies=[Depends(require_any_permissions(["shifts.view", "schedule.view"]))],
  17. )
  18. async def list_shifts(
  19. db: AsyncSession = Depends(get_db),
  20. keyword: Optional[str] = Query(default=None),
  21. status: Optional[str] = Query(default=None),
  22. shift_type: Optional[str] = Query(default=None, alias="type"),
  23. page: Optional[int] = Query(default=None, ge=1),
  24. page_size: Optional[int] = Query(default=None, ge=1, le=200, alias="pageSize"),
  25. size: Optional[int] = Query(default=None, ge=1, le=200),
  26. ):
  27. query = select(Shift)
  28. if keyword:
  29. like = f"%{keyword.strip()}%"
  30. query = query.where(or_(Shift.name.ilike(like), Shift.remark.ilike(like)))
  31. if status == "enabled":
  32. query = query.where(Shift.enabled.is_(True))
  33. elif status == "disabled":
  34. query = query.where(Shift.enabled.is_(False))
  35. if shift_type == "cross":
  36. query = query.where(Shift.end_time < Shift.start_time)
  37. elif shift_type == "night":
  38. query = query.where(
  39. or_(
  40. Shift.end_time < Shift.start_time,
  41. extract("hour", Shift.start_time) >= 18,
  42. extract("hour", Shift.start_time) < 6,
  43. )
  44. )
  45. elif shift_type == "day":
  46. query = query.where(
  47. and_(
  48. Shift.end_time >= Shift.start_time,
  49. extract("hour", Shift.start_time) >= 6,
  50. extract("hour", Shift.start_time) < 18,
  51. )
  52. )
  53. if page is None and page_size is None and size is None:
  54. result = await db.execute(query.order_by(Shift.name))
  55. return result.scalars().all()
  56. page_value = page or 1
  57. size_value = page_size or size or 10
  58. total = await db.scalar(select(func.count()).select_from(query.subquery()))
  59. result = await db.execute(
  60. query.order_by(Shift.name)
  61. .offset((page_value - 1) * size_value)
  62. .limit(size_value)
  63. )
  64. return {
  65. "items": result.scalars().all(),
  66. "total": total or 0,
  67. "page": page_value,
  68. "pageSize": size_value,
  69. }
  70. @router.post("", response_model=ShiftResponse, dependencies=[Depends(require_permissions(["shifts.create"]))])
  71. async def create_shift(payload: ShiftCreate, db: AsyncSession = Depends(get_db)):
  72. start_time = time.fromisoformat(payload.start_time)
  73. end_time = time.fromisoformat(payload.end_time)
  74. shift = Shift(
  75. name=payload.name,
  76. start_time=start_time,
  77. end_time=end_time,
  78. enabled=payload.enabled,
  79. remark=payload.remark
  80. )
  81. db.add(shift)
  82. await db.commit()
  83. await db.refresh(shift)
  84. return shift
  85. @router.put("/{shift_id}", response_model=ShiftResponse, dependencies=[Depends(require_permissions(["shifts.edit"]))])
  86. async def update_shift(shift_id: UUID, payload: ShiftUpdate, db: AsyncSession = Depends(get_db)):
  87. result = await db.execute(select(Shift).where(Shift.id == shift_id))
  88. shift = result.scalar_one_or_none()
  89. if not shift:
  90. raise HTTPException(status_code=404, detail="班次不存在")
  91. update_data = payload.model_dump(exclude_unset=True)
  92. if "start_time" in update_data and update_data["start_time"]:
  93. update_data["start_time"] = time.fromisoformat(update_data["start_time"])
  94. if "end_time" in update_data and update_data["end_time"]:
  95. update_data["end_time"] = time.fromisoformat(update_data["end_time"])
  96. for key, value in update_data.items():
  97. setattr(shift, key, value)
  98. await db.commit()
  99. await db.refresh(shift)
  100. return shift
  101. @router.patch("/{shift_id}/toggle", response_model=ShiftResponse, dependencies=[Depends(require_permissions(["shifts.edit"]))])
  102. async def toggle_shift(shift_id: UUID, enabled: bool, db: AsyncSession = Depends(get_db)):
  103. result = await db.execute(select(Shift).where(Shift.id == shift_id))
  104. shift = result.scalar_one_or_none()
  105. if not shift:
  106. raise HTTPException(status_code=404, detail="班次不存在")
  107. shift.enabled = enabled
  108. await db.commit()
  109. await db.refresh(shift)
  110. return shift