"""URL detection, auth, streaming download with progress, save to ComfyUI paths.

Idempotency: the same URL + the same model type always writes to the same path
(comfyui_base_dir/models/<subdir>/<filename>). Re-running is safe: existing
files are skipped (no re-download, no overwrite).
"""
|
|
|
|
import re
import tempfile
from pathlib import Path
from urllib.parse import parse_qs, unquote, urlencode, urlparse

import requests
from tqdm import tqdm

from .config import get_download_dir
|
|
|
|
CIVITAI_HOST = "civitai.com"
HUGGINGFACE_HOST = "huggingface.co"
CHUNK_SIZE = 1024 * 1024  # 1 MiB


def detect_source(url: str) -> str:
    """Return 'civitai' or 'huggingface'. Raises ValueError for unsupported URLs.

    The URL's host must be exactly one of the supported domains or a
    subdomain of it (e.g. ``cdn-lfs.huggingface.co``). A plain substring test
    is not sufficient: hosts such as ``civitai.com.evil.example`` or URLs such
    as ``https://huggingface.co@evil.example/`` would otherwise be accepted,
    and download_one in this module attaches API tokens based on this result.

    Args:
        url: Absolute URL including a scheme (``https://...``).

    Returns:
        ``"civitai"`` or ``"huggingface"``.

    Raises:
        ValueError: If the URL has no host or the host is not supported.
    """
    # `hostname` is already lower-cased and excludes userinfo and port.
    host = (urlparse(url).hostname or "").rstrip(".")
    supported = (("civitai", CIVITAI_HOST), ("huggingface", HUGGINGFACE_HOST))
    for source, domain in supported:
        if host == domain or host.endswith("." + domain):
            return source
    raise ValueError(f"Unsupported URL host: {url}")
|
|
|
|
|
|
def _ensure_civitai_token(url: str, token: str | None) -> str:
|
|
"""Append or replace token in Civitai URL query."""
|
|
if not token:
|
|
return url
|
|
parsed = urlparse(url)
|
|
qs = parse_qs(parsed.query, keep_blank_values=True)
|
|
qs["token"] = [token]
|
|
new_query = urlencode(qs, doseq=True)
|
|
return parsed._replace(query=new_query).geturl()
|
|
|
|
|
|
def _filename_from_content_disposition(headers: requests.structures.CaseInsensitiveDict) -> str | None:
|
|
"""Extract filename from Content-Disposition header if present."""
|
|
cd = headers.get("Content-Disposition")
|
|
if not cd:
|
|
return None
|
|
# filename="..."; or filename*=UTF-8''...
|
|
m = re.search(r'filename\*?=(?:UTF-8\'\')?["\']?([^";\n]+)["\']?', cd, re.I)
|
|
if m:
|
|
name = m.group(1).strip()
|
|
if name:
|
|
return name
|
|
return None
|
|
|
|
|
|
def _filename_from_url_path(url: str) -> str:
|
|
"""Last path segment of URL, or 'model' if empty."""
|
|
path = urlparse(url).path.rstrip("/")
|
|
name = path.split("/")[-1] if path else ""
|
|
return name or "model"
|
|
|
|
|
|
def _civitai_model_id(url: str) -> str:
|
|
"""Extract model version id from Civitai API URL (e.g. .../models/1523317?...)."""
|
|
parsed = urlparse(url)
|
|
parts = parsed.path.strip("/").split("/")
|
|
# .../api/download/models/1523317
|
|
if len(parts) >= 4 and parts[-2] == "models":
|
|
return parts[-1]
|
|
return "unknown"
|
|
|
|
|
|
def resolve_filename(
    url: str, source: str, response_headers: requests.structures.CaseInsensitiveDict | None
) -> str:
    """Choose output filename: Content-Disposition > URL path > sensible default.

    Args:
        url: The original (configured) URL; only its path is inspected.
        source: Source label such as ``"civitai"`` or ``"huggingface"``.
        response_headers: Response headers, or None when no response is
            available (e.g. for a dry run).

    Returns:
        A non-empty file name.

    Resolution order:
        1. The name from the Content-Disposition header, when present.
        2. The last URL path segment.
        3. For Civitai URLs whose last segment has no file extension (API
           download URLs end in a bare numeric id):
           ``model_<id>.safetensors``.

    _filename_from_url_path never returns an empty string (it falls back to
    ``"model"``), so a plain truthiness check on its result made the Civitai
    default unreachable. The default is selected here by checking for a
    missing extension instead.
    """
    if response_headers:
        cd_name = _filename_from_content_disposition(response_headers)
        if cd_name:
            return cd_name

    path_name = _filename_from_url_path(url)
    if source == "civitai" and not Path(path_name).suffix:
        return f"model_{_civitai_model_id(url)}.safetensors"
    # Already falls back to "model" when the URL path is empty.
    return path_name
|
|
|
|
|
|
def _redact_token(message: str) -> str:
    """Mask the value of any ``token=`` query parameter appearing in *message*."""
    return re.sub(r"(token=)[^&\s]+", r"\1***", message)


def download_one(
    url: str,
    dest_dir: Path,
    config: dict,
    *,
    dry_run: bool = False,
) -> bool:
    """Download one URL into dest_dir, using config for tokens.

    The same URL + the same dest_dir always yields the same output path, and
    the download is skipped if that file already exists. Data is streamed to
    a temporary file inside dest_dir and moved into place only after the
    whole body has been received, so a failed or interrupted download never
    leaves a partial file under the final name.

    Args:
        url: Civitai or Hugging Face URL.
        dest_dir: Directory to save into (created if missing).
        config: Mapping read for ``huggingface_token`` and ``civitai_token``.
        dry_run: When True, only print the target path; nothing is requested
            or written.

    Returns:
        True on success or skip, False on failure. On dry_run, only prints
        and returns True.
    """
    try:
        source = detect_source(url)
    except ValueError as e:
        print(f" Skip: {e}")
        return False

    dest_dir = Path(dest_dir)
    hf_token = config.get("huggingface_token") or ""
    civitai_token = config.get("civitai_token") or ""

    if source == "huggingface" and not hf_token:
        print(" Skip: Hugging Face token not set (huggingface_token in config)")
        return False
    if source == "civitai" and not civitai_token:
        print(" Skip: Civitai token not set (civitai_token in config)")
        return False

    # Hugging Face authenticates via header, Civitai via a query parameter.
    headers = {}
    if source == "huggingface":
        headers["Authorization"] = f"Bearer {hf_token}"
    get_url = _ensure_civitai_token(url, civitai_token) if source == "civitai" else url

    if dry_run:
        # No response headers are available, so this is the header-less guess.
        filename_guess = resolve_filename(url, source, None)
        print(f" Would download to: {dest_dir / filename_guess}")
        return True

    try:
        dest_dir.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        print(f" Write failed: {e}")
        return False

    # Single GET with stream (Civitai redirects to S3 and often doesn't support HEAD)
    try:
        get_resp = requests.get(get_url, headers=headers, stream=True, allow_redirects=True, timeout=60)
    except requests.RequestException as e:
        # The message can embed the request URL, which carries the Civitai token.
        print(f" Request failed: {_redact_token(str(e))}")
        return False

    tmp_path = None
    try:
        get_resp.raise_for_status()

        # Use original url for filename so same config URL always yields same path (idempotency)
        filename = resolve_filename(url, source, get_resp.headers)
        # The name may come from a server-controlled header: keep only the
        # final path component so the file cannot land outside dest_dir.
        filename = filename.replace("\\", "/").rsplit("/", 1)[-1]
        if filename in ("", ".", ".."):
            filename = "model"
        dest_path = dest_dir / filename
        if dest_path.exists():
            print(f" Skip (already exists): {filename}")
            return True

        try:
            total = int(get_resp.headers.get("Content-Length", 0) or 0)
        except ValueError:
            total = 0  # malformed header: treat the size as unknown
        # Content-Length counts encoded bytes, while iter_content yields decoded
        # bytes, so the sizes are only comparable without a Content-Encoding.
        check_length = total > 0 and not get_resp.headers.get("Content-Encoding")

        written = 0
        with tempfile.NamedTemporaryFile(dir=dest_dir, delete=False, prefix=".download_") as tmp:
            tmp_path = Path(tmp.name)
            with tqdm(total=total or None, unit="B", unit_scale=True, unit_divisor=1024, desc=filename) as pbar:
                for chunk in get_resp.iter_content(chunk_size=CHUNK_SIZE):
                    if chunk:
                        tmp.write(chunk)
                        written += len(chunk)
                        pbar.update(len(chunk))

        if check_length and written != total:
            print(f" Download incomplete: got {written} of {total} bytes")
            return False

        tmp_path.replace(dest_path)
        tmp_path = None  # moved into place; nothing left to clean up
        return True
    except requests.RequestException as e:
        # Listed before OSError because RequestException is a subclass of it.
        print(f" Request failed: {_redact_token(str(e))}")
        return False
    except OSError as e:
        print(f" Write failed: {e}")
        return False
    finally:
        get_resp.close()
        if tmp_path is not None:
            # Best-effort removal of the partial file; a failure here must
            # not mask the original outcome.
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass
|
|
|
|
|
|
def run_downloads(
    config: dict,
    tasks: list[tuple[str, str]],
    *,
    dry_run: bool = False,
) -> tuple[int, int]:
    """Process every (model_type, url) task and tally the outcomes.

    For each task the destination directory is obtained from
    get_download_dir(config, model_type) and the URL is handed to
    download_one. In dry-run mode a ``[model_type] url`` line is printed
    before each task.

    Returns:
        ``(success_count, fail_count)``, where a task counts as a success
        when download_one returns a truthy value.
    """
    succeeded = 0
    failed = 0
    for kind, link in tasks:
        target_dir = get_download_dir(config, kind)
        if dry_run:
            print(f"[{kind}] {link}")
        if not download_one(link, target_dir, config, dry_run=dry_run):
            failed += 1
            continue
        succeeded += 1
    return succeeded, failed
|