- Changed the configuration format to use a single 'models' list with entries containing 'url' and 'type'. - Updated validation logic to ensure 'models' entries are correctly structured. - Modified download logic to check for existing directories before downloading. - Revised README to reflect new configuration format and usage instructions. Signed-off-by: Daniel Henry <iamdanhenry@gmail.com>
85 lines
3.0 KiB
Python
85 lines
3.0 KiB
Python
"""Load and validate YAML config; build ComfyUI paths."""
|
|
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import yaml
|
|
|
|
# Keys that must be present in every config.
# NOTE(review): validate_config checks "comfyui_base_dir" directly rather
# than iterating over this set.
REQUIRED_KEYS = {"comfyui_base_dir"}
# Optional credential keys. Not referenced elsewhere in this module;
# presumably consumed by the download code -- verify before changing.
TOKEN_KEYS = {
    "huggingface_token",
    "civitai_token",
}
|
|
|
|
|
|
def _is_safe_type(t: str) -> bool:
|
|
"""Reject type values that could escape models/ (path traversal)."""
|
|
if not t or not isinstance(t, str):
|
|
return False
|
|
if ".." in t or "/" in t or "\\" in t:
|
|
return False
|
|
return True
|
|
|
|
|
|
def load_config(path: str | Path) -> dict[str, Any]:
|
|
"""Load YAML config from path. Raises FileNotFoundError or yaml error."""
|
|
p = Path(path)
|
|
if not p.exists():
|
|
raise FileNotFoundError(f"Config not found: {p}")
|
|
with open(p, encoding="utf-8") as f:
|
|
data = yaml.safe_load(f)
|
|
if not isinstance(data, dict):
|
|
raise ValueError("Config must be a YAML object")
|
|
return data
|
|
|
|
|
|
def validate_config(data: dict[str, Any]) -> None:
    """Validate a loaded config mapping, normalising ``models`` in place.

    Checks that ``comfyui_base_dir`` is a non-blank string and that
    ``models`` is a list of objects, each with a non-blank ``url`` string and
    a non-blank ``type`` string accepted by ``_is_safe_type``. A missing or
    empty (YAML ``null``) ``models`` entry is replaced by ``[]`` in *data*.

    Raises:
        ValueError: describing the first problem found.
    """
    if "comfyui_base_dir" not in data:
        raise ValueError("Config must contain 'comfyui_base_dir'")
    base = data["comfyui_base_dir"]
    # Reject whitespace-only values too, consistent with the url/type checks
    # below; "   " would otherwise pass and be used as a path.
    if not isinstance(base, str) or not base.strip():
        raise ValueError("comfyui_base_dir must be a non-empty string")
    # `models:` with every entry commented out parses as None; treat it the
    # same as a missing key instead of failing the whole config.
    if data.get("models") is None:
        data["models"] = []
    models = data["models"]
    if not isinstance(models, list):
        raise ValueError("'models' must be a list")
    for i, item in enumerate(models):
        if not isinstance(item, dict):
            raise ValueError(f"models[{i}] must be an object with url and type")
        url = item.get("url")
        if not isinstance(url, str) or not url.strip():
            raise ValueError(f"models[{i}] must have a non-empty 'url' string")
        t = item.get("type")
        if not isinstance(t, str) or not t.strip():
            raise ValueError(f"models[{i}] must have a non-empty 'type' string")
        if not _is_safe_type(t):
            raise ValueError(f"models[{i}] 'type' must not contain /, \\, or ..")
|
|
|
|
|
|
def get_model_tasks(
|
|
data: dict[str, Any], only_types: list[str] | None = None
|
|
) -> list[tuple[str, str]]:
|
|
"""Return list of (model_type, url) for all configured model URLs."""
|
|
models = data.get("models")
|
|
if not isinstance(models, list):
|
|
return []
|
|
tasks: list[tuple[str, str]] = []
|
|
for item in models:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
url = item.get("url")
|
|
t = item.get("type")
|
|
if not isinstance(url, str) or not url.strip():
|
|
continue
|
|
if not isinstance(t, str) or not t.strip() or not _is_safe_type(t):
|
|
continue
|
|
if only_types is not None and t not in only_types:
|
|
continue
|
|
tasks.append((t.strip(), url.strip()))
|
|
return tasks
|
|
|
|
|
|
def get_download_dir(data: dict[str, Any], model_type: str) -> Path:
    """Build the target folder ``<comfyui_base_dir>/models/<model_type>``.

    ``comfyui_base_dir`` has ``~`` expanded and is made absolute (resolving
    symlinks) before ``models`` and *model_type* are appended. *model_type*
    is joined as given; it is not validated here.
    """
    comfy_root = Path(data["comfyui_base_dir"]).expanduser().resolve()
    return comfy_root.joinpath("models", model_type)
|