Files
email-classifier/app/config.py

87 lines
3.0 KiB
Python

from __future__ import annotations
import os
from functools import lru_cache
from pathlib import Path
from typing import Any, Literal
import yaml
from pydantic import BaseModel
# Closed set of LLM backends accepted by ``LLMSettings.provider``.
Provider = Literal["openai", "anthropic"]
# Config file locations probed, in this order, when neither
# EMAIL_CLASSIFIER_CONFIG nor APP_CONFIG_FILE is set; the first path that is
# an existing regular file is used.
DEFAULT_CONFIG_PATHS = [
    "config.yml",
    "config.yaml",
    "/config/config.yml",
    "/config/config.yaml",
]
class LLMSettings(BaseModel):
    """Resolved configuration for talking to the LLM backend.

    Instances are constructed by ``get_settings()`` and
    ``get_request_settings()`` in this module. The field defaults below apply
    whenever the model is constructed without that field.
    """

    # Which client API to use; pydantic rejects values outside ``Provider``.
    provider: Provider = "openai"
    # "none" is a placeholder key. NOTE(review): presumably meant for local
    # OpenAI-compatible servers that ignore the key — confirm.
    api_key: str = "none"
    # Model identifier sent to the backend.
    model: str = "qwen2.5-7b-instruct.q4_k_m"
    # NOTE(review): the default points at an internal host; other deployments
    # must override it (LLM_BASE_URL env var or the YAML ``base_url`` key).
    base_url: str = "http://ollama.internal.henryhosted.com:9292/v1"
    # Sampling temperature.
    temperature: float = 0.1
    # Timeout in seconds. NOTE(review): how it is enforced depends on the
    # client code, which is not shown here — confirm.
    timeout_seconds: float = 60
    # Retry count. NOTE(review): presumably forwarded to the LLM client — confirm.
    max_retries: int = 3
def _load_yaml_config() -> dict[str, Any]:
explicit = os.getenv("EMAIL_CLASSIFIER_CONFIG") or os.getenv("APP_CONFIG_FILE")
candidates = [explicit] if explicit else DEFAULT_CONFIG_PATHS
for candidate in candidates:
if not candidate:
continue
path = Path(candidate)
if not path.exists() or not path.is_file():
continue
data = yaml.safe_load(path.read_text()) or {}
if not isinstance(data, dict):
raise ValueError(f"Config file must contain a mapping/object: {path}")
llm = data.get("llm", data)
if not isinstance(llm, dict):
raise ValueError(f"LLM config must be a mapping/object: {path}")
return llm
return {}
def _env_or_yaml(env_name: str, yaml_data: dict[str, Any], yaml_key: str, default: Any) -> Any:
value = os.getenv(env_name)
if value is not None:
return value
if yaml_key in yaml_data and yaml_data[yaml_key] is not None:
return yaml_data[yaml_key]
return default
@lru_cache(maxsize=1)
def get_settings() -> LLMSettings:
    """Build the process-wide LLM settings (cached after the first call).

    Each field resolves as: ``LLM_*`` environment variable, then the YAML
    config file (via ``_load_yaml_config``), then the ``LLMSettings`` field
    default. Because of ``lru_cache``, environment or file changes made after
    the first call are not seen; call ``get_settings.cache_clear()`` to force
    a reload.

    Raises:
        ValueError: If a numeric setting cannot be parsed by float()/int(),
            or if the YAML file is not a mapping (from ``_load_yaml_config``).
    """
    yaml_data = _load_yaml_config()
    # Take fallbacks from the model itself so defaults live in exactly one
    # place (LLMSettings) instead of being duplicated, and drifting, here.
    defaults = LLMSettings()

    def resolve(env_name: str, key: str) -> Any:
        # Environment > YAML > LLMSettings default for a single field.
        return _env_or_yaml(env_name, yaml_data, key, getattr(defaults, key))

    return LLMSettings(
        provider=resolve("LLM_PROVIDER", "provider"),
        api_key=resolve("LLM_API_KEY", "api_key"),
        model=resolve("LLM_MODEL", "model"),
        base_url=resolve("LLM_BASE_URL", "base_url"),
        temperature=float(resolve("LLM_TEMPERATURE", "temperature")),
        timeout_seconds=float(resolve("LLM_TIMEOUT_SECONDS", "timeout_seconds")),
        max_retries=int(resolve("LLM_MAX_RETRIES", "max_retries")),
    )
def get_request_settings(
    provider: str | None = None,
    model: str | None = None,
    base_url: str | None = None,
    api_key: str | None = None,
    temperature: float | None = None,
) -> LLMSettings:
    """Return the global settings with per-request overrides applied.

    Any argument left as ``None`` keeps the value from ``get_settings()``.
    The merged values go through ``LLMSettings`` construction again, so they
    are validated, and the cached global instance is left untouched.
    """
    requested = {
        "provider": provider,
        "model": model,
        "base_url": base_url,
        "api_key": api_key,
        "temperature": temperature,
    }
    overrides = {field: value for field, value in requested.items() if value is not None}
    # Overrides only replace existing keys, so field order matches model_dump().
    merged = {**get_settings().model_dump(), **overrides}
    return LLMSettings(**merged)