"""URL detection, auth, streaming download with progress, save to ComfyUI paths. Idempotency: Same URL + same model type always writes to the same path (comfyui_base_dir/models//). Re-run is safe: existing files are skipped (no re-download, no overwrite). """ import re import tempfile from pathlib import Path from urllib.parse import parse_qs, urlencode, urlparse import requests from tqdm import tqdm from .config import get_download_dir CIVITAI_HOST = "civitai.com" HUGGINGFACE_HOST = "huggingface.co" CHUNK_SIZE = 1024 * 1024 # 1 MiB def detect_source(url: str) -> str: """Return 'civitai' or 'huggingface'. Raises ValueError for unsupported URLs.""" parsed = urlparse(url) netloc = (parsed.netloc or "").lower() if CIVITAI_HOST in netloc: return "civitai" if HUGGINGFACE_HOST in netloc: return "huggingface" raise ValueError(f"Unsupported URL host: {url}") def _ensure_civitai_token(url: str, token: str | None) -> str: """Append or replace token in Civitai URL query.""" if not token: return url parsed = urlparse(url) qs = parse_qs(parsed.query, keep_blank_values=True) qs["token"] = [token] new_query = urlencode(qs, doseq=True) return parsed._replace(query=new_query).geturl() def _filename_from_content_disposition(headers: requests.structures.CaseInsensitiveDict) -> str | None: """Extract filename from Content-Disposition header if present.""" cd = headers.get("Content-Disposition") if not cd: return None # filename="..."; or filename*=UTF-8''... m = re.search(r'filename\*?=(?:UTF-8\'\')?["\']?([^";\n]+)["\']?', cd, re.I) if m: name = m.group(1).strip() if name: return name return None def _filename_from_url_path(url: str) -> str: """Last path segment of URL, or 'model' if empty.""" path = urlparse(url).path.rstrip("/") name = path.split("/")[-1] if path else "" return name or "model" def _civitai_model_id(url: str) -> str: """Extract model version id from Civitai API URL (e.g. .../models/1523317?...).""" parsed = urlparse(url) parts = parsed.path.strip("/").split("/") # .../api/download/models/1523317 if len(parts) >= 4 and parts[-2] == "models": return parts[-1] return "unknown" def resolve_filename( url: str, source: str, response_headers: requests.structures.CaseInsensitiveDict | None ) -> str: """Choose output filename: Content-Disposition > URL path > sensible default.""" if response_headers: cd_name = _filename_from_content_disposition(response_headers) if cd_name: return cd_name path_name = _filename_from_url_path(url) if path_name: return path_name if source == "civitai": return f"model_{_civitai_model_id(url)}.safetensors" return "model" def download_one( url: str, dest_dir: Path, config: dict, *, dry_run: bool = False, ) -> bool: """ Download one URL into dest_dir. Use config for tokens. Same URL + same dest_dir always yields the same output path. Skips if file already exists. Returns True on success or skip, False on failure. On dry_run, only prints and returns True. """ try: source = detect_source(url) except ValueError as e: print(f" Skip: {e}") return False dest_dir = Path(dest_dir) hf_token = config.get("huggingface_token") or "" civitai_token = config.get("civitai_token") or "" if source == "huggingface" and not hf_token: print(" Skip: Hugging Face token not set (huggingface_token in config)") return False if source == "civitai" and not civitai_token: print(" Skip: Civitai token not set (civitai_token in config)") return False # HEAD to get Content-Length and optionally Content-Disposition (Civitai may redirect on GET) headers = {} if source == "huggingface": headers["Authorization"] = f"Bearer {hf_token}" get_url = _ensure_civitai_token(url, civitai_token) if source == "civitai" else url if dry_run: filename_guess = _filename_from_url_path(url) or f"model_{_civitai_model_id(url)}.safetensors" print(f" Would download to: {dest_dir / filename_guess}") return True dest_dir.mkdir(parents=True, exist_ok=True) # Single GET with stream (Civitai redirects to S3 and often doesn't support HEAD) try: get_resp = requests.get(get_url, headers=headers, stream=True, allow_redirects=True, timeout=60) get_resp.raise_for_status() except requests.RequestException as e: print(f" Request failed: {e}") return False # Use original url for filename so same config URL always yields same path (idempotency) filename = resolve_filename(url, source, get_resp.headers) dest_path = dest_dir / filename if dest_path.exists(): get_resp.close() print(f" Skip (already exists): {filename}") return True total = int(get_resp.headers.get("Content-Length", 0) or 0) tmp_path = None try: with tempfile.NamedTemporaryFile(dir=dest_dir, delete=False, prefix=".download_") as tmp: tmp_path = Path(tmp.name) with open(tmp_path, "wb") as f: with tqdm(total=total, unit="B", unit_scale=True, unit_divisor=1024, desc=filename) as pbar: for chunk in get_resp.iter_content(chunk_size=CHUNK_SIZE): if chunk: f.write(chunk) pbar.update(len(chunk)) tmp_path.replace(dest_path) return True except OSError as e: print(f" Write failed: {e}") if tmp_path is not None and tmp_path.exists(): try: tmp_path.unlink() except OSError: pass return False def run_downloads( config: dict, tasks: list[tuple[str, str]], *, dry_run: bool = False, ) -> tuple[int, int]: """Run download for each (model_type, url). Returns (success_count, fail_count).""" ok, fail = 0, 0 for model_type, url in tasks: dest_dir = get_download_dir(config, model_type) if dry_run: print(f"[{model_type}] {url}") if download_one(url, dest_dir, config, dry_run=dry_run): ok += 1 else: fail += 1 return ok, fail