💬 Commit message: Update 2026-01-26 06:35:11, 1 files, 326 lines

📁 Files changed: 1
📝 Lines changed: 326

  • sft_get.py
This commit is contained in:
Adam Ladachowski
2026-01-26 06:35:11 +01:00
parent 67bc60b3ad
commit 65d80c5eb5
+289 -37
View File
@@ -8,6 +8,7 @@ from __future__ import annotations
import argparse import argparse
import hashlib import hashlib
import json import json
import re
import struct import struct
import sys import sys
from pathlib import Path from pathlib import Path
@@ -30,6 +31,7 @@ from rich.table import Table
console = Console() console = Console()
CIVITAI_API_BASE = "https://civitai.com/api/v1" CIVITAI_API_BASE = "https://civitai.com/api/v1"
CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
def read_safetensor_metadata(file_path: Path) -> dict[str, Any]: def read_safetensor_metadata(file_path: Path) -> dict[str, Any]:
@@ -90,6 +92,30 @@ def compute_sha256(file_path: Path) -> str:
return sha256.hexdigest().upper() return sha256.hexdigest().upper()
def fetch_civitai_model_version(
version_id: int, api_key: str | None = None
) -> dict[str, Any] | None:
"""Fetch model version information from CivitAI by version ID."""
url = f"{CIVITAI_API_BASE}/model-versions/{version_id}"
headers: dict[str, str] = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
try:
response = httpx.get(url, headers=headers, timeout=30.0)
if response.status_code == 404:
return None
response.raise_for_status()
result: dict[str, Any] = response.json()
return result
except httpx.HTTPStatusError as e:
console.print(f"[red]API error: {e.response.status_code}[/red]")
return None
except httpx.RequestError as e:
console.print(f"[red]Request error: {e}[/red]")
return None
def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[str, Any] | None: def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[str, Any] | None:
"""Fetch model information from CivitAI by SHA256 hash.""" """Fetch model information from CivitAI by SHA256 hash."""
url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}" url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}"
@@ -120,6 +146,93 @@ def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[
return None return None
def download_model(
version_id: int,
dest_path: Path,
api_key: str | None = None,
resume: bool = True,
) -> bool:
"""Download a model from CivitAI by version ID with resume support.
Returns True on success, False on failure.
"""
url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}"
params: dict[str, str] = {}
if api_key:
params["token"] = api_key
headers: dict[str, str] = {}
mode = "wb"
initial_size = 0
# Check for existing partial download
if resume and dest_path.exists():
initial_size = dest_path.stat().st_size
headers["Range"] = f"bytes={initial_size}-"
mode = "ab"
console.print(f"[cyan]Resuming download from {initial_size / (1024**2):.1f} MB[/cyan]")
try:
with httpx.stream(
"GET",
url,
params=params,
headers=headers,
follow_redirects=True,
timeout=httpx.Timeout(30.0, read=None), # No read timeout for large files
) as response:
# Handle 416 Range Not Satisfiable (file already complete)
if response.status_code == 416:
console.print("[green]File already fully downloaded.[/green]")
return True
response.raise_for_status()
# Get total size from Content-Length or Content-Range
content_length = response.headers.get("content-length")
total_size = int(content_length) + initial_size if content_length else 0
# Get filename from Content-Disposition if available
content_disp = response.headers.get("content-disposition", "")
if "filename=" in content_disp:
match = re.search(r'filename="?([^";\n]+)"?', content_disp)
if match and dest_path.is_dir():
dest_path = dest_path / match.group(1)
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
DownloadColumn(),
TransferSpeedColumn(),
TimeRemainingColumn(),
console=console,
) as progress:
task = progress.add_task(
f"[cyan]Downloading {dest_path.name}...",
total=total_size if total_size > 0 else None,
completed=initial_size,
)
with dest_path.open(mode) as f:
for chunk in response.iter_bytes(1024 * 1024): # 1MB chunks
f.write(chunk)
progress.update(task, advance=len(chunk))
console.print(f"[green]Downloaded:[/green] {dest_path}")
return True
except httpx.HTTPStatusError as e:
console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]")
if e.response.status_code == 401:
console.print("[yellow]Hint: This model may require an API key.[/yellow]")
return False
except httpx.RequestError as e:
console.print(f"[red]Download error: {e}[/red]")
return False
def _display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str) -> None: def _display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str) -> None:
"""Display file information table.""" """Display file information table."""
file_table = Table(title="File Information", show_header=True, header_style="bold magenta") file_table = Table(title="File Information", show_header=True, header_style="bold magenta")
@@ -264,43 +377,8 @@ def save_metadata(
return json_path, sha_path return json_path, sha_path
def main() -> int: def cmd_info(args: argparse.Namespace) -> int:
"""Main entry point.""" """Handle the info subcommand (default behavior)."""
parser = argparse.ArgumentParser(
description="Read safetensor metadata and fetch CivitAI model information.",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"file",
type=Path,
help="Path to the safetensor file",
)
parser.add_argument(
"--api-key",
type=str,
default=None,
help="CivitAI API key for authenticated requests",
)
parser.add_argument(
"--skip-civitai",
action="store_true",
help="Skip CivitAI API lookup",
)
parser.add_argument(
"--json",
action="store_true",
dest="json_output",
help="Output results as JSON",
)
parser.add_argument(
"--save-to",
type=Path,
metavar="DIR",
help="Save metadata JSON and SHA256 hash to the specified directory",
)
args = parser.parse_args()
file_path: Path = args.file.resolve() file_path: Path = args.file.resolve()
if not file_path.exists(): if not file_path.exists():
@@ -362,5 +440,179 @@ def main() -> int:
return 1 return 1
def _resolve_version_id(
version_id: int | None, sha256_hash: str | None, api_key: str | None
) -> int | None:
"""Resolve version ID from hash if needed."""
if version_id:
return version_id
if sha256_hash:
console.print(f"[cyan]Looking up model by hash: {sha256_hash[:16]}...[/cyan]")
civitai_data = fetch_civitai_by_hash(sha256_hash.upper(), api_key)
if not civitai_data:
console.print("[red]Error: Model not found on CivitAI for this hash.[/red]")
return None
vid = civitai_data.get("id")
if vid:
console.print(f"[green]Found model version:[/green] {civitai_data.get('name', 'N/A')}")
else:
console.print("[red]Error: Could not determine version ID from CivitAI response.[/red]")
return vid
return None
def cmd_download(args: argparse.Namespace) -> int:
"""Handle the download subcommand."""
api_key: str | None = args.api_key
output_dir: Path = args.output.resolve() if args.output else Path.cwd()
if not output_dir.exists():
console.print(f"[red]Error: Output directory not found: {output_dir}[/red]")
return 1
if not output_dir.is_dir():
console.print(f"[red]Error: Not a directory: {output_dir}[/red]")
return 1
# Resolve version ID from hash if needed
version_id = _resolve_version_id(args.version_id, args.hash, api_key)
if not version_id:
if not args.version_id and not args.hash:
console.print("[red]Error: Must specify --version-id or --hash[/red]")
return 1
# Fetch version info to get filename
console.print(f"[cyan]Fetching model info for version {version_id}...[/cyan]")
version_info = fetch_civitai_model_version(version_id, api_key)
if not version_info:
console.print("[red]Error: Could not fetch model version info.[/red]")
return 1
# Find primary file or first file
files: list[dict[str, Any]] = version_info.get("files", [])
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
if not primary_file:
console.print("[red]Error: No files found for this model version.[/red]")
return 1
filename = primary_file.get("name", f"model-{version_id}.safetensors")
dest_path = output_dir / filename
# Display model info
model_table = Table(title="Model Download", show_header=True, header_style="bold magenta")
model_table.add_column("Property", style="cyan")
model_table.add_column("Value", style="green")
model_table.add_row("Version", version_info.get("name", "N/A"))
model_table.add_row("Base Model", version_info.get("baseModel", "N/A"))
model_table.add_row("File", filename)
model_table.add_row("Size", f"{primary_file.get('sizeKB', 0) / 1024:.2f} MB")
model_table.add_row("Destination", str(dest_path))
console.print()
console.print(model_table)
console.print()
# Download
success = download_model(version_id, dest_path, api_key, resume=not args.no_resume)
return 0 if success else 1
def main() -> int:
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Read safetensor metadata and download CivitAI models.",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
subparsers = parser.add_subparsers(dest="command", help="Commands")
# Info command (default)
info_parser = subparsers.add_parser(
"info",
help="Read safetensor metadata and fetch CivitAI info (default)",
)
info_parser.add_argument(
"file",
type=Path,
help="Path to the safetensor file",
)
info_parser.add_argument(
"--api-key",
type=str,
default=None,
help="CivitAI API key for authenticated requests",
)
info_parser.add_argument(
"--skip-civitai",
action="store_true",
help="Skip CivitAI API lookup",
)
info_parser.add_argument(
"--json",
action="store_true",
dest="json_output",
help="Output results as JSON",
)
info_parser.add_argument(
"--save-to",
type=Path,
metavar="DIR",
help="Save metadata JSON and SHA256 hash to the specified directory",
)
info_parser.set_defaults(func=cmd_info)
# Download command
dl_parser = subparsers.add_parser(
"download",
aliases=["dl"],
help="Download a model from CivitAI",
)
dl_parser.add_argument(
"--version-id",
"-v",
type=int,
help="CivitAI model version ID to download",
)
dl_parser.add_argument(
"--hash",
"-H",
type=str,
help="SHA256 hash to look up and download",
)
dl_parser.add_argument(
"--api-key",
type=str,
default=None,
help="CivitAI API key for authenticated requests",
)
dl_parser.add_argument(
"--output",
"-o",
type=Path,
help="Output directory (default: current directory)",
)
dl_parser.add_argument(
"--no-resume",
action="store_true",
help="Don't resume partial downloads, start fresh",
)
dl_parser.set_defaults(func=cmd_download)
# Parse and handle default command
args = parser.parse_args()
# If no command specified and file argument given, assume 'info' command
if args.command is None:
# Check if there's a positional argument (file path)
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
# Re-parse with 'info' prepended
args = parser.parse_args(["info", *sys.argv[1:]])
else:
parser.print_help()
return 0
result: int = args.func(args)
return result
if __name__ == "__main__": if __name__ == "__main__":
sys.exit(main()) sys.exit(main())