diff --git a/app/api.py b/app/api.py index de6d6a2..88d4a51 100644 --- a/app/api.py +++ b/app/api.py @@ -1,4 +1,5 @@ import hashlib +from collections import defaultdict from datetime import datetime from enum import StrEnum, auto from pathlib import Path @@ -258,6 +259,10 @@ class Flag(FlagCreate): device_id: str = Field(..., description="Device ID of the flagger") +class FlagsByTicketIdRequest(BaseModel): + ticket_ids: list[int] + + @api_v1_router.post("/flags") def create_flag(flag: FlagCreate, request: Request): """Create a flag for a product. @@ -339,6 +344,7 @@ def get_flag(flag_id: int): def _create_ticket(ticket: TicketCreate): + """Create a ticket.""" return TicketModel.create(**ticket.model_dump()) @@ -353,14 +359,28 @@ def create_ticket(ticket: TicketCreate) -> Ticket: return _create_ticket(ticket) +def _get_ticket(status: TicketStatus | None, type_: IssueType | None): + """Get tickets with optional filters.""" + query = TicketModel.select() + + if status is not None: + query = query.where(TicketModel.status == status) + + if type_ is not None: + query = query.where(TicketModel.type == type_) + + with db: + return list(query.dicts()) + + @api_v1_router.get("/tickets") -def get_tickets(): +def get_tickets(status: TicketStatus | None = None, type_: IssueType | None = None): """Get all tickets. - This function is used to get all tickets. + This function is used to get all tickets with status open. """ with db: - return {"tickets": list(TicketModel.select().dicts().iterator())} + return _get_ticket(status, type_) @api_v1_router.get("/tickets/{ticket_id}") @@ -376,21 +396,24 @@ def get_ticket(ticket_id: int): raise HTTPException(status_code=404, detail="Not found") -@api_v1_router.get("/tickets/{ticket_id}/flags") -def get_flags_by_ticket(ticket_id: int): - """Get all flags for a ticket by ID. +@api_v1_router.post("/flags/batch") +def get_flags_by_ticket_batch(flag_request: FlagsByTicketIdRequest): + """Get all flags for tickets by IDs. - This function is used to get all flags for a ticket by its ID. + This function is used to get all flags for tickets by there IDs. """ with db: - return { - "flags": list( - FlagModel.select() - .where(FlagModel.ticket_id == ticket_id) - .dicts() - .iterator() - ) - } + flags = list( + FlagModel.select() + .where(FlagModel.ticket_id.in_(flag_request.ticket_ids)) + .dicts() + ) + + ticket_id_to_flags = defaultdict(list) + for flag in flags: + ticket_id_to_flags[flag["ticket"]].append(flag) + + return {"ticket_id_to_flags": dict(ticket_id_to_flags)} @api_v1_router.put("/tickets/{ticket_id}/status")