# sft_get.py -- CivitAI download support, reconstructed from the patch
# 79c472c..3e233ad with the whitespace restored.
#
# Only definitions that the patch shows in full are reproduced here.  The
# following names are used below but their definitions/imports are not fully
# visible in the patch, so they are expected to come from the rest of
# sft_get.py: `httpx`, `console`, `Table`, `Any`, the rich progress classes
# (`Progress`, `SpinnerColumn`, `TextColumn`, `BarColumn`,
# `TaskProgressColumn`, `DownloadColumn`, `TransferSpeedColumn`,
# `TimeRemainingColumn`), `fetch_civitai_by_hash` and `cmd_info`.
# The module is also expected to start with
# `from __future__ import annotations`, as shown in the patch context.

import argparse
import hashlib
import json
import re
import struct
import sys
from pathlib import Path

CIVITAI_API_BASE = "https://civitai.com/api/v1"
CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"


def _safe_filename(name: object, fallback: str) -> str:
    """Reduce an externally supplied file name to a bare final path component.

    Names coming from the network (API payloads, ``Content-Disposition``) must
    never be able to point outside the chosen output directory.

    Args:
        name: Candidate file name; anything that is not a string is rejected.
        fallback: Value returned when no usable name can be extracted.

    Returns:
        The last path component of ``name``, or ``fallback`` when ``name`` is
        not a string or collapses to an empty/``.``/``..`` component.
    """
    if not isinstance(name, str):
        return fallback
    # Treat backslashes as separators too, so Windows-style paths are also
    # stripped when running on POSIX.
    candidate = Path(name.strip().replace("\\", "/")).name
    if candidate in {"", ".", ".."}:
        return fallback
    return candidate


def _filename_from_content_disposition(content_disp: str) -> str | None:
    """Extract a sanitized file name from a ``Content-Disposition`` header value.

    Returns:
        The bare file name, or ``None`` when the header carries no usable
        ``filename=`` parameter.
    """
    match = re.search(r'filename="?([^";\n]+)"?', content_disp)
    if not match:
        return None
    return _safe_filename(match.group(1), "") or None


def _expected_total_size(
    content_length: str | None, content_range: str | None, offset: int
) -> int:
    """Compute the expected final file size in bytes from response headers.

    Args:
        content_length: Raw ``Content-Length`` header value, if any.
        content_range: Raw ``Content-Range`` header value, if any.
        offset: Number of bytes already on disk that this response continues from.

    Returns:
        The total size in bytes, or ``0`` when it cannot be determined.
    """
    # A Content-Range of the form "bytes start-end/total" states the full size directly.
    if content_range:
        match = re.fullmatch(r"\s*bytes\s+\d+-\d+/(\d+)\s*", content_range)
        if match:
            return int(match.group(1))
    # Otherwise Content-Length only covers the bytes still to be transferred.
    if content_length and content_length.strip().isdigit():
        return int(content_length) + offset
    return 0


def _with_default_command(argv: list[str], commands: tuple[str, ...]) -> list[str]:
    """Prepend the implicit ``info`` subcommand when none was given.

    Args:
        argv: Command-line arguments without the program name.
        commands: Every accepted subcommand name, aliases included.

    Returns:
        A new argument list.  It is unchanged when ``argv`` is empty, already
        starts with a known subcommand, or asks for top-level help; otherwise
        ``"info"`` is inserted at the front.
    """
    if not argv:
        return []
    first = argv[0]
    if first in commands or first in ("-h", "--help"):
        return list(argv)
    return ["info", *argv]


def fetch_civitai_model_version(
    version_id: int, api_key: str | None = None
) -> dict[str, Any] | None:
    """Fetch model version information from CivitAI by version ID.

    Args:
        version_id: CivitAI model version ID.
        api_key: Optional API key, sent as a bearer token.

    Returns:
        The decoded JSON object, or ``None`` when the version does not exist
        (HTTP 404), the request fails, or the body is not a JSON object.
        Failures other than 404 are reported on the console.
    """
    url = f"{CIVITAI_API_BASE}/model-versions/{version_id}"
    headers: dict[str, str] = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    try:
        response = httpx.get(url, headers=headers, timeout=30.0)
        if response.status_code == 404:
            return None
        response.raise_for_status()
        payload = response.json()
    except httpx.HTTPStatusError as e:
        console.print(f"[red]API error: {e.response.status_code}[/red]")
        return None
    except httpx.RequestError as e:
        console.print(f"[red]Request error: {e}[/red]")
        return None
    except ValueError as e:
        # json.JSONDecodeError is a ValueError: the body was not valid JSON.
        console.print(f"[red]API error: invalid JSON response ({e})[/red]")
        return None

    if not isinstance(payload, dict):
        console.print("[red]API error: unexpected response format[/red]")
        return None
    result: dict[str, Any] = payload
    return result


def download_model(
    version_id: int,
    dest_path: Path,
    api_key: str | None = None,
    resume: bool = True,
) -> bool:
    """Download a model from CivitAI by version ID with resume support.

    Args:
        version_id: CivitAI model version ID.
        dest_path: Target file.  If it is an existing directory, the file name
            is taken from the response's ``Content-Disposition`` header (falling
            back to ``model-<version_id>.safetensors``); resuming is not
            possible in that case because the name is unknown before the request.
        api_key: Optional API key, sent as the ``token`` query parameter.
        resume: When true and ``dest_path`` is a non-empty regular file, ask the
            server to continue from the current size.

    Returns:
        True on success, False on failure.  Failures are reported on the console.
    """
    url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}"
    params: dict[str, str] = {"token": api_key} if api_key else {}

    headers: dict[str, str] = {}
    initial_size = 0

    # Only a regular file can be a partial download.  A directory also
    # "exists", but its st_size says nothing about downloaded bytes.
    if resume and dest_path.is_file():
        initial_size = dest_path.stat().st_size
        if initial_size > 0:
            headers["Range"] = f"bytes={initial_size}-"
            console.print(f"[cyan]Resuming download from {initial_size / (1024**2):.1f} MB[/cyan]")

    try:
        with httpx.stream(
            "GET",
            url,
            params=params,
            headers=headers,
            follow_redirects=True,
            timeout=httpx.Timeout(30.0, read=None),  # No read timeout for large files
        ) as response:
            # 416 Range Not Satisfiable in answer to our Range request is
            # treated as "nothing left to fetch".
            if response.status_code == 416 and initial_size > 0:
                console.print("[green]File already fully downloaded.[/green]")
                return True

            response.raise_for_status()

            # Append only when the server actually honoured the range (206).
            # A plain 200 carries the whole file again; appending it to the
            # partial file would corrupt the result, so start over instead.
            if initial_size > 0 and response.status_code != 206:
                console.print(
                    "[yellow]Server did not honor the resume request; restarting download.[/yellow]"
                )
                initial_size = 0
            mode = "ab" if initial_size > 0 else "wb"

            total_size = _expected_total_size(
                response.headers.get("content-length"),
                response.headers.get("content-range"),
                initial_size,
            )

            # A directory destination needs a file name; prefer the server's
            # suggestion, sanitized so it cannot escape the directory.
            if dest_path.is_dir():
                suggested = _filename_from_content_disposition(
                    response.headers.get("content-disposition", "")
                )
                dest_path = dest_path / (suggested or f"model-{version_id}.safetensors")

            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                DownloadColumn(),
                TransferSpeedColumn(),
                TimeRemainingColumn(),
                console=console,
            ) as progress:
                task = progress.add_task(
                    f"[cyan]Downloading {dest_path.name}...",
                    total=total_size if total_size > 0 else None,
                    completed=initial_size,
                )

                with dest_path.open(mode) as f:
                    for chunk in response.iter_bytes(1024 * 1024):  # 1MB chunks
                        f.write(chunk)
                        progress.update(task, advance=len(chunk))

        console.print(f"[green]Downloaded:[/green] {dest_path}")
        return True

    except httpx.HTTPStatusError as e:
        console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]")
        if e.response.status_code == 401:
            console.print("[yellow]Hint: This model may require an API key.[/yellow]")
        return False
    except httpx.RequestError as e:
        console.print(f"[red]Download error: {e}[/red]")
        return False
    except OSError as e:
        # Disk full, permission denied, ...: honour the "False on failure" contract.
        console.print(f"[red]Download error: {e}[/red]")
        return False


def _resolve_version_id(
    version_id: int | None, sha256_hash: str | None, api_key: str | None
) -> int | None:
    """Resolve the model version ID, looking it up by hash if needed.

    Args:
        version_id: Explicit version ID; takes precedence when truthy.
        sha256_hash: SHA256 hash to look up when no version ID is given.
        api_key: Optional API key forwarded to the lookup.

    Returns:
        The version ID, or ``None`` when neither input is given or the lookup
        does not yield an integer ID (lookup failures are reported on the console).
    """
    if version_id:
        return version_id
    if not sha256_hash:
        return None

    console.print(f"[cyan]Looking up model by hash: {sha256_hash[:16]}...[/cyan]")
    civitai_data = fetch_civitai_by_hash(sha256_hash.upper(), api_key)
    if not civitai_data:
        console.print("[red]Error: Model not found on CivitAI for this hash.[/red]")
        return None

    vid = civitai_data.get("id")
    # Anything but a non-zero integer cannot be used as a version ID downstream.
    if not isinstance(vid, int) or not vid:
        console.print("[red]Error: Could not determine version ID from CivitAI response.[/red]")
        return None

    console.print(f"[green]Found model version:[/green] {civitai_data.get('name', 'N/A')}")
    return vid


def cmd_download(args: argparse.Namespace) -> int:
    """Handle the download subcommand.

    Reads ``api_key``, ``output``, ``version_id``, ``hash`` and ``no_resume``
    from ``args``.

    Returns:
        Process exit code: 0 on success, 1 on any failure.
    """
    api_key: str | None = args.api_key
    output_dir: Path = args.output.resolve() if args.output else Path.cwd()

    if not output_dir.exists():
        console.print(f"[red]Error: Output directory not found: {output_dir}[/red]")
        return 1
    if not output_dir.is_dir():
        console.print(f"[red]Error: Not a directory: {output_dir}[/red]")
        return 1

    if not args.version_id and not args.hash:
        console.print("[red]Error: Must specify --version-id or --hash[/red]")
        return 1

    # Resolve version ID from hash if needed (the helper reports its own errors)
    version_id = _resolve_version_id(args.version_id, args.hash, api_key)
    if not version_id:
        return 1

    # Fetch version info to get filename
    console.print(f"[cyan]Fetching model info for version {version_id}...[/cyan]")
    version_info = fetch_civitai_model_version(version_id, api_key)

    if not version_info:
        console.print("[red]Error: Could not fetch model version info.[/red]")
        return 1

    # Find primary file or first file ("files" may be missing or null)
    files: list[dict[str, Any]] = version_info.get("files") or []
    primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)

    if not primary_file:
        console.print("[red]Error: No files found for this model version.[/red]")
        return 1

    # The name comes from the API; keep only its final component so the
    # download cannot be written outside the output directory.
    filename = _safe_filename(primary_file.get("name"), f"model-{version_id}.safetensors")
    dest_path = output_dir / filename

    size_kb = primary_file.get("sizeKB")
    if not isinstance(size_kb, (int, float)):
        size_kb = 0

    # Display model info
    model_table = Table(title="Model Download", show_header=True, header_style="bold magenta")
    model_table.add_column("Property", style="cyan")
    model_table.add_column("Value", style="green")
    model_table.add_row("Version", str(version_info.get("name") or "N/A"))
    model_table.add_row("Base Model", str(version_info.get("baseModel") or "N/A"))
    model_table.add_row("File", filename)
    model_table.add_row("Size", f"{size_kb / 1024:.2f} MB")
    model_table.add_row("Destination", str(dest_path))
    console.print()
    console.print(model_table)
    console.print()

    # Download
    success = download_model(version_id, dest_path, api_key, resume=not args.no_resume)
    return 0 if success else 1


def main(argv: list[str] | None = None) -> int:
    """Main entry point.

    Args:
        argv: Command-line arguments without the program name; defaults to
            ``sys.argv[1:]``.

    Returns:
        Process exit code.  With no arguments the help text is printed and 0
        is returned.

    Note:
        For backward compatibility a command line that does not start with a
        subcommand is treated as ``info ...`` (for example
        ``sft_get.py model.safetensors`` or ``sft_get.py --json model.safetensors``).
        A file literally named like a subcommand must be passed as
        ``info <name>``.
    """
    parser = argparse.ArgumentParser(
        description="Read safetensor metadata and download CivitAI models.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    subparsers = parser.add_subparsers(dest="command", help="Commands")

    # Info command (default)
    info_parser = subparsers.add_parser(
        "info",
        help="Read safetensor metadata and fetch CivitAI info (default)",
    )
    info_parser.add_argument(
        "file",
        type=Path,
        help="Path to the safetensor file",
    )
    info_parser.add_argument(
        "--api-key",
        type=str,
        default=None,
        help="CivitAI API key for authenticated requests",
    )
    info_parser.add_argument(
        "--skip-civitai",
        action="store_true",
        help="Skip CivitAI API lookup",
    )
    info_parser.add_argument(
        "--json",
        action="store_true",
        dest="json_output",
        help="Output results as JSON",
    )
    info_parser.add_argument(
        "--save-to",
        type=Path,
        metavar="DIR",
        help="Save metadata JSON and SHA256 hash to the specified directory",
    )
    info_parser.set_defaults(func=cmd_info)

    # Download command
    dl_parser = subparsers.add_parser(
        "download",
        aliases=["dl"],
        help="Download a model from CivitAI",
    )
    dl_parser.add_argument(
        "--version-id",
        "-v",
        type=int,
        help="CivitAI model version ID to download",
    )
    dl_parser.add_argument(
        "--hash",
        "-H",
        type=str,
        help="SHA256 hash to look up and download",
    )
    dl_parser.add_argument(
        "--api-key",
        type=str,
        default=None,
        help="CivitAI API key for authenticated requests",
    )
    dl_parser.add_argument(
        "--output",
        "-o",
        type=Path,
        help="Output directory (default: current directory)",
    )
    dl_parser.add_argument(
        "--no-resume",
        action="store_true",
        help="Don't resume partial downloads, start fresh",
    )
    dl_parser.set_defaults(func=cmd_download)

    raw_args = list(sys.argv[1:] if argv is None else argv)
    if not raw_args:
        parser.print_help()
        return 0

    # The default command has to be injected BEFORE parsing: argparse exits
    # with "invalid choice" as soon as the first positional is not a known
    # subcommand, so a post-parse fallback would never be reached.
    # `subparsers.choices` holds every registered name, aliases included.
    args = parser.parse_args(_with_default_command(raw_args, tuple(subparsers.choices)))

    result: int = args.func(args)
    return result


if __name__ == "__main__":
    sys.exit(main())