"""Configuration for the email classifier's LLM client.

Each setting is resolved with the precedence: environment variable, then
YAML config file, then built-in default.
"""

from __future__ import annotations

import os
from functools import lru_cache
from pathlib import Path
from typing import Any, Literal

import yaml
from pydantic import BaseModel

Provider = Literal["openai", "anthropic"]

# Checked in order; the first existing file is used.
DEFAULT_CONFIG_PATHS = ["config.yml", "config.yaml", "/config/config.yml", "/config/config.yaml"]
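
# The loader accepts settings nested under a top-level ``llm`` key or flat at
# the document root. An illustrative (not prescriptive) config.yml:
#
#   llm:
#     provider: openai
#     model: qwen2.5-7b-instruct.q4_k_m
#     base_url: http://ollama.internal.henryhosted.com:9292/v1
#     temperature: 0.1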


class LLMSettings(BaseModel):
    """Connection and sampling settings for the LLM backend."""

    provider: Provider = "openai"
    api_key: str = "none"
    model: str = "qwen2.5-7b-instruct.q4_k_m"
    base_url: str = "http://ollama.internal.henryhosted.com:9292/v1"
    temperature: float = 0.1
    timeout_seconds: float = 60
    max_retries: int = 3


def _load_yaml_config() -> dict[str, Any]:
    """Return the LLM section of the first config file found, or ``{}``."""
    # An explicit path from the environment takes precedence over the defaults.
    explicit = os.getenv("EMAIL_CLASSIFIER_CONFIG") or os.getenv("APP_CONFIG_FILE")
    candidates = [explicit] if explicit else DEFAULT_CONFIG_PATHS
    for candidate in candidates:
        if not candidate:
            continue
        path = Path(candidate)
        if not path.is_file():  # also covers nonexistent paths
            continue
        data = yaml.safe_load(path.read_text()) or {}
        if not isinstance(data, dict):
            raise ValueError(f"Config file must contain a mapping/object: {path}")
        # Accept either a top-level ``llm:`` block or a flat mapping of keys.
        llm = data.get("llm", data)
        if not isinstance(llm, dict):
            raise ValueError(f"LLM config must be a mapping/object: {path}")
        return llm
    return {}


def _env_or_yaml(env_name: str, yaml_data: dict[str, Any], yaml_key: str, default: Any) -> Any:
    """Resolve a single setting: environment variable, then YAML, then default."""
    value = os.getenv(env_name)
    if value is not None:
        return value
    if yaml_key in yaml_data and yaml_data[yaml_key] is not None:
        return yaml_data[yaml_key]
    return default


@lru_cache(maxsize=1)
def get_settings() -> LLMSettings:
    """Build the process-wide settings once and cache them."""
    yaml_data = _load_yaml_config()
    return LLMSettings(
        provider=_env_or_yaml("LLM_PROVIDER", yaml_data, "provider", "openai"),
        api_key=_env_or_yaml("LLM_API_KEY", yaml_data, "api_key", "none"),
        model=_env_or_yaml("LLM_MODEL", yaml_data, "model", "qwen2.5-7b-instruct.q4_k_m"),
        base_url=_env_or_yaml("LLM_BASE_URL", yaml_data, "base_url", "http://ollama.internal.henryhosted.com:9292/v1"),
        temperature=float(_env_or_yaml("LLM_TEMPERATURE", yaml_data, "temperature", 0.1)),
        timeout_seconds=float(_env_or_yaml("LLM_TIMEOUT_SECONDS", yaml_data, "timeout_seconds", 60)),
        max_retries=int(_env_or_yaml("LLM_MAX_RETRIES", yaml_data, "max_retries", 3)),
    )
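
# Note: the environment always wins. For example, running the process with
# LLM_TEMPERATURE=0 set overrides both the YAML value and the 0.1 default.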


def get_request_settings(
    provider: str | None = None,
    model: str | None = None,
    base_url: str | None = None,
    api_key: str | None = None,
    temperature: float | None = None,
) -> LLMSettings:
    """Return a copy of the cached settings with per-request overrides applied."""
    base = get_settings()
    data = base.model_dump()
    if provider is not None:
        data["provider"] = provider
    if model is not None:
        data["model"] = model
    if base_url is not None:
        data["base_url"] = base_url
    if api_key is not None:
        data["api_key"] = api_key
    if temperature is not None:
        data["temperature"] = temperature
    return LLMSettings(**data)
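

if __name__ == "__main__":
    # Minimal smoke test; output depends on your environment and config file.
    settings = get_settings()
    print(settings.provider, settings.model, settings.base_url)

    # Per-request overrides produce a new object; the cached settings are
    # untouched.
    override = get_request_settings(temperature=0.0)
    print(override.temperature)        # 0.0
    print(get_settings().temperature)  # cached value, unchanged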