diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..6664b3c
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,41 @@
+[project]
+name = "notebook-tools"
+version = "0.1.0"
+description = "FastAPI service to OCR Paperless-ngx PDFs via llama.cpp"
+readme = "README.md"
+requires-python = ">=3.11"
+dependencies = [
+    "fastapi>=0.115",
+    "uvicorn[standard]>=0.30",
+    "httpx>=0.27",
+    "pydantic>=2.7",
+    "pydantic-settings>=2.3",
+    "pymupdf>=1.24",
+    "pillow>=10.4",
+    "img2pdf>=0.5",
+    "tenacity>=8.3",
+]
+
+[project.optional-dependencies]
+test = [
+    "pytest>=8.2",
+    "pytest-asyncio>=0.24",
+    "respx>=0.21",
+]
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+testpaths = ["tests"]
+
+[tool.ruff]
+line-length = 100
+
+[tool.ruff.lint]
+select = ["E", "F", "I", "UP", "B"]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/notebook_tools"]
diff --git a/src/notebook_tools/__init__.py b/src/notebook_tools/__init__.py
new file mode 100644
index 0000000..1c866ef
--- /dev/null
+++ b/src/notebook_tools/__init__.py
@@ -0,0 +1,6 @@
+"""notebook_tools package.
+
+This repository is intentionally written to be easy to read and modify if you're new to Python.
+Most modules include short docstrings and type hints, and we keep functions small.
+"""
+
diff --git a/src/notebook_tools/api.py b/src/notebook_tools/api.py
new file mode 100644
index 0000000..f1e9570
--- /dev/null
+++ b/src/notebook_tools/api.py
@@ -0,0 +1,74 @@
+"""FastAPI application entrypoint.
+
+This file is intentionally small:
+- Routes call into a job manager.
+- The job manager calls the pipeline.
+
+Keeping the web layer thin makes the business logic easier to test and maintain.
+"""
+
+from __future__ import annotations
+
+import logging
+
+from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException
+
+from notebook_tools.jobs import JobManager, get_job_manager
+from notebook_tools.logging_utils import configure_logging
+from notebook_tools.models import JobStartRequest, JobStatusResponse
+from notebook_tools.settings import Settings, get_settings
+
+app = FastAPI(title="notebook-tools", version="0.1.0")
+logger = logging.getLogger("notebook_tools.api")
+
+
+@app.on_event("startup")
+async def _startup() -> None:
+    # Load settings once at startup so we fail fast if env vars are missing.
+    settings = get_settings()
+    configure_logging(level=settings.log_level)
+    logger.info("Service starting up")
+
+
+@app.get("/health")
+async def health() -> dict[str, str]:
+    return {"status": "ok"}
+
+
+@app.post("/jobs/paperless/{document_id}", response_model=JobStatusResponse)
+async def start_job_for_paperless_document(
+    document_id: int,
+    req: JobStartRequest,
+    background: BackgroundTasks,
+    settings: Settings = Depends(get_settings),
+    manager: JobManager = Depends(get_job_manager),
+) -> JobStatusResponse:
+    """Start an OCR job for an existing Paperless document id."""
+
+    if document_id <= 0:
+        raise HTTPException(status_code=422, detail="document_id must be a positive integer")
+
+    job = manager.create_job(document_id=document_id, notebook_id=req.notebook_id)
+    logger.info(
+        "Job created job_id=%s paperless_document_id=%s notebook_id=%s",
+        job.job_id,
+        document_id,
+        req.notebook_id,
+    )
+    background.add_task(
+        manager.run_job,
+        job_id=job.job_id,
+        settings=settings,
+        ocr_prompt_override=req.ocr_prompt,
+        title_prefix=req.title_prefix,
+    )
+    return job
+
+
+@app.get("/jobs/{job_id}", response_model=JobStatusResponse)
+async def get_job(
+    job_id: str, manager: JobManager = Depends(get_job_manager)
+) -> JobStatusResponse:
+    job = manager.get_job(job_id)
+    if not job:
+        raise HTTPException(status_code=404, detail="job not found")
+    return job
+
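Usage sketch (not part of the patch): how a client might drive the two job endpoints above. The base URL, document id, and notebook id are placeholders for illustration.

    import asyncio

    import httpx


    async def main() -> None:
        async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
            # Start an OCR job for Paperless document 42 (placeholder id).
            resp = await client.post("/jobs/paperless/42", json={"notebook_id": "nb-001"})
            resp.raise_for_status()
            job_id = resp.json()["job_id"]

            # Poll until the background task settles.
            while True:
                job = (await client.get(f"/jobs/{job_id}")).json()
                if job["state"] in ("succeeded", "failed"):
                    print(job["state"], job.get("message"))
                    break
                await asyncio.sleep(2)


    asyncio.run(main())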
diff --git a/src/notebook_tools/jobs.py b/src/notebook_tools/jobs.py
new file mode 100644
index 0000000..994289c
--- /dev/null
+++ b/src/notebook_tools/jobs.py
@@ -0,0 +1,122 @@
+"""In-memory job tracking.
+
+We store job status in memory because it's the simplest way to get started.
+Trade-off:
+- If the server restarts, in-progress jobs are lost.
+
+If you later want durability, we can swap this for SQLite without changing the API shape.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import traceback
+import uuid
+from dataclasses import dataclass, field
+
+from notebook_tools.models import JobStatusResponse
+from notebook_tools.pipeline import run_pipeline_for_paperless_document
+from notebook_tools.settings import Settings
+
+
+@dataclass
+class _JobRecord:
+    document_id: int
+    notebook_id: str
+    status: JobStatusResponse
+    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+
+
+class JobManager:
+    def __init__(self) -> None:
+        self._jobs: dict[str, _JobRecord] = {}
+
+    def create_job(self, *, document_id: int, notebook_id: str) -> JobStatusResponse:
+        job_id = uuid.uuid4().hex
+        status = JobStatusResponse(job_id=job_id, state="queued", completed_pages=0)
+        self._jobs[job_id] = _JobRecord(
+            document_id=document_id,
+            notebook_id=notebook_id,
+            status=status,
+        )
+        return status
+
+    def get_job(self, job_id: str) -> JobStatusResponse | None:
+        rec = self._jobs.get(job_id)
+        return rec.status if rec else None
+
+    async def _set_status(self, job_id: str, updater) -> None:
+        rec = self._jobs[job_id]
+        async with rec.lock:
+            updater(rec.status)
+
+    async def _mark_running(self, job_id: str) -> None:
+        def _update(s: JobStatusResponse) -> None:
+            s.state = "running"
+            s.message = None
+
+        await self._set_status(job_id, _update)
+
+    async def _update_progress(self, job_id: str, *, completed: int, total: int) -> None:
+        def _update(s: JobStatusResponse) -> None:
+            s.completed_pages = completed
+            s.total_pages = total
+
+        await self._set_status(job_id, _update)
+
+    async def _mark_succeeded(self, job_id: str, *, created_document_ids: list[int]) -> None:
+        def _update(s: JobStatusResponse) -> None:
+            s.state = "succeeded"
+            s.created_document_ids = created_document_ids
+            s.message = "done"
+
+        await self._set_status(job_id, _update)
+
+    async def _mark_failed(self, job_id: str, *, error_message: str, traceback_text: str) -> None:
+        def _update(s: JobStatusResponse) -> None:
+            s.state = "failed"
+            s.message = error_message
+            s.errors.append(error_message)
+            s.debug["traceback"] = traceback_text
+
+        await self._set_status(job_id, _update)
+
+    async def run_job(
+        self,
+        *,
+        job_id: str,
+        settings: Settings,
+        ocr_prompt_override: str | None,
+        title_prefix: str | None,
+    ) -> None:
+        rec = self._jobs[job_id]
+
+        await self._mark_running(job_id)
+
+        try:
+            result = await run_pipeline_for_paperless_document(
+                settings=settings,
+                paperless_document_id=rec.document_id,
+                notebook_id=rec.notebook_id,
+                job_id=job_id,
+                on_progress=lambda completed, total: self._update_progress(
+                    job_id, completed=completed, total=total
+                ),
+                ocr_prompt_override=ocr_prompt_override,
+                title_prefix=title_prefix,
+            )
+
+            await self._mark_succeeded(job_id, created_document_ids=result["created_document_ids"])
+        except Exception as e:  # noqa: BLE001 - we want to catch and report job failures
+            tb = traceback.format_exc()
+            await self._mark_failed(job_id, error_message=str(e), traceback_text=tb)
+
+
+_manager = JobManager()
+
+
+def get_job_manager() -> JobManager:
+    """FastAPI dependency for a singleton in-memory manager."""
+
+    return _manager
+
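Quick sanity sketch (not part of the patch): create_job and get_job are synchronous, so the manager can be poked from a REPL; only run_job needs a running event loop. The ids below are placeholders.

    from notebook_tools.jobs import get_job_manager

    manager = get_job_manager()
    job = manager.create_job(document_id=42, notebook_id="nb-001")
    assert manager.get_job(job.job_id).state == "queued"
    assert manager.get_job("does-not-exist") is None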
+""" + +from __future__ import annotations + +import asyncio +import traceback +import uuid +from dataclasses import dataclass, field + +from notebook_tools.models import JobStatusResponse +from notebook_tools.pipeline import run_pipeline_for_paperless_document +from notebook_tools.settings import Settings + + +@dataclass +class _JobRecord: + document_id: int + notebook_id: str + status: JobStatusResponse + lock: asyncio.Lock = field(default_factory=asyncio.Lock) + + +class JobManager: + def __init__(self) -> None: + self._jobs: dict[str, _JobRecord] = {} + + def create_job(self, *, document_id: int, notebook_id: str) -> JobStatusResponse: + job_id = uuid.uuid4().hex + status = JobStatusResponse(job_id=job_id, state="queued", completed_pages=0) + self._jobs[job_id] = _JobRecord( + document_id=document_id, + notebook_id=notebook_id, + status=status, + ) + return status + + def get_job(self, job_id: str) -> JobStatusResponse | None: + rec = self._jobs.get(job_id) + return rec.status if rec else None + + async def _set_status(self, job_id: str, updater) -> None: + rec = self._jobs[job_id] + async with rec.lock: + updater(rec.status) + + async def _mark_running(self, job_id: str) -> None: + def _update(s: JobStatusResponse) -> None: + s.state = "running" + s.message = None + + await self._set_status(job_id, _update) + + async def _update_progress(self, job_id: str, *, completed: int, total: int) -> None: + def _update(s: JobStatusResponse) -> None: + s.completed_pages = completed + s.total_pages = total + + await self._set_status(job_id, _update) + + async def _mark_succeeded(self, job_id: str, *, created_document_ids: list[int]) -> None: + def _update(s: JobStatusResponse) -> None: + s.state = "succeeded" + s.created_document_ids = created_document_ids + s.message = "done" + + await self._set_status(job_id, _update) + + async def _mark_failed(self, job_id: str, *, error_message: str, traceback_text: str) -> None: + def _update(s: JobStatusResponse) -> None: + s.state = "failed" + s.message = error_message + s.errors.append(error_message) + s.debug["traceback"] = traceback_text + + await self._set_status(job_id, _update) + + async def run_job( + self, + *, + job_id: str, + settings: Settings, + ocr_prompt_override: str | None, + title_prefix: str | None, + ) -> None: + rec = self._jobs[job_id] + + await self._mark_running(job_id) + + try: + result = await run_pipeline_for_paperless_document( + settings=settings, + paperless_document_id=rec.document_id, + notebook_id=rec.notebook_id, + job_id=job_id, + on_progress=lambda completed, total: self._update_progress( + job_id, completed=completed, total=total + ), + ocr_prompt_override=ocr_prompt_override, + title_prefix=title_prefix, + ) + + await self._mark_succeeded(job_id, created_document_ids=result["created_document_ids"]) + except Exception as e: # noqa: BLE001 - we want to catch and report job failures + tb = traceback.format_exc() + await self._mark_failed(job_id, error_message=str(e), traceback_text=tb) + + +_manager = JobManager() + + +def get_job_manager() -> JobManager: + """FastAPI dependency for a singleton in-memory manager.""" + + return _manager + diff --git a/src/notebook_tools/llama_client.py b/src/notebook_tools/llama_client.py new file mode 100644 index 0000000..81d3fc8 --- /dev/null +++ b/src/notebook_tools/llama_client.py @@ -0,0 +1,101 @@ +"""llama.cpp OCR client (OpenAI-compatible chat/completions). 
diff --git a/src/notebook_tools/logging_utils.py b/src/notebook_tools/logging_utils.py
new file mode 100644
index 0000000..0b34306
--- /dev/null
+++ b/src/notebook_tools/logging_utils.py
@@ -0,0 +1,41 @@
+"""Logging setup for the service.
+
+We use the standard library `logging` module so it works everywhere:
+- local dev
+- Docker
+- systemd
+- Kubernetes
+
+Uvicorn has its own logging, but application logs should still be configured so
+our modules emit useful messages too.
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+
+
+def configure_logging(*, level: str) -> None:
+    """Configure root logging once.
+
+    We keep the format readable in terminals but also structured enough for log collectors.
+    """
+
+    level_norm = (level or "INFO").upper()
+    # If logging was already configured (common with uvicorn), basicConfig will be a no-op.
+    logging.basicConfig(
+        level=getattr(logging, level_norm, logging.INFO),
+        format="%(asctime)s %(levelname)s %(name)s - %(message)s",
+    )
+
+    # Keep noisy libraries at WARNING unless you explicitly crank LOG_LEVEL to DEBUG.
+    if level_norm != "DEBUG":
+        logging.getLogger("httpx").setLevel(logging.WARNING)
+        logging.getLogger("httpcore").setLevel(logging.WARNING)
+
+    # Helpful banner once.
+    logging.getLogger("notebook_tools").info(
+        "Logging configured (level=%s, pid=%s)", level_norm, os.getpid()
+    )
+
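Call-site sketch (not part of the patch): the function is safe to reuse from one-off scripts, because basicConfig does nothing if handlers already exist.

    from notebook_tools.logging_utils import configure_logging

    # At DEBUG, httpx/httpcore are deliberately left verbose; see the guard above.
    configure_logging(level="DEBUG")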
diff --git a/src/notebook_tools/models.py b/src/notebook_tools/models.py
new file mode 100644
index 0000000..aeb1c07
--- /dev/null
+++ b/src/notebook_tools/models.py
@@ -0,0 +1,46 @@
+"""Pydantic models for request/response payloads.
+
+Why bother with models?
+- They validate incoming JSON (so we can give clear errors early).
+- They produce automatic OpenAPI docs in FastAPI.
+- They double as "living documentation" for how to call the API.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Literal
+
+from pydantic import BaseModel, Field
+
+
+class JobStartRequest(BaseModel):
+    """Parameters for starting an OCR job."""
+
+    notebook_id: str = Field(
+        ..., description="Your logical notebook identifier to store in custom fields."
+    )
+    # If you want to tweak the OCR prompt later, we allow overriding it per job.
+    ocr_prompt: str | None = Field(
+        None,
+        description="Optional override for the OCR prompt sent to llama.cpp.",
+    )
+    # Optional: metadata you might want to set on each uploaded per-page document.
+    title_prefix: str | None = Field(
+        None,
+        description="Optional prefix for per-page document titles (e.g. 'Notebook 123').",
+    )
+
+
+JobState = Literal["queued", "running", "failed", "succeeded"]
+
+
+class JobStatusResponse(BaseModel):
+    job_id: str
+    state: JobState
+    message: str | None = None
+    # Progress
+    total_pages: int | None = None
+    completed_pages: int = 0
+    # Output
+    created_document_ids: list[int] = Field(default_factory=list)
+    errors: list[str] = Field(default_factory=list)
+    debug: dict[str, Any] = Field(default_factory=dict)
+
diff --git a/src/notebook_tools/pdf_utils.py b/src/notebook_tools/pdf_utils.py
new file mode 100644
index 0000000..796bbdb
--- /dev/null
+++ b/src/notebook_tools/pdf_utils.py
@@ -0,0 +1,50 @@
+"""PDF/image helpers for the OCR pipeline."""
+
+from __future__ import annotations
+
+import io
+
+import fitz  # PyMuPDF
+import img2pdf
+from PIL import Image
+
+
+def render_pdf_to_jpegs(*, pdf_bytes: bytes, dpi: int) -> list[bytes]:
+    """Render each PDF page to a JPEG (one JPEG per page).
+
+    Why render to images at all?
+    - llama vision models take images as input.
+    - PDFs can contain vector text, scans, rotations, etc. Rendering normalizes all of that.
+
+    DPI notes:
+    - Higher DPI improves handwriting legibility but increases latency and payload sizes.
+    - 200 DPI is a reasonable starting point for notebook pages.
+    """
+
+    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
+    try:
+        zoom = dpi / 72.0  # PDF units are 72 points per inch; scale to the requested DPI.
+        mat = fitz.Matrix(zoom, zoom)
+
+        out: list[bytes] = []
+        for page in doc:
+            pix = page.get_pixmap(matrix=mat, alpha=False)
+            # pix.samples is the raw RGB pixel buffer; build a PIL Image from it and
+            # re-encode as JPEG.
+            img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
+
+            buf = io.BytesIO()
+            # quality=90 keeps text crisp without huge files; optimize reduces size a bit.
+            img.save(buf, format="JPEG", quality=90, optimize=True)
+            out.append(buf.getvalue())
+
+        return out
+    finally:
+        doc.close()
+
+
+def jpeg_to_pdf_bytes(*, jpeg_bytes: bytes) -> bytes:
+    """Convert a single JPEG image to a single-page PDF."""
+
+    # img2pdf takes the JPEG bytes and returns PDF bytes, embedding the image without re-encoding.
+    return img2pdf.convert(jpeg_bytes)
+
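Round-trip sketch (not part of the patch): rendering a scanned PDF to page JPEGs, then wrapping each page back into a single-page PDF. "scan.pdf" is a placeholder input path.

    from pathlib import Path

    from notebook_tools.pdf_utils import jpeg_to_pdf_bytes, render_pdf_to_jpegs

    pages = render_pdf_to_jpegs(pdf_bytes=Path("scan.pdf").read_bytes(), dpi=200)
    for number, jpeg in enumerate(pages, start=1):
        Path(f"page-{number:03d}.pdf").write_bytes(jpeg_to_pdf_bytes(jpeg_bytes=jpeg))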
diff --git a/tests/test_llama_payload.py b/tests/test_llama_payload.py
new file mode 100644
index 0000000..ac00bbc
--- /dev/null
+++ b/tests/test_llama_payload.py
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+import base64
+import json
+
+import respx
+from httpx import Response
+
+from notebook_tools.llama_client import LlamaClient
+
+
+@respx.mock
+async def test_llama_client_parses_openai_style_response() -> None:
+    # Arrange: mock the llama endpoint.
+    route = respx.post("http://llama.local/v1/chat/completions").mock(
+        return_value=Response(200, json={"choices": [{"message": {"content": "Hello\nWorld"}}]})
+    )
+
+    client = LlamaClient(base_url="http://llama.local", model="m")
+
+    # Act
+    out = await client.ocr_jpeg(jpeg_bytes=b"\xff\xd8\xff\xe0fakejpeg")
+
+    # Assert
+    assert out == "Hello\nWorld"
+    assert route.called
+
+    # Optional: sanity-check that we really sent a base64 data URL.
+    sent = json.loads(route.calls[0].request.content.decode("utf-8"))
+    url = sent["messages"][0]["content"][1]["image_url"]["url"]
+    assert url.startswith("data:image/jpeg;base64,")
+    b64 = url.split(",", 1)[1]
+    # If this fails, our payload construction changed.
+    base64.b64decode(b64)
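Companion sketch (not part of the patch): the error path could be covered with the same respx pattern. A 500 is not retried (retries only cover timeouts and network errors), so LlamaError surfaces immediately.

    import pytest
    import respx
    from httpx import Response

    from notebook_tools.llama_client import LlamaClient, LlamaError


    @respx.mock
    async def test_llama_client_raises_on_server_error() -> None:
        respx.post("http://llama.local/v1/chat/completions").mock(
            return_value=Response(500, text="boom")
        )
        client = LlamaClient(base_url="http://llama.local", model="m")
        with pytest.raises(LlamaError):
            await client.ocr_jpeg(jpeg_bytes=b"\xff\xd8\xff\xe0fakejpeg")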