# File: tensors/tensors/cli.py (1380 lines, 52 KiB, Python)

"""CLI application and commands for tsr."""
from __future__ import annotations
import json
import sys
from importlib.metadata import version
from pathlib import Path
from typing import Annotated, Any
import typer
from rich.console import Console
from rich.table import Table
from tensors.api import (
download_model,
fetch_civitai_by_hash,
fetch_civitai_model,
fetch_civitai_model_version,
search_civitai,
)
from tensors.config import (
CONFIG_FILE,
MODEL_FAMILY_DEFAULTS,
BaseModel,
CommercialUse,
ModelType,
NsfwLevel,
Period,
Provider,
SortOrder,
detect_model_family,
get_default_output_path,
get_model_paths,
load_api_key,
load_config,
save_config,
)
from tensors.db import DB_PATH, Database
from tensors.display import (
_format_size,
display_civitai_data,
display_file_info,
display_hf_model_info,
display_hf_search_results,
display_local_metadata,
display_model_info,
display_search_results,
)
from tensors.hf import (
download_all_safetensors,
download_hf_safetensor,
get_hf_model,
list_safetensor_files,
search_hf_models,
)
from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata
# Key masking threshold: API keys longer than this are shown as
# "<first 4>...<last 4>" by `tsr config`; shorter keys are shown as "***".
MIN_KEY_LENGTH_FOR_MASKING: int = 8

# Display truncation limits for list-style console output.
MAX_QUEUE_DISPLAY: int = 10  # pending jobs listed by `tsr comfy queue`
MAX_MODEL_LIST_DISPLAY: int = 20  # names listed per model type by `tsr comfy models`
MAX_PROMPT_ID_DISPLAY: int = 36
def _cache_model_quietly(model_data: dict[str, Any]) -> None:
"""Cache model data to database without output."""
try:
with Database() as db:
db.init_schema()
db.cache_model(model_data)
except Exception:
pass # Silently ignore cache failures
def _cache_models_quietly(models: list[dict[str, Any]]) -> None:
"""Cache multiple models to database without output."""
if not models:
return
try:
with Database() as db:
db.init_schema()
for model_data in models:
db.cache_model(model_data)
except Exception:
pass # Silently ignore cache failures
def _version_callback(value: bool) -> None:
if value:
print(f"tsr {version('tensors')}")
raise typer.Exit
# Root Typer application: running `tsr` with no arguments prints the help.
app = typer.Typer(
    no_args_is_help=True,
    name="tsr",
    help="Read safetensor metadata, search and download CivitAI models.",
)
@app.callback()
def _main(
    _version: Annotated[
        bool,
        typer.Option(
            "--version",
            "-V",
            help="Show version and exit.",
            is_eager=True,
            callback=_version_callback,
        ),
    ] = False,
) -> None:
    """Read safetensor metadata, search and download CivitAI models."""
    # Root callback: it only hosts the eager --version/-V option handled by
    # _version_callback; all real work happens in the subcommands.
# Shared Rich console; the commands and helpers in this module print through it.
console = Console()
@app.command()
def info(
    file: Annotated[Path, typer.Argument(help="Path to the safetensor file")],
    meta: Annotated[list[str] | None, typer.Option("--meta", "-m", help="Show specific metadata key(s) in full")] = None,
    api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
    skip_civitai: Annotated[bool, typer.Option("--skip-civitai", help="Skip CivitAI API lookup")] = False,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
    save_to: Annotated[Path | None, typer.Option("--save-to", help="Save metadata to directory")] = None,
) -> None:
    """Read safetensor metadata and fetch CivitAI info."""
    # The docstring above is also this command's --help text.
    #
    # Flow: validate the path -> read the header -> (optionally) show only the
    # requested --meta keys -> hash the file -> look it up on CivitAI ->
    # print (table or JSON) -> optionally save to --save-to.
    file_path = file.resolve()
    if not file_path.exists():
        console.print(f"[red]Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)
    if not file_path.is_file():
        # Reject directories and other non-regular files with a clear message.
        console.print(f"[red]Error: Not a file: {file_path}[/red]")
        raise typer.Exit(1)
    if file_path.suffix.lower() not in (".safetensors", ".sft"):
        console.print("[yellow]Warning: File does not have .safetensors extension[/yellow]")
    try:
        local_metadata = read_safetensor_metadata(file_path)
        if meta:
            display_local_metadata(local_metadata, console, keys_filter=meta)
            return
        # Validate --save-to before hashing. Hashing a potentially large
        # file and querying CivitAI only to then fail on a bad output
        # directory wastes time.
        if save_to is not None:
            save_dir = save_to.resolve()
            if not save_dir.is_dir():
                console.print(f"[red]Error: Invalid directory: {save_dir}[/red]")
                raise typer.Exit(1)
        # In --json mode stdout should carry only the JSON document.
        if not json_output:
            console.print(f"[bold]Reading safetensor file:[/bold] {file_path.name}")
        sha256_hash = compute_sha256(file_path, console)
        civitai_data = None
        if not skip_civitai:
            key = api_key or load_api_key()
            civitai_data = fetch_civitai_by_hash(sha256_hash, key, None if json_output else console)
        if json_output:
            _output_info_json(file_path, sha256_hash, local_metadata, civitai_data)
        else:
            display_file_info(file_path, local_metadata, sha256_hash, console)
            display_local_metadata(local_metadata, console)
            display_civitai_data(civitai_data, console)
        if save_to:
            _save_metadata(save_to, file_path, sha256_hash, local_metadata, civitai_data)
    except ValueError as e:
        console.print(f"[red]Error reading safetensor: {e}[/red]")
        raise typer.Exit(1) from e
    except OSError as e:
        # I/O failures (permissions, unreadable file, write errors) get a
        # clean error message and exit code instead of a traceback.
        console.print(f"[red]Error: {e}[/red]")
        raise typer.Exit(1) from e
def _output_info_json(
    file_path: Path,
    sha256_hash: str,
    local_metadata: dict[str, Any],
    civitai_data: dict[str, Any] | None,
) -> None:
    """Print the `info` result as one JSON document on the shared console.

    The document holds, in order:
    - the file path and its SHA256;
    - the header size, tensor count and embedded metadata taken from
      *local_metadata*;
    - the CivitAI data, which is ``null`` when none is available.
    """
    header_fields = ("header_size", "tensor_count", "metadata")
    payload: dict[str, Any] = {"file": str(file_path), "sha256": sha256_hash}
    payload.update({field: local_metadata[field] for field in header_fields})
    payload["civitai"] = civitai_data
    console.print_json(data=payload)
def _save_metadata(
    save_to: Path,
    file_path: Path,
    sha256_hash: str,
    local_metadata: dict[str, Any],
    civitai_data: dict[str, Any] | None,
) -> None:
    """Write the `info` results for *file_path* into the directory *save_to*.

    Two files named after the model's base name are created or overwritten:
    - ``<base>.json`` holds the same fields as ``tsr info --json``, indented;
    - ``<base>.sha256`` holds a ``<hash> <filename>`` checksum line.

    Raises:
        typer.Exit: if *save_to* is not an existing directory.
    """
    output_dir = save_to.resolve()
    # is_dir() is False for missing paths too, so one check covers both cases.
    if not output_dir.is_dir():
        console.print(f"[red]Error: Invalid directory: {output_dir}[/red]")
        raise typer.Exit(1)
    base_name = get_base_name(file_path)
    json_path = output_dir / f"{base_name}.json"
    sha_path = output_dir / f"{base_name}.sha256"
    output = {
        "file": str(file_path),
        "sha256": sha256_hash,
        "header_size": local_metadata["header_size"],
        "tensor_count": local_metadata["tensor_count"],
        "metadata": local_metadata["metadata"],
        "civitai": civitai_data,
    }
    # Explicit UTF-8. Without it, write_text uses the locale encoding (for
    # example cp1252 on Windows), which cannot encode every model filename
    # and would raise UnicodeEncodeError.
    json_path.write_text(json.dumps(output, indent=2), encoding="utf-8")
    sha_path.write_text(f"{sha256_hash} {file_path.name}\n", encoding="utf-8")
    console.print()
    console.print(f"[green]Saved:[/green] {json_path}")
    console.print(f"[green]Saved:[/green] {sha_path}")
@app.command()
def search(
    query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None,
    provider: Annotated[Provider, typer.Option("--provider", "-P", help="Search provider")] = Provider.all,
    model_type: Annotated[ModelType | None, typer.Option("-t", "--type", help="Model type filter (CivitAI)")] = None,
    base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter (CivitAI)")] = None,
    sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads,
    limit: Annotated[int, typer.Option("-n", "--limit", help="Max results per provider")] = 20,
    period: Annotated[Period | None, typer.Option("-p", "--period", help="Time period (CivitAI)")] = None,
    tag: Annotated[str | None, typer.Option("--tag", help="Filter by tag")] = None,
    username: Annotated[str | None, typer.Option("-u", "--user", "-a", "--author", help="Filter by creator/author")] = None,
    page: Annotated[int | None, typer.Option("--page", help="Page number (CivitAI)")] = None,
    nsfw: Annotated[NsfwLevel | None, typer.Option("--nsfw", help="NSFW filter level (CivitAI)")] = None,
    sfw: Annotated[bool, typer.Option("--sfw", help="Exclude NSFW content (CivitAI)")] = False,
    commercial: Annotated[CommercialUse | None, typer.Option("--commercial", help="Commercial use filter (CivitAI)")] = None,
    pipeline: Annotated[str | None, typer.Option("--pipeline", help="Pipeline tag (HuggingFace)")] = None,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
    api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
) -> None:
    """Search models on CivitAI and/or Hugging Face.
    Examples:
    tsr search "flux" # Search both providers
    tsr search "anime" -P civitai # CivitAI only
    tsr search "llama" -P hf # Hugging Face only
    tsr search -t lora -b pony # CivitAI LoRAs for Pony
    tsr search -a stabilityai -P hf # HF by author
    tsr search --sfw -P civitai # CivitAI SFW only
    """
    # The docstring above is this command's --help text; keep it verbatim.
    key = api_key or load_api_key()
    civitai_results: dict[str, Any] | None = None
    hf_results: list[dict[str, Any]] | None = None
    # Search CivitAI. The API console is passed only for single-provider
    # searches, so a combined search stays quiet per provider.
    if provider in (Provider.civitai, Provider.all):
        nsfw_filter: NsfwLevel | bool | None = NsfwLevel.none if sfw else nsfw
        civitai_results = search_civitai(
            query=query,
            model_type=model_type,
            base_model=base,
            sort=sort,
            limit=limit,
            api_key=key,
            console=console if provider == Provider.civitai else None,
            period=period,
            nsfw=nsfw_filter,
            tag=tag,
            username=username,
            page=page,
            commercial_use=commercial,
        )
        if civitai_results:
            _cache_models_quietly(civitai_results.get("items", []))
    # Search Hugging Face; CivitAI sort orders are mapped to HF sort keys.
    if provider in (Provider.hf, Provider.all):
        tags = [tag] if tag else None
        hf_results = search_hf_models(
            query=query,
            author=username,
            tags=tags,
            pipeline_tag=pipeline,
            sort="downloads" if sort == SortOrder.downloads else "likes" if sort == SortOrder.rating else "created_at",
            limit=limit,
            console=console if provider == Provider.hf else None,
        )
    # A failed single-provider search is an error in every output mode. This
    # lets scripts using --json detect it from the exit status, as with
    # `get` and `db cache`, which also exit 1 on lookup failure.
    if provider == Provider.civitai and not civitai_results:
        console.print("[red]CivitAI search failed.[/red]")
        raise typer.Exit(1)
    if provider == Provider.hf and hf_results is None:
        console.print("[red]Hugging Face search failed.[/red]")
        raise typer.Exit(1)
    # Output results
    if json_output:
        output: dict[str, Any] = {}
        if civitai_results:
            output["civitai"] = civitai_results
        if hf_results:
            output["huggingface"] = hf_results
        console.print_json(data=output)
        return
    # Display based on provider
    if provider == Provider.civitai:
        display_search_results(civitai_results, console)
    elif provider == Provider.hf:
        display_hf_search_results(hf_results, console)
    else:
        # Both providers
        civitai_has_items = bool(civitai_results and civitai_results.get("items"))
        if civitai_has_items:
            console.print("\n[bold cyan]═══ CivitAI Results ═══[/bold cyan]")
            display_search_results(civitai_results, console)
        if hf_results:
            console.print("\n[bold cyan]═══ Hugging Face Results ═══[/bold cyan]")
            display_hf_search_results(hf_results, console)
        if not civitai_has_items and not hf_results:
            console.print("[yellow]No results found on either provider.[/yellow]")
@app.command()
def get(
    id_value: Annotated[int, typer.Argument(help="CivitAI model ID or version ID")],
    version: Annotated[bool, typer.Option("-v", "--version", help="Treat ID as version ID instead of model ID")] = False,
    api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
    no_cache: Annotated[bool, typer.Option("--no-cache", help="Don't cache to local database")] = False,
) -> None:
    """Fetch model information from CivitAI by model ID or version ID."""
    # NOTE: the `version` flag shadows importlib.metadata.version in here.
    key = api_key or load_api_key()

    if not version:
        # Model-ID lookup: fetch, cache (unless --no-cache), then print.
        model_data = fetch_civitai_model(id_value, key, console)
        if not model_data:
            console.print(f"[red]Error: Model {id_value} not found on CivitAI.[/red]")
            raise typer.Exit(1)
        if not no_cache:
            _cache_model_quietly(model_data)
        if json_output:
            console.print_json(data=model_data)
        else:
            display_model_info(model_data, console)
        return

    # Version-ID lookup.
    version_data = fetch_civitai_model_version(id_value, key, console)
    if not version_data:
        console.print(f"[red]Error: Version {id_value} not found on CivitAI.[/red]")
        raise typer.Exit(1)
    if not no_cache:
        # The cache stores whole models, so the parent model of this version
        # is fetched (without console output) and cached instead.
        parent_id = version_data.get("modelId")
        parent_model = fetch_civitai_model(parent_id, key) if parent_id else None
        if parent_model:
            _cache_model_quietly(parent_model)
    if json_output:
        console.print_json(data=version_data)
    else:
        display_civitai_data(version_data, console)
def _resolve_by_hash(hash_val: str, api_key: str | None) -> int | None:
    """Find the CivitAI model version ID matching a file's SHA256 hash.

    The hash is upper-cased before the lookup. Progress and a not-found
    error are printed to the shared console.

    Returns:
        The version ID from the CivitAI response, or None when there is no
        match (a match without an ``id`` also yields None, silently).
    """
    console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]")
    match = fetch_civitai_by_hash(hash_val.upper(), api_key, console)
    if match:
        found_id: int | None = match.get("id")
        if found_id:
            console.print(f"[green]Found:[/green] {match.get('name', 'N/A')}")
        return found_id
    console.print("[red]Error: Model not found on CivitAI for this hash.[/red]")
    return None
def _resolve_by_model_id(model_id: int, api_key: str | None) -> int | None:
    """Find the version ID of the latest version of a CivitAI model.

    The first entry of the model's ``modelVersions`` list is treated as the
    latest one. Progress and errors are printed to the shared console.

    Returns:
        That version's ID, or None if the model is missing or has no versions.
    """
    console.print(f"[cyan]Looking up model {model_id}...[/cyan]")
    model_data = fetch_civitai_model(model_id, api_key, console)
    if not model_data:
        console.print(f"[red]Error: Model {model_id} not found.[/red]")
        return None
    model_versions = model_data.get("modelVersions", [])
    if not model_versions:
        console.print("[red]Error: Model has no versions.[/red]")
        return None
    newest = model_versions[0]
    newest_id: int | None = newest.get("id")
    if newest_id:
        console.print(f"[green]Found latest:[/green] {newest.get('name', 'N/A')} (ID: {newest_id})")
    return newest_id
def _resolve_version_id(
version_id: int | None,
hash_val: str | None,
model_id: int | None,
api_key: str | None,
) -> int | None:
"""Resolve version ID from direct ID, hash, or model ID."""
if version_id:
return version_id
if hash_val:
return _resolve_by_hash(hash_val, api_key)
if model_id:
return _resolve_by_model_id(model_id, api_key)
return None
def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Path | None:
    """Choose and create the directory a CivitAI download is written to.

    Without *output*, the configured default path for *model_type_str* is
    used. The chosen directory is created (with parents) if it is missing.

    Returns:
        The directory, or None (after printing an error) when there is no
        default for the type or the path exists but is not a directory.
    """
    if output is not None:
        output_dir = output.resolve()
    else:
        output_dir = get_default_output_path(model_type_str)
        if output_dir is None:
            console.print(f"[red]Error: No default path for type '{model_type_str}'. Use --output to specify.[/red]")
            return None
        console.print(f"[dim]Using default path for {model_type_str}: {output_dir}[/dim]")
    if output_dir.exists():
        if not output_dir.is_dir():
            console.print(f"[red]Error: Not a directory: {output_dir}[/red]")
            return None
    else:
        console.print(f"[cyan]Creating directory: {output_dir}[/cyan]")
        output_dir.mkdir(parents=True, exist_ok=True)
    return output_dir
@app.command("dl")
def download(
    version_id: Annotated[int | None, typer.Option("-v", "--version-id", help="Model version ID")] = None,
    model_id: Annotated[int | None, typer.Option("-m", "--model-id", help="Model ID (downloads latest)")] = None,
    hash_val: Annotated[str | None, typer.Option("-H", "--hash", help="SHA256 hash to look up")] = None,
    output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
    no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False,
    api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
) -> None:
    """Download a model from CivitAI."""
    if not (version_id or hash_val or model_id):
        console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
        raise typer.Exit(1)
    key = api_key or load_api_key()
    resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
    if not resolved_version_id:
        # The hash or model-ID lookup failed; the resolvers print the reason
        # for the common failures. Stop here rather than requesting version
        # info for a missing ID.
        raise typer.Exit(1)
    console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]")
    version_info = fetch_civitai_model_version(resolved_version_id, key, console)
    if not version_info:
        console.print("[red]Error: Could not fetch model version info.[/red]")
        raise typer.Exit(1)
    # "model" may be absent or null in the response; both mean "no type".
    model_type_str: str | None = (version_info.get("model") or {}).get("type")
    output_dir = _prepare_download_dir(output, model_type_str)
    if not output_dir:
        raise typer.Exit(1)
    files: list[dict[str, Any]] = version_info.get("files") or []
    # Prefer the file flagged as primary, otherwise fall back to the first one.
    primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
    if not primary_file:
        console.print("[red]Error: No files found for this version.[/red]")
        raise typer.Exit(1)
    filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors")
    dest_path = output_dir / filename
    _display_download_info(version_info, filename, primary_file, dest_path)
    success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume)
    if not success:
        raise typer.Exit(1)
def _display_download_info(
    version_info: dict[str, Any],
    filename: str,
    primary_file: dict[str, Any],
    dest_path: Path,
) -> None:
    """Print a "Model Download" summary table framed by blank lines.

    The table shows the version name, base model, file name, file size
    (from the file's ``sizeKB`` field) and destination path; missing version
    fields show as "N/A".
    """
    rows = (
        ("Version", version_info.get("name", "N/A")),
        ("Base Model", version_info.get("baseModel", "N/A")),
        ("File", filename),
        ("Size", _format_size(primary_file.get("sizeKB", 0))),
        ("Destination", str(dest_path)),
    )
    summary = Table(title="Model Download", show_header=True, header_style="bold magenta")
    summary.add_column("Property", style="cyan")
    summary.add_column("Value", style="green")
    for label, value in rows:
        summary.add_row(label, value)
    console.print()
    console.print(summary)
    console.print()
@app.command()
def config(
    show: Annotated[bool, typer.Option("--show", help="Show current config")] = False,
    set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None,
    set_path: Annotated[str | None, typer.Option("--set-path", help="Set model path (TYPE=PATH)")] = None,
) -> None:
    """Manage configuration."""
    # Path types accepted by --set-path. These are also the config keys
    # checked when deciding whether a displayed path is custom.
    valid_types = ("checkpoints", "loras", "embeddings", "vae", "controlnet", "upscalers", "other")

    if set_key:
        cfg = load_config()
        if "api" not in cfg:
            cfg["api"] = {}
        cfg["api"]["civitai_key"] = set_key
        save_config(cfg)
        console.print(f"[green]API key saved to {CONFIG_FILE}[/green]")
        return

    if set_path:
        # Parse TYPE=PATH; only the first "=" separates, so paths may contain "=".
        if "=" not in set_path:
            console.print("[red]Error: Use format TYPE=PATH (e.g., checkpoints=/opt/models/checkpoints)[/red]")
            raise typer.Exit(1)
        raw_type, raw_value = set_path.split("=", 1)
        path_type = raw_type.lower().strip()
        path_value = raw_value.strip()
        if path_type not in valid_types:
            console.print(f"[red]Error: Invalid type '{path_type}'. Valid: {', '.join(valid_types)}[/red]")
            raise typer.Exit(1)
        if not path_value:
            # "checkpoints=" would otherwise store an empty path.
            console.print(f"[red]Error: Empty path for type '{path_type}'.[/red]")
            raise typer.Exit(1)
        cfg = load_config()
        if "paths" not in cfg:
            cfg["paths"] = {}
        cfg["paths"][path_type] = path_value
        save_config(cfg)
        # Echo the value actually stored (stripped), not the raw argument.
        console.print(f"[green]Path for {path_type} set to: {path_value}[/green]")
        return

    # No setter was given, so show the configuration (--show is implied).
    console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}")
    console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}")
    key = load_api_key()
    if key:
        # Long keys keep their first and last 4 characters; short keys are hidden.
        masked = key[:4] + "..." + key[-4:] if len(key) > MIN_KEY_LENGTH_FOR_MASKING else "***"
        console.print(f"[bold]API key:[/bold] {masked}")
    else:
        console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]")
    console.print()
    console.print("[bold]Model paths:[/bold]")
    # Group model types that share a directory so each path prints once.
    shown_paths: dict[str, list[str]] = {}
    for model_type, path in get_model_paths().items():
        shown_paths.setdefault(str(path), []).append(model_type)
    configured_paths = load_config().get("paths", {})
    configured_values = [configured_paths.get(k) for k in valid_types]
    for path_str in sorted(shown_paths):
        is_custom = path_str in configured_values
        marker = " [green](custom)[/green]" if is_custom else " [dim](default)[/dim]"
        console.print(f" {', '.join(sorted(shown_paths[path_str]))}: {path_str}{marker}")
    console.print()
    console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]")
    console.print("[dim]Set paths with: tsr config --set-path checkpoints=/path/to/models[/dim]")
@app.command()
def serve(
    host: Annotated[str, typer.Option(help="Listen address.")] = "127.0.0.1",
    port: Annotated[int, typer.Option(help="Listen port.")] = 51200,
    log_level: Annotated[str, typer.Option(help="Log level.")] = "info",
) -> None:
    """Start the tensors server (gallery and CivitAI management)."""
    # The server dependencies are an optional extra, so they are imported
    # here rather than at module level.
    try:
        import uvicorn  # noqa: PLC0415
        from tensors.server import create_app  # noqa: PLC0415
    except ImportError:
        console.print("[red]Missing server dependencies. Install with:[/red]")
        # markup=False: otherwise Rich parses "[server]" as a markup tag and
        # drops it, printing a wrong install command. The quotes stop shells
        # such as zsh from globbing the brackets.
        console.print(" pip install 'tensors[server]'", markup=False)
        raise typer.Exit(1) from None
    uvicorn.run(create_app(), host=host, port=port, log_level=log_level)
# =============================================================================
# Database Commands
# =============================================================================
# `tsr db ...` sub-application for the local models database.
db_app = typer.Typer(
    no_args_is_help=True,
    help="Manage local models database and CivitAI cache.",
    name="db",
)
app.add_typer(db_app, name="db")
@db_app.command("scan")
def db_scan(
    directory: Annotated[Path, typer.Argument(help="Directory to scan for safetensor files")],
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """Scan directory for safetensor files and add to database."""
    path = directory.resolve()
    # is_dir() is False both for missing paths and for regular files.
    if not path.is_dir():
        console.print(f"[red]Error: Directory not found: {path}[/red]")
        raise typer.Exit(1)
    with Database() as db:
        db.init_schema()
        # In --json mode stdout carries only the JSON document: no status
        # line and no progress console for the scanner.
        if not json_output:
            console.print(f"[cyan]Scanning {path}...[/cyan]")
        results = db.scan_directory(path, None if json_output else console)
    if json_output:
        console.print_json(data=results)
        return
    console.print(f"[green]Scanned {len(results)} file(s)[/green]")
    for entry in results:
        # markup=False: names such as "model [fp16].safetensors" would
        # otherwise be parsed as Rich markup and lose the bracketed part.
        console.print(str(entry["file_path"]), markup=False)
@db_app.command("link")
def db_link(
    api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """Link unlinked local files to CivitAI by hash lookup."""
    key = api_key or load_api_key()
    # In --json mode stdout must contain only the JSON list. Progress lines
    # are skipped and the API helper gets no console to print to.
    show_progress = not json_output
    linked: list[dict[str, Any]] = []
    with Database() as db:
        db.init_schema()
        unlinked = db.get_unlinked_files()
        if not unlinked:
            if json_output:
                console.print_json(data=linked)
            else:
                console.print("[green]All files already linked.[/green]")
            return
        if show_progress:
            console.print(f"[cyan]Found {len(unlinked)} unlinked file(s)[/cyan]")
        for file_info in unlinked:
            sha256 = file_info["sha256"]
            if show_progress:
                console.print(f"[dim]Looking up {sha256[:16]}...[/dim]")
            civitai_data = fetch_civitai_by_hash(sha256, key, console if show_progress else None)
            if not civitai_data:
                continue
            version_id: int = civitai_data.get("id", 0)
            model_id: int = civitai_data.get("modelId", 0)
            # Link only when CivitAI returned both identifiers.
            if not (version_id and model_id):
                continue
            db.link_file_to_civitai(file_info["id"], model_id, version_id)
            linked.append(
                {
                    "file": file_info["file_path"],
                    "model_id": model_id,
                    "version_id": version_id,
                    "name": civitai_data.get("name", ""),
                }
            )
            if show_progress:
                console.print(f" [green]✓[/green] {civitai_data.get('name', 'N/A')}")
    if json_output:
        console.print_json(data=linked)
    else:
        console.print(f"[green]Linked {len(linked)} file(s)[/green]")
@db_app.command("cache")
def db_cache(
    model_id: Annotated[int, typer.Argument(help="CivitAI model ID to cache")],
    api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """Fetch and cache full CivitAI model data."""
    # The API helper gets no console in --json mode, so it prints nothing.
    model_data = fetch_civitai_model(model_id, api_key or load_api_key(), None if json_output else console)
    if not model_data:
        console.print(f"[red]Error: Model {model_id} not found on CivitAI.[/red]")
        raise typer.Exit(1)
    with Database() as db:
        db.init_schema()
        internal_id = db.cache_model(model_data)
    model_name = model_data.get("name")
    if not json_output:
        console.print(f"[green]Cached:[/green] {model_name} (internal ID: {internal_id})")
        return
    console.print_json(data={"model_id": model_id, "internal_id": internal_id, "name": model_name})
@db_app.command("list")
def db_list(
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """List local files with CivitAI info."""
    with Database() as db:
        db.init_schema()
        files = db.list_local_files()
    if json_output:
        console.print_json(data=files)
    elif not files:
        console.print("[yellow]No files in database. Run 'tsr db scan' first.[/yellow]")
    else:
        listing = Table(title="Local Files", show_header=True, header_style="bold magenta")
        listing.add_column("Path", style="cyan", max_width=50)
        for heading, colour in (("Model", "green"), ("Version", "white"), ("Type", "yellow"), ("Base", "dim")):
            listing.add_column(heading, style=colour)
        for row in files:
            # Only the file name is shown; unlinked files get a dimmed marker.
            listing.add_row(
                Path(row["file_path"]).name,
                row.get("model_name") or "[dim]unlinked[/dim]",
                row.get("version_name") or "",
                row.get("model_type") or "",
                row.get("base_model") or "",
            )
        console.print(listing)
@db_app.command("search")
def db_search(
    query: Annotated[str | None, typer.Argument(help="Search query")] = None,
    model_type: Annotated[str | None, typer.Option("-t", "--type", help="Model type filter")] = None,
    base_model: Annotated[str | None, typer.Option("-b", "--base", help="Base model filter")] = None,
    limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """Search cached models offline."""
    with Database() as db:
        db.init_schema()
        matches = db.search_models(query=query, model_type=model_type, base_model=base_model, limit=limit)
    if json_output:
        console.print_json(data=matches)
        return
    if not matches:
        console.print("[yellow]No models found.[/yellow]")
        return
    results_table = Table(title="Cached Models", show_header=True, header_style="bold magenta")
    column_specs: tuple[tuple[str, dict[str, str]], ...] = (
        ("ID", {"style": "dim"}),
        ("Name", {"style": "cyan"}),
        ("Type", {"style": "yellow"}),
        ("Base", {"style": "green"}),
        ("Creator", {"style": "dim"}),
        ("Downloads", {"justify": "right"}),
    )
    for heading, options in column_specs:
        results_table.add_column(heading, **options)
    for model in matches:
        results_table.add_row(
            str(model.get("civitai_id", "")),
            model.get("name", ""),
            model.get("type", ""),
            model.get("base_model", ""),
            model.get("creator", ""),
            str(model.get("download_count", 0)),
        )
    console.print(results_table)
@db_app.command("triggers")
def db_triggers(
    file: Annotated[Path, typer.Argument(help="Path to safetensor file")],
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """Show trigger words for a LoRA file."""
    file_path = file.resolve()
    if not file_path.exists():
        console.print(f"[red]Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)
    # Trigger words are looked up by the file's absolute path.
    with Database() as db:
        db.init_schema()
        triggers = db.get_triggers(str(file_path))
    if json_output:
        console.print_json(data=triggers)
        return
    if triggers:
        console.print(f"[bold]Trigger words for {file_path.name}:[/bold]")
        for trigger in triggers:
            console.print(f"{trigger}")
        return
    console.print("[yellow]No trigger words found. File may not be linked to CivitAI.[/yellow]")
    console.print("[dim]Run 'tsr db link' to link files to CivitAI.[/dim]")
@db_app.command("stats")
def db_stats(
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """Show database statistics."""
    with Database() as db:
        db.init_schema()
        stats = db.get_stats()
    if json_output:
        console.print_json(data={"db_path": str(DB_PATH), "stats": stats})
        return
    summary = Table(title="Database Statistics", show_header=True, header_style="bold magenta")
    summary.add_column("Table", style="cyan")
    summary.add_column("Count", style="green", justify="right")
    for name, count in stats.items():
        summary.add_row(name, str(count))
    console.print(f"[dim]Database: {DB_PATH}[/dim]")
    console.print(summary)
# =============================================================================
# Hugging Face Commands
# =============================================================================
# `tsr hf ...` sub-application for Hugging Face Hub commands.
hf_app = typer.Typer(help="Hugging Face Hub commands for safetensor files.", name="hf")
app.add_typer(hf_app)
@hf_app.command("get")
def hf_get(
    model_id: Annotated[str, typer.Argument(help="Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)")],
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """Get Hugging Face model info and list safetensor files."""
    model = get_hf_model(model_id, console=console)
    # An empty or missing result exits with status 1 without printing here.
    if not model:
        raise typer.Exit(1)
    if json_output:
        console.print_json(data=model)
    else:
        display_hf_model_info(model, console)
@hf_app.command("files")
def hf_files(
    model_id: Annotated[str, typer.Argument(help="Model ID")],
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """List safetensor files in a Hugging Face model."""
    files = list_safetensor_files(model_id, console=console)
    if json_output:
        console.print_json(data=files)
    elif files:
        console.print(f"[bold]Safetensor files in {model_id}:[/bold]")
        for position, name in enumerate(files, start=1):
            console.print(f" {position}. {name}")
    else:
        console.print("[yellow]No safetensor files found.[/yellow]")
@hf_app.command("dl")
def hf_download(
    model_id: Annotated[str, typer.Argument(help="Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)")],
    filename: Annotated[str | None, typer.Option("-f", "--file", help="Specific file to download")] = None,
    output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
    all_files: Annotated[bool, typer.Option("--all", "-a", help="Download all safetensor files")] = False,
) -> None:
    """Download safetensor files from Hugging Face.
    Examples:
    tsr hf dl stabilityai/stable-diffusion-xl-base-1.0 -f sd_xl_base_1.0.safetensors
    tsr hf dl author/model --all
    """
    # The docstring above is this command's --help text; keep it verbatim.
    # Files go to --output, or to the current working directory by default.
    output_dir = output or Path.cwd()
    if all_files:
        downloaded = download_all_safetensors(model_id, output_dir, console=console)
        if not downloaded:
            console.print("[red]No files downloaded[/red]")
            raise typer.Exit(1)
        console.print(f"[green]Downloaded {len(downloaded)} files[/green]")
        return
    if not filename:
        # No file named: use the only safetensor file, or list the choices.
        available = list_safetensor_files(model_id, console=console)
        if not available:
            console.print("[red]No safetensor files found in model[/red]")
            raise typer.Exit(1)
        if len(available) > 1:
            console.print("[yellow]Multiple safetensor files found. Specify one with -f or use --all:[/yellow]")
            for position, name in enumerate(available, start=1):
                console.print(f" {position}. {name}")
            raise typer.Exit(1)
        filename = available[0]
        console.print(f"[dim]Downloading only safetensor file: {filename}[/dim]")
    if not download_hf_safetensor(model_id, filename, output_dir, console=console):
        raise typer.Exit(1)
# =============================================================================
# ComfyUI Commands
# =============================================================================
# `tsr comfy ...` sub-application for the ComfyUI integration.
comfy_app = typer.Typer(help="ComfyUI integration for image generation.", name="comfy")
app.add_typer(comfy_app)
@comfy_app.command("status")
def comfy_status(
    url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """Show ComfyUI system status (GPU, RAM, queue)."""
    # Imported inside the command, like the other comfy commands.
    from tensors.comfyui import get_queue_status, get_system_stats  # noqa: PLC0415

    stats = get_system_stats(url=url, console=None if json_output else console)
    if not stats:
        console.print("[red]Error: Could not connect to ComfyUI[/red]")
        raise typer.Exit(1)
    queue = get_queue_status(url=url)
    if json_output:
        console.print_json(data={"system": stats, "queue": queue})
        return

    console.print("[bold cyan]ComfyUI System Status[/bold cyan]")
    console.print()
    system_info = stats.get("system", {})
    for label, field in (("OS", "os"), ("Python", "python_version"), ("PyTorch", "pytorch_version")):
        console.print(f"[bold]{label}:[/bold] {system_info.get(field, 'N/A')}")

    devices = stats.get("devices", [])
    if devices:
        console.print()
        console.print("[bold]GPU Devices:[/bold]")
        gib = 1024**3  # VRAM figures are divided by this to print GB
        for index, device in enumerate(devices):
            total = device.get("vram_total", 0)
            used = total - device.get("vram_free", 0)
            percent = (used / total * 100) if total > 0 else 0
            console.print(f" [{index}] {device.get('name', 'Unknown')}")
            console.print(f" VRAM: {used / gib:.1f} / {total / gib:.1f} GB ({percent:.0f}%)")

    if queue:
        running_count = len(queue.get("queue_running", []))
        pending_count = len(queue.get("queue_pending", []))
        console.print()
        console.print(f"[bold]Queue:[/bold] {running_count} running, {pending_count} pending")
@comfy_app.command("queue")
def comfy_queue(
    url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None,
    clear: Annotated[bool, typer.Option("--clear", "-c", help="Clear the queue")] = False,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """Show or clear the ComfyUI queue."""
    from tensors.comfyui import clear_queue as do_clear_queue, get_queue_status  # noqa: PLC0415

    if clear:
        if not do_clear_queue(url=url, console=console):
            raise typer.Exit(1)
        return

    queue = get_queue_status(url=url, console=None if json_output else console)
    if not queue:
        console.print("[red]Error: Could not connect to ComfyUI[/red]")
        raise typer.Exit(1)
    if json_output:
        console.print_json(data=queue)
        return

    def _prompt_id(job: Any) -> Any:
        # Each queue entry is a sequence whose second item is the prompt ID.
        return job[1] if len(job) > 1 else "unknown"

    running = queue.get("queue_running", [])
    pending = queue.get("queue_pending", [])
    console.print("[bold cyan]ComfyUI Queue[/bold cyan]")
    console.print()
    console.print(f"[bold]Running:[/bold] {len(running)}")
    console.print(f"[bold]Pending:[/bold] {len(pending)}")

    if running:
        console.print()
        console.print("[bold]Running Jobs:[/bold]")
        for job in running:
            console.print(f"{_prompt_id(job)}")

    if pending:
        console.print()
        console.print("[bold]Pending Jobs:[/bold]")
        # Only the first MAX_QUEUE_DISPLAY pending jobs are listed.
        for job in pending[:MAX_QUEUE_DISPLAY]:
            console.print(f"{_prompt_id(job)}")
        hidden = len(pending) - MAX_QUEUE_DISPLAY
        if hidden > 0:
            console.print(f" ... and {hidden} more")
@comfy_app.command("models")
def comfy_models(
    url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """List available models in ComfyUI."""
    from rich.markup import escape  # noqa: PLC0415

    from tensors.comfyui import get_loaded_models  # noqa: PLC0415

    # A falsy result (None or empty mapping) is reported as a fetch failure.
    models = get_loaded_models(url=url, console=console if not json_output else None)
    if not models:
        console.print("[red]Error: Could not fetch models from ComfyUI[/red]")
        raise typer.Exit(1)
    if json_output:
        console.print_json(data=models)
        return

    console.print("[bold cyan]ComfyUI Available Models[/bold cyan]")
    # Groups are printed in sorted order of model type; each group is capped at
    # MAX_MODEL_LIST_DISPLAY entries.
    for model_type, model_list in sorted(models.items()):
        console.print()
        console.print(f"[bold]{escape(str(model_type))}:[/bold] ({len(model_list)})")
        for name in model_list[:MAX_MODEL_LIST_DISPLAY]:
            # Filenames may contain "[...]"; escape so Rich prints them literally
            # instead of parsing them as markup tags.
            console.print(escape(str(name)))
        if len(model_list) > MAX_MODEL_LIST_DISPLAY:
            console.print(f" ... and {len(model_list) - MAX_MODEL_LIST_DISPLAY} more")
def _print_history_entry(prompt_id: str, entry: dict[str, Any]) -> None:
    """Print the status and output image filenames of one history entry."""
    console.print(f"[bold cyan]Prompt: {prompt_id}[/bold cyan]")
    console.print()
    status_str = entry.get("status", {}).get("status_str", "unknown")
    console.print(f"[bold]Status:[/bold] {status_str}")

    node_outputs = entry.get("outputs", {})
    if not node_outputs:
        return
    console.print()
    console.print("[bold]Outputs:[/bold]")
    for node_id, node_output in node_outputs.items():
        if "images" not in node_output:
            continue
        for image in node_output["images"]:
            console.print(f" [{node_id}] {image.get('filename', 'unknown')}")


def _print_history_table(history: dict[str, Any], limit: int) -> None:
    """Print up to ``limit`` history entries as a table of ID, status and image count."""
    console.print("[bold cyan]ComfyUI History[/bold cyan]")
    console.print()

    table = Table(show_header=True, header_style="bold magenta")
    table.add_column("Prompt ID", style="cyan", max_width=40)
    table.add_column("Status", style="green")
    table.add_column("Images", justify="right")

    for pid, entry in list(history.items())[:limit]:
        status_str = entry.get("status", {}).get("status_str", "unknown")
        n_images = sum(len(node_out.get("images", [])) for node_out in entry.get("outputs", {}).values())
        shown_pid = pid if len(pid) <= MAX_PROMPT_ID_DISPLAY else f"{pid[:MAX_PROMPT_ID_DISPLAY]}..."
        table.add_row(shown_pid, status_str, str(n_images))

    console.print(table)


@comfy_app.command("history")
def comfy_history(
    prompt_id: Annotated[str | None, typer.Argument(help="Specific prompt ID to view")] = None,
    url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None,
    limit: Annotated[int, typer.Option("-n", "--limit", help="Max history items")] = 20,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """View ComfyUI generation history."""
    from tensors.comfyui import get_history  # noqa: PLC0415

    history = get_history(
        url=url,
        prompt_id=prompt_id,
        max_items=limit,
        console=None if json_output else console,
    )
    # None means the fetch failed; an empty mapping just means no entries.
    if history is None:
        console.print("[red]Error: Could not fetch history from ComfyUI[/red]")
        raise typer.Exit(1)
    if json_output:
        console.print_json(data=history)
        return
    if not history:
        console.print("[yellow]No history found.[/yellow]")
        return

    # Without a prompt ID, show the overview table.
    if not prompt_id:
        _print_history_table(history, limit)
        return

    if prompt_id not in history:
        console.print(f"[yellow]Prompt {prompt_id} not found in history.[/yellow]")
        return
    _print_history_entry(prompt_id, history[prompt_id])
def _detect_generation_family(model: str) -> str | None:
    """Return the detected model family for ``model``, or None if none is detected.

    The base model recorded for this filename in the local database (if any) is
    passed to ``detect_model_family`` together with the filename. The lookup is
    best-effort: on any database error, detection uses the filename alone.
    """
    base_model_str: str | None = None
    try:
        with Database() as db:
            db.init_schema()
            base_model_str = db.get_base_model_by_filename(model)
    except Exception:
        pass  # The cached base model is only an optional hint
    return detect_model_family(model, base_model_str)


def _lora_trigger_words(lora: str) -> list[str]:
    """Return the cached trigger words for ``lora``; empty if none or on any DB error."""
    try:
        with Database() as db:
            db.init_schema()
            return list(db.get_trigger_words_by_filename(lora) or [])
    except Exception:
        return []  # Best-effort lookup: generation proceeds without trigger words


def _enhance_prompts(
    prompt: str,
    negative: str,
    family_defaults: dict[str, Any],
    trigger_words: list[str],
    *,
    no_quality: bool,
    no_negative: bool,
) -> tuple[str, str]:
    """Build the final (positive, negative) prompt pair.

    Positive: trigger words, then the family ``quality_prefix`` (unless
    ``no_quality``), then the user prompt, joined with ", ".
    Negative: the user negative followed by the family ``negative_prompt``
    (unless ``no_negative``); just the family negative if the user gave none.
    """
    parts = list(trigger_words)
    quality_prefix = family_defaults.get("quality_prefix")
    if quality_prefix and not no_quality:
        parts.append(quality_prefix)
    parts.append(prompt)
    enhanced_prompt = ", ".join(parts)

    enhanced_negative = negative
    family_negative = family_defaults.get("negative_prompt")
    if family_negative and not no_negative:
        enhanced_negative = f"{negative}, {family_negative}" if negative else family_negative
    return enhanced_prompt, enhanced_negative


def _truncate_for_display(text: str, limit: int) -> str:
    """Return ``text`` cut to ``limit`` characters plus "..." when it is longer."""
    return text[:limit] + "..." if len(text) > limit else text


@comfy_app.command("generate")
def comfy_generate(  # noqa: PLR0915
    prompt: Annotated[str, typer.Argument(help="Positive prompt text")],
    url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None,
    negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "",
    model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None,
    width: Annotated[int, typer.Option("-W", "--width", help="Image width")] = 1024,
    height: Annotated[int, typer.Option("-H", "--height", help="Image height")] = 1024,
    steps: Annotated[int, typer.Option("--steps", help="Sampling steps")] = 20,
    cfg: Annotated[float, typer.Option("--cfg", help="CFG scale")] = 7.0,
    seed: Annotated[int, typer.Option("--seed", "-s", help="Random seed (-1 for random)")] = -1,
    sampler: Annotated[str, typer.Option("--sampler", help="Sampler name")] = "euler",
    scheduler: Annotated[str, typer.Option("--scheduler", help="Scheduler name")] = "normal",
    output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None,
    count: Annotated[int, typer.Option("-c", "--count", help="Number of images to generate")] = 1,
    lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
    lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 1.0,
    no_quality: Annotated[bool, typer.Option("--no-quality", help="Disable auto quality tags")] = False,
    no_negative: Annotated[bool, typer.Option("--no-negative", help="Disable auto negative prompt")] = False,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """Generate an image with a simple text-to-image workflow.

    Exits with status 1 if no image could be generated.

    Examples:
        tsr comfy generate "a cat sitting on a windowsill"
        tsr comfy generate "portrait photo" -n "blurry, bad quality" --steps 30
        tsr comfy generate "landscape" -m "flux1-dev-fp8.safetensors" -W 1024 -H 768
        tsr comfy generate "cyberpunk city" --count 4 -o batch.png
        tsr comfy generate "girl" --lora spumcostyle.safetensors --lora-strength 0.8
        tsr comfy generate "raw prompt" --no-quality --no-negative
    """
    from rich.markup import escape  # noqa: PLC0415

    from tensors.comfyui import generate_image, get_image  # noqa: PLC0415

    # Without this check, range(count) would be empty and JSON would report success.
    if count < 1:
        console.print("[red]Error: --count must be at least 1[/red]")
        raise typer.Exit(1)

    status_console = None if json_output else console

    # Detect model family and apply its defaults (quality prefix, negative prompt)
    family_defaults: dict[str, Any] = {}
    if model:
        model_family = _detect_generation_family(model)
        if model_family:
            family_defaults = MODEL_FAMILY_DEFAULTS.get(model_family, {})
            if not json_output:
                console.print(f"[dim]Detected model family: {model_family}[/dim]")

    # LoRA trigger words are placed at the front of the prompt
    trigger_words = _lora_trigger_words(lora) if lora else []
    if trigger_words and not json_output:
        console.print(f"[dim]LoRA trigger words: {escape(', '.join(trigger_words))}[/dim]")

    enhanced_prompt, enhanced_negative = _enhance_prompts(
        prompt,
        negative,
        family_defaults,
        trigger_words,
        no_quality=no_quality,
        no_negative=no_negative,
    )
    if not json_output:
        # User text is escaped: SD prompts often contain "[...]", which Rich would
        # otherwise parse as markup (dropping text, or raising on "[/...]").
        if enhanced_prompt != prompt:
            console.print(f"[dim]Enhanced prompt: {escape(_truncate_for_display(enhanced_prompt, 100))}[/dim]")
        if enhanced_negative != negative:
            console.print(f"[dim]Enhanced negative: {escape(_truncate_for_display(enhanced_negative, 80))}[/dim]")

    all_results: list[dict[str, Any]] = []
    all_saved: list[Path] = []
    for i in range(count):
        # A fixed seed is incremented per image so a batch is reproducible;
        # -1 ("random" per the --seed help) is passed through for every image.
        current_seed = seed + i if seed >= 0 else -1
        if count > 1 and not json_output:
            console.print(f"\n[cyan]Generating image {i + 1}/{count}...[/cyan]")

        result = generate_image(
            prompt=enhanced_prompt,
            url=url,
            negative_prompt=enhanced_negative,
            model=model,
            width=width,
            height=height,
            steps=steps,
            cfg=cfg,
            seed=current_seed,
            sampler=sampler,
            scheduler=scheduler,
            console=status_console,
            lora_name=lora,
            lora_strength=lora_strength,
        )
        if not result:
            if json_output:
                all_results.append({"success": False, "index": i, "errors": {"generation": "Failed to generate"}})
            else:
                console.print(f"[red]Generation {i + 1} failed[/red]")
            continue
        if not result.success:
            if json_output:
                all_results.append({"success": False, "index": i, "errors": result.node_errors})
            else:
                console.print(f"[red]Generation {i + 1} failed[/red]")
                for node_id, errors in result.node_errors.items():
                    console.print(f" [yellow]Node {node_id}:[/yellow] {escape(str(errors))}")
            continue

        # Save the first image of this run if an output path was requested
        saved_path: Path | None = None
        if output and result.images:
            img_path = result.images[0]
            img_data = get_image(str(img_path), url=url)
            if img_data:
                save_path = output if count == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
                # Create missing parent directories instead of failing with FileNotFoundError
                save_path.parent.mkdir(parents=True, exist_ok=True)
                save_path.write_bytes(img_data)
                saved_path = save_path
                all_saved.append(save_path)
                if not json_output:
                    console.print(f"[green]Saved:[/green] {escape(str(save_path))}")
            elif not json_output:
                console.print(f"[yellow]Could not download image: {escape(str(img_path))}[/yellow]")

        all_results.append(
            {
                "success": True,
                "index": i,
                "prompt_id": result.prompt_id,
                "images": [str(img) for img in result.images],
                "saved": str(saved_path) if saved_path else None,
            }
        )

    successful = sum(1 for r in all_results if r.get("success", False))

    if json_output:
        console.print_json(
            data={
                "success": all(r.get("success", False) for r in all_results),
                "count": len(all_results),
                "results": all_results,
            }
        )
        # Exit 1 only when nothing at all succeeded (mirrors `comfy run` failures)
        if successful == 0:
            raise typer.Exit(1)
        return

    if successful == 0:
        console.print("\n[bold red]No images were generated[/bold red]")
        raise typer.Exit(1)

    console.print("\n[bold green]Generation complete![/bold green]")
    if count > 1:
        console.print(f"[dim]Generated {successful}/{count} images[/dim]")
        if all_saved:
            console.print(f"[dim]Saved to: {escape(str(all_saved[0].parent))}/[/dim]")
    elif all_results and all_results[0].get("prompt_id"):
        console.print(f"[dim]Prompt ID: {all_results[0]['prompt_id']}[/dim]")
@comfy_app.command("run")
def comfy_run(
    workflow_file: Annotated[Path, typer.Argument(help="Path to workflow JSON file")],
    url: Annotated[str | None, typer.Option("--url", "-u", help="ComfyUI server URL")] = None,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """Run an arbitrary ComfyUI workflow from a JSON file.

    The workflow should be in ComfyUI API format (exported via "Save (API Format)").
    """
    from rich.markup import escape  # noqa: PLC0415

    from tensors.comfyui import run_workflow  # noqa: PLC0415

    # is_file() rather than exists(): a directory is not a usable workflow file
    if not workflow_file.is_file():
        console.print(f"[red]Error: Workflow file not found: {escape(str(workflow_file))}[/red]")
        raise typer.Exit(1)

    result = run_workflow(
        workflow=workflow_file,
        url=url,
        console=None if json_output else console,
    )
    if not result:
        if json_output:
            # Keep --json output machine-readable even when queuing fails
            console.print_json(
                data={"success": False, "prompt_id": None, "errors": {"queue": "Failed to queue workflow"}}
            )
        else:
            console.print("[red]Failed to queue workflow[/red]")
        raise typer.Exit(1)

    if not result.success:
        if json_output:
            console.print_json(data={"success": False, "prompt_id": result.prompt_id, "errors": result.node_errors})
        else:
            console.print("[red]Workflow execution failed[/red]")
            for node_id, errors in result.node_errors.items():
                # Error text may quote filenames containing "[...]"; print it literally
                console.print(f" [yellow]Node {node_id}:[/yellow] {escape(str(errors))}")
        raise typer.Exit(1)

    if json_output:
        console.print_json(data={"success": True, "prompt_id": result.prompt_id, "outputs": result.outputs})
        return

    console.print("[bold green]Workflow complete![/bold green]")
    console.print(f"[dim]Prompt ID: {result.prompt_id}[/dim]")
    # List image files reported by any output node
    for node_output in result.outputs.values():
        if "images" in node_output:
            for img in node_output["images"]:
                console.print(f" [green]Image:[/green] {escape(str(img.get('filename', 'unknown')))}")
def main() -> int:
"""Main entry point."""
# Handle legacy invocation: tsr <file.safetensors> -> tsr info <file>
known_commands = (
"info",
"search",
"get",
"dl",
"download",
"config",
"serve",
"db",
"hf",
"comfy",
)
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
arg = sys.argv[1]
if arg not in known_commands and (arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists()):
sys.argv = [sys.argv[0], "info", *sys.argv[1:]]
app()
return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())