Add Hugging Face Hub integration for safetensor files
- Add `tsr hf search` to search HF models with safetensor files - Add `tsr hf get` to view model info and list safetensor files - Add `tsr hf files` to list safetensor files in a model - Add `tsr hf dl` to download safetensor files from HF Uses official huggingface_hub library for API access. Only safetensor files are supported (enforced at search and download). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
+128
@@ -37,10 +37,19 @@ from tensors.display import (
|
||||
_format_size,
|
||||
display_civitai_data,
|
||||
display_file_info,
|
||||
display_hf_model_info,
|
||||
display_hf_search_results,
|
||||
display_local_metadata,
|
||||
display_model_info,
|
||||
display_search_results,
|
||||
)
|
||||
from tensors.hf import (
|
||||
download_all_safetensors,
|
||||
download_hf_safetensor,
|
||||
get_hf_model,
|
||||
list_safetensor_files,
|
||||
search_hf_models,
|
||||
)
|
||||
from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata
|
||||
|
||||
# Key masking threshold
|
||||
@@ -720,6 +729,124 @@ def db_stats(
|
||||
console.print(table)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Hugging Face Commands
|
||||
# =============================================================================
|
||||
|
||||
hf_app = typer.Typer(name="hf", help="Hugging Face Hub commands for safetensor files.")
|
||||
app.add_typer(hf_app)
|
||||
|
||||
|
||||
@hf_app.command("search")
|
||||
def hf_search(
|
||||
query: Annotated[str | None, typer.Argument(help="Search query")] = None,
|
||||
author: Annotated[str | None, typer.Option("-a", "--author", help="Filter by author/org")] = None,
|
||||
pipeline: Annotated[str | None, typer.Option("-p", "--pipeline", help="Pipeline tag (text-to-image, etc.)")] = None,
|
||||
sort: Annotated[str | None, typer.Option("-s", "--sort", help="Sort by (downloads, likes, created_at)")] = None,
|
||||
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 25,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""Search Hugging Face for models with safetensor files."""
|
||||
results = search_hf_models(
|
||||
query=query,
|
||||
author=author,
|
||||
pipeline_tag=pipeline,
|
||||
sort=sort,
|
||||
limit=limit,
|
||||
console=console,
|
||||
)
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=results)
|
||||
return
|
||||
|
||||
display_hf_search_results(results, console)
|
||||
|
||||
|
||||
@hf_app.command("get")
|
||||
def hf_get(
|
||||
model_id: Annotated[str, typer.Argument(help="Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)")],
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""Get Hugging Face model info and list safetensor files."""
|
||||
model = get_hf_model(model_id, console=console)
|
||||
|
||||
if not model:
|
||||
raise typer.Exit(1)
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=model)
|
||||
return
|
||||
|
||||
display_hf_model_info(model, console)
|
||||
|
||||
|
||||
@hf_app.command("files")
|
||||
def hf_files(
|
||||
model_id: Annotated[str, typer.Argument(help="Model ID")],
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""List safetensor files in a Hugging Face model."""
|
||||
files = list_safetensor_files(model_id, console=console)
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=files)
|
||||
return
|
||||
|
||||
if not files:
|
||||
console.print("[yellow]No safetensor files found.[/yellow]")
|
||||
return
|
||||
|
||||
console.print(f"[bold]Safetensor files in {model_id}:[/bold]")
|
||||
for i, f in enumerate(files, 1):
|
||||
console.print(f" {i}. {f}")
|
||||
|
||||
|
||||
@hf_app.command("dl")
|
||||
def hf_download(
|
||||
model_id: Annotated[str, typer.Argument(help="Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)")],
|
||||
filename: Annotated[str | None, typer.Option("-f", "--file", help="Specific file to download")] = None,
|
||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
|
||||
all_files: Annotated[bool, typer.Option("--all", "-a", help="Download all safetensor files")] = False,
|
||||
) -> None:
|
||||
"""Download safetensor files from Hugging Face.
|
||||
|
||||
Examples:
|
||||
tsr hf dl stabilityai/stable-diffusion-xl-base-1.0 -f sd_xl_base_1.0.safetensors
|
||||
tsr hf dl author/model --all
|
||||
"""
|
||||
output_dir = output or Path.cwd()
|
||||
|
||||
if all_files:
|
||||
downloaded = download_all_safetensors(model_id, output_dir, console=console)
|
||||
if downloaded:
|
||||
console.print(f"[green]Downloaded {len(downloaded)} files[/green]")
|
||||
else:
|
||||
console.print("[red]No files downloaded[/red]")
|
||||
raise typer.Exit(1)
|
||||
return
|
||||
|
||||
if not filename:
|
||||
# List files and prompt or show help
|
||||
files = list_safetensor_files(model_id, console=console)
|
||||
if not files:
|
||||
console.print("[red]No safetensor files found in model[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if len(files) == 1:
|
||||
filename = files[0]
|
||||
console.print(f"[dim]Downloading only safetensor file: {filename}[/dim]")
|
||||
else:
|
||||
console.print("[yellow]Multiple safetensor files found. Specify one with -f or use --all:[/yellow]")
|
||||
for i, f in enumerate(files, 1):
|
||||
console.print(f" {i}. {f}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
result = download_hf_safetensor(model_id, filename, output_dir, console=console)
|
||||
if not result:
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Main entry point."""
|
||||
# Handle legacy invocation: tsr <file.safetensors> -> tsr info <file>
|
||||
@@ -732,6 +859,7 @@ def main() -> int:
|
||||
"config",
|
||||
"serve",
|
||||
"db",
|
||||
"hf",
|
||||
)
|
||||
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
|
||||
arg = sys.argv[1]
|
||||
|
||||
@@ -329,3 +329,135 @@ def display_search_results(results: dict[str, Any], console: Console) -> None:
|
||||
total = metadata.get("totalItems", len(items))
|
||||
console.print(f"\n[dim]Showing {len(items)} of {total:,} results[/dim]")
|
||||
console.print("[dim]Use 'tsr get <id>' to view details or 'tsr dl -m <id>' to download[/dim]")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Hugging Face Display Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _format_bytes(size_bytes: int) -> str:
|
||||
"""Format size in bytes to human-readable string."""
|
||||
if size_bytes < KB:
|
||||
return f"{size_bytes} B"
|
||||
if size_bytes < KB * KB:
|
||||
return f"{size_bytes / KB:.1f} KB"
|
||||
if size_bytes < KB * KB * KB:
|
||||
return f"{size_bytes / KB / KB:.1f} MB"
|
||||
return f"{size_bytes / KB / KB / KB:.2f} GB"
|
||||
|
||||
|
||||
def _build_hf_search_table(console: Console) -> Table:
|
||||
"""Build Hugging Face search results table."""
|
||||
id_width = 40
|
||||
dls_width = 8
|
||||
likes_width = 6
|
||||
files_width = 5
|
||||
|
||||
terminal_width = console.size.width
|
||||
fixed_width = id_width + dls_width + likes_width + files_width
|
||||
overhead = 17
|
||||
author_width = max(15, (terminal_width - fixed_width - overhead) // 2)
|
||||
|
||||
table = Table(show_header=True, header_style="bold magenta")
|
||||
table.add_column("Model ID", style="cyan", width=id_width, no_wrap=True, overflow="ellipsis")
|
||||
table.add_column("Author", style="yellow", width=author_width, no_wrap=True, overflow="ellipsis")
|
||||
table.add_column("DLs", justify="right", width=dls_width, no_wrap=True)
|
||||
table.add_column("Likes", justify="right", width=likes_width, no_wrap=True)
|
||||
table.add_column("Files", justify="right", width=files_width, no_wrap=True)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def display_hf_search_results(models: list[dict[str, Any]], console: Console) -> None:
|
||||
"""Display Hugging Face search results in a table."""
|
||||
if not models:
|
||||
console.print("[yellow]No results found.[/yellow]")
|
||||
return
|
||||
|
||||
table = _build_hf_search_table(console)
|
||||
|
||||
for model in models:
|
||||
model_id = model.get("id", "N/A")
|
||||
author = model.get("author", model_id.split("/")[0] if "/" in model_id else "N/A")
|
||||
downloads = _format_count(model.get("downloads", 0))
|
||||
likes = _format_count(model.get("likes", 0))
|
||||
safetensor_files = model.get("_safetensor_files", [])
|
||||
files_count = str(len(safetensor_files))
|
||||
|
||||
table.add_row(model_id, author, downloads, likes, files_count)
|
||||
|
||||
console.print()
|
||||
console.print(table)
|
||||
console.print(f"\n[dim]Showing {len(models)} models with safetensor files[/dim]")
|
||||
console.print("[dim]Use 'tsr hf get <model_id>' to view details or 'tsr hf dl <model_id>' to download[/dim]")
|
||||
|
||||
|
||||
def _build_hf_model_table(console: Console) -> Table:
|
||||
"""Build Hugging Face model info table."""
|
||||
prop_width = 12
|
||||
terminal_width = console.size.width
|
||||
overhead = 7
|
||||
value_width = max(40, terminal_width - prop_width - overhead)
|
||||
|
||||
table = Table(title="Hugging Face Model", show_header=True, header_style="bold magenta")
|
||||
table.add_column("Property", style="cyan", width=prop_width, no_wrap=True)
|
||||
table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis")
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def display_hf_model_info(model: dict[str, Any], console: Console) -> None:
|
||||
"""Display Hugging Face model information."""
|
||||
if not model:
|
||||
console.print("[yellow]Model not found.[/yellow]")
|
||||
return
|
||||
|
||||
table = _build_hf_model_table(console)
|
||||
|
||||
model_id = model.get("id", "N/A")
|
||||
table.add_row("Model ID", model_id)
|
||||
table.add_row("Author", model.get("author", "N/A"))
|
||||
table.add_row("Downloads", f"{model.get('downloads', 0):,}")
|
||||
table.add_row("Likes", f"{model.get('likes', 0):,}")
|
||||
|
||||
# Handle datetime or string values
|
||||
created = model.get("created_at") or model.get("createdAt")
|
||||
if created:
|
||||
created_str = created.strftime("%Y-%m-%d") if hasattr(created, "strftime") else str(created)[:10]
|
||||
table.add_row("Created", created_str)
|
||||
|
||||
updated = model.get("last_modified") or model.get("lastModified")
|
||||
if updated:
|
||||
updated_str = updated.strftime("%Y-%m-%d") if hasattr(updated, "strftime") else str(updated)[:10]
|
||||
table.add_row("Updated", updated_str)
|
||||
|
||||
tags = model.get("tags", [])
|
||||
if tags:
|
||||
table.add_row("Tags", ", ".join(tags[:MAX_TAGS_DISPLAY]) + ("..." if len(tags) > MAX_TAGS_DISPLAY else ""))
|
||||
|
||||
pipeline = model.get("pipeline_tag")
|
||||
if pipeline:
|
||||
table.add_row("Pipeline", pipeline)
|
||||
|
||||
console.print()
|
||||
console.print(table)
|
||||
|
||||
# Display safetensor files
|
||||
safetensor_files = model.get("_safetensor_files", [])
|
||||
if safetensor_files:
|
||||
files_table = Table(title="Safetensor Files", show_header=True, header_style="bold magenta")
|
||||
files_table.add_column("#", style="dim", width=3, justify="right")
|
||||
files_table.add_column("Filename", style="cyan", no_wrap=True, overflow="ellipsis")
|
||||
files_table.add_column("Size", style="green", justify="right", width=10)
|
||||
|
||||
for i, f in enumerate(safetensor_files, 1):
|
||||
filename = f.get("rfilename", "N/A")
|
||||
size = _format_bytes(f.get("size", 0)) if f.get("size") else "N/A"
|
||||
files_table.add_row(str(i), filename, size)
|
||||
|
||||
console.print()
|
||||
console.print(files_table)
|
||||
|
||||
console.print()
|
||||
console.print(f"[bold blue]View on HuggingFace:[/bold blue] https://huggingface.co/{model_id}")
|
||||
|
||||
+246
@@ -0,0 +1,246 @@
|
||||
"""Hugging Face Hub integration for safetensor files."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from huggingface_hub import HfApi, hf_hub_download, list_repo_files
|
||||
from huggingface_hub.errors import RepositoryNotFoundError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from rich.console import Console
|
||||
|
||||
# Shared API instance
|
||||
_api: HfApi | None = None
|
||||
|
||||
|
||||
def _get_api(token: str | None = None) -> HfApi:
|
||||
"""Get or create HfApi instance."""
|
||||
global _api # noqa: PLW0603
|
||||
if _api is None:
|
||||
_api = HfApi(token=token)
|
||||
return _api
|
||||
|
||||
|
||||
def search_hf_models(
|
||||
query: str | None = None,
|
||||
*,
|
||||
author: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
pipeline_tag: str | None = None,
|
||||
sort: str | None = None,
|
||||
limit: int = 25,
|
||||
token: str | None = None,
|
||||
console: Console | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search Hugging Face models with safetensor files.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
author: Filter by author/organization
|
||||
tags: Additional tags to filter by
|
||||
pipeline_tag: Pipeline type (text-generation, text-to-image, etc.)
|
||||
sort: Sort field (downloads, likes, created_at, trending_score)
|
||||
limit: Maximum results
|
||||
token: HuggingFace API token
|
||||
console: Rich console for output
|
||||
|
||||
Returns:
|
||||
List of model info dictionaries with safetensor files
|
||||
"""
|
||||
api = _get_api(token)
|
||||
|
||||
# Build filter list - always include safetensors
|
||||
filters = ["safetensors"]
|
||||
if tags:
|
||||
filters.extend(tags)
|
||||
|
||||
try:
|
||||
models = api.list_models(
|
||||
search=query,
|
||||
author=author,
|
||||
filter=filters,
|
||||
pipeline_tag=pipeline_tag,
|
||||
sort=sort or "downloads",
|
||||
limit=limit,
|
||||
expand=["siblings", "downloads", "likes", "author", "lastModified", "createdAt", "tags"],
|
||||
)
|
||||
|
||||
results = []
|
||||
for model in models:
|
||||
model_dict = model.__dict__.copy()
|
||||
|
||||
# Get safetensor files from siblings
|
||||
siblings = getattr(model, "siblings", None) or []
|
||||
safetensor_files = [
|
||||
{"rfilename": s.rfilename, "size": getattr(s, "size", None)}
|
||||
for s in siblings
|
||||
if s.rfilename.endswith(".safetensors")
|
||||
]
|
||||
|
||||
if safetensor_files:
|
||||
model_dict["_safetensor_files"] = safetensor_files
|
||||
results.append(model_dict)
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
if console:
|
||||
console.print(f"[red]Error searching models: {e}[/red]")
|
||||
return []
|
||||
|
||||
|
||||
def get_hf_model(
|
||||
model_id: str,
|
||||
token: str | None = None,
|
||||
console: Console | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Get detailed model information from Hugging Face.
|
||||
|
||||
Args:
|
||||
model_id: Model ID (e.g., "stabilityai/stable-diffusion-xl-base-1.0")
|
||||
token: HuggingFace API token
|
||||
console: Rich console for output
|
||||
|
||||
Returns:
|
||||
Model info dictionary or None if not found
|
||||
"""
|
||||
api = _get_api(token)
|
||||
|
||||
try:
|
||||
model = api.model_info(model_id, files_metadata=True)
|
||||
model_dict = model.__dict__.copy()
|
||||
|
||||
# Get safetensor files
|
||||
siblings = getattr(model, "siblings", None) or []
|
||||
safetensor_files = [
|
||||
{"rfilename": s.rfilename, "size": getattr(s, "size", None)}
|
||||
for s in siblings
|
||||
if s.rfilename.endswith(".safetensors")
|
||||
]
|
||||
model_dict["_safetensor_files"] = safetensor_files
|
||||
|
||||
return model_dict
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
if console:
|
||||
console.print(f"[red]Model not found: {model_id}[/red]")
|
||||
return None
|
||||
except Exception as e:
|
||||
if console:
|
||||
console.print(f"[red]Error fetching model: {e}[/red]")
|
||||
return None
|
||||
|
||||
|
||||
def list_safetensor_files(
|
||||
model_id: str,
|
||||
token: str | None = None,
|
||||
console: Console | None = None,
|
||||
) -> list[str]:
|
||||
"""List all safetensor files in a Hugging Face model.
|
||||
|
||||
Args:
|
||||
model_id: Model ID
|
||||
token: HuggingFace API token
|
||||
console: Rich console for output
|
||||
|
||||
Returns:
|
||||
List of safetensor filenames
|
||||
"""
|
||||
try:
|
||||
files = list_repo_files(model_id, token=token)
|
||||
return [f for f in files if f.endswith(".safetensors")]
|
||||
except RepositoryNotFoundError:
|
||||
if console:
|
||||
console.print(f"[red]Model not found: {model_id}[/red]")
|
||||
return []
|
||||
except Exception as e:
|
||||
if console:
|
||||
console.print(f"[red]Error listing files: {e}[/red]")
|
||||
return []
|
||||
|
||||
|
||||
def download_hf_safetensor(
|
||||
model_id: str,
|
||||
filename: str,
|
||||
output_dir: Path,
|
||||
token: str | None = None,
|
||||
console: Console | None = None,
|
||||
*,
|
||||
resume: bool = True,
|
||||
) -> Path | None:
|
||||
"""Download a safetensor file from Hugging Face.
|
||||
|
||||
Args:
|
||||
model_id: Model ID (e.g., "stabilityai/stable-diffusion-xl-base-1.0")
|
||||
filename: File name within the model repo
|
||||
output_dir: Directory to save the file
|
||||
token: HuggingFace API token
|
||||
console: Rich console for progress output
|
||||
resume: Whether to resume partial downloads
|
||||
|
||||
Returns:
|
||||
Path to downloaded file, or None on failure
|
||||
"""
|
||||
if not filename.endswith(".safetensors"):
|
||||
if console:
|
||||
console.print("[red]Only .safetensors files are supported[/red]")
|
||||
return None
|
||||
|
||||
try:
|
||||
# hf_hub_download handles caching and resume automatically
|
||||
downloaded_path = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=filename,
|
||||
local_dir=output_dir,
|
||||
token=token,
|
||||
force_download=not resume,
|
||||
)
|
||||
|
||||
if console:
|
||||
console.print(f"[green]Downloaded: {downloaded_path}[/green]")
|
||||
|
||||
return Path(downloaded_path)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
if console:
|
||||
console.print(f"[red]Model not found: {model_id}[/red]")
|
||||
return None
|
||||
except Exception as e:
|
||||
if console:
|
||||
console.print(f"[red]Download failed: {e}[/red]")
|
||||
return None
|
||||
|
||||
|
||||
def download_all_safetensors(
|
||||
model_id: str,
|
||||
output_dir: Path,
|
||||
token: str | None = None,
|
||||
console: Console | None = None,
|
||||
) -> list[Path]:
|
||||
"""Download all safetensor files from a model.
|
||||
|
||||
Args:
|
||||
model_id: Model ID
|
||||
output_dir: Directory to save files
|
||||
token: HuggingFace API token
|
||||
console: Rich console for output
|
||||
|
||||
Returns:
|
||||
List of downloaded file paths
|
||||
"""
|
||||
files = list_safetensor_files(model_id, token, console)
|
||||
if not files:
|
||||
return []
|
||||
|
||||
downloaded = []
|
||||
for filename in files:
|
||||
if console:
|
||||
console.print(f"[dim]Downloading {filename}...[/dim]")
|
||||
|
||||
path = download_hf_safetensor(model_id, filename, output_dir, token, console)
|
||||
if path:
|
||||
downloaded.append(path)
|
||||
|
||||
return downloaded
|
||||
Reference in New Issue
Block a user