132 lines
4.6 KiB
Python
132 lines
4.6 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from contextlib import contextmanager
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from uuid import UUID, uuid4
|
|
|
|
import psycopg
|
|
from fastapi import HTTPException
|
|
from psycopg.rows import dict_row
|
|
|
|
from app.config import settings
|
|
|
|
# SQL file applied by run_migrations(); lives in a "migrations" directory next to this module.
MIGRATION_PATH = Path(__file__).parent / "migrations" / "001_initial.sql"

# Work-item status machine: current status -> set of statuses it may move to next.
VALID_TRANSITIONS = {
    "queued": {"dispatched", "cancelled"},
    "dispatched": {"in_progress", "cancelled"},
    "in_progress": {"blocked", "completed", "failed"},
    "blocked": {"in_progress", "completed", "failed"},
    # Nothing may follow these statuses.
    "completed": set(),
    "failed": set(),
    "cancelled": set(),
}

# A status is terminal exactly when it has no outgoing transitions, i.e.
# {"completed", "failed", "cancelled"}; deriving it keeps both tables in sync.
TERMINAL_STATUSES = {status for status, targets in VALID_TRANSITIONS.items() if not targets}
|
|
|
|
|
|
def utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC datetime with microseconds zeroed."""
    current = datetime.now(tz=timezone.utc)
    return current.replace(microsecond=0)
|
|
|
|
|
|
@contextmanager
def get_conn():
    """Yield a psycopg connection to ``settings.database_url`` for one ``with`` block.

    Rows are returned as dicts via ``dict_row``. The connection's own context
    manager wraps the ``yield``. Per the psycopg 3 docs, that context manager
    commits on a clean exit, rolls back on error, and closes the connection;
    confirm this against the installed version.
    """
    connection = psycopg.connect(settings.database_url, row_factory=dict_row)
    with connection:
        yield connection
|
|
|
|
|
|
def run_migrations() -> None:
    """Execute the SQL in ``MIGRATION_PATH`` against the configured database and commit.

    The whole file is sent as one parameterless ``execute`` call. psycopg 3 only
    allows several ``;``-separated statements per call when no parameters are
    bound. Nothing here records whether the migration already ran, so presumably
    the SQL is written to be re-runnable; TODO confirm against 001_initial.sql.

    Raises:
        OSError: if the migration file cannot be read.
    """
    # Read first: a missing or unreadable file fails before any DB connection is
    # opened. The explicit encoding avoids depending on the host's locale default.
    sql = MIGRATION_PATH.read_text(encoding="utf-8")
    with get_conn() as conn:
        conn.execute(sql)
        conn.commit()
|
|
|
|
|
|
def _parse_payload(value: Any) -> dict[str, Any] | None:
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, dict):
|
|
return value
|
|
return json.loads(value)
|
|
|
|
|
|
def _normalize_work_row(row: dict[str, Any]) -> dict[str, Any]:
    """Prepare a work_items row for output, mutating and returning the same dict.

    The ``payload`` value is decoded with ``_parse_payload``; the key is written
    even if it was missing (as ``None``). A ``dispatch_log`` key is added as an
    empty list only when the row does not already have one.
    """
    raw_payload = row.get("payload")
    row["payload"] = _parse_payload(raw_payload)
    if "dispatch_log" not in row:
        row["dispatch_log"] = []
    return row
|
|
|
|
|
|
def fetch_project_or_404(conn: psycopg.Connection, project_id: str | UUID) -> dict[str, Any]:
    """Look up one project by id and return its row.

    ``project_id`` is bound as its string form. The row shape is whatever the
    connection's row factory produces (``dict_row`` when opened via ``get_conn``).

    Raises:
        HTTPException: 404 ``"project not found"`` when no row matches.
    """
    query = "SELECT id, name, external_ref, created_at, updated_at FROM projects WHERE id = %s"
    project = conn.execute(query, (str(project_id),)).fetchone()
    if project is not None:
        return project
    raise HTTPException(status_code=404, detail="project not found")
|
|
|
|
|
|
def fetch_work_or_404(conn: psycopg.Connection, work_id: str | UUID) -> dict[str, Any]:
    """Load one work item by id and return it normalized by ``_normalize_work_row``.

    ``work_id`` is bound as its string form.

    Raises:
        HTTPException: 404 ``"work item not found"`` when no row matches.
    """
    cursor = conn.execute(
        """
        SELECT id, project_id, type, description, payload, priority, status, assigned_agent,
               created_by, created_at, updated_at, completed_at, outcome, notes
        FROM work_items WHERE id = %s
        """,
        (str(work_id),),
    )
    work = cursor.fetchone()
    if work is None:
        raise HTTPException(status_code=404, detail="work item not found")
    return _normalize_work_row(work)
|
|
|
|
|
|
def fetch_dispatch_log(conn: psycopg.Connection, work_id: str | UUID) -> list[dict[str, Any]]:
    """Return every dispatch_log row for a work item, oldest ``dispatched_at`` first.

    ``work_id`` is bound as its string form. An item with no dispatches yields
    whatever ``fetchall`` returns for an empty result.
    """
    return conn.execute(
        """
        SELECT id, work_item_id, dispatched_at, agent, completed_at, outcome
        FROM dispatch_log WHERE work_item_id = %s ORDER BY dispatched_at ASC
        """,
        (str(work_id),),
    ).fetchall()
|
|
|
|
|
|
def validate_transition(current: str, new: str, assigned_agent: str | None, outcome: str | None, notes: str | None) -> None:
    """Check that a work item may move from status ``current`` to ``new``.

    Args:
        current: present status. It is looked up directly in ``VALID_TRANSITIONS``,
            so an unknown value raises ``KeyError``.
        new: requested status; must be in ``VALID_TRANSITIONS[current]``.
        assigned_agent: required (non-blank) for ``queued -> dispatched``.
        outcome: required for terminal statuses and must match them:
            completed -> "success", failed -> "failed", cancelled -> "cancelled".
        notes: required (non-blank) when moving to ``blocked``.

    Raises:
        HTTPException: status 400 with a descriptive ``detail`` on the first
            rule that is violated.
    """
    if new not in VALID_TRANSITIONS[current]:
        raise HTTPException(status_code=400, detail=f"invalid status transition: {current} -> {new}")

    # Whitespace-only agent names are treated as missing, the same way notes are below.
    if current == "queued" and new == "dispatched" and not (assigned_agent or "").strip():
        raise HTTPException(status_code=400, detail="assigned_agent is required for queued -> dispatched")

    if new == "blocked" and not (notes or "").strip():
        raise HTTPException(status_code=400, detail="notes are required when status is blocked")

    # Generic "missing" message first; the per-status checks below then pin the exact value.
    if new in TERMINAL_STATUSES and not outcome:
        raise HTTPException(status_code=400, detail="outcome is required for terminal statuses")

    if new == "completed" and outcome != "success":
        raise HTTPException(status_code=400, detail="completed requires outcome=success")

    if new == "failed" and outcome != "failed":
        raise HTTPException(status_code=400, detail="failed requires outcome=failed")

    if new == "cancelled" and outcome != "cancelled":
        raise HTTPException(status_code=400, detail="cancelled requires outcome=cancelled")
|
|
|
|
|
|
def create_dispatch_log(conn: psycopg.Connection, work_item_id: str | UUID, agent: str, dispatched_at: datetime) -> None:
    """Insert a new dispatch_log row recording that ``agent`` was given a work item.

    The row id is a fresh ``uuid4`` in string form. ``completed_at`` and
    ``outcome`` are not set by this INSERT. The function does not commit;
    that is left to the caller's transaction handling.

    Args:
        conn: open connection to execute on.
        work_item_id: work item id; normalized with ``str()`` like the fetch
            helpers in this module, so a ``UUID`` is accepted too.
        agent: agent the item was dispatched to.
        dispatched_at: dispatch timestamp to store.
    """
    conn.execute(
        "INSERT INTO dispatch_log (id, work_item_id, dispatched_at, agent) VALUES (%s, %s, %s, %s)",
        (str(uuid4()), str(work_item_id), dispatched_at, agent),
    )
|
|
|
|
|
|
def complete_dispatch_log(conn: psycopg.Connection, work_item_id: str | UUID, completed_at: datetime, outcome: str) -> None:
    """Close the most recent open dispatch_log row for a work item.

    "Open" means ``completed_at IS NULL``. Among those rows, the one with the
    latest ``dispatched_at`` gets ``completed_at`` and ``outcome`` set. If the
    item has no open row, the subquery yields no id and the UPDATE changes
    nothing; no error is raised. The function does not commit.

    Args:
        conn: open connection to execute on.
        work_item_id: work item id; normalized with ``str()`` like the fetch
            helpers in this module, so a ``UUID`` is accepted too.
        completed_at: completion timestamp to store.
        outcome: outcome value to store.
    """
    conn.execute(
        """
        UPDATE dispatch_log
        SET completed_at = %s, outcome = %s
        WHERE id = (
            SELECT id FROM dispatch_log
            WHERE work_item_id = %s AND completed_at IS NULL
            ORDER BY dispatched_at DESC
            LIMIT 1
        )
        """,
        (completed_at, outcome, str(work_item_id)),
    )
|