Add Hugging Face Hub integration for safetensor files

- Add `tsr hf search` to search HF models with safetensor files
- Add `tsr hf get` to view model info and list safetensor files
- Add `tsr hf files` to list safetensor files in a model
- Add `tsr hf dl` to download safetensor files from HF

Uses official huggingface_hub library for API access.
Only safetensor files are supported (enforced at search and download).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Adam Ladachowski
2026-02-15 19:27:23 +01:00
parent 29d96e2a00
commit eb151dac8d
6 changed files with 634 additions and 1 deletions
+128
View File
@@ -37,10 +37,19 @@ from tensors.display import (
_format_size,
display_civitai_data,
display_file_info,
display_hf_model_info,
display_hf_search_results,
display_local_metadata,
display_model_info,
display_search_results,
)
from tensors.hf import (
download_all_safetensors,
download_hf_safetensor,
get_hf_model,
list_safetensor_files,
search_hf_models,
)
from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata
# Key masking threshold
@@ -720,6 +729,124 @@ def db_stats(
console.print(table)
# =============================================================================
# Hugging Face Commands
# =============================================================================
hf_app = typer.Typer(name="hf", help="Hugging Face Hub commands for safetensor files.")
app.add_typer(hf_app)
@hf_app.command("search")
def hf_search(
query: Annotated[str | None, typer.Argument(help="Search query")] = None,
author: Annotated[str | None, typer.Option("-a", "--author", help="Filter by author/org")] = None,
pipeline: Annotated[str | None, typer.Option("-p", "--pipeline", help="Pipeline tag (text-to-image, etc.)")] = None,
sort: Annotated[str | None, typer.Option("-s", "--sort", help="Sort by (downloads, likes, created_at)")] = None,
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 25,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""Search Hugging Face for models with safetensor files."""
results = search_hf_models(
query=query,
author=author,
pipeline_tag=pipeline,
sort=sort,
limit=limit,
console=console,
)
if json_output:
console.print_json(data=results)
return
display_hf_search_results(results, console)
@hf_app.command("get")
def hf_get(
model_id: Annotated[str, typer.Argument(help="Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)")],
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""Get Hugging Face model info and list safetensor files."""
model = get_hf_model(model_id, console=console)
if not model:
raise typer.Exit(1)
if json_output:
console.print_json(data=model)
return
display_hf_model_info(model, console)
@hf_app.command("files")
def hf_files(
model_id: Annotated[str, typer.Argument(help="Model ID")],
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""List safetensor files in a Hugging Face model."""
files = list_safetensor_files(model_id, console=console)
if json_output:
console.print_json(data=files)
return
if not files:
console.print("[yellow]No safetensor files found.[/yellow]")
return
console.print(f"[bold]Safetensor files in {model_id}:[/bold]")
for i, f in enumerate(files, 1):
console.print(f" {i}. {f}")
@hf_app.command("dl")
def hf_download(
model_id: Annotated[str, typer.Argument(help="Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)")],
filename: Annotated[str | None, typer.Option("-f", "--file", help="Specific file to download")] = None,
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
all_files: Annotated[bool, typer.Option("--all", "-a", help="Download all safetensor files")] = False,
) -> None:
"""Download safetensor files from Hugging Face.
Examples:
tsr hf dl stabilityai/stable-diffusion-xl-base-1.0 -f sd_xl_base_1.0.safetensors
tsr hf dl author/model --all
"""
output_dir = output or Path.cwd()
if all_files:
downloaded = download_all_safetensors(model_id, output_dir, console=console)
if downloaded:
console.print(f"[green]Downloaded {len(downloaded)} files[/green]")
else:
console.print("[red]No files downloaded[/red]")
raise typer.Exit(1)
return
if not filename:
# List files and prompt or show help
files = list_safetensor_files(model_id, console=console)
if not files:
console.print("[red]No safetensor files found in model[/red]")
raise typer.Exit(1)
if len(files) == 1:
filename = files[0]
console.print(f"[dim]Downloading only safetensor file: {filename}[/dim]")
else:
console.print("[yellow]Multiple safetensor files found. Specify one with -f or use --all:[/yellow]")
for i, f in enumerate(files, 1):
console.print(f" {i}. {f}")
raise typer.Exit(1)
result = download_hf_safetensor(model_id, filename, output_dir, console=console)
if not result:
raise typer.Exit(1)
def main() -> int:
"""Main entry point."""
# Handle legacy invocation: tsr <file.safetensors> -> tsr info <file>
@@ -732,6 +859,7 @@ def main() -> int:
"config",
"serve",
"db",
"hf",
)
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
arg = sys.argv[1]
+132
View File
@@ -329,3 +329,135 @@ def display_search_results(results: dict[str, Any], console: Console) -> None:
total = metadata.get("totalItems", len(items))
console.print(f"\n[dim]Showing {len(items)} of {total:,} results[/dim]")
console.print("[dim]Use 'tsr get <id>' to view details or 'tsr dl -m <id>' to download[/dim]")
# =============================================================================
# Hugging Face Display Functions
# =============================================================================
def _format_bytes(size_bytes: int) -> str:
"""Format size in bytes to human-readable string."""
if size_bytes < KB:
return f"{size_bytes} B"
if size_bytes < KB * KB:
return f"{size_bytes / KB:.1f} KB"
if size_bytes < KB * KB * KB:
return f"{size_bytes / KB / KB:.1f} MB"
return f"{size_bytes / KB / KB / KB:.2f} GB"
def _build_hf_search_table(console: Console) -> Table:
"""Build Hugging Face search results table."""
id_width = 40
dls_width = 8
likes_width = 6
files_width = 5
terminal_width = console.size.width
fixed_width = id_width + dls_width + likes_width + files_width
overhead = 17
author_width = max(15, (terminal_width - fixed_width - overhead) // 2)
table = Table(show_header=True, header_style="bold magenta")
table.add_column("Model ID", style="cyan", width=id_width, no_wrap=True, overflow="ellipsis")
table.add_column("Author", style="yellow", width=author_width, no_wrap=True, overflow="ellipsis")
table.add_column("DLs", justify="right", width=dls_width, no_wrap=True)
table.add_column("Likes", justify="right", width=likes_width, no_wrap=True)
table.add_column("Files", justify="right", width=files_width, no_wrap=True)
return table
def display_hf_search_results(models: list[dict[str, Any]], console: Console) -> None:
"""Display Hugging Face search results in a table."""
if not models:
console.print("[yellow]No results found.[/yellow]")
return
table = _build_hf_search_table(console)
for model in models:
model_id = model.get("id", "N/A")
author = model.get("author", model_id.split("/")[0] if "/" in model_id else "N/A")
downloads = _format_count(model.get("downloads", 0))
likes = _format_count(model.get("likes", 0))
safetensor_files = model.get("_safetensor_files", [])
files_count = str(len(safetensor_files))
table.add_row(model_id, author, downloads, likes, files_count)
console.print()
console.print(table)
console.print(f"\n[dim]Showing {len(models)} models with safetensor files[/dim]")
console.print("[dim]Use 'tsr hf get <model_id>' to view details or 'tsr hf dl <model_id>' to download[/dim]")
def _build_hf_model_table(console: Console) -> Table:
"""Build Hugging Face model info table."""
prop_width = 12
terminal_width = console.size.width
overhead = 7
value_width = max(40, terminal_width - prop_width - overhead)
table = Table(title="Hugging Face Model", show_header=True, header_style="bold magenta")
table.add_column("Property", style="cyan", width=prop_width, no_wrap=True)
table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis")
return table
def display_hf_model_info(model: dict[str, Any], console: Console) -> None:
"""Display Hugging Face model information."""
if not model:
console.print("[yellow]Model not found.[/yellow]")
return
table = _build_hf_model_table(console)
model_id = model.get("id", "N/A")
table.add_row("Model ID", model_id)
table.add_row("Author", model.get("author", "N/A"))
table.add_row("Downloads", f"{model.get('downloads', 0):,}")
table.add_row("Likes", f"{model.get('likes', 0):,}")
# Handle datetime or string values
created = model.get("created_at") or model.get("createdAt")
if created:
created_str = created.strftime("%Y-%m-%d") if hasattr(created, "strftime") else str(created)[:10]
table.add_row("Created", created_str)
updated = model.get("last_modified") or model.get("lastModified")
if updated:
updated_str = updated.strftime("%Y-%m-%d") if hasattr(updated, "strftime") else str(updated)[:10]
table.add_row("Updated", updated_str)
tags = model.get("tags", [])
if tags:
table.add_row("Tags", ", ".join(tags[:MAX_TAGS_DISPLAY]) + ("..." if len(tags) > MAX_TAGS_DISPLAY else ""))
pipeline = model.get("pipeline_tag")
if pipeline:
table.add_row("Pipeline", pipeline)
console.print()
console.print(table)
# Display safetensor files
safetensor_files = model.get("_safetensor_files", [])
if safetensor_files:
files_table = Table(title="Safetensor Files", show_header=True, header_style="bold magenta")
files_table.add_column("#", style="dim", width=3, justify="right")
files_table.add_column("Filename", style="cyan", no_wrap=True, overflow="ellipsis")
files_table.add_column("Size", style="green", justify="right", width=10)
for i, f in enumerate(safetensor_files, 1):
filename = f.get("rfilename", "N/A")
size = _format_bytes(f.get("size", 0)) if f.get("size") else "N/A"
files_table.add_row(str(i), filename, size)
console.print()
console.print(files_table)
console.print()
console.print(f"[bold blue]View on HuggingFace:[/bold blue] https://huggingface.co/{model_id}")
+246
View File
@@ -0,0 +1,246 @@
"""Hugging Face Hub integration for safetensor files."""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Any
from huggingface_hub import HfApi, hf_hub_download, list_repo_files
from huggingface_hub.errors import RepositoryNotFoundError
if TYPE_CHECKING:
from rich.console import Console
# Shared API instance
_api: HfApi | None = None
def _get_api(token: str | None = None) -> HfApi:
"""Get or create HfApi instance."""
global _api # noqa: PLW0603
if _api is None:
_api = HfApi(token=token)
return _api
def search_hf_models(
query: str | None = None,
*,
author: str | None = None,
tags: list[str] | None = None,
pipeline_tag: str | None = None,
sort: str | None = None,
limit: int = 25,
token: str | None = None,
console: Console | None = None,
) -> list[dict[str, Any]]:
"""Search Hugging Face models with safetensor files.
Args:
query: Search query string
author: Filter by author/organization
tags: Additional tags to filter by
pipeline_tag: Pipeline type (text-generation, text-to-image, etc.)
sort: Sort field (downloads, likes, created_at, trending_score)
limit: Maximum results
token: HuggingFace API token
console: Rich console for output
Returns:
List of model info dictionaries with safetensor files
"""
api = _get_api(token)
# Build filter list - always include safetensors
filters = ["safetensors"]
if tags:
filters.extend(tags)
try:
models = api.list_models(
search=query,
author=author,
filter=filters,
pipeline_tag=pipeline_tag,
sort=sort or "downloads",
limit=limit,
expand=["siblings", "downloads", "likes", "author", "lastModified", "createdAt", "tags"],
)
results = []
for model in models:
model_dict = model.__dict__.copy()
# Get safetensor files from siblings
siblings = getattr(model, "siblings", None) or []
safetensor_files = [
{"rfilename": s.rfilename, "size": getattr(s, "size", None)}
for s in siblings
if s.rfilename.endswith(".safetensors")
]
if safetensor_files:
model_dict["_safetensor_files"] = safetensor_files
results.append(model_dict)
return results
except Exception as e:
if console:
console.print(f"[red]Error searching models: {e}[/red]")
return []
def get_hf_model(
model_id: str,
token: str | None = None,
console: Console | None = None,
) -> dict[str, Any] | None:
"""Get detailed model information from Hugging Face.
Args:
model_id: Model ID (e.g., "stabilityai/stable-diffusion-xl-base-1.0")
token: HuggingFace API token
console: Rich console for output
Returns:
Model info dictionary or None if not found
"""
api = _get_api(token)
try:
model = api.model_info(model_id, files_metadata=True)
model_dict = model.__dict__.copy()
# Get safetensor files
siblings = getattr(model, "siblings", None) or []
safetensor_files = [
{"rfilename": s.rfilename, "size": getattr(s, "size", None)}
for s in siblings
if s.rfilename.endswith(".safetensors")
]
model_dict["_safetensor_files"] = safetensor_files
return model_dict
except RepositoryNotFoundError:
if console:
console.print(f"[red]Model not found: {model_id}[/red]")
return None
except Exception as e:
if console:
console.print(f"[red]Error fetching model: {e}[/red]")
return None
def list_safetensor_files(
model_id: str,
token: str | None = None,
console: Console | None = None,
) -> list[str]:
"""List all safetensor files in a Hugging Face model.
Args:
model_id: Model ID
token: HuggingFace API token
console: Rich console for output
Returns:
List of safetensor filenames
"""
try:
files = list_repo_files(model_id, token=token)
return [f for f in files if f.endswith(".safetensors")]
except RepositoryNotFoundError:
if console:
console.print(f"[red]Model not found: {model_id}[/red]")
return []
except Exception as e:
if console:
console.print(f"[red]Error listing files: {e}[/red]")
return []
def download_hf_safetensor(
model_id: str,
filename: str,
output_dir: Path,
token: str | None = None,
console: Console | None = None,
*,
resume: bool = True,
) -> Path | None:
"""Download a safetensor file from Hugging Face.
Args:
model_id: Model ID (e.g., "stabilityai/stable-diffusion-xl-base-1.0")
filename: File name within the model repo
output_dir: Directory to save the file
token: HuggingFace API token
console: Rich console for progress output
resume: Whether to resume partial downloads
Returns:
Path to downloaded file, or None on failure
"""
if not filename.endswith(".safetensors"):
if console:
console.print("[red]Only .safetensors files are supported[/red]")
return None
try:
# hf_hub_download handles caching and resume automatically
downloaded_path = hf_hub_download(
repo_id=model_id,
filename=filename,
local_dir=output_dir,
token=token,
force_download=not resume,
)
if console:
console.print(f"[green]Downloaded: {downloaded_path}[/green]")
return Path(downloaded_path)
except RepositoryNotFoundError:
if console:
console.print(f"[red]Model not found: {model_id}[/red]")
return None
except Exception as e:
if console:
console.print(f"[red]Download failed: {e}[/red]")
return None
def download_all_safetensors(
model_id: str,
output_dir: Path,
token: str | None = None,
console: Console | None = None,
) -> list[Path]:
"""Download all safetensor files from a model.
Args:
model_id: Model ID
output_dir: Directory to save files
token: HuggingFace API token
console: Rich console for output
Returns:
List of downloaded file paths
"""
files = list_safetensor_files(model_id, token, console)
if not files:
return []
downloaded = []
for filename in files:
if console:
console.print(f"[dim]Downloading {filename}...[/dim]")
path = download_hf_safetensor(model_id, filename, output_dir, token, console)
if path:
downloaded.append(path)
return downloaded