Files
tensors/tensors/display.py
T

332 lines
12 KiB
Python

"""Rich table display functions for tsr CLI."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from pathlib import Path
from rich.table import Table
if TYPE_CHECKING:
from rich.console import Console
# --- Formatting thresholds ------------------------------------------------
# Multiplier between successive size units (KB -> MB -> GB).
KB = 2**10
# Despite its name, this is KB * KB: the number of KB in one GB, i.e. the
# threshold at or above which _format_size switches to GB.
MB_IN_KB = KB**2
# Thresholds for the K / M suffixes used by _format_count.
THOUSAND = 10**3
MILLION = 10**6
# How many model tags to list before truncating with "...".
MAX_TAGS_DISPLAY = 10
def _format_size(size_kb: float) -> str:
    """Format a size given in kilobytes as a human-readable string.

    Sizes below ``KB`` kilobytes are shown in KB, below ``KB * KB`` in MB,
    and otherwise in GB. A value is promoted to the next unit when rounding
    for display would print the next unit's threshold, so 1023.6 KB shows as
    "1.0 MB" rather than "1024 KB".

    Args:
        size_kb: Size in kilobytes; units step by a factor of ``KB`` (1024).

    Returns:
        A string such as "512 KB", "1.5 MB" or "2.00 GB".
    """
    # round() uses the same correctly-rounded algorithm as the f-string
    # formatting below, so the threshold test agrees with what is printed.
    if round(size_kb) < KB:
        return f"{size_kb:.0f} KB"
    if round(size_kb / KB, 1) < KB:
        return f"{size_kb / KB:.1f} MB"
    return f"{size_kb / KB / KB:.2f} GB"
def _format_count(count: int) -> str:
    """Format a count compactly with a K or M suffix.

    Counts below 1,000 are returned verbatim. Larger counts get one decimal
    place with a "K" suffix, or "M" once the "K" form would round up to
    "1000.0K" (so 999,999 shows as "1.0M" rather than "1000.0K").

    Args:
        count: The number to format.

    Returns:
        A string such as "999", "12.3K" or "4.5M".
    """
    if count < THOUSAND:
        return str(count)
    # Compare the value as it will be displayed (one decimal place) so a
    # count just under a million is promoted to the M suffix.
    if round(count / THOUSAND, 1) < THOUSAND:
        return f"{count / THOUSAND:.1f}K"
    return f"{count / MILLION:.1f}M"
def display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str, console: Console) -> None:
    """Print a "File Information" table for a local safetensors file.

    Args:
        file_path: Path to the file; its size is read via ``stat()``.
        local_metadata: Parsed header info; must contain ``header_size``
            (shown in bytes) and ``tensor_count``.
        sha256_hash: Digest string, shown verbatim.
        console: Rich console to print to.
    """
    from rich.markup import escape

    prop_width = 12
    file_table = Table(title="File Information", show_header=True, header_style="bold magenta", expand=True)
    file_table.add_column("Property", style="cyan", width=prop_width, no_wrap=True)
    file_table.add_column("Value", style="green", no_wrap=True, overflow="ellipsis")
    # File names often contain bracketed text (e.g. "model [v2].safetensors");
    # escape it so Rich prints it literally instead of parsing it as markup.
    file_table.add_row("File", escape(file_path.name))
    file_table.add_row("Path", escape(str(file_path.parent)))
    # Use the shared formatter (it takes KB) so small files are not shown as
    # "0.00 GB" and units match the CivitAI "File Size" row.
    file_table.add_row("Size", _format_size(file_path.stat().st_size / KB))
    file_table.add_row("SHA256", sha256_hash)
    file_table.add_row("Header Size", f"{local_metadata['header_size']:,} bytes")
    file_table.add_row("Tensor Count", str(local_metadata["tensor_count"]))
    console.print()
    console.print(file_table)
def display_local_metadata(local_metadata: dict[str, Any], console: Console, keys_filter: list[str] | None = None) -> None:
    """Print the metadata embedded in a safetensors file.

    By default the metadata is shown as a two-column table sorted by key,
    with long values truncated to fit the terminal. When ``keys_filter`` is
    given, only those keys are printed, each value in full, and missing keys
    are reported as "not found".

    Args:
        local_metadata: Parsed file info; its ``metadata`` entry maps keys to
            values (values are rendered with ``str()``).
        console: Rich console to print to.
        keys_filter: Optional list of keys to print in full instead of the table.
    """
    from rich.markup import escape

    metadata = local_metadata.get("metadata")
    if not metadata:
        console.print()
        console.print("[yellow]No embedded metadata found in safetensor file.[/yellow]")
        return

    # Metadata is arbitrary file content and may contain bracketed text.
    # Escape it so Rich neither swallows "[...]" as a style tag nor raises a
    # MarkupError on an unmatched closing tag such as "[/x]".
    if keys_filter:
        for key in keys_filter:
            if key in metadata:
                console.print(f"[cyan]{escape(key)}[/cyan]: {escape(str(metadata[key]))}")
            else:
                console.print(f"[yellow]{escape(key)}: not found[/yellow]")
        return

    # Key column fits the longest key, but is never narrower than its header.
    key_width = max(len("Key"), max(len(str(k)) for k in metadata))
    # Value width: terminal minus key column and table borders (7 chars),
    # floored so a long key or narrow terminal never yields a zero/negative width.
    value_width = max(10, console.size.width - key_width - 7)
    meta_table = Table(
        title="Safetensor Metadata",
        show_header=True,
        header_style="bold magenta",
    )
    meta_table.add_column("Key", style="cyan", width=key_width, no_wrap=True)
    meta_table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis")
    for key, value in sorted(metadata.items()):
        meta_table.add_row(escape(str(key)), escape(str(value)))
    console.print()
    console.print(meta_table)
def _build_civitai_table(console: Console) -> tuple[Table, int]:
    """Create the empty two-column "CivitAI Model Information" table.

    The property column is fixed at 14 characters. The value column takes the
    remaining terminal width (less 7 characters of borders/padding), with a
    floor of 40 characters.

    Returns:
        A ``(table, value_column_width)`` pair.
    """
    label_col = 14
    value_col = max(40, console.size.width - label_col - 7)
    tbl = Table(title="CivitAI Model Information", show_header=True, header_style="bold magenta")
    tbl.add_column("Property", style="cyan", width=label_col, no_wrap=True)
    tbl.add_column("Value", style="green", width=value_col, no_wrap=True, overflow="ellipsis")
    return tbl, value_col
def display_civitai_data(civitai_data: dict[str, Any] | None, console: Console) -> None:
    """Print a CivitAI model-version table followed by a link to the model.

    Args:
        civitai_data: CivitAI data for a model version. The keys read are
            ``modelId``, ``id``, ``name``, ``baseModel``, ``createdAt``,
            ``trainedWords``, ``downloadUrl`` and ``files``. If it is ``None``
            or empty, a "not found" notice is printed instead.
        console: Rich console to print to.
    """
    from rich.markup import escape

    if not civitai_data:
        console.print()
        console.print("[yellow]Model not found on CivitAI.[/yellow]")
        return

    def cell(value: Any) -> str:
        # Values come from the API and may contain "[...]"; escape them so Rich
        # prints them literally instead of parsing them as markup tags.
        return escape(str(value))

    civit_table, _ = _build_civitai_table(console)
    civit_table.add_row("Model ID", cell(civitai_data.get("modelId", "N/A")))
    civit_table.add_row("Version ID", cell(civitai_data.get("id", "N/A")))
    civit_table.add_row("Version Name", cell(civitai_data.get("name", "N/A")))
    civit_table.add_row("Base Model", cell(civitai_data.get("baseModel", "N/A")))
    civit_table.add_row("Created At", cell(civitai_data.get("createdAt", "N/A")))

    # `or []` / `or {}` below also cover keys that are present but null.
    trained_words: list[str] = civitai_data.get("trainedWords") or []
    if trained_words:
        civit_table.add_row("Trigger Words", cell(", ".join(trained_words)))

    civit_table.add_row("Download URL", cell(civitai_data.get("downloadUrl", "N/A")))

    files: list[dict[str, Any]] = civitai_data.get("files") or []
    for f in files:
        if not f.get("primary"):
            continue
        civit_table.add_row("Primary File", cell(f.get("name", "N/A")))
        civit_table.add_row("File Size", _format_size(f.get("sizeKB", 0)))
        meta: dict[str, Any] = f.get("metadata") or {}
        if meta:
            civit_table.add_row("Format", cell(meta.get("format", "N/A")))
            civit_table.add_row("Precision", cell(meta.get("fp", "N/A")))
            civit_table.add_row("Size Type", cell(meta.get("size", "N/A")))

    console.print()
    console.print(civit_table)

    model_id = civitai_data.get("modelId")
    if model_id:
        console.print()
        console.print(f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}")
def _build_model_table(console: Console) -> Table:
    """Create the empty two-column "Model Information" table.

    The property column is fixed at 10 characters. The value column fills the
    remaining terminal width (less 7 characters of borders/padding), with a
    floor of 40 characters.
    """
    label_col = 10
    value_col = max(40, console.size.width - label_col - 7)
    tbl = Table(title="Model Information", show_header=True, header_style="bold magenta")
    tbl.add_column("Property", style="cyan", width=label_col, no_wrap=True)
    tbl.add_column("Value", style="green", width=value_col, no_wrap=True, overflow="ellipsis")
    return tbl
def _add_model_basic_info(table: Table, model_data: dict[str, Any]) -> None:
    """Append the core model fields to a table built by ``_build_model_table``.

    ID, Name, Type and NSFW are always added. Creator, Tags, Downloads/Likes
    and Status are added only when the corresponding data is present. Tags
    are capped at ``MAX_TAGS_DISPLAY``, with "..." appended when more exist.
    Text from the API is escaped so Rich prints it literally rather than as
    markup.
    """
    from rich.markup import escape

    table.add_row("ID", escape(str(model_data.get("id", "N/A"))))
    table.add_row("Name", escape(str(model_data.get("name", "N/A"))))
    table.add_row("Type", escape(str(model_data.get("type", "N/A"))))
    table.add_row("NSFW", str(model_data.get("nsfw", False)))

    creator = model_data.get("creator") or {}
    if creator:
        table.add_row("Creator", escape(str(creator.get("username", "N/A"))))

    tags: list[str] = model_data.get("tags") or []
    if tags:
        shown = ", ".join(tags[:MAX_TAGS_DISPLAY])
        if len(tags) > MAX_TAGS_DISPLAY:
            shown += "..."
        table.add_row("Tags", escape(shown))

    stats: dict[str, Any] = model_data.get("stats") or {}
    if stats:
        # `or 0` also covers counts reported as null, which `:,` cannot format.
        table.add_row("Downloads", f"{stats.get('downloadCount') or 0:,}")
        table.add_row("Likes", f"{stats.get('thumbsUpCount') or 0:,}")

    mode = model_data.get("mode")
    if mode:
        table.add_row("Status", escape(str(mode)))
def _build_versions_table(console: Console) -> Table:
    """Create the empty "Model Versions" table.

    The ID, Base Model, Created and Size columns have fixed widths. The space
    left after them and 20 characters of borders/padding (never less than 40)
    is split between the other two columns: one third (rounded down) goes to
    Name and the rest to Filename.
    """
    fixed = {"ID": 7, "Base Model": 20, "Created": 10, "Size": 8}
    spare = max(40, console.size.width - sum(fixed.values()) - 20)
    name_col = spare // 3
    file_col = spare - name_col
    tbl = Table(title="Model Versions", show_header=True, header_style="bold magenta")
    tbl.add_column("ID", style="cyan", width=fixed["ID"], no_wrap=True)
    tbl.add_column("Name", style="green", width=name_col, no_wrap=True, overflow="ellipsis")
    tbl.add_column("Base Model", style="yellow", width=fixed["Base Model"], no_wrap=True, overflow="ellipsis")
    tbl.add_column("Created", style="blue", width=fixed["Created"], no_wrap=True)
    tbl.add_column("Filename", style="white", width=file_col, no_wrap=True, overflow="ellipsis")
    tbl.add_column("Size", justify="right", width=fixed["Size"], no_wrap=True)
    return tbl
def _add_version_rows(table: Table, versions: list[dict[str, Any]]) -> None:
    """Append one row per model version to a table from ``_build_versions_table``.

    Each row shows the following for one version:

    - its ID, name and base model;
    - its creation date (the first 10 characters of ``createdAt``);
    - the filename and size of its primary file, or of the first listed file
      when none is flagged primary.

    Text from the API is escaped so Rich prints it literally.
    """
    from rich.markup import escape

    for ver in versions:
        # `or []` also covers a "files" key that is present but null.
        files: list[dict[str, Any]] = ver.get("files") or []
        primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
        filename = "N/A"
        size = "N/A"
        if primary_file:
            filename = escape(str(primary_file.get("name") or "N/A"))
            size = _format_size(primary_file.get("sizeKB", 0))
        # Keep only the leading 10 characters (the date part) of the timestamp.
        created = str(ver.get("createdAt", "N/A"))[:10]
        table.add_row(
            str(ver.get("id", "N/A")),
            escape(str(ver.get("name", "N/A"))),
            escape(str(ver.get("baseModel", "N/A"))),
            created,
            filename,
            size,
        )
def display_model_info(model_data: dict[str, Any], console: Console) -> None:
    """Print a CivitAI model's summary table, its versions table, and a link.

    The versions table is printed only when ``modelVersions`` is non-empty.
    The "View on CivitAI" link is printed only when the model has a truthy
    ``id``.
    """
    summary = _build_model_table(console)
    _add_model_basic_info(summary, model_data)
    console.print()
    console.print(summary)

    version_list: list[dict[str, Any]] = model_data.get("modelVersions", [])
    if version_list:
        versions_tbl = _build_versions_table(console)
        _add_version_rows(versions_tbl, version_list)
        console.print()
        console.print(versions_tbl)

    model_id = model_data.get("id")
    if not model_id:
        return
    console.print()
    console.print(f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}")
def _build_search_table(console: Console) -> Table:
    """Create the empty, untitled search-results table.

    Every column except Name has a fixed width. Name receives whatever the
    terminal has left after those columns and 23 characters of
    borders/padding, with a minimum of 20.
    """
    fixed = {"ID": 7, "Type": 16, "Base": 20, "Size": 8, "DLs": 6, "Likes": 6}
    name_col = max(20, console.size.width - sum(fixed.values()) - 23)
    tbl = Table(show_header=True, header_style="bold magenta")
    tbl.add_column("ID", style="cyan", justify="right", width=fixed["ID"], no_wrap=True)
    tbl.add_column("Name", style="green", width=name_col, no_wrap=True, overflow="ellipsis")
    tbl.add_column("Type", style="yellow", width=fixed["Type"], no_wrap=True)
    tbl.add_column("Base", style="blue", width=fixed["Base"], no_wrap=True, overflow="ellipsis")
    # The numeric columns are plain, right-aligned, fixed-width.
    for heading in ("Size", "DLs", "Likes"):
        tbl.add_column(heading, justify="right", width=fixed[heading], no_wrap=True)
    return tbl
def _add_search_rows(table: Table, items: list[dict[str, Any]]) -> None:
    """Append one row per search hit to a table from ``_build_search_table``.

    Base model and size come from the model's first listed version, using its
    primary file or else its first file. Download and like counts are
    compacted with ``_format_count``. API text is escaped so Rich prints it
    literally.
    """
    from rich.markup import escape

    for model in items:
        model_id = str(model.get("id", ""))
        name = escape(str(model.get("name") or "N/A"))
        model_type = escape(str(model.get("type") or "N/A"))
        base_model = "N/A"
        size = "N/A"
        # `or []` / `or {}` also cover keys that are present but null.
        versions = model.get("modelVersions") or []
        if versions:
            # The first entry is treated as the latest version (assumes the
            # API lists versions newest-first).
            latest = versions[0]
            base_model = escape(str(latest.get("baseModel") or "N/A"))
            files = latest.get("files") or []
            primary = next((f for f in files if f.get("primary")), files[0] if files else None)
            if primary:
                size = _format_size(primary.get("sizeKB", 0))
        stats = model.get("stats") or {}
        downloads = _format_count(stats.get("downloadCount") or 0)
        likes = _format_count(stats.get("thumbsUpCount") or 0)
        table.add_row(model_id, name, model_type, base_model, size, downloads, likes)
def display_search_results(results: dict[str, Any], console: Console) -> None:
    """Print CivitAI search hits as a table followed by a short footer.

    Args:
        results: Search response. ``items`` holds the hits; the optional
            ``metadata.totalItems`` holds the overall match count.
        console: Rich console to print to.
    """
    items = results.get("items") or []
    if not items:
        console.print("[yellow]No results found.[/yellow]")
        return
    table = _build_search_table(console)
    _add_search_rows(table, items)
    console.print()
    console.print(table)
    # `metadata` / `totalItems` may be absent (or null). Only state a total
    # when one was actually reported, instead of passing this page's size
    # off as the total number of matches.
    metadata = results.get("metadata") or {}
    total = metadata.get("totalItems")
    if isinstance(total, int):
        console.print(f"\n[dim]Showing {len(items)} of {total:,} results[/dim]")
    else:
        console.print(f"\n[dim]Showing {len(items)} results[/dim]")
    console.print("[dim]Use 'tsr get <id>' to view details or 'tsr dl -m <id>' to download[/dim]")