diff --git a/.coverage b/.coverage index 1d9f032..78fc8c4 100644 Binary files a/.coverage and b/.coverage differ diff --git a/tensors.py b/tensors.py index 9c148c8..29402a6 100644 --- a/tensors.py +++ b/tensors.py @@ -1,11 +1,10 @@ #!/usr/bin/env python3 """ -sft-get: Read safetensor metadata and fetch CivitAI model information. +tsr: Read safetensor metadata, search and download CivitAI models. """ from __future__ import annotations -import argparse import hashlib import json import os @@ -13,10 +12,12 @@ import re import struct import sys import tomllib +from enum import Enum from pathlib import Path -from typing import Any +from typing import Annotated, Any import httpx +import typer from rich.console import Console from rich.progress import ( BarColumn, @@ -30,8 +31,21 @@ from rich.progress import ( ) from rich.table import Table +# ============================================================================ +# App and Console Setup +# ============================================================================ + +app = typer.Typer( + name="tsr", + help="Read safetensor metadata, search and download CivitAI models.", + no_args_is_help=True, +) console = Console() +# ============================================================================ +# Configuration +# ============================================================================ + # XDG Base Directory spec: ~/.config/tensors/config.toml CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config")) / "tensors" CONFIG_FILE = CONFIG_DIR / "config.toml" @@ -46,6 +60,48 @@ DEFAULT_PATHS: dict[str, Path] = { "LoCon": Path.home() / ".xm" / "models" / "loras", } +CIVITAI_API_BASE = "https://civitai.com/api/v1" +CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models" + + +# ============================================================================ +# Enums for CLI +# ============================================================================ + + +class ModelType(str, Enum): + """CivitAI model types.""" + + checkpoint = "Checkpoint" + lora = "LORA" + embedding = "TextualInversion" + vae = "VAE" + controlnet = "Controlnet" + locon = "LoCon" + + +class BaseModel(str, Enum): + """Common base models.""" + + sd15 = "SD 1.5" + sdxl = "SDXL 1.0" + pony = "Pony" + flux = "Flux.1 D" + illustrious = "Illustrious" + + +class SortOrder(str, Enum): + """Sort options for search.""" + + downloads = "Most Downloaded" + rating = "Highest Rated" + newest = "Newest" + + +# ============================================================================ +# Config Functions +# ============================================================================ + def load_config() -> dict[str, Any]: """Load configuration from TOML config file.""" @@ -107,8 +163,9 @@ def get_default_output_path(model_type: str | None) -> Path | None: return None -CIVITAI_API_BASE = "https://civitai.com/api/v1" -CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models" +# ============================================================================ +# Safetensor Functions +# ============================================================================ def read_safetensor_metadata(file_path: Path) -> dict[str, Any]: @@ -169,17 +226,36 @@ def compute_sha256(file_path: Path) -> str: return sha256.hexdigest().upper() +def get_base_name(file_path: Path) -> str: + """Get base filename without .safetensors extension.""" + name = file_path.name + for ext in (".safetensors", ".sft"): + if name.lower().endswith(ext): + return name[: -len(ext)] + return file_path.stem + + +# ============================================================================ +# CivitAI API Functions +# ============================================================================ + + +def _get_headers(api_key: str | None) -> dict[str, str]: + """Get headers for CivitAI API requests.""" + headers: dict[str, str] = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + return headers + + def fetch_civitai_model_version( version_id: int, api_key: str | None = None ) -> dict[str, Any] | None: """Fetch model version information from CivitAI by version ID.""" url = f"{CIVITAI_API_BASE}/model-versions/{version_id}" - headers: dict[str, str] = {} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" try: - response = httpx.get(url, headers=headers, timeout=30.0) + response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) if response.status_code == 404: return None response.raise_for_status() @@ -196,9 +272,6 @@ def fetch_civitai_model_version( def fetch_civitai_model(model_id: int, api_key: str | None = None) -> dict[str, Any] | None: """Fetch model information from CivitAI by model ID.""" url = f"{CIVITAI_API_BASE}/models/{model_id}" - headers: dict[str, str] = {} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" with Progress( SpinnerColumn(), @@ -209,7 +282,7 @@ def fetch_civitai_model(model_id: int, api_key: str | None = None) -> dict[str, progress.add_task("[cyan]Fetching model from CivitAI...", total=None) try: - response = httpx.get(url, headers=headers, timeout=30.0) + response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) if response.status_code == 404: return None response.raise_for_status() @@ -226,9 +299,6 @@ def fetch_civitai_model(model_id: int, api_key: str | None = None) -> dict[str, def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[str, Any] | None: """Fetch model information from CivitAI by SHA256 hash.""" url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}" - headers: dict[str, str] = {} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" with Progress( SpinnerColumn(), @@ -239,7 +309,7 @@ def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[ progress.add_task("[cyan]Fetching from CivitAI...", total=None) try: - response = httpx.get(url, headers=headers, timeout=30.0) + response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) if response.status_code == 404: return None response.raise_for_status() @@ -253,16 +323,77 @@ def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[ return None +def search_civitai( + query: str | None = None, + model_type: ModelType | None = None, + base_model: BaseModel | None = None, + sort: SortOrder = SortOrder.downloads, + limit: int = 20, + api_key: str | None = None, +) -> dict[str, Any] | None: + """Search CivitAI models.""" + params: dict[str, Any] = { + "limit": min(limit, 100), + "nsfw": "true", + } + + # API quirk: query + filters don't work reliably together + # If we have filters, skip query and filter client-side + has_filters = model_type is not None or base_model is not None + + if query and not has_filters: + params["query"] = query + + if model_type: + params["types"] = model_type.value + + if base_model: + params["baseModels"] = base_model.value + + params["sort"] = sort.value + + # Request more if we need client-side filtering + if query and has_filters: + params["limit"] = 100 + + url = f"{CIVITAI_API_BASE}/models" + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + transient=True, + ) as progress: + progress.add_task("[cyan]Searching CivitAI...", total=None) + + try: + response = httpx.get(url, params=params, headers=_get_headers(api_key), timeout=30.0) + response.raise_for_status() + result: dict[str, Any] = response.json() + + # Client-side filtering when query + filters combined + if query and has_filters: + q_lower = query.lower() + result["items"] = [ + m for m in result.get("items", []) if q_lower in m.get("name", "").lower() + ][:limit] + + return result + except httpx.HTTPStatusError as e: + console.print(f"[red]API error: {e.response.status_code}[/red]") + return None + except httpx.RequestError as e: + console.print(f"[red]Request error: {e}[/red]") + return None + + def download_model( version_id: int, dest_path: Path, api_key: str | None = None, resume: bool = True, ) -> bool: - """Download a model from CivitAI by version ID with resume support. - - Returns True on success, False on failure. - """ + """Download a model from CivitAI by version ID with resume support.""" url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}" params: dict[str, str] = {} if api_key: @@ -286,20 +417,17 @@ def download_model( params=params, headers=headers, follow_redirects=True, - timeout=httpx.Timeout(30.0, read=None), # No read timeout for large files + timeout=httpx.Timeout(30.0, read=None), ) as response: - # Handle 416 Range Not Satisfiable (file already complete) if response.status_code == 416: console.print("[green]File already fully downloaded.[/green]") return True response.raise_for_status() - # Get total size from Content-Length or Content-Range content_length = response.headers.get("content-length") total_size = int(content_length) + initial_size if content_length else 0 - # Get filename from Content-Disposition if available content_disp = response.headers.get("content-disposition", "") if "filename=" in content_disp: match = re.search(r'filename="?([^";\n]+)"?', content_disp) @@ -323,7 +451,7 @@ def download_model( ) with dest_path.open(mode) as f: - for chunk in response.iter_bytes(1024 * 1024): # 1MB chunks + for chunk in response.iter_bytes(1024 * 1024): f.write(chunk) progress.update(task, advance=len(chunk)) @@ -340,6 +468,29 @@ def download_model( return False +# ============================================================================ +# Display Functions +# ============================================================================ + + +def _format_size(size_kb: float) -> str: + """Format size in KB to human-readable string.""" + if size_kb < 1024: + return f"{size_kb:.0f} KB" + if size_kb < 1024 * 1024: + return f"{size_kb / 1024:.1f} MB" + return f"{size_kb / 1024 / 1024:.2f} GB" + + +def _format_count(count: int) -> str: + """Format large numbers with K/M suffix.""" + if count < 1000: + return str(count) + if count < 1_000_000: + return f"{count / 1000:.1f}K" + return f"{count / 1_000_000:.1f}M" + + def _display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str) -> None: """Display file information table.""" file_table = Table(title="File Information", show_header=True, header_style="bold magenta") @@ -398,24 +549,18 @@ def _display_civitai_data(civitai_data: dict[str, Any] | None) -> None: civit_table.add_row("Base Model", str(civitai_data.get("baseModel", "N/A"))) civit_table.add_row("Created At", str(civitai_data.get("createdAt", "N/A"))) - # Trained words trained_words: list[str] = civitai_data.get("trainedWords", []) if trained_words: civit_table.add_row("Trigger Words", ", ".join(trained_words)) - # Download URL download_url = str(civitai_data.get("downloadUrl", "N/A")) civit_table.add_row("Download URL", download_url) - # File info from CivitAI files: list[dict[str, Any]] = civitai_data.get("files", []) for f in files: if f.get("primary"): civit_table.add_row("Primary File", str(f.get("name", "N/A"))) - civit_table.add_row( - "File Size (CivitAI)", - f"{f.get('sizeKB', 0) / 1024:.2f} MB", - ) + civit_table.add_row("File Size (CivitAI)", _format_size(f.get("sizeKB", 0))) meta: dict[str, Any] = f.get("metadata", {}) if meta: civit_table.add_row("Format", str(meta.get("format", "N/A"))) @@ -425,7 +570,6 @@ def _display_civitai_data(civitai_data: dict[str, Any] | None) -> None: console.print() console.print(civit_table) - # Model page link model_id = civitai_data.get("modelId") if model_id: console.print() @@ -436,7 +580,6 @@ def _display_civitai_data(civitai_data: dict[str, Any] | None) -> None: def _display_model_info(model_data: dict[str, Any]) -> None: """Display full CivitAI model information.""" - # Main model info table model_table = Table(title="Model Information", show_header=True, header_style="bold magenta") model_table.add_column("Property", style="cyan") model_table.add_column("Value", style="green", max_width=80) @@ -446,17 +589,14 @@ def _display_model_info(model_data: dict[str, Any]) -> None: model_table.add_row("Type", str(model_data.get("type", "N/A"))) model_table.add_row("NSFW", str(model_data.get("nsfw", False))) - # Creator info creator = model_data.get("creator", {}) if creator: model_table.add_row("Creator", str(creator.get("username", "N/A"))) - # Tags tags: list[str] = model_data.get("tags", []) if tags: model_table.add_row("Tags", ", ".join(tags[:10]) + ("..." if len(tags) > 10 else "")) - # Stats stats: dict[str, Any] = model_data.get("stats", {}) if stats: model_table.add_row("Downloads", f"{stats.get('downloadCount', 0):,}") @@ -465,7 +605,6 @@ def _display_model_info(model_data: dict[str, Any]) -> None: "Rating", f"{stats.get('rating', 0):.1f} ({stats.get('ratingCount', 0)} ratings)" ) - # Mode (archived/taken down) mode = model_data.get("mode") if mode: model_table.add_row("Status", str(mode)) @@ -473,7 +612,6 @@ def _display_model_info(model_data: dict[str, Any]) -> None: console.print() console.print(model_table) - # Versions table versions: list[dict[str, Any]] = model_data.get("modelVersions", []) if versions: ver_table = Table(title="Model Versions", show_header=True, header_style="bold magenta") @@ -488,15 +626,12 @@ def _display_model_info(model_data: dict[str, Any]) -> None: primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) file_info = "" if primary_file: - size_kb = primary_file.get("sizeKB", 0) - size_str = ( - f"{size_kb / 1024:.0f} MB" - if size_kb < 1024 * 1024 - else f"{size_kb / 1024 / 1024:.1f} GB" + file_info = ( + f"{primary_file.get('name', 'N/A')} " + f"({_format_size(primary_file.get('sizeKB', 0))})" ) - file_info = f"{primary_file.get('name', 'N/A')} ({size_str})" - created = str(ver.get("createdAt", "N/A"))[:10] # Just date portion + created = str(ver.get("createdAt", "N/A"))[:10] ver_table.add_row( str(ver.get("id", "N/A")), str(ver.get("name", "N/A")), @@ -508,7 +643,6 @@ def _display_model_info(model_data: dict[str, Any]) -> None: console.print() console.print(ver_table) - # Model page link model_id = model_data.get("id") if model_id: console.print() @@ -517,82 +651,94 @@ def _display_model_info(model_data: dict[str, Any]) -> None: ) -def display_results( - file_path: Path, - local_metadata: dict[str, Any], - sha256_hash: str, - civitai_data: dict[str, Any] | None, +def _display_search_results(results: dict[str, Any]) -> None: + """Display search results in a table.""" + items = results.get("items", []) + if not items: + console.print("[yellow]No results found.[/yellow]") + return + + table = Table(show_header=True, header_style="bold magenta") + table.add_column("ID", style="cyan", justify="right") + table.add_column("Name", style="green", max_width=40) + table.add_column("Type", style="yellow") + table.add_column("Base", style="blue") + table.add_column("Size", justify="right") + table.add_column("DLs", justify="right") + table.add_column("Rating", justify="right") + + for model in items: + model_id = str(model.get("id", "")) + name = model.get("name", "N/A") + if len(name) > 40: + name = name[:37] + "..." + model_type = model.get("type", "N/A") + + # Get latest version info + versions = model.get("modelVersions", []) + base_model = "N/A" + size = "N/A" + if versions: + latest = versions[0] + base_model = latest.get("baseModel", "N/A") + files = latest.get("files", []) + primary = next((f for f in files if f.get("primary")), files[0] if files else None) + if primary: + size = _format_size(primary.get("sizeKB", 0)) + + stats = model.get("stats", {}) + downloads = _format_count(stats.get("downloadCount", 0)) + rating = f"{stats.get('rating', 0):.1f}" + + table.add_row(model_id, name, model_type, base_model, size, downloads, rating) + + console.print() + console.print(table) + + metadata = results.get("metadata", {}) + total = metadata.get("totalItems", len(items)) + console.print(f"\n[dim]Showing {len(items)} of {total:,} results[/dim]") + console.print("[dim]Use 'tsr get ' to view details or 'tsr dl -m ' to download[/dim]") + + +# ============================================================================ +# CLI Commands +# ============================================================================ + + +@app.command() +def info( + file: Annotated[Path, typer.Argument(help="Path to the safetensor file")], + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, + skip_civitai: Annotated[ + bool, typer.Option("--skip-civitai", help="Skip CivitAI API lookup") + ] = False, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, + save_to: Annotated[ + Path | None, typer.Option("--save-to", help="Save metadata to directory") + ] = None, ) -> None: - """Display results in rich tables.""" - _display_file_info(file_path, local_metadata, sha256_hash) - _display_local_metadata(local_metadata) - _display_civitai_data(civitai_data) - - -def get_base_name(file_path: Path) -> str: - """Get base filename without .safetensors extension.""" - name = file_path.name - for ext in (".safetensors", ".sft"): - if name.lower().endswith(ext): - return name[: -len(ext)] - return file_path.stem - - -def save_metadata( - file_path: Path, - sha256_hash: str, - local_metadata: dict[str, Any], - civitai_data: dict[str, Any] | None, - output_dir: Path, -) -> tuple[Path, Path]: - """Save metadata JSON and SHA256 hash to the specified output directory.""" - base_name = get_base_name(file_path) - - # Save JSON metadata - json_path = output_dir / f"{base_name}-xm.json" - output = { - "file": str(file_path), - "sha256": sha256_hash, - "header_size": local_metadata["header_size"], - "tensor_count": local_metadata["tensor_count"], - "metadata": local_metadata["metadata"], - "civitai": civitai_data, - } - json_path.write_text(json.dumps(output, indent=2)) - - # Save SHA256 hash - sha_path = output_dir / f"{base_name}-xm.sha256" - sha_path.write_text(f"{sha256_hash} {file_path.name}\n") - - return json_path, sha_path - - -def cmd_info(args: argparse.Namespace) -> int: - """Handle the info subcommand (default behavior).""" - file_path: Path = args.file.resolve() + """Read safetensor metadata and fetch CivitAI info.""" + file_path = file.resolve() if not file_path.exists(): console.print(f"[red]Error: File not found: {file_path}[/red]") - return 1 + raise typer.Exit(1) if file_path.suffix.lower() not in (".safetensors", ".sft"): console.print("[yellow]Warning: File does not have .safetensors extension[/yellow]") try: - # Read local metadata console.print(f"[bold]Reading safetensor file:[/bold] {file_path.name}") local_metadata = read_safetensor_metadata(file_path) - - # Compute SHA256 sha256_hash = compute_sha256(file_path) - # Fetch from CivitAI civitai_data = None - if not args.skip_civitai: - api_key = args.api_key or load_api_key() - civitai_data = fetch_civitai_by_hash(sha256_hash, api_key) + if not skip_civitai: + key = api_key or load_api_key() + civitai_data = fetch_civitai_by_hash(sha256_hash, key) - if args.json_output: + if json_output: output = { "file": str(file_path), "sha256": sha256_hash, @@ -603,296 +749,284 @@ def cmd_info(args: argparse.Namespace) -> int: } console.print_json(data=output) else: - display_results(file_path, local_metadata, sha256_hash, civitai_data) + _display_file_info(file_path, local_metadata, sha256_hash) + _display_local_metadata(local_metadata) + _display_civitai_data(civitai_data) + + if save_to: + output_dir = save_to.resolve() + if not output_dir.exists() or not output_dir.is_dir(): + console.print(f"[red]Error: Invalid directory: {output_dir}[/red]") + raise typer.Exit(1) + + base_name = get_base_name(file_path) + json_path = output_dir / f"{base_name}-xm.json" + sha_path = output_dir / f"{base_name}-xm.sha256" + + output = { + "file": str(file_path), + "sha256": sha256_hash, + "header_size": local_metadata["header_size"], + "tensor_count": local_metadata["tensor_count"], + "metadata": local_metadata["metadata"], + "civitai": civitai_data, + } + json_path.write_text(json.dumps(output, indent=2)) + sha_path.write_text(f"{sha256_hash} {file_path.name}\n") - # Save files if requested - if args.save_to: - output_dir: Path = args.save_to.resolve() - if not output_dir.exists(): - console.print(f"[red]Error: Output directory not found: {output_dir}[/red]") - return 1 - if not output_dir.is_dir(): - console.print(f"[red]Error: Not a directory: {output_dir}[/red]") - return 1 - json_path, sha_path = save_metadata( - file_path, sha256_hash, local_metadata, civitai_data, output_dir - ) console.print() console.print(f"[green]Saved:[/green] {json_path}") console.print(f"[green]Saved:[/green] {sha_path}") - return 0 - except ValueError as e: console.print(f"[red]Error reading safetensor: {e}[/red]") - return 1 - except Exception as e: - console.print(f"[red]Unexpected error: {e}[/red]") - return 1 + raise typer.Exit(1) from e + + +@app.command() +def search( + query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None, + model_type: Annotated[ + ModelType | None, typer.Option("-t", "--type", help="Model type filter") + ] = None, + base: Annotated[ + BaseModel | None, typer.Option("-b", "--base", help="Base model filter") + ] = None, + sort: Annotated[ + SortOrder, typer.Option("-s", "--sort", help="Sort order") + ] = SortOrder.downloads, + limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, +) -> None: + """Search CivitAI models.""" + key = api_key or load_api_key() + + results = search_civitai( + query=query, + model_type=model_type, + base_model=base, + sort=sort, + limit=limit, + api_key=key, + ) + + if not results: + console.print("[red]Search failed.[/red]") + raise typer.Exit(1) + + if json_output: + console.print_json(data=results) + else: + _display_search_results(results) + + +@app.command() +def get( + id_value: Annotated[int, typer.Argument(help="CivitAI model ID or version ID")], + version: Annotated[ + bool, typer.Option("-v", "--version", help="Treat ID as version ID instead of model ID") + ] = False, + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Fetch model information from CivitAI by model ID or version ID.""" + key = api_key or load_api_key() + + if version: + # Fetch by version ID + version_data = fetch_civitai_model_version(id_value, key) + if not version_data: + console.print(f"[red]Error: Version {id_value} not found on CivitAI.[/red]") + raise typer.Exit(1) + + if json_output: + console.print_json(data=version_data) + else: + _display_civitai_data(version_data) + else: + # Fetch by model ID + model_data = fetch_civitai_model(id_value, key) + if not model_data: + console.print(f"[red]Error: Model {id_value} not found on CivitAI.[/red]") + raise typer.Exit(1) + + if json_output: + console.print_json(data=model_data) + else: + _display_model_info(model_data) def _resolve_version_id( version_id: int | None, - sha256_hash: str | None, + hash_val: str | None, model_id: int | None, api_key: str | None, ) -> int | None: - """Resolve version ID from hash or model ID if needed.""" + """Resolve version ID from hash or model ID.""" if version_id: return version_id - if sha256_hash: - console.print(f"[cyan]Looking up model by hash: {sha256_hash[:16]}...[/cyan]") - civitai_data = fetch_civitai_by_hash(sha256_hash.upper(), api_key) + + if hash_val: + console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]") + civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key) if not civitai_data: console.print("[red]Error: Model not found on CivitAI for this hash.[/red]") return None - vid = civitai_data.get("id") + vid: int | None = civitai_data.get("id") if vid: - console.print(f"[green]Found model version:[/green] {civitai_data.get('name', 'N/A')}") - else: - console.print("[red]Error: Could not determine version ID from CivitAI response.[/red]") + console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}") return vid + if model_id: console.print(f"[cyan]Looking up model {model_id}...[/cyan]") model_data = fetch_civitai_model(model_id, api_key) if not model_data: - console.print(f"[red]Error: Model {model_id} not found on CivitAI.[/red]") + console.print(f"[red]Error: Model {model_id} not found.[/red]") return None - versions: list[dict[str, Any]] = model_data.get("modelVersions", []) + versions = model_data.get("modelVersions", []) if not versions: console.print("[red]Error: Model has no versions.[/red]") return None - # First version is the latest latest = versions[0] - vid = latest.get("id") - if vid: - console.print( - f"[green]Found latest version:[/green] {latest.get('name', 'N/A')} (ID: {vid})" - ) - return vid + latest_vid: int | None = latest.get("id") + if latest_vid: + name = latest.get("name", "N/A") + console.print(f"[green]Found latest:[/green] {name} (ID: {latest_vid})") + return latest_vid + return None -def cmd_download(args: argparse.Namespace) -> int: - """Handle the download subcommand.""" - api_key: str | None = args.api_key or load_api_key() - - # Resolve version ID from hash or model ID if needed - version_id = _resolve_version_id( - args.version_id, args.hash, getattr(args, "model_id", None), api_key - ) - if not version_id: - if not args.version_id and not args.hash and not getattr(args, "model_id", None): - console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") - return 1 - - # Fetch version info to get filename and model type - console.print(f"[cyan]Fetching model info for version {version_id}...[/cyan]") - version_info = fetch_civitai_model_version(version_id, api_key) - - if not version_info: - console.print("[red]Error: Could not fetch model version info.[/red]") - return 1 - - # Determine model type for default path - model_type: str | None = version_info.get("model", {}).get("type") - - # Determine output directory - if args.output is None: - # Use model type-based default - output_dir = get_default_output_path(model_type) +def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Path | None: + """Prepare output directory for download.""" + if output is None: + output_dir = get_default_output_path(model_type_str) if output_dir is None: console.print( - f"[red]Error: No default path for model type '{model_type}'. " + f"[red]Error: No default path for type '{model_type_str}'. " "Use --output to specify.[/red]" ) - return 1 - console.print(f"[dim]Using default path for {model_type}: {output_dir}[/dim]") + return None + console.print(f"[dim]Using default path for {model_type_str}: {output_dir}[/dim]") else: - output_dir = args.output.resolve() + output_dir = output.resolve() - # Create directory if it doesn't exist if not output_dir.exists(): console.print(f"[cyan]Creating directory: {output_dir}[/cyan]") output_dir.mkdir(parents=True, exist_ok=True) elif not output_dir.is_dir(): console.print(f"[red]Error: Not a directory: {output_dir}[/red]") - return 1 + return None + + return output_dir + + +@app.command("dl") +def download( + version_id: Annotated[ + int | None, typer.Option("-v", "--version-id", help="Model version ID") + ] = None, + model_id: Annotated[ + int | None, typer.Option("-m", "--model-id", help="Model ID (downloads latest)") + ] = None, + hash_val: Annotated[ + str | None, typer.Option("-H", "--hash", help="SHA256 hash to look up") + ] = None, + output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None, + no_resume: Annotated[ + bool, typer.Option("--no-resume", help="Don't resume partial downloads") + ] = False, + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, +) -> None: + """Download a model from CivitAI.""" + key = api_key or load_api_key() + + resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key) + if not resolved_version_id: + if not version_id and not hash_val and not model_id: + console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") + raise typer.Exit(1) + + console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]") + version_info = fetch_civitai_model_version(resolved_version_id, key) + if not version_info: + console.print("[red]Error: Could not fetch model version info.[/red]") + raise typer.Exit(1) + + model_type_str: str | None = version_info.get("model", {}).get("type") + output_dir = _prepare_download_dir(output, model_type_str) + if not output_dir: + raise typer.Exit(1) - # Find primary file or first file files: list[dict[str, Any]] = version_info.get("files", []) primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) - if not primary_file: - console.print("[red]Error: No files found for this model version.[/red]") - return 1 + console.print("[red]Error: No files found for this version.[/red]") + raise typer.Exit(1) - filename = primary_file.get("name", f"model-{version_id}.safetensors") + filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors") dest_path = output_dir / filename - # Display model info - model_table = Table(title="Model Download", show_header=True, header_style="bold magenta") - model_table.add_column("Property", style="cyan") - model_table.add_column("Value", style="green") - model_table.add_row("Version", version_info.get("name", "N/A")) - model_table.add_row("Base Model", version_info.get("baseModel", "N/A")) - model_table.add_row("File", filename) - model_table.add_row("Size", f"{primary_file.get('sizeKB', 0) / 1024:.2f} MB") - model_table.add_row("Destination", str(dest_path)) + table = Table(title="Model Download", show_header=True, header_style="bold magenta") + table.add_column("Property", style="cyan") + table.add_column("Value", style="green") + table.add_row("Version", version_info.get("name", "N/A")) + table.add_row("Base Model", version_info.get("baseModel", "N/A")) + table.add_row("File", filename) + table.add_row("Size", _format_size(primary_file.get("sizeKB", 0))) + table.add_row("Destination", str(dest_path)) console.print() - console.print(model_table) + console.print(table) console.print() - # Download - success = download_model(version_id, dest_path, api_key, resume=not args.no_resume) - return 0 if success else 1 + success = download_model(resolved_version_id, dest_path, key, resume=not no_resume) + if not success: + raise typer.Exit(1) -def cmd_get(args: argparse.Namespace) -> int: - """Handle the get subcommand - fetch model info by ID.""" - model_id: int = args.model_id - api_key: str | None = args.api_key or load_api_key() +@app.command() +def config( + show: Annotated[bool, typer.Option("--show", help="Show current config")] = False, + set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None, +) -> None: + """Manage configuration.""" + if set_key: + cfg = load_config() + if "api" not in cfg: + cfg["api"] = {} + cfg["api"]["civitai_key"] = set_key + save_config(cfg) + console.print(f"[green]API key saved to {CONFIG_FILE}[/green]") + return - model_data = fetch_civitai_model(model_id, api_key) + if show or (not set_key): + console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}") + console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}") - if not model_data: - console.print(f"[red]Error: Model {model_id} not found on CivitAI.[/red]") - return 1 + key = load_api_key() + if key: + masked = key[:4] + "..." + key[-4:] if len(key) > 8 else "***" + console.print(f"[bold]API key:[/bold] {masked}") + else: + console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]") - if args.json_output: - console.print_json(data=model_data) - else: - _display_model_info(model_data) - - return 0 + console.print() + console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]") def main() -> int: """Main entry point.""" - parser = argparse.ArgumentParser( - description="Read safetensor metadata and download CivitAI models.", - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - subparsers = parser.add_subparsers(dest="command", help="Commands") + # Handle legacy invocation: tsr -> tsr info + if len(sys.argv) > 1 and not sys.argv[1].startswith("-"): + arg = sys.argv[1] + if arg not in ("info", "search", "get", "dl", "download", "config") and ( + arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists() + ): + sys.argv = [sys.argv[0], "info", *sys.argv[1:]] - # Info command (default) - info_parser = subparsers.add_parser( - "info", - help="Read safetensor metadata and fetch CivitAI info (default)", - ) - info_parser.add_argument( - "file", - type=Path, - help="Path to the safetensor file", - ) - info_parser.add_argument( - "--api-key", - type=str, - default=None, - help="CivitAI API key for authenticated requests", - ) - info_parser.add_argument( - "--skip-civitai", - action="store_true", - help="Skip CivitAI API lookup", - ) - info_parser.add_argument( - "--json", - action="store_true", - dest="json_output", - help="Output results as JSON", - ) - info_parser.add_argument( - "--save-to", - type=Path, - metavar="DIR", - help="Save metadata JSON and SHA256 hash to the specified directory", - ) - info_parser.set_defaults(func=cmd_info) - - # Download command - dl_parser = subparsers.add_parser( - "download", - aliases=["dl"], - help="Download a model from CivitAI", - ) - dl_parser.add_argument( - "--version-id", - "-v", - type=int, - help="CivitAI model version ID to download", - ) - dl_parser.add_argument( - "--model-id", - "-m", - type=int, - help="CivitAI model ID (downloads latest version)", - ) - dl_parser.add_argument( - "--hash", - "-H", - type=str, - help="SHA256 hash to look up and download", - ) - dl_parser.add_argument( - "--api-key", - type=str, - default=None, - help="CivitAI API key for authenticated requests", - ) - dl_parser.add_argument( - "--output", - "-o", - type=Path, - default=None, - help="Output directory (default: type-based, e.g. ~/.xm/models/checkpoints for Checkpoint)", - ) - dl_parser.add_argument( - "--no-resume", - action="store_true", - help="Don't resume partial downloads, start fresh", - ) - dl_parser.set_defaults(func=cmd_download) - - # Get command - get_parser = subparsers.add_parser( - "get", - help="Fetch model information from CivitAI by model ID", - ) - get_parser.add_argument( - "model_id", - type=int, - help="CivitAI model ID", - ) - get_parser.add_argument( - "--api-key", - type=str, - default=None, - help="CivitAI API key for authenticated requests", - ) - get_parser.add_argument( - "--json", - action="store_true", - dest="json_output", - help="Output results as JSON", - ) - get_parser.set_defaults(func=cmd_get) - - # Parse and handle default command - args = parser.parse_args() - - # If no command specified and file argument given, assume 'info' command - if args.command is None: - # Check if there's a positional argument (file path) - if len(sys.argv) > 1 and not sys.argv[1].startswith("-"): - # Re-parse with 'info' prepended - args = parser.parse_args(["info", *sys.argv[1:]]) - else: - parser.print_help() - return 0 - - result: int = args.func(args) - return result + app() + return 0 if __name__ == "__main__":