💬 Commit message: Update 2026-01-26 06:35:11, 1 files, 326 lines
📁 Files changed: 1 📝 Lines changed: 326 • sft_get.py
This commit is contained in:
+289
-37
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -30,6 +31,7 @@ from rich.table import Table
|
||||
console = Console()
|
||||
|
||||
CIVITAI_API_BASE = "https://civitai.com/api/v1"
|
||||
CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
|
||||
|
||||
|
||||
def read_safetensor_metadata(file_path: Path) -> dict[str, Any]:
|
||||
@@ -90,6 +92,30 @@ def compute_sha256(file_path: Path) -> str:
|
||||
return sha256.hexdigest().upper()
|
||||
|
||||
|
||||
def fetch_civitai_model_version(
|
||||
version_id: int, api_key: str | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch model version information from CivitAI by version ID."""
|
||||
url = f"{CIVITAI_API_BASE}/model-versions/{version_id}"
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=30.0)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
result: dict[str, Any] = response.json()
|
||||
return result
|
||||
except httpx.HTTPStatusError as e:
|
||||
console.print(f"[red]API error: {e.response.status_code}[/red]")
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
console.print(f"[red]Request error: {e}[/red]")
|
||||
return None
|
||||
|
||||
|
||||
def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[str, Any] | None:
|
||||
"""Fetch model information from CivitAI by SHA256 hash."""
|
||||
url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}"
|
||||
@@ -120,6 +146,93 @@ def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[
|
||||
return None
|
||||
|
||||
|
||||
def download_model(
|
||||
version_id: int,
|
||||
dest_path: Path,
|
||||
api_key: str | None = None,
|
||||
resume: bool = True,
|
||||
) -> bool:
|
||||
"""Download a model from CivitAI by version ID with resume support.
|
||||
|
||||
Returns True on success, False on failure.
|
||||
"""
|
||||
url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}"
|
||||
params: dict[str, str] = {}
|
||||
if api_key:
|
||||
params["token"] = api_key
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
mode = "wb"
|
||||
initial_size = 0
|
||||
|
||||
# Check for existing partial download
|
||||
if resume and dest_path.exists():
|
||||
initial_size = dest_path.stat().st_size
|
||||
headers["Range"] = f"bytes={initial_size}-"
|
||||
mode = "ab"
|
||||
console.print(f"[cyan]Resuming download from {initial_size / (1024**2):.1f} MB[/cyan]")
|
||||
|
||||
try:
|
||||
with httpx.stream(
|
||||
"GET",
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
follow_redirects=True,
|
||||
timeout=httpx.Timeout(30.0, read=None), # No read timeout for large files
|
||||
) as response:
|
||||
# Handle 416 Range Not Satisfiable (file already complete)
|
||||
if response.status_code == 416:
|
||||
console.print("[green]File already fully downloaded.[/green]")
|
||||
return True
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
# Get total size from Content-Length or Content-Range
|
||||
content_length = response.headers.get("content-length")
|
||||
total_size = int(content_length) + initial_size if content_length else 0
|
||||
|
||||
# Get filename from Content-Disposition if available
|
||||
content_disp = response.headers.get("content-disposition", "")
|
||||
if "filename=" in content_disp:
|
||||
match = re.search(r'filename="?([^";\n]+)"?', content_disp)
|
||||
if match and dest_path.is_dir():
|
||||
dest_path = dest_path / match.group(1)
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
DownloadColumn(),
|
||||
TransferSpeedColumn(),
|
||||
TimeRemainingColumn(),
|
||||
console=console,
|
||||
) as progress:
|
||||
task = progress.add_task(
|
||||
f"[cyan]Downloading {dest_path.name}...",
|
||||
total=total_size if total_size > 0 else None,
|
||||
completed=initial_size,
|
||||
)
|
||||
|
||||
with dest_path.open(mode) as f:
|
||||
for chunk in response.iter_bytes(1024 * 1024): # 1MB chunks
|
||||
f.write(chunk)
|
||||
progress.update(task, advance=len(chunk))
|
||||
|
||||
console.print(f"[green]Downloaded:[/green] {dest_path}")
|
||||
return True
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]")
|
||||
if e.response.status_code == 401:
|
||||
console.print("[yellow]Hint: This model may require an API key.[/yellow]")
|
||||
return False
|
||||
except httpx.RequestError as e:
|
||||
console.print(f"[red]Download error: {e}[/red]")
|
||||
return False
|
||||
|
||||
|
||||
def _display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str) -> None:
|
||||
"""Display file information table."""
|
||||
file_table = Table(title="File Information", show_header=True, header_style="bold magenta")
|
||||
@@ -264,43 +377,8 @@ def save_metadata(
|
||||
return json_path, sha_path
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Main entry point."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Read safetensor metadata and fetch CivitAI model information.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"file",
|
||||
type=Path,
|
||||
help="Path to the safetensor file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="CivitAI API key for authenticated requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-civitai",
|
||||
action="store_true",
|
||||
help="Skip CivitAI API lookup",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--json",
|
||||
action="store_true",
|
||||
dest="json_output",
|
||||
help="Output results as JSON",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-to",
|
||||
type=Path,
|
||||
metavar="DIR",
|
||||
help="Save metadata JSON and SHA256 hash to the specified directory",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def cmd_info(args: argparse.Namespace) -> int:
|
||||
"""Handle the info subcommand (default behavior)."""
|
||||
file_path: Path = args.file.resolve()
|
||||
|
||||
if not file_path.exists():
|
||||
@@ -362,5 +440,179 @@ def main() -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def _resolve_version_id(
|
||||
version_id: int | None, sha256_hash: str | None, api_key: str | None
|
||||
) -> int | None:
|
||||
"""Resolve version ID from hash if needed."""
|
||||
if version_id:
|
||||
return version_id
|
||||
if sha256_hash:
|
||||
console.print(f"[cyan]Looking up model by hash: {sha256_hash[:16]}...[/cyan]")
|
||||
civitai_data = fetch_civitai_by_hash(sha256_hash.upper(), api_key)
|
||||
if not civitai_data:
|
||||
console.print("[red]Error: Model not found on CivitAI for this hash.[/red]")
|
||||
return None
|
||||
vid = civitai_data.get("id")
|
||||
if vid:
|
||||
console.print(f"[green]Found model version:[/green] {civitai_data.get('name', 'N/A')}")
|
||||
else:
|
||||
console.print("[red]Error: Could not determine version ID from CivitAI response.[/red]")
|
||||
return vid
|
||||
return None
|
||||
|
||||
|
||||
def cmd_download(args: argparse.Namespace) -> int:
|
||||
"""Handle the download subcommand."""
|
||||
api_key: str | None = args.api_key
|
||||
output_dir: Path = args.output.resolve() if args.output else Path.cwd()
|
||||
|
||||
if not output_dir.exists():
|
||||
console.print(f"[red]Error: Output directory not found: {output_dir}[/red]")
|
||||
return 1
|
||||
if not output_dir.is_dir():
|
||||
console.print(f"[red]Error: Not a directory: {output_dir}[/red]")
|
||||
return 1
|
||||
|
||||
# Resolve version ID from hash if needed
|
||||
version_id = _resolve_version_id(args.version_id, args.hash, api_key)
|
||||
if not version_id:
|
||||
if not args.version_id and not args.hash:
|
||||
console.print("[red]Error: Must specify --version-id or --hash[/red]")
|
||||
return 1
|
||||
|
||||
# Fetch version info to get filename
|
||||
console.print(f"[cyan]Fetching model info for version {version_id}...[/cyan]")
|
||||
version_info = fetch_civitai_model_version(version_id, api_key)
|
||||
|
||||
if not version_info:
|
||||
console.print("[red]Error: Could not fetch model version info.[/red]")
|
||||
return 1
|
||||
|
||||
# Find primary file or first file
|
||||
files: list[dict[str, Any]] = version_info.get("files", [])
|
||||
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
|
||||
|
||||
if not primary_file:
|
||||
console.print("[red]Error: No files found for this model version.[/red]")
|
||||
return 1
|
||||
|
||||
filename = primary_file.get("name", f"model-{version_id}.safetensors")
|
||||
dest_path = output_dir / filename
|
||||
|
||||
# Display model info
|
||||
model_table = Table(title="Model Download", show_header=True, header_style="bold magenta")
|
||||
model_table.add_column("Property", style="cyan")
|
||||
model_table.add_column("Value", style="green")
|
||||
model_table.add_row("Version", version_info.get("name", "N/A"))
|
||||
model_table.add_row("Base Model", version_info.get("baseModel", "N/A"))
|
||||
model_table.add_row("File", filename)
|
||||
model_table.add_row("Size", f"{primary_file.get('sizeKB', 0) / 1024:.2f} MB")
|
||||
model_table.add_row("Destination", str(dest_path))
|
||||
console.print()
|
||||
console.print(model_table)
|
||||
console.print()
|
||||
|
||||
# Download
|
||||
success = download_model(version_id, dest_path, api_key, resume=not args.no_resume)
|
||||
return 0 if success else 1
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Main entry point."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Read safetensor metadata and download CivitAI models.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", help="Commands")
|
||||
|
||||
# Info command (default)
|
||||
info_parser = subparsers.add_parser(
|
||||
"info",
|
||||
help="Read safetensor metadata and fetch CivitAI info (default)",
|
||||
)
|
||||
info_parser.add_argument(
|
||||
"file",
|
||||
type=Path,
|
||||
help="Path to the safetensor file",
|
||||
)
|
||||
info_parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="CivitAI API key for authenticated requests",
|
||||
)
|
||||
info_parser.add_argument(
|
||||
"--skip-civitai",
|
||||
action="store_true",
|
||||
help="Skip CivitAI API lookup",
|
||||
)
|
||||
info_parser.add_argument(
|
||||
"--json",
|
||||
action="store_true",
|
||||
dest="json_output",
|
||||
help="Output results as JSON",
|
||||
)
|
||||
info_parser.add_argument(
|
||||
"--save-to",
|
||||
type=Path,
|
||||
metavar="DIR",
|
||||
help="Save metadata JSON and SHA256 hash to the specified directory",
|
||||
)
|
||||
info_parser.set_defaults(func=cmd_info)
|
||||
|
||||
# Download command
|
||||
dl_parser = subparsers.add_parser(
|
||||
"download",
|
||||
aliases=["dl"],
|
||||
help="Download a model from CivitAI",
|
||||
)
|
||||
dl_parser.add_argument(
|
||||
"--version-id",
|
||||
"-v",
|
||||
type=int,
|
||||
help="CivitAI model version ID to download",
|
||||
)
|
||||
dl_parser.add_argument(
|
||||
"--hash",
|
||||
"-H",
|
||||
type=str,
|
||||
help="SHA256 hash to look up and download",
|
||||
)
|
||||
dl_parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="CivitAI API key for authenticated requests",
|
||||
)
|
||||
dl_parser.add_argument(
|
||||
"--output",
|
||||
"-o",
|
||||
type=Path,
|
||||
help="Output directory (default: current directory)",
|
||||
)
|
||||
dl_parser.add_argument(
|
||||
"--no-resume",
|
||||
action="store_true",
|
||||
help="Don't resume partial downloads, start fresh",
|
||||
)
|
||||
dl_parser.set_defaults(func=cmd_download)
|
||||
|
||||
# Parse and handle default command
|
||||
args = parser.parse_args()
|
||||
|
||||
# If no command specified and file argument given, assume 'info' command
|
||||
if args.command is None:
|
||||
# Check if there's a positional argument (file path)
|
||||
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
|
||||
# Re-parse with 'info' prepended
|
||||
args = parser.parse_args(["info", *sys.argv[1:]])
|
||||
else:
|
||||
parser.print_help()
|
||||
return 0
|
||||
|
||||
result: int = args.func(args)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
Reference in New Issue
Block a user