diff --git a/.coverage b/.coverage index e98daa6..1d6cb60 100644 Binary files a/.coverage and b/.coverage differ diff --git a/pyproject.toml b/pyproject.toml index 3f94683..0b500e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ requires = ["hatchling"] build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] -packages = ["tensors.py"] +packages = ["tensors"] [dependency-groups] dev = [ @@ -33,7 +33,7 @@ dev = [ [tool.ruff] target-version = "py312" -line-length = 100 +line-length = 130 [tool.ruff.lint] select = [ @@ -51,11 +51,7 @@ select = [ "PL", # pylint "RUF", # ruff-specific ] -ignore = [ - "PLR0911", # too many return statements - "PLR0913", # too many arguments - "PLR2004", # magic value comparison -] +ignore = [] [tool.ruff.lint.isort] known-first-party = ["tensors"] diff --git a/tensors.py b/tensors.py deleted file mode 100644 index a7bf101..0000000 --- a/tensors.py +++ /dev/null @@ -1,1147 +0,0 @@ -#!/usr/bin/env python3 -""" -tsr: Read safetensor metadata, search and download CivitAI models. 
-""" - -from __future__ import annotations - -import hashlib -import json -import os -import re -import struct -import sys -import tomllib -from enum import Enum -from pathlib import Path -from typing import Annotated, Any - -import httpx -import typer -from rich.console import Console -from rich.progress import ( - BarColumn, - DownloadColumn, - Progress, - SpinnerColumn, - TaskProgressColumn, - TextColumn, - TimeRemainingColumn, - TransferSpeedColumn, -) -from rich.table import Table - -# ============================================================================ -# App and Console Setup -# ============================================================================ - -app = typer.Typer( - name="tsr", - help="Read safetensor metadata, search and download CivitAI models.", - no_args_is_help=True, -) -console = Console() - -# ============================================================================ -# Configuration -# ============================================================================ - -# XDG Base Directory spec -# Config: ~/.config/tensors/config.toml -# Data: ~/.local/share/tensors/models/, ~/.local/share/tensors/metadata/ -CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config")) / "tensors" -CONFIG_FILE = CONFIG_DIR / "config.toml" - -DATA_DIR = Path(os.environ.get("XDG_DATA_HOME", Path.home() / ".local" / "share")) / "tensors" -MODELS_DIR = DATA_DIR / "models" -METADATA_DIR = DATA_DIR / "metadata" - -# Legacy config for migration -LEGACY_RC_FILE = Path.home() / ".sftrc" - -# Default download paths by model type -DEFAULT_PATHS: dict[str, Path] = { - "Checkpoint": MODELS_DIR / "checkpoints", - "LORA": MODELS_DIR / "loras", - "LoCon": MODELS_DIR / "loras", -} - -CIVITAI_API_BASE = "https://civitai.com/api/v1" -CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models" - - -# ============================================================================ -# Enums for CLI -# 
============================================================================ - - -class ModelType(str, Enum): - """CivitAI model types.""" - - checkpoint = "checkpoint" - lora = "lora" - embedding = "embedding" - vae = "vae" - controlnet = "controlnet" - locon = "locon" - - def to_api(self) -> str: - """Convert to CivitAI API value.""" - mapping = { - "checkpoint": "Checkpoint", - "lora": "LORA", - "embedding": "TextualInversion", - "vae": "VAE", - "controlnet": "Controlnet", - "locon": "LoCon", - } - return mapping[self.value] - - -class BaseModel(str, Enum): - """Common base models.""" - - sd15 = "sd15" - sdxl = "sdxl" - pony = "pony" - flux = "flux" - illustrious = "illustrious" - - def to_api(self) -> str: - """Convert to CivitAI API value.""" - mapping = { - "sd15": "SD 1.5", - "sdxl": "SDXL 1.0", - "pony": "Pony", - "flux": "Flux.1 D", - "illustrious": "Illustrious", - } - return mapping[self.value] - - -class SortOrder(str, Enum): - """Sort options for search.""" - - downloads = "downloads" - rating = "rating" - newest = "newest" - - def to_api(self) -> str: - """Convert to CivitAI API value.""" - mapping = { - "downloads": "Most Downloaded", - "rating": "Highest Rated", - "newest": "Newest", - } - return mapping[self.value] - - -# ============================================================================ -# Config Functions -# ============================================================================ - - -def load_config() -> dict[str, Any]: - """Load configuration from TOML config file.""" - if CONFIG_FILE.exists(): - with CONFIG_FILE.open("rb") as f: - return tomllib.load(f) - return {} - - -def save_config(config: dict[str, Any]) -> None: - """Save configuration to TOML config file.""" - CONFIG_DIR.mkdir(parents=True, exist_ok=True) - - lines: list[str] = [] - for key, value in config.items(): - if isinstance(value, dict): - lines.append(f"[{key}]") - for k, v in value.items(): - if isinstance(v, str): - lines.append(f'{k} = "{v}"') - else: - 
lines.append(f"{k} = {v}") - lines.append("") - elif isinstance(value, str): - lines.append(f'{key} = "{value}"') - else: - lines.append(f"{key} = {value}") - - CONFIG_FILE.write_text("\n".join(lines) + "\n") - - -def load_api_key() -> str | None: - """Load API key from config file or CIVITAI_API_KEY env var.""" - # Check environment variable first - env_key = os.environ.get("CIVITAI_API_KEY") - if env_key: - return env_key - - # Check TOML config file - config = load_config() - api_section = config.get("api", {}) - if isinstance(api_section, dict): - key = api_section.get("civitai_key") - if key: - return str(key) - - # Fall back to legacy RC file for migration - if LEGACY_RC_FILE.exists(): - content = LEGACY_RC_FILE.read_text().strip() - if content: - return content - return None - - -def get_default_output_path(model_type: str | None) -> Path | None: - """Get default output path based on model type.""" - if model_type and model_type in DEFAULT_PATHS: - return DEFAULT_PATHS[model_type] - return None - - -# ============================================================================ -# Safetensor Functions -# ============================================================================ - - -def read_safetensor_metadata(file_path: Path) -> dict[str, Any]: - """Read metadata from a safetensor file header.""" - with file_path.open("rb") as f: - # First 8 bytes are the header size (little-endian u64) - header_size_bytes = f.read(8) - if len(header_size_bytes) < 8: - raise ValueError("Invalid safetensor file: too short") - - header_size = struct.unpack(" 100_000_000: # 100MB sanity check - raise ValueError(f"Invalid header size: {header_size}") - - header_bytes = f.read(header_size) - if len(header_bytes) < header_size: - raise ValueError("Invalid safetensor file: header truncated") - - header: dict[str, Any] = json.loads(header_bytes.decode("utf-8")) - - # Extract __metadata__ if present - metadata: dict[str, Any] = header.get("__metadata__", {}) - - # Count tensors 
(keys that aren't __metadata__) - tensor_count = sum(1 for k in header if k != "__metadata__") - - return { - "metadata": metadata, - "tensor_count": tensor_count, - "header_size": header_size, - } - - -def compute_sha256(file_path: Path) -> str: - """Compute SHA256 hash of a file with progress display.""" - file_size = file_path.stat().st_size - sha256 = hashlib.sha256() - chunk_size = 1024 * 1024 * 8 # 8MB chunks - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - DownloadColumn(), - TransferSpeedColumn(), - TimeRemainingColumn(), - console=console, - ) as progress: - task = progress.add_task(f"[cyan]Hashing {file_path.name}...", total=file_size) - - with file_path.open("rb") as f: - while chunk := f.read(chunk_size): - sha256.update(chunk) - progress.update(task, advance=len(chunk)) - - return sha256.hexdigest().upper() - - -def get_base_name(file_path: Path) -> str: - """Get base filename without .safetensors extension.""" - name = file_path.name - for ext in (".safetensors", ".sft"): - if name.lower().endswith(ext): - return name[: -len(ext)] - return file_path.stem - - -# ============================================================================ -# CivitAI API Functions -# ============================================================================ - - -def _get_headers(api_key: str | None) -> dict[str, str]: - """Get headers for CivitAI API requests.""" - headers: dict[str, str] = {} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - return headers - - -def fetch_civitai_model_version( - version_id: int, api_key: str | None = None -) -> dict[str, Any] | None: - """Fetch model version information from CivitAI by version ID.""" - url = f"{CIVITAI_API_BASE}/model-versions/{version_id}" - - try: - response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) - if response.status_code == 404: - return None - response.raise_for_status() - result: dict[str, 
Any] = response.json() - return result - except httpx.HTTPStatusError as e: - console.print(f"[red]API error: {e.response.status_code}[/red]") - return None - except httpx.RequestError as e: - console.print(f"[red]Request error: {e}[/red]") - return None - - -def fetch_civitai_model(model_id: int, api_key: str | None = None) -> dict[str, Any] | None: - """Fetch model information from CivitAI by model ID.""" - url = f"{CIVITAI_API_BASE}/models/{model_id}" - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - transient=True, - ) as progress: - progress.add_task("[cyan]Fetching model from CivitAI...", total=None) - - try: - response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) - if response.status_code == 404: - return None - response.raise_for_status() - result: dict[str, Any] = response.json() - return result - except httpx.HTTPStatusError as e: - console.print(f"[red]API error: {e.response.status_code}[/red]") - return None - except httpx.RequestError as e: - console.print(f"[red]Request error: {e}[/red]") - return None - - -def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[str, Any] | None: - """Fetch model information from CivitAI by SHA256 hash.""" - url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}" - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - transient=True, - ) as progress: - progress.add_task("[cyan]Fetching from CivitAI...", total=None) - - try: - response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) - if response.status_code == 404: - return None - response.raise_for_status() - result: dict[str, Any] = response.json() - return result - except httpx.HTTPStatusError as e: - console.print(f"[red]API error: {e.response.status_code}[/red]") - return None - except httpx.RequestError as e: - console.print(f"[red]Request error: {e}[/red]") - return None - 
- -def search_civitai( - query: str | None = None, - model_type: ModelType | None = None, - base_model: BaseModel | None = None, - sort: SortOrder = SortOrder.downloads, - limit: int = 20, - api_key: str | None = None, -) -> dict[str, Any] | None: - """Search CivitAI models.""" - params: dict[str, Any] = { - "limit": min(limit, 100), - "nsfw": "true", - } - - # API quirk: query + filters don't work reliably together - # If we have filters, skip query and filter client-side - has_filters = model_type is not None or base_model is not None - - if query and not has_filters: - params["query"] = query - - if model_type: - params["types"] = model_type.to_api() - - if base_model: - params["baseModels"] = base_model.to_api() - - params["sort"] = sort.to_api() - - # Request more if we need client-side filtering - if query and has_filters: - params["limit"] = 100 - - url = f"{CIVITAI_API_BASE}/models" - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - transient=True, - ) as progress: - progress.add_task("[cyan]Searching CivitAI...", total=None) - - try: - response = httpx.get(url, params=params, headers=_get_headers(api_key), timeout=30.0) - response.raise_for_status() - result: dict[str, Any] = response.json() - - # Client-side filtering when query + filters combined - if query and has_filters: - q_lower = query.lower() - result["items"] = [ - m for m in result.get("items", []) if q_lower in m.get("name", "").lower() - ][:limit] - - return result - except httpx.HTTPStatusError as e: - console.print(f"[red]API error: {e.response.status_code}[/red]") - return None - except httpx.RequestError as e: - console.print(f"[red]Request error: {e}[/red]") - return None - - -def download_model( - version_id: int, - dest_path: Path, - api_key: str | None = None, - resume: bool = True, -) -> bool: - """Download a model from CivitAI by version ID with resume support.""" - url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}" - 
params: dict[str, str] = {} - if api_key: - params["token"] = api_key - - headers: dict[str, str] = {} - mode = "wb" - initial_size = 0 - - # Check for existing partial download - if resume and dest_path.exists(): - initial_size = dest_path.stat().st_size - headers["Range"] = f"bytes={initial_size}-" - mode = "ab" - console.print(f"[cyan]Resuming download from {initial_size / (1024**2):.1f} MB[/cyan]") - - try: - with httpx.stream( - "GET", - url, - params=params, - headers=headers, - follow_redirects=True, - timeout=httpx.Timeout(30.0, read=None), - ) as response: - if response.status_code == 416: - console.print("[green]File already fully downloaded.[/green]") - return True - - response.raise_for_status() - - content_length = response.headers.get("content-length") - total_size = int(content_length) + initial_size if content_length else 0 - - content_disp = response.headers.get("content-disposition", "") - if "filename=" in content_disp: - match = re.search(r'filename="?([^";\n]+)"?', content_disp) - if match and dest_path.is_dir(): - dest_path = dest_path / match.group(1) - - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - DownloadColumn(), - TransferSpeedColumn(), - TimeRemainingColumn(), - console=console, - ) as progress: - task = progress.add_task( - f"[cyan]Downloading {dest_path.name}...", - total=total_size if total_size > 0 else None, - completed=initial_size, - ) - - with dest_path.open(mode) as f: - for chunk in response.iter_bytes(1024 * 1024): - f.write(chunk) - progress.update(task, advance=len(chunk)) - - console.print() - console.print(f"[magenta]Downloaded:[/magenta] [green]\"{dest_path}\"[/green]") - return True - - except httpx.HTTPStatusError as e: - console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]") - if e.response.status_code == 401: - console.print("[yellow]Hint: This model may require an API key.[/yellow]") - return False - except 
httpx.RequestError as e: - console.print(f"[red]Download error: {e}[/red]") - return False - - -# ============================================================================ -# Display Functions -# ============================================================================ - - -def _format_size(size_kb: float) -> str: - """Format size in KB to human-readable string.""" - if size_kb < 1024: - return f"{size_kb:.0f} KB" - if size_kb < 1024 * 1024: - return f"{size_kb / 1024:.1f} MB" - return f"{size_kb / 1024 / 1024:.2f} GB" - - -def _format_count(count: int) -> str: - """Format large numbers with K/M suffix.""" - if count < 1000: - return str(count) - if count < 1_000_000: - return f"{count / 1000:.1f}K" - return f"{count / 1_000_000:.1f}M" - - -def _display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str) -> None: - """Display file information table.""" - # Property column: 12 chars, Value fills remaining width - prop_width = 12 - - file_table = Table(title="File Information", show_header=True, header_style="bold magenta", expand=True) - file_table.add_column("Property", style="cyan", width=prop_width, no_wrap=True) - file_table.add_column("Value", style="green", no_wrap=True, overflow="ellipsis") - - file_table.add_row("File", str(file_path.name)) - file_table.add_row("Path", str(file_path.parent)) - file_table.add_row("Size", f"{file_path.stat().st_size / (1024**3):.2f} GB") - file_table.add_row("SHA256", sha256_hash) - file_table.add_row("Header Size", f"{local_metadata['header_size']:,} bytes") - file_table.add_row("Tensor Count", str(local_metadata["tensor_count"])) - - console.print() - console.print(file_table) - - -def _display_local_metadata(local_metadata: dict[str, Any], keys_filter: list[str] | None = None) -> None: - """Display local safetensor metadata table.""" - if not local_metadata["metadata"]: - console.print() - console.print("[yellow]No embedded metadata found in safetensor file.[/yellow]") - return - - metadata = 
local_metadata["metadata"] - - # If specific keys requested, show them in full - if keys_filter: - for key in keys_filter: - if key in metadata: - console.print(f"[cyan]{key}[/cyan]: {metadata[key]}") - else: - console.print(f"[yellow]{key}: not found[/yellow]") - return - - # Find the longest key to set column width - all_keys = list(metadata.keys()) - key_width = max(len(k) for k in all_keys) if all_keys else 20 - - # Value width: terminal minus key column and table borders (7 chars) - terminal_width = console.size.width - value_width = terminal_width - key_width - 7 - - meta_table = Table( - title="Safetensor Metadata", show_header=True, header_style="bold magenta", - ) - meta_table.add_column("Key", style="cyan", width=key_width, no_wrap=True) - meta_table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis") - - for key, value in sorted(metadata.items()): - meta_table.add_row(key, str(value)) - - console.print() - console.print(meta_table) - - -def _display_civitai_data(civitai_data: dict[str, Any] | None) -> None: - """Display CivitAI model information table.""" - if not civitai_data: - console.print() - console.print("[yellow]Model not found on CivitAI.[/yellow]") - return - - # Property column: 14 chars, Value fills remaining width - prop_width = 14 - terminal_width = console.size.width - overhead = 7 # borders and separators for 2 columns - value_width = max(40, terminal_width - prop_width - overhead) - - civit_table = Table( - title="CivitAI Model Information", show_header=True, header_style="bold magenta" - ) - civit_table.add_column("Property", style="cyan", width=prop_width, no_wrap=True) - civit_table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis") - - civit_table.add_row("Model ID", str(civitai_data.get("modelId", "N/A"))) - civit_table.add_row("Version ID", str(civitai_data.get("id", "N/A"))) - civit_table.add_row("Version Name", str(civitai_data.get("name", "N/A"))) - 
civit_table.add_row("Base Model", str(civitai_data.get("baseModel", "N/A"))) - civit_table.add_row("Created At", str(civitai_data.get("createdAt", "N/A"))) - - trained_words: list[str] = civitai_data.get("trainedWords", []) - if trained_words: - civit_table.add_row("Trigger Words", ", ".join(trained_words)) - - download_url = str(civitai_data.get("downloadUrl", "N/A")) - civit_table.add_row("Download URL", download_url) - - files: list[dict[str, Any]] = civitai_data.get("files", []) - for f in files: - if f.get("primary"): - civit_table.add_row("Primary File", str(f.get("name", "N/A"))) - civit_table.add_row("File Size", _format_size(f.get("sizeKB", 0))) - meta: dict[str, Any] = f.get("metadata", {}) - if meta: - civit_table.add_row("Format", str(meta.get("format", "N/A"))) - civit_table.add_row("Precision", str(meta.get("fp", "N/A"))) - civit_table.add_row("Size Type", str(meta.get("size", "N/A"))) - - console.print() - console.print(civit_table) - - model_id = civitai_data.get("modelId") - if model_id: - console.print() - console.print( - f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}" - ) - - -def _display_model_info(model_data: dict[str, Any]) -> None: - """Display full CivitAI model information.""" - # Property column: 10 chars, Value fills remaining width - prop_width = 10 - terminal_width = console.size.width - overhead = 7 # borders and separators for 2 columns - value_width = max(40, terminal_width - prop_width - overhead) - - model_table = Table(title="Model Information", show_header=True, header_style="bold magenta") - model_table.add_column("Property", style="cyan", width=prop_width, no_wrap=True) - model_table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis") - - model_table.add_row("ID", str(model_data.get("id", "N/A"))) - model_table.add_row("Name", str(model_data.get("name", "N/A"))) - model_table.add_row("Type", str(model_data.get("type", "N/A"))) - model_table.add_row("NSFW", 
str(model_data.get("nsfw", False))) - - creator = model_data.get("creator", {}) - if creator: - model_table.add_row("Creator", str(creator.get("username", "N/A"))) - - tags: list[str] = model_data.get("tags", []) - if tags: - model_table.add_row("Tags", ", ".join(tags[:10]) + ("..." if len(tags) > 10 else "")) - - stats: dict[str, Any] = model_data.get("stats", {}) - if stats: - model_table.add_row("Downloads", f"{stats.get('downloadCount', 0):,}") - model_table.add_row("Likes", f"{stats.get('thumbsUpCount', 0):,}") - - mode = model_data.get("mode") - if mode: - model_table.add_row("Status", str(mode)) - - console.print() - console.print(model_table) - - versions: list[dict[str, Any]] = model_data.get("modelVersions", []) - if versions: - # Static column widths for version table - # ID: 7 chars, Base Model: 20 chars, Created: 10 chars, Size: 8 chars - id_width = 7 - base_width = 20 - created_width = 10 - size_width = 8 - - # Calculate dynamic widths for Name and Filename - terminal_width = console.size.width - fixed_width = id_width + base_width + created_width + size_width - overhead = 20 # borders and separators for 5 columns - remaining = max(40, terminal_width - fixed_width - overhead) - name_width = remaining // 3 - file_width = remaining - name_width - - ver_table = Table(title="Model Versions", show_header=True, header_style="bold magenta") - ver_table.add_column("ID", style="cyan", width=id_width, no_wrap=True) - ver_table.add_column("Name", style="green", width=name_width, no_wrap=True, overflow="ellipsis") - ver_table.add_column("Base Model", style="yellow", width=base_width, no_wrap=True, overflow="ellipsis") - ver_table.add_column("Created", style="blue", width=created_width, no_wrap=True) - ver_table.add_column("Filename", style="white", width=file_width, no_wrap=True, overflow="ellipsis") - ver_table.add_column("Size", justify="right", width=size_width, no_wrap=True) - - for ver in versions: - files: list[dict[str, Any]] = ver.get("files", []) - 
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) - filename = "N/A" - size = "N/A" - if primary_file: - filename = primary_file.get("name", "N/A") - size = _format_size(primary_file.get("sizeKB", 0)) - - created = str(ver.get("createdAt", "N/A"))[:10] - ver_table.add_row( - str(ver.get("id", "N/A")), - str(ver.get("name", "N/A")), - str(ver.get("baseModel", "N/A")), - created, - filename, - size, - ) - - console.print() - console.print(ver_table) - - model_id = model_data.get("id") - if model_id: - console.print() - console.print( - f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}" - ) - - -def _display_search_results(results: dict[str, Any]) -> None: - """Display search results in a table.""" - items = results.get("items", []) - if not items: - console.print("[yellow]No results found.[/yellow]") - return - - # Static column widths based on expected max values - # ID: 7 chars (max ~9,999,999) - # Type: 16 chars (longest: "TextualInversion") - # Base: 20 chars (e.g., "Flux.2 Klein 9B-base") - # Size: 8 chars (e.g., "11.08 GB") - # DLs: 6 chars (e.g., "999.9K") - # Likes: 6 chars (e.g., "999.9K") - id_width = 7 - type_width = 16 - base_width = 20 - size_width = 8 - dls_width = 6 - likes_width = 6 - - # Calculate name width: terminal width minus fixed columns and separators - # Table has 7 columns with separators: "│ col │ col │ ..." 
= 3 chars per col (space+pipe+space) - # Plus outer borders: "┃" on each side = 2 chars - # Total overhead: 2 (outer) + 7*3 (separators) = 23 chars - terminal_width = console.size.width - fixed_width = id_width + type_width + base_width + size_width + dls_width + likes_width - overhead = 23 # borders and separators - name_width = max(20, terminal_width - fixed_width - overhead) - - table = Table(show_header=True, header_style="bold magenta") - table.add_column("ID", style="cyan", justify="right", width=id_width, no_wrap=True) - table.add_column("Name", style="green", width=name_width, no_wrap=True, overflow="ellipsis") - table.add_column("Type", style="yellow", width=type_width, no_wrap=True) - table.add_column("Base", style="blue", width=base_width, no_wrap=True, overflow="ellipsis") - table.add_column("Size", justify="right", width=size_width, no_wrap=True) - table.add_column("DLs", justify="right", width=dls_width, no_wrap=True) - table.add_column("Likes", justify="right", width=likes_width, no_wrap=True) - - for model in items: - model_id = str(model.get("id", "")) - name = model.get("name", "N/A") - model_type = model.get("type", "N/A") - - # Get latest version info - versions = model.get("modelVersions", []) - base_model = "N/A" - size = "N/A" - if versions: - latest = versions[0] - base_model = latest.get("baseModel", "N/A") - files = latest.get("files", []) - primary = next((f for f in files if f.get("primary")), files[0] if files else None) - if primary: - size = _format_size(primary.get("sizeKB", 0)) - - stats = model.get("stats", {}) - downloads = _format_count(stats.get("downloadCount", 0)) - likes = _format_count(stats.get("thumbsUpCount", 0)) - - table.add_row(model_id, name, model_type, base_model, size, downloads, likes) - - console.print() - console.print(table) - - metadata = results.get("metadata", {}) - total = metadata.get("totalItems", len(items)) - console.print(f"\n[dim]Showing {len(items)} of {total:,} results[/dim]") - 
console.print("[dim]Use 'tsr get ' to view details or 'tsr dl -m ' to download[/dim]") - - -# ============================================================================ -# CLI Commands -# ============================================================================ - - -@app.command() -def info( - file: Annotated[Path, typer.Argument(help="Path to the safetensor file")], - meta: Annotated[ - list[str] | None, typer.Option("--meta", "-m", help="Show specific metadata key(s) in full") - ] = None, - api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, - skip_civitai: Annotated[ - bool, typer.Option("--skip-civitai", help="Skip CivitAI API lookup") - ] = False, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, - save_to: Annotated[ - Path | None, typer.Option("--save-to", help="Save metadata to directory") - ] = None, -) -> None: - """Read safetensor metadata and fetch CivitAI info.""" - file_path = file.resolve() - - if not file_path.exists(): - console.print(f"[red]Error: File not found: {file_path}[/red]") - raise typer.Exit(1) - - if file_path.suffix.lower() not in (".safetensors", ".sft"): - console.print("[yellow]Warning: File does not have .safetensors extension[/yellow]") - - try: - local_metadata = read_safetensor_metadata(file_path) - - # If just fetching specific metadata keys, skip everything else - if meta: - _display_local_metadata(local_metadata, keys_filter=meta) - return - - console.print(f"[bold]Reading safetensor file:[/bold] {file_path.name}") - sha256_hash = compute_sha256(file_path) - - civitai_data = None - if not skip_civitai: - key = api_key or load_api_key() - civitai_data = fetch_civitai_by_hash(sha256_hash, key) - - if json_output: - output = { - "file": str(file_path), - "sha256": sha256_hash, - "header_size": local_metadata["header_size"], - "tensor_count": local_metadata["tensor_count"], - "metadata": local_metadata["metadata"], - "civitai": civitai_data, - } 
- console.print_json(data=output) - else: - _display_file_info(file_path, local_metadata, sha256_hash) - _display_local_metadata(local_metadata) - _display_civitai_data(civitai_data) - - if save_to: - output_dir = save_to.resolve() - if not output_dir.exists() or not output_dir.is_dir(): - console.print(f"[red]Error: Invalid directory: {output_dir}[/red]") - raise typer.Exit(1) - - base_name = get_base_name(file_path) - json_path = output_dir / f"{base_name}.json" - sha_path = output_dir / f"{base_name}.sha256" - - output = { - "file": str(file_path), - "sha256": sha256_hash, - "header_size": local_metadata["header_size"], - "tensor_count": local_metadata["tensor_count"], - "metadata": local_metadata["metadata"], - "civitai": civitai_data, - } - json_path.write_text(json.dumps(output, indent=2)) - sha_path.write_text(f"{sha256_hash} {file_path.name}\n") - - console.print() - console.print(f"[green]Saved:[/green] {json_path}") - console.print(f"[green]Saved:[/green] {sha_path}") - - except ValueError as e: - console.print(f"[red]Error reading safetensor: {e}[/red]") - raise typer.Exit(1) from e - - -@app.command() -def search( - query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None, - model_type: Annotated[ - ModelType | None, typer.Option("-t", "--type", help="Model type filter") - ] = None, - base: Annotated[ - BaseModel | None, typer.Option("-b", "--base", help="Base model filter") - ] = None, - sort: Annotated[ - SortOrder, typer.Option("-s", "--sort", help="Sort order") - ] = SortOrder.downloads, - limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, - api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, -) -> None: - """Search CivitAI models.""" - key = api_key or load_api_key() - - results = search_civitai( - query=query, - model_type=model_type, - base_model=base, - 
sort=sort, - limit=limit, - api_key=key, - ) - - if not results: - console.print("[red]Search failed.[/red]") - raise typer.Exit(1) - - if json_output: - console.print_json(data=results) - else: - _display_search_results(results) - - -@app.command() -def get( - id_value: Annotated[int, typer.Argument(help="CivitAI model ID or version ID")], - version: Annotated[ - bool, typer.Option("-v", "--version", help="Treat ID as version ID instead of model ID") - ] = False, - api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, -) -> None: - """Fetch model information from CivitAI by model ID or version ID.""" - key = api_key or load_api_key() - - if version: - # Fetch by version ID - version_data = fetch_civitai_model_version(id_value, key) - if not version_data: - console.print(f"[red]Error: Version {id_value} not found on CivitAI.[/red]") - raise typer.Exit(1) - - if json_output: - console.print_json(data=version_data) - else: - _display_civitai_data(version_data) - else: - # Fetch by model ID - model_data = fetch_civitai_model(id_value, key) - if not model_data: - console.print(f"[red]Error: Model {id_value} not found on CivitAI.[/red]") - raise typer.Exit(1) - - if json_output: - console.print_json(data=model_data) - else: - _display_model_info(model_data) - - -def _resolve_version_id( - version_id: int | None, - hash_val: str | None, - model_id: int | None, - api_key: str | None, -) -> int | None: - """Resolve version ID from hash or model ID.""" - if version_id: - return version_id - - if hash_val: - console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]") - civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key) - if not civitai_data: - console.print("[red]Error: Model not found on CivitAI for this hash.[/red]") - return None - vid: int | None = civitai_data.get("id") - if vid: - 
console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}") - return vid - - if model_id: - console.print(f"[cyan]Looking up model {model_id}...[/cyan]") - model_data = fetch_civitai_model(model_id, api_key) - if not model_data: - console.print(f"[red]Error: Model {model_id} not found.[/red]") - return None - versions = model_data.get("modelVersions", []) - if not versions: - console.print("[red]Error: Model has no versions.[/red]") - return None - latest = versions[0] - latest_vid: int | None = latest.get("id") - if latest_vid: - name = latest.get("name", "N/A") - console.print(f"[green]Found latest:[/green] {name} (ID: {latest_vid})") - return latest_vid - - return None - - -def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Path | None: - """Prepare output directory for download.""" - if output is None: - output_dir = get_default_output_path(model_type_str) - if output_dir is None: - console.print( - f"[red]Error: No default path for type '{model_type_str}'. 
" - "Use --output to specify.[/red]" - ) - return None - console.print(f"[dim]Using default path for {model_type_str}: {output_dir}[/dim]") - else: - output_dir = output.resolve() - - if not output_dir.exists(): - console.print(f"[cyan]Creating directory: {output_dir}[/cyan]") - output_dir.mkdir(parents=True, exist_ok=True) - elif not output_dir.is_dir(): - console.print(f"[red]Error: Not a directory: {output_dir}[/red]") - return None - - return output_dir - - -@app.command("dl") -def download( - version_id: Annotated[ - int | None, typer.Option("-v", "--version-id", help="Model version ID") - ] = None, - model_id: Annotated[ - int | None, typer.Option("-m", "--model-id", help="Model ID (downloads latest)") - ] = None, - hash_val: Annotated[ - str | None, typer.Option("-H", "--hash", help="SHA256 hash to look up") - ] = None, - output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None, - no_resume: Annotated[ - bool, typer.Option("--no-resume", help="Don't resume partial downloads") - ] = False, - api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, -) -> None: - """Download a model from CivitAI.""" - key = api_key or load_api_key() - - resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key) - if not resolved_version_id: - if not version_id and not hash_val and not model_id: - console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") - raise typer.Exit(1) - - console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]") - version_info = fetch_civitai_model_version(resolved_version_id, key) - if not version_info: - console.print("[red]Error: Could not fetch model version info.[/red]") - raise typer.Exit(1) - - model_type_str: str | None = version_info.get("model", {}).get("type") - output_dir = _prepare_download_dir(output, model_type_str) - if not output_dir: - raise typer.Exit(1) - - files: list[dict[str, Any]] = 
version_info.get("files", []) - primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) - if not primary_file: - console.print("[red]Error: No files found for this version.[/red]") - raise typer.Exit(1) - - filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors") - dest_path = output_dir / filename - - table = Table(title="Model Download", show_header=True, header_style="bold magenta") - table.add_column("Property", style="cyan") - table.add_column("Value", style="green") - table.add_row("Version", version_info.get("name", "N/A")) - table.add_row("Base Model", version_info.get("baseModel", "N/A")) - table.add_row("File", filename) - table.add_row("Size", _format_size(primary_file.get("sizeKB", 0))) - table.add_row("Destination", str(dest_path)) - console.print() - console.print(table) - console.print() - - success = download_model(resolved_version_id, dest_path, key, resume=not no_resume) - if not success: - raise typer.Exit(1) - - -@app.command() -def config( - show: Annotated[bool, typer.Option("--show", help="Show current config")] = False, - set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None, -) -> None: - """Manage configuration.""" - if set_key: - cfg = load_config() - if "api" not in cfg: - cfg["api"] = {} - cfg["api"]["civitai_key"] = set_key - save_config(cfg) - console.print(f"[green]API key saved to {CONFIG_FILE}[/green]") - return - - if show or (not set_key): - console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}") - console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}") - - key = load_api_key() - if key: - masked = key[:4] + "..." 
+ key[-4:] if len(key) > 8 else "***" - console.print(f"[bold]API key:[/bold] {masked}") - else: - console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]") - - console.print() - console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]") - - -def main() -> int: - """Main entry point.""" - # Handle legacy invocation: tsr -> tsr info - if len(sys.argv) > 1 and not sys.argv[1].startswith("-"): - arg = sys.argv[1] - if arg not in ("info", "search", "get", "dl", "download", "config") and ( - arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists() - ): - sys.argv = [sys.argv[0], "info", *sys.argv[1:]] - - app() - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/tensors/__init__.py b/tensors/__init__.py new file mode 100644 index 0000000..4456cfe --- /dev/null +++ b/tensors/__init__.py @@ -0,0 +1,26 @@ +"""tsr: Read safetensor metadata, search and download CivitAI models.""" + +from tensors.cli import main +from tensors.config import ( + CONFIG_DIR, + CONFIG_FILE, + LEGACY_RC_FILE, + get_default_output_path, + load_api_key, + load_config, + save_config, +) +from tensors.safetensor import get_base_name, read_safetensor_metadata + +__all__ = [ + "CONFIG_DIR", + "CONFIG_FILE", + "LEGACY_RC_FILE", + "get_base_name", + "get_default_output_path", + "load_api_key", + "load_config", + "main", + "read_safetensor_metadata", + "save_config", +] diff --git a/tensors/api.py b/tensors/api.py new file mode 100644 index 0000000..06d72a0 --- /dev/null +++ b/tensors/api.py @@ -0,0 +1,287 @@ +"""CivitAI API functions.""" + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pathlib import Path + +import httpx +from rich.progress import ( + BarColumn, + DownloadColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeRemainingColumn, + TransferSpeedColumn, +) + +from tensors.config import CIVITAI_API_BASE, CIVITAI_DOWNLOAD_BASE, BaseModel, 
ModelType, SortOrder + +if TYPE_CHECKING: + from rich.console import Console + + +def _get_headers(api_key: str | None) -> dict[str, str]: + """Get headers for CivitAI API requests.""" + headers: dict[str, str] = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + return headers + + +def fetch_civitai_model_version(version_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None: + """Fetch model version information from CivitAI by version ID.""" + url = f"{CIVITAI_API_BASE}/model-versions/{version_id}" + + try: + response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) + if response.status_code == 404: + return None + response.raise_for_status() + result: dict[str, Any] = response.json() + return result + except httpx.HTTPStatusError as e: + console.print(f"[red]API error: {e.response.status_code}[/red]") + return None + except httpx.RequestError as e: + console.print(f"[red]Request error: {e}[/red]") + return None + + +def fetch_civitai_model(model_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None: + """Fetch model information from CivitAI by model ID.""" + url = f"{CIVITAI_API_BASE}/models/{model_id}" + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + transient=True, + ) as progress: + progress.add_task("[cyan]Fetching model from CivitAI...", total=None) + + try: + response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) + if response.status_code == 404: + return None + response.raise_for_status() + result: dict[str, Any] = response.json() + return result + except httpx.HTTPStatusError as e: + console.print(f"[red]API error: {e.response.status_code}[/red]") + return None + except httpx.RequestError as e: + console.print(f"[red]Request error: {e}[/red]") + return None + + +def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Console) -> dict[str, Any] | None: + """Fetch model information from 
CivitAI by SHA256 hash.""" + url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}" + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + transient=True, + ) as progress: + progress.add_task("[cyan]Fetching from CivitAI...", total=None) + + try: + response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) + if response.status_code == 404: + return None + response.raise_for_status() + result: dict[str, Any] = response.json() + return result + except httpx.HTTPStatusError as e: + console.print(f"[red]API error: {e.response.status_code}[/red]") + return None + except httpx.RequestError as e: + console.print(f"[red]Request error: {e}[/red]") + return None + + +def _build_search_params( + query: str | None, + model_type: ModelType | None, + base_model: BaseModel | None, + sort: SortOrder, + limit: int, +) -> tuple[dict[str, Any], bool]: + """Build search parameters and return (params, has_filters).""" + params: dict[str, Any] = { + "limit": min(limit, 100), + "nsfw": "true", + } + + # API quirk: query + filters don't work reliably together + has_filters = model_type is not None or base_model is not None + + if query and not has_filters: + params["query"] = query + + if model_type: + params["types"] = model_type.to_api() + + if base_model: + params["baseModels"] = base_model.to_api() + + params["sort"] = sort.to_api() + + # Request more if we need client-side filtering + if query and has_filters: + params["limit"] = 100 + + return params, has_filters + + +def _filter_results(result: dict[str, Any], query: str | None, has_filters: bool, limit: int) -> dict[str, Any]: + """Apply client-side filtering when query + filters combined.""" + if query and has_filters: + q_lower = query.lower() + result["items"] = [m for m in result.get("items", []) if q_lower in m.get("name", "").lower()][:limit] + return result + + +def search_civitai( + query: str | None, + model_type: ModelType | None, + 
base_model: BaseModel | None, + sort: SortOrder, + limit: int, + api_key: str | None, + console: Console, +) -> dict[str, Any] | None: + """Search CivitAI models.""" + params, has_filters = _build_search_params(query, model_type, base_model, sort, limit) + url = f"{CIVITAI_API_BASE}/models" + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + transient=True, + ) as progress: + progress.add_task("[cyan]Searching CivitAI...", total=None) + + try: + response = httpx.get(url, params=params, headers=_get_headers(api_key), timeout=30.0) + response.raise_for_status() + result: dict[str, Any] = response.json() + return _filter_results(result, query, has_filters, limit) + except httpx.HTTPStatusError as e: + console.print(f"[red]API error: {e.response.status_code}[/red]") + return None + except httpx.RequestError as e: + console.print(f"[red]Request error: {e}[/red]") + return None + + +def _setup_resume(dest_path: Path, resume: bool, console: Console) -> tuple[dict[str, str], str, int]: + """Set up resume headers and mode for download.""" + headers: dict[str, str] = {} + mode = "wb" + initial_size = 0 + + if resume and dest_path.exists(): + initial_size = dest_path.stat().st_size + headers["Range"] = f"bytes={initial_size}-" + mode = "ab" + console.print(f"[cyan]Resuming download from {initial_size / (1024**2):.1f} MB[/cyan]") + + return headers, mode, initial_size + + +def _get_dest_from_response(response: httpx.Response, dest_path: Path) -> Path: + """Extract destination path from response headers if dest is directory.""" + content_disp = response.headers.get("content-disposition", "") + if "filename=" in content_disp: + match = re.search(r'filename="?([^";\n]+)"?', content_disp) + if match and dest_path.is_dir(): + return dest_path / match.group(1) + return dest_path + + +def _stream_download( + response: httpx.Response, + dest_path: Path, + mode: str, + initial_size: int, + console: Console, +) -> bool: + 
"""Stream download content to file with progress.""" + content_length = response.headers.get("content-length") + total_size = int(content_length) + initial_size if content_length else 0 + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + DownloadColumn(), + TransferSpeedColumn(), + TimeRemainingColumn(), + console=console, + ) as progress: + task = progress.add_task( + f"[cyan]Downloading {dest_path.name}...", + total=total_size if total_size > 0 else None, + completed=initial_size, + ) + + with dest_path.open(mode) as f: + for chunk in response.iter_bytes(1024 * 1024): + f.write(chunk) + progress.update(task, advance=len(chunk)) + + console.print() + console.print(f'[magenta]Downloaded:[/magenta] [green]"{dest_path}"[/green]') + return True + + +def download_model( + version_id: int, + dest_path: Path, + api_key: str | None, + console: Console, + resume: bool = True, +) -> bool: + """Download a model from CivitAI by version ID with resume support.""" + url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}" + params: dict[str, str] = {} + if api_key: + params["token"] = api_key + + headers, mode, initial_size = _setup_resume(dest_path, resume, console) + + try: + with httpx.stream( + "GET", + url, + params=params, + headers=headers, + follow_redirects=True, + timeout=httpx.Timeout(30.0, read=None), + ) as response: + if response.status_code == 416: + console.print("[green]File already fully downloaded.[/green]") + return True + + response.raise_for_status() + dest_path = _get_dest_from_response(response, dest_path) + return _stream_download(response, dest_path, mode, initial_size, console) + + except httpx.HTTPStatusError as e: + console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]") + if e.response.status_code == 401: + console.print("[yellow]Hint: This model may require an API key.[/yellow]") + return False + except httpx.RequestError as e: + 
console.print(f"[red]Download error: {e}[/red]") + return False diff --git a/tensors/cli.py b/tensors/cli.py new file mode 100644 index 0000000..d0f58f9 --- /dev/null +++ b/tensors/cli.py @@ -0,0 +1,386 @@ +"""CLI application and commands for tsr.""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import Annotated, Any + +import typer +from rich.console import Console +from rich.table import Table + +from tensors.api import ( + download_model, + fetch_civitai_by_hash, + fetch_civitai_model, + fetch_civitai_model_version, + search_civitai, +) +from tensors.config import ( + CONFIG_FILE, + BaseModel, + ModelType, + SortOrder, + get_default_output_path, + load_api_key, + load_config, + save_config, +) +from tensors.display import ( + _format_size, + display_civitai_data, + display_file_info, + display_local_metadata, + display_model_info, + display_search_results, +) +from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata + +app = typer.Typer( + name="tsr", + help="Read safetensor metadata, search and download CivitAI models.", + no_args_is_help=True, +) +console = Console() + + +@app.command() +def info( + file: Annotated[Path, typer.Argument(help="Path to the safetensor file")], + meta: Annotated[list[str] | None, typer.Option("--meta", "-m", help="Show specific metadata key(s) in full")] = None, + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, + skip_civitai: Annotated[bool, typer.Option("--skip-civitai", help="Skip CivitAI API lookup")] = False, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, + save_to: Annotated[Path | None, typer.Option("--save-to", help="Save metadata to directory")] = None, +) -> None: + """Read safetensor metadata and fetch CivitAI info.""" + file_path = file.resolve() + + if not file_path.exists(): + console.print(f"[red]Error: File not found: {file_path}[/red]") + 
raise typer.Exit(1) + + if file_path.suffix.lower() not in (".safetensors", ".sft"): + console.print("[yellow]Warning: File does not have .safetensors extension[/yellow]") + + try: + local_metadata = read_safetensor_metadata(file_path) + + if meta: + display_local_metadata(local_metadata, console, keys_filter=meta) + return + + console.print(f"[bold]Reading safetensor file:[/bold] {file_path.name}") + sha256_hash = compute_sha256(file_path, console) + + civitai_data = None + if not skip_civitai: + key = api_key or load_api_key() + civitai_data = fetch_civitai_by_hash(sha256_hash, key, console) + + if json_output: + _output_info_json(file_path, sha256_hash, local_metadata, civitai_data) + else: + display_file_info(file_path, local_metadata, sha256_hash, console) + display_local_metadata(local_metadata, console) + display_civitai_data(civitai_data, console) + + if save_to: + _save_metadata(save_to, file_path, sha256_hash, local_metadata, civitai_data) + + except ValueError as e: + console.print(f"[red]Error reading safetensor: {e}[/red]") + raise typer.Exit(1) from e + + +def _output_info_json( + file_path: Path, + sha256_hash: str, + local_metadata: dict[str, Any], + civitai_data: dict[str, Any] | None, +) -> None: + """Output info command result as JSON.""" + output = { + "file": str(file_path), + "sha256": sha256_hash, + "header_size": local_metadata["header_size"], + "tensor_count": local_metadata["tensor_count"], + "metadata": local_metadata["metadata"], + "civitai": civitai_data, + } + console.print_json(data=output) + + +def _save_metadata( + save_to: Path, + file_path: Path, + sha256_hash: str, + local_metadata: dict[str, Any], + civitai_data: dict[str, Any] | None, +) -> None: + """Save metadata to directory.""" + output_dir = save_to.resolve() + if not output_dir.exists() or not output_dir.is_dir(): + console.print(f"[red]Error: Invalid directory: {output_dir}[/red]") + raise typer.Exit(1) + + base_name = get_base_name(file_path) + json_path = output_dir / 
f"{base_name}.json" + sha_path = output_dir / f"{base_name}.sha256" + + output = { + "file": str(file_path), + "sha256": sha256_hash, + "header_size": local_metadata["header_size"], + "tensor_count": local_metadata["tensor_count"], + "metadata": local_metadata["metadata"], + "civitai": civitai_data, + } + json_path.write_text(json.dumps(output, indent=2)) + sha_path.write_text(f"{sha256_hash} {file_path.name}\n") + + console.print() + console.print(f"[green]Saved:[/green] {json_path}") + console.print(f"[green]Saved:[/green] {sha_path}") + + +@app.command() +def search( + query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None, + model_type: Annotated[ModelType | None, typer.Option("-t", "--type", help="Model type filter")] = None, + base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter")] = None, + sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads, + limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, +) -> None: + """Search CivitAI models.""" + key = api_key or load_api_key() + + results = search_civitai( + query=query, + model_type=model_type, + base_model=base, + sort=sort, + limit=limit, + api_key=key, + console=console, + ) + + if not results: + console.print("[red]Search failed.[/red]") + raise typer.Exit(1) + + if json_output: + console.print_json(data=results) + else: + display_search_results(results, console) + + +@app.command() +def get( + id_value: Annotated[int, typer.Argument(help="CivitAI model ID or version ID")], + version: Annotated[bool, typer.Option("-v", "--version", help="Treat ID as version ID instead of model ID")] = False, + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, + 
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Fetch model information from CivitAI by model ID or version ID.""" + key = api_key or load_api_key() + + if version: + version_data = fetch_civitai_model_version(id_value, key, console) + if not version_data: + console.print(f"[red]Error: Version {id_value} not found on CivitAI.[/red]") + raise typer.Exit(1) + + if json_output: + console.print_json(data=version_data) + else: + display_civitai_data(version_data, console) + else: + model_data = fetch_civitai_model(id_value, key, console) + if not model_data: + console.print(f"[red]Error: Model {id_value} not found on CivitAI.[/red]") + raise typer.Exit(1) + + if json_output: + console.print_json(data=model_data) + else: + display_model_info(model_data, console) + + +def _resolve_version_id( + version_id: int | None, + hash_val: str | None, + model_id: int | None, + api_key: str | None, +) -> int | None: + """Resolve version ID from hash or model ID.""" + if version_id: + return version_id + + if hash_val: + console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]") + civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key, console) + if not civitai_data: + console.print("[red]Error: Model not found on CivitAI for this hash.[/red]") + return None + vid: int | None = civitai_data.get("id") + if vid: + console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}") + return vid + + if model_id: + console.print(f"[cyan]Looking up model {model_id}...[/cyan]") + model_data = fetch_civitai_model(model_id, api_key, console) + if not model_data: + console.print(f"[red]Error: Model {model_id} not found.[/red]") + return None + versions = model_data.get("modelVersions", []) + if not versions: + console.print("[red]Error: Model has no versions.[/red]") + return None + latest = versions[0] + latest_vid: int | None = latest.get("id") + if latest_vid: + name = latest.get("name", "N/A") + 
console.print(f"[green]Found latest:[/green] {name} (ID: {latest_vid})") + return latest_vid + + return None + + +def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Path | None: + """Prepare output directory for download.""" + if output is None: + output_dir = get_default_output_path(model_type_str) + if output_dir is None: + console.print(f"[red]Error: No default path for type '{model_type_str}'. Use --output to specify.[/red]") + return None + console.print(f"[dim]Using default path for {model_type_str}: {output_dir}[/dim]") + else: + output_dir = output.resolve() + + if not output_dir.exists(): + console.print(f"[cyan]Creating directory: {output_dir}[/cyan]") + output_dir.mkdir(parents=True, exist_ok=True) + elif not output_dir.is_dir(): + console.print(f"[red]Error: Not a directory: {output_dir}[/red]") + return None + + return output_dir + + +@app.command("dl") +def download( + version_id: Annotated[int | None, typer.Option("-v", "--version-id", help="Model version ID")] = None, + model_id: Annotated[int | None, typer.Option("-m", "--model-id", help="Model ID (downloads latest)")] = None, + hash_val: Annotated[str | None, typer.Option("-H", "--hash", help="SHA256 hash to look up")] = None, + output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None, + no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False, + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, +) -> None: + """Download a model from CivitAI.""" + key = api_key or load_api_key() + + resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key) + if not resolved_version_id: + if not version_id and not hash_val and not model_id: + console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") + raise typer.Exit(1) + + console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]") + version_info 
= fetch_civitai_model_version(resolved_version_id, key, console) + if not version_info: + console.print("[red]Error: Could not fetch model version info.[/red]") + raise typer.Exit(1) + + model_type_str: str | None = version_info.get("model", {}).get("type") + output_dir = _prepare_download_dir(output, model_type_str) + if not output_dir: + raise typer.Exit(1) + + files: list[dict[str, Any]] = version_info.get("files", []) + primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) + if not primary_file: + console.print("[red]Error: No files found for this version.[/red]") + raise typer.Exit(1) + + filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors") + dest_path = output_dir / filename + + _display_download_info(version_info, filename, primary_file, dest_path) + + success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume) + if not success: + raise typer.Exit(1) + + +def _display_download_info( + version_info: dict[str, Any], + filename: str, + primary_file: dict[str, Any], + dest_path: Path, +) -> None: + """Display download info table.""" + table = Table(title="Model Download", show_header=True, header_style="bold magenta") + table.add_column("Property", style="cyan") + table.add_column("Value", style="green") + table.add_row("Version", version_info.get("name", "N/A")) + table.add_row("Base Model", version_info.get("baseModel", "N/A")) + table.add_row("File", filename) + table.add_row("Size", _format_size(primary_file.get("sizeKB", 0))) + table.add_row("Destination", str(dest_path)) + console.print() + console.print(table) + console.print() + + +@app.command() +def config( + show: Annotated[bool, typer.Option("--show", help="Show current config")] = False, + set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None, +) -> None: + """Manage configuration.""" + if set_key: + cfg = load_config() + if "api" not in cfg: + cfg["api"] = {} + 
cfg["api"]["civitai_key"] = set_key + save_config(cfg) + console.print(f"[green]API key saved to {CONFIG_FILE}[/green]") + return + + if show or (not set_key): + console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}") + console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}") + + key = load_api_key() + if key: + masked = key[:4] + "..." + key[-4:] if len(key) > 8 else "***" + console.print(f"[bold]API key:[/bold] {masked}") + else: + console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]") + + console.print() + console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]") + + +def main() -> int: + """Main entry point.""" + # Handle legacy invocation: tsr -> tsr info + if len(sys.argv) > 1 and not sys.argv[1].startswith("-"): + arg = sys.argv[1] + if arg not in ("info", "search", "get", "dl", "download", "config") and ( + arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists() + ): + sys.argv = [sys.argv[0], "info", *sys.argv[1:]] + + app() + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tensors/config.py b/tensors/config.py new file mode 100644 index 0000000..9dd14bd --- /dev/null +++ b/tensors/config.py @@ -0,0 +1,166 @@ +"""Configuration, constants, and enums for tsr CLI.""" + +from __future__ import annotations + +import os +import tomllib +from enum import Enum +from pathlib import Path +from typing import Any + +# ============================================================================ +# XDG Base Directory Configuration +# ============================================================================ + +# Config: ~/.config/tensors/config.toml +# Data: ~/.local/share/tensors/models/, ~/.local/share/tensors/metadata/ +CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config")) / "tensors" +CONFIG_FILE = CONFIG_DIR / "config.toml" + +DATA_DIR = Path(os.environ.get("XDG_DATA_HOME", Path.home() / ".local" / "share")) / "tensors" +MODELS_DIR = DATA_DIR / "models" 
+METADATA_DIR = DATA_DIR / "metadata" + +# Legacy config for migration +LEGACY_RC_FILE = Path.home() / ".sftrc" + +# Default download paths by model type +DEFAULT_PATHS: dict[str, Path] = { + "Checkpoint": MODELS_DIR / "checkpoints", + "LORA": MODELS_DIR / "loras", + "LoCon": MODELS_DIR / "loras", +} + +CIVITAI_API_BASE = "https://civitai.com/api/v1" +CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models" + + +# ============================================================================ +# Enums for CLI +# ============================================================================ + + +class ModelType(str, Enum): + """CivitAI model types.""" + + checkpoint = "checkpoint" + lora = "lora" + embedding = "embedding" + vae = "vae" + controlnet = "controlnet" + locon = "locon" + + def to_api(self) -> str: + """Convert to CivitAI API value.""" + mapping = { + "checkpoint": "Checkpoint", + "lora": "LORA", + "embedding": "TextualInversion", + "vae": "VAE", + "controlnet": "Controlnet", + "locon": "LoCon", + } + return mapping[self.value] + + +class BaseModel(str, Enum): + """Common base models.""" + + sd15 = "sd15" + sdxl = "sdxl" + pony = "pony" + flux = "flux" + illustrious = "illustrious" + + def to_api(self) -> str: + """Convert to CivitAI API value.""" + mapping = { + "sd15": "SD 1.5", + "sdxl": "SDXL 1.0", + "pony": "Pony", + "flux": "Flux.1 D", + "illustrious": "Illustrious", + } + return mapping[self.value] + + +class SortOrder(str, Enum): + """Sort options for search.""" + + downloads = "downloads" + rating = "rating" + newest = "newest" + + def to_api(self) -> str: + """Convert to CivitAI API value.""" + mapping = { + "downloads": "Most Downloaded", + "rating": "Highest Rated", + "newest": "Newest", + } + return mapping[self.value] + + +# ============================================================================ +# Config Functions +# ============================================================================ + + +def load_config() -> dict[str, 
Any]: + """Load configuration from TOML config file.""" + if CONFIG_FILE.exists(): + with CONFIG_FILE.open("rb") as f: + return tomllib.load(f) + return {} + + +def save_config(config: dict[str, Any]) -> None: + """Save configuration to TOML config file.""" + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + + lines: list[str] = [] + for key, value in config.items(): + if isinstance(value, dict): + lines.append(f"[{key}]") + for k, v in value.items(): + if isinstance(v, str): + lines.append(f'{k} = "{v}"') + else: + lines.append(f"{k} = {v}") + lines.append("") + elif isinstance(value, str): + lines.append(f'{key} = "{value}"') + else: + lines.append(f"{key} = {value}") + + CONFIG_FILE.write_text("\n".join(lines) + "\n") + + +def load_api_key() -> str | None: + """Load API key from config file or CIVITAI_API_KEY env var.""" + # Check environment variable first + env_key = os.environ.get("CIVITAI_API_KEY") + if env_key: + return env_key + + # Check TOML config file + config = load_config() + api_section = config.get("api", {}) + if isinstance(api_section, dict): + key = api_section.get("civitai_key") + if key: + return str(key) + + # Fall back to legacy RC file for migration + if LEGACY_RC_FILE.exists(): + content = LEGACY_RC_FILE.read_text().strip() + if content: + return content + return None + + +def get_default_output_path(model_type: str | None) -> Path | None: + """Get default output path based on model type.""" + if model_type and model_type in DEFAULT_PATHS: + return DEFAULT_PATHS[model_type] + return None diff --git a/tensors/display.py b/tensors/display.py new file mode 100644 index 0000000..9f5d64d --- /dev/null +++ b/tensors/display.py @@ -0,0 +1,324 @@ +"""Rich table display functions for tsr CLI.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pathlib import Path + +from rich.table import Table + +if TYPE_CHECKING: + from rich.console import Console + + +def _format_size(size_kb: float) -> 
str: + """Format size in KB to human-readable string.""" + if size_kb < 1024: + return f"{size_kb:.0f} KB" + if size_kb < 1024 * 1024: + return f"{size_kb / 1024:.1f} MB" + return f"{size_kb / 1024 / 1024:.2f} GB" + + +def _format_count(count: int) -> str: + """Format large numbers with K/M suffix.""" + if count < 1000: + return str(count) + if count < 1_000_000: + return f"{count / 1000:.1f}K" + return f"{count / 1_000_000:.1f}M" + + +def display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str, console: Console) -> None: + """Display file information table.""" + prop_width = 12 + + file_table = Table(title="File Information", show_header=True, header_style="bold magenta", expand=True) + file_table.add_column("Property", style="cyan", width=prop_width, no_wrap=True) + file_table.add_column("Value", style="green", no_wrap=True, overflow="ellipsis") + + file_table.add_row("File", str(file_path.name)) + file_table.add_row("Path", str(file_path.parent)) + file_table.add_row("Size", f"{file_path.stat().st_size / (1024**3):.2f} GB") + file_table.add_row("SHA256", sha256_hash) + file_table.add_row("Header Size", f"{local_metadata['header_size']:,} bytes") + file_table.add_row("Tensor Count", str(local_metadata["tensor_count"])) + + console.print() + console.print(file_table) + + +def display_local_metadata(local_metadata: dict[str, Any], console: Console, keys_filter: list[str] | None = None) -> None: + """Display local safetensor metadata table.""" + if not local_metadata["metadata"]: + console.print() + console.print("[yellow]No embedded metadata found in safetensor file.[/yellow]") + return + + metadata = local_metadata["metadata"] + + # If specific keys requested, show them in full + if keys_filter: + for key in keys_filter: + if key in metadata: + console.print(f"[cyan]{key}[/cyan]: {metadata[key]}") + else: + console.print(f"[yellow]{key}: not found[/yellow]") + return + + # Find the longest key to set column width + all_keys = 
list(metadata.keys()) + key_width = max(len(k) for k in all_keys) if all_keys else 20 + + # Value width: terminal minus key column and table borders (7 chars) + terminal_width = console.size.width + value_width = terminal_width - key_width - 7 + + meta_table = Table( + title="Safetensor Metadata", + show_header=True, + header_style="bold magenta", + ) + meta_table.add_column("Key", style="cyan", width=key_width, no_wrap=True) + meta_table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis") + + for key, value in sorted(metadata.items()): + meta_table.add_row(key, str(value)) + + console.print() + console.print(meta_table) + + +def _build_civitai_table(console: Console) -> tuple[Table, int]: + """Build CivitAI info table with proper column widths.""" + prop_width = 14 + terminal_width = console.size.width + overhead = 7 + value_width = max(40, terminal_width - prop_width - overhead) + + table = Table(title="CivitAI Model Information", show_header=True, header_style="bold magenta") + table.add_column("Property", style="cyan", width=prop_width, no_wrap=True) + table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis") + + return table, value_width + + +def display_civitai_data(civitai_data: dict[str, Any] | None, console: Console) -> None: + """Display CivitAI model information table.""" + if not civitai_data: + console.print() + console.print("[yellow]Model not found on CivitAI.[/yellow]") + return + + civit_table, _ = _build_civitai_table(console) + + civit_table.add_row("Model ID", str(civitai_data.get("modelId", "N/A"))) + civit_table.add_row("Version ID", str(civitai_data.get("id", "N/A"))) + civit_table.add_row("Version Name", str(civitai_data.get("name", "N/A"))) + civit_table.add_row("Base Model", str(civitai_data.get("baseModel", "N/A"))) + civit_table.add_row("Created At", str(civitai_data.get("createdAt", "N/A"))) + + trained_words: list[str] = civitai_data.get("trainedWords", []) + if 
trained_words: + civit_table.add_row("Trigger Words", ", ".join(trained_words)) + + download_url = str(civitai_data.get("downloadUrl", "N/A")) + civit_table.add_row("Download URL", download_url) + + files: list[dict[str, Any]] = civitai_data.get("files", []) + for f in files: + if f.get("primary"): + civit_table.add_row("Primary File", str(f.get("name", "N/A"))) + civit_table.add_row("File Size", _format_size(f.get("sizeKB", 0))) + meta: dict[str, Any] = f.get("metadata", {}) + if meta: + civit_table.add_row("Format", str(meta.get("format", "N/A"))) + civit_table.add_row("Precision", str(meta.get("fp", "N/A"))) + civit_table.add_row("Size Type", str(meta.get("size", "N/A"))) + + console.print() + console.print(civit_table) + + model_id = civitai_data.get("modelId") + if model_id: + console.print() + console.print(f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}") + + +def _build_model_table(console: Console) -> Table: + """Build model info table with proper column widths.""" + prop_width = 10 + terminal_width = console.size.width + overhead = 7 + value_width = max(40, terminal_width - prop_width - overhead) + + table = Table(title="Model Information", show_header=True, header_style="bold magenta") + table.add_column("Property", style="cyan", width=prop_width, no_wrap=True) + table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis") + + return table + + +def _add_model_basic_info(table: Table, model_data: dict[str, Any]) -> None: + """Add basic model info rows to table.""" + table.add_row("ID", str(model_data.get("id", "N/A"))) + table.add_row("Name", str(model_data.get("name", "N/A"))) + table.add_row("Type", str(model_data.get("type", "N/A"))) + table.add_row("NSFW", str(model_data.get("nsfw", False))) + + creator = model_data.get("creator", {}) + if creator: + table.add_row("Creator", str(creator.get("username", "N/A"))) + + tags: list[str] = model_data.get("tags", []) + if tags: + 
table.add_row("Tags", ", ".join(tags[:10]) + ("..." if len(tags) > 10 else "")) + + stats: dict[str, Any] = model_data.get("stats", {}) + if stats: + table.add_row("Downloads", f"{stats.get('downloadCount', 0):,}") + table.add_row("Likes", f"{stats.get('thumbsUpCount', 0):,}") + + mode = model_data.get("mode") + if mode: + table.add_row("Status", str(mode)) + + +def _build_versions_table(console: Console) -> Table: + """Build model versions table with proper column widths.""" + id_width = 7 + base_width = 20 + created_width = 10 + size_width = 8 + + terminal_width = console.size.width + fixed_width = id_width + base_width + created_width + size_width + overhead = 20 + remaining = max(40, terminal_width - fixed_width - overhead) + name_width = remaining // 3 + file_width = remaining - name_width + + table = Table(title="Model Versions", show_header=True, header_style="bold magenta") + table.add_column("ID", style="cyan", width=id_width, no_wrap=True) + table.add_column("Name", style="green", width=name_width, no_wrap=True, overflow="ellipsis") + table.add_column("Base Model", style="yellow", width=base_width, no_wrap=True, overflow="ellipsis") + table.add_column("Created", style="blue", width=created_width, no_wrap=True) + table.add_column("Filename", style="white", width=file_width, no_wrap=True, overflow="ellipsis") + table.add_column("Size", justify="right", width=size_width, no_wrap=True) + + return table + + +def _add_version_rows(table: Table, versions: list[dict[str, Any]]) -> None: + """Add version rows to versions table.""" + for ver in versions: + files: list[dict[str, Any]] = ver.get("files", []) + primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) + filename = "N/A" + size = "N/A" + if primary_file: + filename = primary_file.get("name", "N/A") + size = _format_size(primary_file.get("sizeKB", 0)) + + created = str(ver.get("createdAt", "N/A"))[:10] + table.add_row( + str(ver.get("id", "N/A")), + str(ver.get("name", 
"N/A")), + str(ver.get("baseModel", "N/A")), + created, + filename, + size, + ) + + +def display_model_info(model_data: dict[str, Any], console: Console) -> None: + """Display full CivitAI model information.""" + model_table = _build_model_table(console) + _add_model_basic_info(model_table, model_data) + + console.print() + console.print(model_table) + + versions: list[dict[str, Any]] = model_data.get("modelVersions", []) + if versions: + ver_table = _build_versions_table(console) + _add_version_rows(ver_table, versions) + console.print() + console.print(ver_table) + + model_id = model_data.get("id") + if model_id: + console.print() + console.print(f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}") + + +def _build_search_table(console: Console) -> Table: + """Build search results table with proper column widths.""" + id_width = 7 + type_width = 16 + base_width = 20 + size_width = 8 + dls_width = 6 + likes_width = 6 + + terminal_width = console.size.width + fixed_width = id_width + type_width + base_width + size_width + dls_width + likes_width + overhead = 23 + name_width = max(20, terminal_width - fixed_width - overhead) + + table = Table(show_header=True, header_style="bold magenta") + table.add_column("ID", style="cyan", justify="right", width=id_width, no_wrap=True) + table.add_column("Name", style="green", width=name_width, no_wrap=True, overflow="ellipsis") + table.add_column("Type", style="yellow", width=type_width, no_wrap=True) + table.add_column("Base", style="blue", width=base_width, no_wrap=True, overflow="ellipsis") + table.add_column("Size", justify="right", width=size_width, no_wrap=True) + table.add_column("DLs", justify="right", width=dls_width, no_wrap=True) + table.add_column("Likes", justify="right", width=likes_width, no_wrap=True) + + return table + + +def _add_search_rows(table: Table, items: list[dict[str, Any]]) -> None: + """Add search result rows to table.""" + for model in items: + model_id = 
str(model.get("id", "")) + name = model.get("name", "N/A") + model_type = model.get("type", "N/A") + + versions = model.get("modelVersions", []) + base_model = "N/A" + size = "N/A" + if versions: + latest = versions[0] + base_model = latest.get("baseModel", "N/A") + files = latest.get("files", []) + primary = next((f for f in files if f.get("primary")), files[0] if files else None) + if primary: + size = _format_size(primary.get("sizeKB", 0)) + + stats = model.get("stats", {}) + downloads = _format_count(stats.get("downloadCount", 0)) + likes = _format_count(stats.get("thumbsUpCount", 0)) + + table.add_row(model_id, name, model_type, base_model, size, downloads, likes) + + +def display_search_results(results: dict[str, Any], console: Console) -> None: + """Display search results in a table.""" + items = results.get("items", []) + if not items: + console.print("[yellow]No results found.[/yellow]") + return + + table = _build_search_table(console) + _add_search_rows(table, items) + + console.print() + console.print(table) + + metadata = results.get("metadata", {}) + total = metadata.get("totalItems", len(items)) + console.print(f"\n[dim]Showing {len(items)} of {total:,} results[/dim]") + console.print("[dim]Use 'tsr get ' to view details or 'tsr dl -m ' to download[/dim]") diff --git a/tensors/safetensor.py b/tensors/safetensor.py new file mode 100644 index 0000000..899ca38 --- /dev/null +++ b/tensors/safetensor.py @@ -0,0 +1,92 @@ +"""Safetensor file reading functions.""" + +from __future__ import annotations + +import hashlib +import json +import struct +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pathlib import Path + +from rich.progress import ( + BarColumn, + DownloadColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeRemainingColumn, + TransferSpeedColumn, +) + +if TYPE_CHECKING: + from rich.console import Console + + +def read_safetensor_metadata(file_path: Path) -> dict[str, Any]: + """Read metadata from a 
# The file starts with an 8-byte little-endian u64 giving the JSON header length.
_HEADER_LENGTH_BYTES = 8
# Headers larger than this (100 MB) are rejected as corrupt.
_MAX_HEADER_SIZE = 100_000_000


def read_safetensor_metadata(file_path: Path) -> dict[str, Any]:
    """Read metadata from a safetensor file header.

    Args:
        file_path: Path of the safetensor file to inspect.

    Returns:
        A dict with ``metadata`` (the header's ``__metadata__`` mapping, or
        an empty dict), ``tensor_count`` (number of other header keys) and
        ``header_size`` (JSON header length in bytes).

    Raises:
        ValueError: If the file is too short, the declared header size
            exceeds the sanity limit, the header is truncated, or the header
            is not a JSON object (decode errors from ``json.loads`` /
            ``bytes.decode`` are ValueError subclasses as well).
    """
    with file_path.open("rb") as f:
        header_size_bytes = f.read(_HEADER_LENGTH_BYTES)
        if len(header_size_bytes) < _HEADER_LENGTH_BYTES:
            raise ValueError("Invalid safetensor file: too short")

        (header_size,) = struct.unpack("<Q", header_size_bytes)
        if header_size > _MAX_HEADER_SIZE:  # 100MB sanity check
            raise ValueError(f"Invalid header size: {header_size}")

        header_bytes = f.read(header_size)
        if len(header_bytes) < header_size:
            raise ValueError("Invalid safetensor file: header truncated")

    header = json.loads(header_bytes.decode("utf-8"))
    if not isinstance(header, dict):
        raise ValueError("Invalid safetensor file: header is not a JSON object")

    # Extract __metadata__ if present (a null entry is treated as empty).
    metadata: dict[str, Any] = header.get("__metadata__") or {}

    # Count tensors (keys that aren't __metadata__).
    tensor_count = sum(1 for k in header if k != "__metadata__")

    return {
        "metadata": metadata,
        "tensor_count": tensor_count,
        "header_size": header_size,
    }


def compute_sha256(file_path: Path, console: Console) -> str:
    """Compute the SHA256 hash of a file while showing a progress bar.

    Args:
        file_path: File to hash; read in 8 MB chunks.
        console: Console the progress bar is rendered on.

    Returns:
        The hex digest in upper case.
    """
    file_size = file_path.stat().st_size
    sha256 = hashlib.sha256()
    chunk_size = 1024 * 1024 * 8  # 8MB chunks

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
        console=console,
    ) as progress:
        task = progress.add_task(f"[cyan]Hashing {file_path.name}...", total=file_size)

        with file_path.open("rb") as f:
            while chunk := f.read(chunk_size):
                sha256.update(chunk)
                progress.update(task, advance=len(chunk))

    return sha256.hexdigest().upper()


def get_base_name(file_path: Path) -> str:
    """Return the file name without its ``.safetensors`` / ``.sft`` extension.

    The extension match is case-insensitive; any other file falls back to
    ``Path.stem``.
    """
    name = file_path.name
    for ext in (".safetensors", ".sft"):
        if name.lower().endswith(ext):
            return name[: -len(ext)]
    return file_path.stem
a/tests/test_tensors.py +++ b/tests/test_tensors.py @@ -7,13 +7,9 @@ from pathlib import Path import pytest -import tensors -from tensors import ( - get_base_name, - get_default_output_path, - load_api_key, - read_safetensor_metadata, -) +from tensors import config +from tensors.config import get_default_output_path, load_api_key +from tensors.safetensor import get_base_name, read_safetensor_metadata class TestReadSafetensorMetadata: @@ -111,28 +107,24 @@ class TestLoadApiKey: """Test that None is returned when no key is available.""" monkeypatch.delenv("CIVITAI_API_KEY", raising=False) # Point config and legacy files to nonexistent paths - monkeypatch.setattr(tensors, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml") - monkeypatch.setattr(tensors, "LEGACY_RC_FILE", tmp_path / "nonexistent") + monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml") + monkeypatch.setattr(config, "LEGACY_RC_FILE", tmp_path / "nonexistent") assert load_api_key() is None - def test_returns_key_from_config_file( - self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path - ) -> None: + def test_returns_key_from_config_file(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: """Test that key is loaded from TOML config file.""" monkeypatch.delenv("CIVITAI_API_KEY", raising=False) config_file = tmp_path / "config.toml" config_file.write_text('[api]\ncivitai_key = "key-from-config"\n') - monkeypatch.setattr(tensors, "CONFIG_FILE", config_file) - monkeypatch.setattr(tensors, "LEGACY_RC_FILE", tmp_path / "nonexistent") + monkeypatch.setattr(config, "CONFIG_FILE", config_file) + monkeypatch.setattr(config, "LEGACY_RC_FILE", tmp_path / "nonexistent") assert load_api_key() == "key-from-config" - def test_returns_key_from_legacy_file( - self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path - ) -> None: + def test_returns_key_from_legacy_file(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: """Test that key is loaded from legacy RC file 
when no config exists.""" monkeypatch.delenv("CIVITAI_API_KEY", raising=False) legacy_file = tmp_path / ".sftrc" legacy_file.write_text("legacy-key") - monkeypatch.setattr(tensors, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml") - monkeypatch.setattr(tensors, "LEGACY_RC_FILE", legacy_file) + monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml") + monkeypatch.setattr(config, "LEGACY_RC_FILE", legacy_file) assert load_api_key() == "legacy-key"