from __future__ import annotations

import json
from typing import Protocol

from anthropic import AsyncAnthropic
from openai import AsyncOpenAI

from app.config import LLMSettings
from app.models import EmailData
from app.prompts import SYSTEM_PROMPT


def _format_user_message(email: EmailData) -> str:
    """Render an email as the user-turn text sent to every provider."""
    return f"Subject: {email.subject}\nBody: {email.body}"


class LLMAdapter(Protocol):
    """Structural interface implemented by the provider adapters below."""

    async def classify(self, email: EmailData) -> str:
        """Return the model's raw text reply for *email*."""
        ...


class OpenAICompatibleAdapter:
    """Classify emails through an OpenAI-compatible chat-completions API."""

    def __init__(self, settings: LLMSettings):
        self.settings = settings
        # max_retries=0 disables the SDK's automatic retries; presumably the
        # retry policy is owned by the caller — confirm before changing.
        self.client = AsyncOpenAI(
            base_url=settings.base_url,
            api_key=settings.api_key,
            timeout=settings.timeout_seconds,
            max_retries=0,
        )

    async def classify(self, email: EmailData) -> str:
        """Send *email* with ``SYSTEM_PROMPT`` and return the reply text.

        JSON mode is requested via ``response_format``. Returns ``""`` when
        the first choice carries no content.
        """
        response = await self.client.chat.completions.create(
            model=self.settings.model,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": _format_user_message(email)},
            ],
            temperature=self.settings.temperature,
            response_format={"type": "json_object"},
        )
        return response.choices[0].message.content or ""


class AnthropicCompatibleAdapter:
    """Classify emails through an Anthropic-compatible Messages API."""

    def __init__(self, settings: LLMSettings, *, max_tokens: int = 500):
        """Create the client.

        Args:
            settings: Connection and sampling settings.
            max_tokens: Upper bound on generated tokens per request
                (the Messages API requires this field).
        """
        self.settings = settings
        self.max_tokens = max_tokens
        # max_retries=0 disables the SDK's automatic retries; presumably the
        # retry policy is owned by the caller — confirm before changing.
        self.client = AsyncAnthropic(
            base_url=settings.base_url,
            api_key=settings.api_key,
            timeout=settings.timeout_seconds,
            max_retries=0,
        )

    async def classify(self, email: EmailData) -> str:
        """Send *email* with ``SYSTEM_PROMPT`` and return the reply text.

        Text from every content block that exposes a non-empty ``text``
        attribute is joined with newlines; other blocks are skipped.
        """
        response = await self.client.messages.create(
            model=self.settings.model,
            max_tokens=self.max_tokens,
            temperature=self.settings.temperature,
            system=SYSTEM_PROMPT,
            messages=[
                {"role": "user", "content": _format_user_message(email)},
            ],
        )
        chunks: list[str] = []
        for block in response.content:
            # Non-text blocks (e.g. tool use) have no ``text``; skip them.
            text = getattr(block, "text", None)
            if text:
                chunks.append(text)
        return "\n".join(chunks)


def build_adapter(settings: LLMSettings) -> LLMAdapter:
    """Return the adapter for ``settings.provider``.

    ``"anthropic"`` selects the Anthropic adapter; any other value falls
    back to the OpenAI-compatible adapter.
    """
    if settings.provider == "anthropic":
        return AnthropicCompatibleAdapter(settings)
    return OpenAICompatibleAdapter(settings)


def _strip_code_fence(text: str) -> str:
    """Remove a Markdown code fence wrapping *text*, if one is present."""
    if not text.startswith("```"):
        return text
    lines = text.splitlines()
    # Need an opening fence line, at least one content line, and a closing fence.
    if len(lines) >= 3 and lines[-1].startswith("```"):
        return "\n".join(lines[1:-1]).strip()
    return text


def _first_embedded_object(text: str) -> str | None:
    """Return the first complete JSON object embedded in *text*, or None.

    Tries ``raw_decode`` at each ``{`` in turn, so braces in surrounding
    prose do not prevent finding a valid object later in the text.
    """
    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            _, end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            pos = text.find("{", pos + 1)
        else:
            return text[pos:end]
    return None


def coerce_json_text(raw: str) -> str:
    """Extract JSON text from raw model output.

    Handles surrounding whitespace, a Markdown code fence, a bare leading
    ``json`` line, and prose before or after the JSON object. Any non-empty
    return value has been validated with the ``json`` module.

    Args:
        raw: Text produced by the model.

    Returns:
        The JSON text, or ``""`` when *raw* is empty or whitespace only.

    Raises:
        json.JSONDecodeError: If no parseable JSON can be located.
    """
    text = raw.strip()
    if not text:
        return text

    text = _strip_code_fence(text)
    if text.lower().startswith("json\n"):
        text = text[5:].strip()

    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end > start:
        candidate = text[start : end + 1]
        try:
            json.loads(candidate)
        except json.JSONDecodeError:
            # The outermost-brace span is not valid JSON (e.g. stray braces
            # in surrounding prose); look for the first object that decodes.
            embedded = _first_embedded_object(text)
            if embedded is not None:
                return embedded
        else:
            return candidate

    # Last resort: the whole text may itself be valid JSON (e.g. an array).
    json.loads(text)
    return text