💬 Commit message: Update 2026-01-24 19:22:34, 7 files, 1285 lines
📁 Files changed: 7 📝 Lines changed: 1285 • build.py • check.py • civit.md • main.py • pyproject.toml • sft_get.py • uv.lock
This commit is contained in:
+306
@@ -0,0 +1,306 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
sft-get: Read safetensor metadata and fetch CivitAI model information.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
DownloadColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeRemainingColumn,
|
||||
TransferSpeedColumn,
|
||||
)
|
||||
from rich.table import Table
|
||||
|
||||
console = Console()
|
||||
|
||||
CIVITAI_API_BASE = "https://civitai.com/api/v1"
|
||||
|
||||
|
||||
def read_safetensor_metadata(file_path: Path) -> dict[str, Any]:
|
||||
"""Read metadata from a safetensor file header."""
|
||||
with file_path.open("rb") as f:
|
||||
# First 8 bytes are the header size (little-endian u64)
|
||||
header_size_bytes = f.read(8)
|
||||
if len(header_size_bytes) < 8:
|
||||
raise ValueError("Invalid safetensor file: too short")
|
||||
|
||||
header_size = struct.unpack("<Q", header_size_bytes)[0]
|
||||
|
||||
if header_size > 100_000_000: # 100MB sanity check
|
||||
raise ValueError(f"Invalid header size: {header_size}")
|
||||
|
||||
header_bytes = f.read(header_size)
|
||||
if len(header_bytes) < header_size:
|
||||
raise ValueError("Invalid safetensor file: header truncated")
|
||||
|
||||
header: dict[str, Any] = json.loads(header_bytes.decode("utf-8"))
|
||||
|
||||
# Extract __metadata__ if present
|
||||
metadata: dict[str, Any] = header.get("__metadata__", {})
|
||||
|
||||
# Count tensors (keys that aren't __metadata__)
|
||||
tensor_count = sum(1 for k in header if k != "__metadata__")
|
||||
|
||||
return {
|
||||
"metadata": metadata,
|
||||
"tensor_count": tensor_count,
|
||||
"header_size": header_size,
|
||||
}
|
||||
|
||||
|
||||
def compute_sha256(file_path: Path) -> str:
|
||||
"""Compute SHA256 hash of a file with progress display."""
|
||||
file_size = file_path.stat().st_size
|
||||
sha256 = hashlib.sha256()
|
||||
chunk_size = 1024 * 1024 * 8 # 8MB chunks
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
DownloadColumn(),
|
||||
TransferSpeedColumn(),
|
||||
TimeRemainingColumn(),
|
||||
console=console,
|
||||
) as progress:
|
||||
task = progress.add_task(f"[cyan]Hashing {file_path.name}...", total=file_size)
|
||||
|
||||
with file_path.open("rb") as f:
|
||||
while chunk := f.read(chunk_size):
|
||||
sha256.update(chunk)
|
||||
progress.update(task, advance=len(chunk))
|
||||
|
||||
return sha256.hexdigest().upper()
|
||||
|
||||
|
||||
def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[str, Any] | None:
|
||||
"""Fetch model information from CivitAI by SHA256 hash."""
|
||||
url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}"
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
console=console,
|
||||
transient=True,
|
||||
) as progress:
|
||||
progress.add_task("[cyan]Fetching from CivitAI...", total=None)
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=30.0)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
result: dict[str, Any] = response.json()
|
||||
return result
|
||||
except httpx.HTTPStatusError as e:
|
||||
console.print(f"[red]API error: {e.response.status_code}[/red]")
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
console.print(f"[red]Request error: {e}[/red]")
|
||||
return None
|
||||
|
||||
|
||||
def _display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str) -> None:
|
||||
"""Display file information table."""
|
||||
file_table = Table(title="File Information", show_header=True, header_style="bold magenta")
|
||||
file_table.add_column("Property", style="cyan")
|
||||
file_table.add_column("Value", style="green")
|
||||
|
||||
file_table.add_row("File", str(file_path.name))
|
||||
file_table.add_row("Path", str(file_path.parent))
|
||||
file_table.add_row("Size", f"{file_path.stat().st_size / (1024**3):.2f} GB")
|
||||
file_table.add_row("SHA256", sha256_hash)
|
||||
file_table.add_row("Header Size", f"{local_metadata['header_size']:,} bytes")
|
||||
file_table.add_row("Tensor Count", str(local_metadata["tensor_count"]))
|
||||
|
||||
console.print()
|
||||
console.print(file_table)
|
||||
|
||||
|
||||
def _display_local_metadata(local_metadata: dict[str, Any]) -> None:
|
||||
"""Display local safetensor metadata table."""
|
||||
if local_metadata["metadata"]:
|
||||
meta_table = Table(
|
||||
title="Safetensor Metadata", show_header=True, header_style="bold magenta"
|
||||
)
|
||||
meta_table.add_column("Key", style="cyan")
|
||||
meta_table.add_column("Value", style="green", max_width=80)
|
||||
|
||||
for key, value in sorted(local_metadata["metadata"].items()):
|
||||
display_value = str(value)
|
||||
if len(display_value) > 200:
|
||||
display_value = display_value[:200] + "..."
|
||||
meta_table.add_row(key, display_value)
|
||||
|
||||
console.print()
|
||||
console.print(meta_table)
|
||||
else:
|
||||
console.print()
|
||||
console.print("[yellow]No embedded metadata found in safetensor file.[/yellow]")
|
||||
|
||||
|
||||
def _display_civitai_data(civitai_data: dict[str, Any] | None) -> None:
|
||||
"""Display CivitAI model information table."""
|
||||
if not civitai_data:
|
||||
console.print()
|
||||
console.print("[yellow]Model not found on CivitAI.[/yellow]")
|
||||
return
|
||||
|
||||
civit_table = Table(
|
||||
title="CivitAI Model Information", show_header=True, header_style="bold magenta"
|
||||
)
|
||||
civit_table.add_column("Property", style="cyan")
|
||||
civit_table.add_column("Value", style="green", max_width=80)
|
||||
|
||||
civit_table.add_row("Model ID", str(civitai_data.get("modelId", "N/A")))
|
||||
civit_table.add_row("Version ID", str(civitai_data.get("id", "N/A")))
|
||||
civit_table.add_row("Version Name", str(civitai_data.get("name", "N/A")))
|
||||
civit_table.add_row("Base Model", str(civitai_data.get("baseModel", "N/A")))
|
||||
civit_table.add_row("Created At", str(civitai_data.get("createdAt", "N/A")))
|
||||
|
||||
# Trained words
|
||||
trained_words: list[str] = civitai_data.get("trainedWords", [])
|
||||
if trained_words:
|
||||
civit_table.add_row("Trigger Words", ", ".join(trained_words))
|
||||
|
||||
# Download URL
|
||||
download_url = str(civitai_data.get("downloadUrl", "N/A"))
|
||||
civit_table.add_row("Download URL", download_url)
|
||||
|
||||
# File info from CivitAI
|
||||
files: list[dict[str, Any]] = civitai_data.get("files", [])
|
||||
for f in files:
|
||||
if f.get("primary"):
|
||||
civit_table.add_row("Primary File", str(f.get("name", "N/A")))
|
||||
civit_table.add_row(
|
||||
"File Size (CivitAI)",
|
||||
f"{f.get('sizeKB', 0) / 1024:.2f} MB",
|
||||
)
|
||||
meta: dict[str, Any] = f.get("metadata", {})
|
||||
if meta:
|
||||
civit_table.add_row("Format", str(meta.get("format", "N/A")))
|
||||
civit_table.add_row("Precision", str(meta.get("fp", "N/A")))
|
||||
civit_table.add_row("Size Type", str(meta.get("size", "N/A")))
|
||||
|
||||
console.print()
|
||||
console.print(civit_table)
|
||||
|
||||
# Model page link
|
||||
model_id = civitai_data.get("modelId")
|
||||
if model_id:
|
||||
console.print()
|
||||
console.print(
|
||||
f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}"
|
||||
)
|
||||
|
||||
|
||||
def display_results(
|
||||
file_path: Path,
|
||||
local_metadata: dict[str, Any],
|
||||
sha256_hash: str,
|
||||
civitai_data: dict[str, Any] | None,
|
||||
) -> None:
|
||||
"""Display results in rich tables."""
|
||||
_display_file_info(file_path, local_metadata, sha256_hash)
|
||||
_display_local_metadata(local_metadata)
|
||||
_display_civitai_data(civitai_data)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Main entry point."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Read safetensor metadata and fetch CivitAI model information.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"file",
|
||||
type=Path,
|
||||
help="Path to the safetensor file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="CivitAI API key for authenticated requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-civitai",
|
||||
action="store_true",
|
||||
help="Skip CivitAI API lookup",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--json",
|
||||
action="store_true",
|
||||
dest="json_output",
|
||||
help="Output results as JSON",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
file_path: Path = args.file.resolve()
|
||||
|
||||
if not file_path.exists():
|
||||
console.print(f"[red]Error: File not found: {file_path}[/red]")
|
||||
return 1
|
||||
|
||||
if file_path.suffix.lower() not in (".safetensors", ".sft"):
|
||||
console.print("[yellow]Warning: File does not have .safetensors extension[/yellow]")
|
||||
|
||||
try:
|
||||
# Read local metadata
|
||||
console.print(f"[bold]Reading safetensor file:[/bold] {file_path.name}")
|
||||
local_metadata = read_safetensor_metadata(file_path)
|
||||
|
||||
# Compute SHA256
|
||||
sha256_hash = compute_sha256(file_path)
|
||||
|
||||
# Fetch from CivitAI
|
||||
civitai_data = None
|
||||
if not args.skip_civitai:
|
||||
civitai_data = fetch_civitai_by_hash(sha256_hash, args.api_key)
|
||||
|
||||
if args.json_output:
|
||||
output = {
|
||||
"file": str(file_path),
|
||||
"sha256": sha256_hash,
|
||||
"header_size": local_metadata["header_size"],
|
||||
"tensor_count": local_metadata["tensor_count"],
|
||||
"metadata": local_metadata["metadata"],
|
||||
"civitai": civitai_data,
|
||||
}
|
||||
console.print_json(data=output)
|
||||
else:
|
||||
display_results(file_path, local_metadata, sha256_hash, civitai_data)
|
||||
|
||||
return 0
|
||||
|
||||
except ValueError as e:
|
||||
console.print(f"[red]Error reading safetensor: {e}[/red]")
|
||||
return 1
|
||||
except Exception as e:
|
||||
console.print(f"[red]Unexpected error: {e}[/red]")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user