Files
Daniel Henry 93e53ad838 Refactor model configuration structure and update README
- Changed the configuration format to use a single 'models' list with entries containing 'url' and 'type'.
- Updated validation logic to ensure 'models' entries are correctly structured.
- Modified download logic to check for existing directories before downloading.
- Revised README to reflect new configuration format and usage instructions.

Signed-off-by: Daniel Henry <iamdanhenry@gmail.com>
2026-01-31 15:09:12 -06:00

192 lines
6.4 KiB
Python

"""URL detection, auth, streaming download with progress, save to ComfyUI paths.
Idempotency: Same URL + same model type always writes to the same path
(comfyui_base_dir/models/<subdir>/<filename>). Re-run is safe: existing files
are skipped (no re-download, no overwrite).
"""
import re
import tempfile
from pathlib import Path
from urllib.parse import parse_qs, urlencode, urlparse
import requests
from tqdm import tqdm
from .config import get_download_dir
CIVITAI_HOST = "civitai.com"
HUGGINGFACE_HOST = "huggingface.co"
CHUNK_SIZE = 1024 * 1024  # 1 MiB


def detect_source(url: str) -> str:
    """Classify a download URL by its host.

    The parsed hostname must equal a supported host or be a dot-separated
    subdomain of it. A plain substring test is not enough: it would accept
    look-alike hosts such as ``civitai.com.evil.example`` or a URL whose
    userinfo part spells a supported host (``civitai.com@evil.example``),
    and the caller attaches API tokens to requests for recognised sources.

    Args:
        url: The model download URL.

    Returns:
        ``'civitai'`` or ``'huggingface'``.

    Raises:
        ValueError: If the URL's host is not a supported one.
    """
    # ``hostname`` drops userinfo and port and is already lower-cased; the
    # rstrip tolerates a fully-qualified trailing dot ("civitai.com.").
    host = (urlparse(url).hostname or "").lower().rstrip(".")
    for known_host, source in ((CIVITAI_HOST, "civitai"), (HUGGINGFACE_HOST, "huggingface")):
        if host == known_host or host.endswith("." + known_host):
            return source
    raise ValueError(f"Unsupported URL host: {url}")
def _ensure_civitai_token(url: str, token: str | None) -> str:
"""Append or replace token in Civitai URL query."""
if not token:
return url
parsed = urlparse(url)
qs = parse_qs(parsed.query, keep_blank_values=True)
qs["token"] = [token]
new_query = urlencode(qs, doseq=True)
return parsed._replace(query=new_query).geturl()
def _filename_from_content_disposition(headers: requests.structures.CaseInsensitiveDict) -> str | None:
"""Extract filename from Content-Disposition header if present."""
cd = headers.get("Content-Disposition")
if not cd:
return None
# filename="..."; or filename*=UTF-8''...
m = re.search(r'filename\*?=(?:UTF-8\'\')?["\']?([^";\n]+)["\']?', cd, re.I)
if m:
name = m.group(1).strip()
if name:
return name
return None
def _filename_from_url_path(url: str) -> str:
"""Last path segment of URL, or 'model' if empty."""
path = urlparse(url).path.rstrip("/")
name = path.split("/")[-1] if path else ""
return name or "model"
def _civitai_model_id(url: str) -> str:
"""Extract model version id from Civitai API URL (e.g. .../models/1523317?...)."""
parsed = urlparse(url)
parts = parsed.path.strip("/").split("/")
# .../api/download/models/1523317
if len(parts) >= 4 and parts[-2] == "models":
return parts[-1]
return "unknown"
def resolve_filename(
    url: str, source: str, response_headers: requests.structures.CaseInsensitiveDict | None
) -> str:
    """Choose the output filename for a download.

    Priority:

    1. The filename from the Content-Disposition header, when headers are
       given and contain one.
    2. For Civitai URLs whose path carries an id (``.../models/<id>``):
       ``model_<id>.safetensors``. The last path segment of such a URL is
       the numeric id, not a filename, so using it verbatim would produce
       an extension-less file.
    3. The last segment of the URL path, or ``'model'`` if the path is empty.

    The result depends only on ``url``, ``source`` and the headers, so the
    same inputs always give the same name.

    Args:
        url: The original (configured) download URL.
        source: ``'civitai'`` or ``'huggingface'``.
        response_headers: Headers of the download response, or ``None`` when
            no response is available.

    Returns:
        The chosen filename (never empty).
    """
    if response_headers:
        cd_name = _filename_from_content_disposition(response_headers)
        if cd_name:
            return cd_name

    # This default used to sit after the URL-path lookup, which always
    # returns a non-empty string, so it could never be reached.
    if source == "civitai":
        model_id = _civitai_model_id(url)
        if model_id != "unknown":
            return f"model_{model_id}.safetensors"

    # Always non-empty: falls back to "model" for an empty path.
    return _filename_from_url_path(url)
def download_one(
    url: str,
    dest_dir: Path,
    config: dict,
    *,
    dry_run: bool = False,
) -> bool:
    """Download one URL into ``dest_dir``, skipping files that already exist.

    The output name is resolved from the original ``url`` and the response
    headers, so the same URL and ``dest_dir`` always map to the same path.
    Data is streamed into a hidden temp file inside ``dest_dir`` and moved
    into place only after the whole body was received; a failed or
    truncated download never leaves a partial file at the final path.

    Args:
        url: Civitai or Hugging Face download URL.
        dest_dir: Existing directory to save into.
        config: Mapping read for ``huggingface_token`` and ``civitai_token``.
        dry_run: When true, only print the guessed destination (no request
            is made) and return ``True``.

    Returns:
        ``True`` on success, skip-because-exists, or dry run; ``False`` on an
        unsupported URL, missing token, missing directory, or any request /
        write failure.
    """
    try:
        source = detect_source(url)
    except ValueError as e:
        print(f" Skip: {e}")
        return False

    dest_dir = Path(dest_dir)
    hf_token = config.get("huggingface_token") or ""
    civitai_token = config.get("civitai_token") or ""
    if source == "huggingface" and not hf_token:
        print(" Skip: Hugging Face token not set (huggingface_token in config)")
        return False
    if source == "civitai" and not civitai_token:
        print(" Skip: Civitai token not set (civitai_token in config)")
        return False

    # Hugging Face authenticates via header; Civitai via a query parameter.
    headers = {}
    if source == "huggingface":
        headers["Authorization"] = f"Bearer {hf_token}"
    get_url = _ensure_civitai_token(url, civitai_token) if source == "civitai" else url

    if dry_run:
        # No response headers are available here, so this is the header-less
        # resolution of the same rules the real download applies.
        filename_guess = resolve_filename(url, source, None)
        print(f" Would download to: {dest_dir / filename_guess}")
        return True

    if not dest_dir.is_dir():
        print(f" Skip: Directory does not exist: {dest_dir}. Create models/<type> under your ComfyUI base.")
        return False

    def _redacted(err: Exception) -> str:
        # requests puts the full request URL into its error text; mask the
        # token query parameter so the secret does not end up in logs.
        return re.sub(r"([?&]token=)[^&\s]+", r"\1***", str(err))

    get_resp = None
    tmp_path = None
    try:
        # Single streamed GET (Civitai redirects to S3 and often doesn't support HEAD)
        try:
            get_resp = requests.get(get_url, headers=headers, stream=True, allow_redirects=True, timeout=60)
            get_resp.raise_for_status()
        except requests.RequestException as e:
            print(f" Request failed: {_redacted(e)}")
            return False

        # Use original url for filename so same config URL always yields same path (idempotency)
        filename = resolve_filename(url, source, get_resp.headers)
        # The name may come from a server header: keep only the last path
        # component so it cannot point outside dest_dir.
        filename = filename.replace("\\", "/").rsplit("/", 1)[-1]
        if filename in ("", ".", ".."):
            print(" Skip: could not determine a usable filename")
            return False

        dest_path = dest_dir / filename
        if dest_path.exists():
            print(f" Skip (already exists): {filename}")
            return True

        try:
            total = int(get_resp.headers.get("Content-Length", 0) or 0)
        except ValueError:
            total = 0  # malformed header: size unknown
        # Content-Length counts encoded bytes; iter_content yields decoded
        # ones, so only compare sizes when no Content-Encoding is in play.
        expected = 0 if get_resp.headers.get("Content-Encoding") else total

        written = 0
        try:
            with tempfile.NamedTemporaryFile(dir=dest_dir, delete=False, prefix=".download_") as tmp:
                tmp_path = Path(tmp.name)
                with tqdm(total=total, unit="B", unit_scale=True, unit_divisor=1024, desc=filename) as pbar:
                    for chunk in get_resp.iter_content(chunk_size=CHUNK_SIZE):
                        if chunk:
                            tmp.write(chunk)
                            written += len(chunk)
                            pbar.update(len(chunk))
            if expected and written != expected:
                print(f" Download incomplete: got {written} of {expected} bytes")
                return False
            # Temp file is closed at this point; move it to its final name.
            tmp_path.replace(dest_path)
        except requests.RequestException as e:
            # Must precede OSError: requests' exceptions derive from it.
            print(f" Download failed: {_redacted(e)}")
            return False
        except OSError as e:
            print(f" Write failed: {e}")
            return False

        tmp_path = None  # moved into place; nothing left to clean up
        return True
    finally:
        # Runs on every exit path: release the connection and remove any
        # leftover partial temp file.
        if get_resp is not None:
            get_resp.close()
        if tmp_path is not None:
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass
def run_downloads(
    config: dict,
    tasks: list[tuple[str, str]],
    *,
    dry_run: bool = False,
) -> tuple[int, int]:
    """Process every ``(model_type, url)`` task in order.

    Each task's destination directory is looked up from ``config`` and the
    URL is handed to ``download_one``. In dry-run mode a ``[type] url``
    line is printed before each task.

    Returns:
        ``(success_count, fail_count)``.
    """
    succeeded = 0
    failed = 0
    for model_type, url in tasks:
        target_dir = get_download_dir(config, model_type)
        if dry_run:
            print(f"[{model_type}] {url}")
        outcome = download_one(url, target_dir, config, dry_run=dry_run)
        if not outcome:
            failed += 1
            continue
        succeeded += 1
    return succeeded, failed