3
model_downloader/__init__.py
Normal file
3
model_downloader/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Download ComfyUI models from Hugging Face and Civitai into configured folders."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
65
model_downloader/__main__.py
Normal file
65
model_downloader/__main__.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""CLI: parse args, load config, run downloads."""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from .config import get_model_tasks, load_config, validate_config
|
||||
from .download import run_downloads
|
||||
|
||||
|
||||
def main() -> int:
    """CLI entry point: parse arguments, load and validate the YAML config,
    then run (or dry-run) the configured downloads.

    Returns a process exit code: 0 on success or nothing to do, 1 on any failure.
    """
    arg_parser = argparse.ArgumentParser(
        description="Download ComfyUI models from Hugging Face and Civitai into configured folders."
    )
    arg_parser.add_argument(
        "--config",
        type=Path,
        default=Path("config.yaml"),
        help="Path to YAML config (default: config.yaml in cwd)",
    )
    arg_parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Only print what would be downloaded and where; no writes",
    )
    arg_parser.add_argument(
        "--only",
        nargs="+",
        metavar="TYPE",
        help="Only process these model types (e.g. diffusion_models loras)",
    )
    opts = arg_parser.parse_args()

    # Load the config, distinguishing a missing file from any other parse error.
    try:
        cfg = load_config(opts.config)
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1
    except Exception as e:
        print(f"Error loading config: {e}", file=sys.stderr)
        return 1

    # Schema-level validation (required keys, list types).
    try:
        validate_config(cfg)
    except ValueError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1

    tasks = get_model_tasks(cfg, only_types=opts.only)
    if not tasks:
        print("No model URLs to download. Add URLs under diffusion_models, text_encoders, vaes, upscale_models, or loras in your config.")
        return 0

    if opts.dry_run:
        print("Dry run – no files will be written.\n")

    succeeded, failed = run_downloads(cfg, tasks, dry_run=opts.dry_run)

    # In dry-run mode per-task lines were already printed; skip the summary.
    if not opts.dry_run:
        print(f"\nDone: {succeeded} succeeded, {failed} failed.")
    return 1 if failed else 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
74
model_downloader/config.py
Normal file
74
model_downloader/config.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Load and validate YAML config; build ComfyUI paths."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
# Config keys for model lists → ComfyUI subdir under models/
|
||||
MODEL_TYPE_SUBDIRS = {
|
||||
"diffusion_models": "diffusion_models",
|
||||
"text_encoders": "text_encoders",
|
||||
"vaes": "vae",
|
||||
"upscale_models": "upscale_models",
|
||||
"loras": "loras",
|
||||
}
|
||||
|
||||
REQUIRED_KEYS = {"comfyui_base_dir"}
|
||||
TOKEN_KEYS = {"huggingface_token", "civitai_token"}
|
||||
|
||||
|
||||
def load_config(path: str | Path) -> dict[str, Any]:
|
||||
"""Load YAML config from path. Raises FileNotFoundError or yaml error."""
|
||||
p = Path(path)
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"Config not found: {p}")
|
||||
with open(p, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("Config must be a YAML object")
|
||||
return data
|
||||
|
||||
|
||||
def validate_config(data: dict[str, Any]) -> None:
|
||||
"""Validate required keys and model list types. Raises ValueError on failure."""
|
||||
if "comfyui_base_dir" not in data:
|
||||
raise ValueError("Config must contain 'comfyui_base_dir'")
|
||||
base = data["comfyui_base_dir"]
|
||||
if not base or not isinstance(base, str):
|
||||
raise ValueError("comfyui_base_dir must be a non-empty string")
|
||||
for key in MODEL_TYPE_SUBDIRS:
|
||||
val = data.get(key)
|
||||
if val is None:
|
||||
data[key] = []
|
||||
elif not isinstance(val, list):
|
||||
raise ValueError(f"'{key}' must be a list of URL strings")
|
||||
else:
|
||||
for i, item in enumerate(val):
|
||||
if not isinstance(item, str):
|
||||
raise ValueError(f"'{key}[{i}]' must be a string")
|
||||
|
||||
|
||||
def get_model_tasks(
|
||||
data: dict[str, Any], only_types: list[str] | None = None
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Return list of (model_type, url) for all configured model URLs."""
|
||||
types = only_types if only_types is not None else list(MODEL_TYPE_SUBDIRS)
|
||||
tasks: list[tuple[str, str]] = []
|
||||
for model_type in types:
|
||||
if model_type not in MODEL_TYPE_SUBDIRS:
|
||||
continue
|
||||
urls = data.get(model_type)
|
||||
if not isinstance(urls, list):
|
||||
continue
|
||||
for url in urls:
|
||||
if isinstance(url, str) and url.strip():
|
||||
tasks.append((model_type, url.strip()))
|
||||
return tasks
|
||||
|
||||
|
||||
def get_download_dir(data: dict[str, Any], model_type: str) -> Path:
|
||||
"""Return the absolute directory path for a model type under ComfyUI base."""
|
||||
base = Path(data["comfyui_base_dir"]).expanduser().resolve()
|
||||
subdir = MODEL_TYPE_SUBDIRS[model_type]
|
||||
return base / "models" / subdir
|
||||
189
model_downloader/download.py
Normal file
189
model_downloader/download.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""URL detection, auth, streaming download with progress, save to ComfyUI paths.
|
||||
|
||||
Idempotency: Same URL + same model type always writes to the same path
|
||||
(comfyui_base_dir/models/<subdir>/<filename>). Re-run is safe: existing files
|
||||
are skipped (no re-download, no overwrite).
|
||||
"""
|
||||
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from .config import get_download_dir
|
||||
|
||||
CIVITAI_HOST = "civitai.com"
|
||||
HUGGINGFACE_HOST = "huggingface.co"
|
||||
CHUNK_SIZE = 1024 * 1024 # 1 MiB
|
||||
|
||||
|
||||
def detect_source(url: str) -> str:
|
||||
"""Return 'civitai' or 'huggingface'. Raises ValueError for unsupported URLs."""
|
||||
parsed = urlparse(url)
|
||||
netloc = (parsed.netloc or "").lower()
|
||||
if CIVITAI_HOST in netloc:
|
||||
return "civitai"
|
||||
if HUGGINGFACE_HOST in netloc:
|
||||
return "huggingface"
|
||||
raise ValueError(f"Unsupported URL host: {url}")
|
||||
|
||||
|
||||
def _ensure_civitai_token(url: str, token: str | None) -> str:
|
||||
"""Append or replace token in Civitai URL query."""
|
||||
if not token:
|
||||
return url
|
||||
parsed = urlparse(url)
|
||||
qs = parse_qs(parsed.query, keep_blank_values=True)
|
||||
qs["token"] = [token]
|
||||
new_query = urlencode(qs, doseq=True)
|
||||
return parsed._replace(query=new_query).geturl()
|
||||
|
||||
|
||||
def _filename_from_content_disposition(headers: requests.structures.CaseInsensitiveDict) -> str | None:
|
||||
"""Extract filename from Content-Disposition header if present."""
|
||||
cd = headers.get("Content-Disposition")
|
||||
if not cd:
|
||||
return None
|
||||
# filename="..."; or filename*=UTF-8''...
|
||||
m = re.search(r'filename\*?=(?:UTF-8\'\')?["\']?([^";\n]+)["\']?', cd, re.I)
|
||||
if m:
|
||||
name = m.group(1).strip()
|
||||
if name:
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def _filename_from_url_path(url: str) -> str:
|
||||
"""Last path segment of URL, or 'model' if empty."""
|
||||
path = urlparse(url).path.rstrip("/")
|
||||
name = path.split("/")[-1] if path else ""
|
||||
return name or "model"
|
||||
|
||||
|
||||
def _civitai_model_id(url: str) -> str:
|
||||
"""Extract model version id from Civitai API URL (e.g. .../models/1523317?...)."""
|
||||
parsed = urlparse(url)
|
||||
parts = parsed.path.strip("/").split("/")
|
||||
# .../api/download/models/1523317
|
||||
if len(parts) >= 4 and parts[-2] == "models":
|
||||
return parts[-1]
|
||||
return "unknown"
|
||||
|
||||
|
||||
def resolve_filename(
    url: str, source: str, response_headers: Mapping[str, str] | None
) -> str:
    """Choose output filename: Content-Disposition > URL path > sensible default.

    Args:
        url: Original config URL (kept stable so re-runs resolve the same path).
        source: 'civitai' or 'huggingface' (see detect_source).
        response_headers: Response headers, or None when no request was made.
            Typed as a plain string Mapping so the signature no longer
            references requests internals at definition time.
    """
    if response_headers:
        cd_name = _filename_from_content_disposition(response_headers)
        if cd_name:
            return cd_name
    # NOTE(review): _filename_from_url_path falls back to "model" and never
    # returns an empty string, so the branches below are effectively
    # unreachable; kept as defense in case that helper's contract changes.
    path_name = _filename_from_url_path(url)
    if path_name:
        return path_name
    if source == "civitai":
        return f"model_{_civitai_model_id(url)}.safetensors"
    return "model"
|
||||
|
||||
|
||||
def download_one(
    url: str,
    dest_dir: Path,
    config: dict,
    *,
    dry_run: bool = False,
) -> bool:
    """
    Download one URL into dest_dir. Use config for tokens.

    Same URL + same dest_dir always yields the same output path. Skips if file already exists.
    Returns True on success or skip, False on failure. On dry_run, only prints and returns True.
    """
    try:
        source = detect_source(url)
    except ValueError as e:
        print(f" Skip: {e}")
        return False

    dest_dir = Path(dest_dir)
    hf_token = config.get("huggingface_token") or ""
    civitai_token = config.get("civitai_token") or ""

    # Fail fast with a clear hint when the required token is missing.
    if source == "huggingface" and not hf_token:
        print(" Skip: Hugging Face token not set (huggingface_token in config)")
        return False
    if source == "civitai" and not civitai_token:
        print(" Skip: Civitai token not set (civitai_token in config)")
        return False

    # Hugging Face authenticates via header; Civitai via query parameter.
    headers = {}
    if source == "huggingface":
        headers["Authorization"] = f"Bearer {hf_token}"
    get_url = _ensure_civitai_token(url, civitai_token) if source == "civitai" else url

    if dry_run:
        # Best-effort name guess without any network traffic.
        # NOTE(review): _filename_from_url_path never returns an empty string,
        # so the civitai fallback after `or` is effectively unreachable.
        filename_guess = _filename_from_url_path(url) or f"model_{_civitai_model_id(url)}.safetensors"
        print(f" Would download to: {dest_dir / filename_guess}")
        return True

    dest_dir.mkdir(parents=True, exist_ok=True)

    # Single GET with stream (Civitai redirects to S3 and often doesn't support HEAD)
    try:
        get_resp = requests.get(get_url, headers=headers, stream=True, allow_redirects=True, timeout=60)
        get_resp.raise_for_status()
    except requests.RequestException as e:
        print(f" Request failed: {e}")
        return False

    # Use original url for filename so same config URL always yields same path (idempotency)
    filename = resolve_filename(url, source, get_resp.headers)
    dest_path = dest_dir / filename
    if dest_path.exists():
        get_resp.close()
        # BUGFIX: previously printed the literal text "(unknown)" instead of the path.
        print(f" Skip (already exists): {dest_path}")
        return True

    total = int(get_resp.headers.get("Content-Length", 0) or 0)
    tmp_path = None
    try:
        with get_resp:  # ensure the streamed connection is released
            # Write into the temp-file handle directly. The previous code
            # re-opened tmp.name with open(..., "wb") while the NamedTemporaryFile
            # handle was still open: an fd leak on POSIX and an error on Windows.
            with tempfile.NamedTemporaryFile(dir=dest_dir, delete=False, prefix=".download_") as tmp:
                tmp_path = Path(tmp.name)
                with tqdm(total=total, unit="B", unit_scale=True, unit_divisor=1024, desc=filename) as pbar:
                    for chunk in get_resp.iter_content(chunk_size=CHUNK_SIZE):
                        if chunk:
                            tmp.write(chunk)
                            pbar.update(len(chunk))
        # Atomic rename: a partial download never lands at dest_path.
        tmp_path.replace(dest_path)
        return True
    except (OSError, requests.RequestException) as e:
        # Also catch requests errors raised mid-stream by iter_content,
        # which are not OSError and previously escaped this handler.
        print(f" Write failed: {e}")
        if tmp_path is not None and tmp_path.exists():
            try:
                tmp_path.unlink()  # best-effort cleanup of the partial file
            except OSError:
                pass
        return False
|
||||
|
||||
|
||||
def run_downloads(
    config: dict,
    tasks: list[tuple[str, str]],
    *,
    dry_run: bool = False,
) -> tuple[int, int]:
    """Run download for each (model_type, url). Returns (success_count, fail_count)."""
    succeeded = 0
    failed = 0
    for model_type, url in tasks:
        target_dir = get_download_dir(config, model_type)
        if dry_run:
            # Echo the task line; download_one prints the resolved path.
            print(f"[{model_type}] {url}")
        if download_one(url, target_dir, config, dry_run=dry_run):
            succeeded += 1
        else:
            failed += 1
    return succeeded, failed
|
||||
Reference in New Issue
Block a user