From a92c9fb83ac44e42f6f18c737bd5af561b5f5dc2 Mon Sep 17 00:00:00 2001 From: Adam Ladachowski Date: Sat, 14 Feb 2026 01:33:56 +0100 Subject: [PATCH] Phase 2.2: Add tsr db CLI commands Add database management commands to CLI: - tsr db scan - Scan safetensors, compute hashes, store metadata - tsr db link - Match unlinked files to CivitAI by hash lookup - tsr db cache - Fetch and cache full CivitAI model data - tsr db list - List local files with CivitAI info - tsr db search - Search cached models offline - tsr db triggers - Show trigger words for a LoRA - tsr db stats - Show database statistics Update API functions to accept optional console for quiet/batch operations. Co-Authored-By: Claude Opus 4.5 --- TODO.md | 2 +- tensors/api.py | 66 ++++++++------ tensors/cli.py | 239 ++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 280 insertions(+), 27 deletions(-) diff --git a/TODO.md b/TODO.md index d7f4e34..28e5112 100644 --- a/TODO.md +++ b/TODO.md @@ -7,7 +7,7 @@ ## Phase 2: Models Database in tensors - [x] Step 2.1: Create `tensors/db.py` + `tensors/schema.sql` (SQLite wrapper, schema, CRUD) -- [ ] Step 2.2: Add `tsr db` CLI commands (scan, link, cache, list, search, triggers) +- [x] Step 2.2: Add `tsr db` CLI commands (scan, link, cache, list, search, triggers, stats) - [ ] Step 2.3: Add `/api/db/*` endpoints (files, models, triggers, scan, link) ## Phase 3: Enhanced Server API diff --git a/tensors/api.py b/tensors/api.py index 3d49e56..0cb2a60 100644 --- a/tensors/api.py +++ b/tensors/api.py @@ -35,7 +35,7 @@ def _get_headers(api_key: str | None) -> dict[str, str]: return headers -def fetch_civitai_model_version(version_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None: +def fetch_civitai_model_version(version_id: int, api_key: str | None, console: Console | None = None) -> dict[str, Any] | None: """Fetch model version information from CivitAI by version ID.""" url = f"{CIVITAI_API_BASE}/model-versions/{version_id}" @@ -47,25 +47,20 @@ def fetch_civitai_model_version(version_id: int, api_key: str | None, console: C result: dict[str, Any] = response.json() return result except httpx.HTTPStatusError as e: - console.print(f"[red]API error: {e.response.status_code}[/red]") + if console: + console.print(f"[red]API error: {e.response.status_code}[/red]") return None except httpx.RequestError as e: - console.print(f"[red]Request error: {e}[/red]") + if console: + console.print(f"[red]Request error: {e}[/red]") return None -def fetch_civitai_model(model_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None: +def fetch_civitai_model(model_id: int, api_key: str | None, console: Console | None = None) -> dict[str, Any] | None: """Fetch model information from CivitAI by model ID.""" url = f"{CIVITAI_API_BASE}/models/{model_id}" - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - transient=True, - ) as progress: - progress.add_task("[cyan]Fetching model from CivitAI...", total=None) - + def _do_fetch() -> dict[str, Any] | None: try: response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) if response.status_code == HTTPStatus.NOT_FOUND: @@ -74,25 +69,32 @@ def fetch_civitai_model(model_id: int, api_key: str | None, console: Console) -> result: dict[str, Any] = response.json() return result except httpx.HTTPStatusError as e: - console.print(f"[red]API error: {e.response.status_code}[/red]") + if console: + console.print(f"[red]API error: {e.response.status_code}[/red]") return None except httpx.RequestError as e: - console.print(f"[red]Request error: {e}[/red]") + if console: + console.print(f"[red]Request error: {e}[/red]") return None + if console: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + transient=True, + ) as progress: + progress.add_task("[cyan]Fetching model from CivitAI...", total=None) + return _do_fetch() + else: + return _do_fetch() -def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Console) -> dict[str, Any] | None: + +def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Console | None = None) -> dict[str, Any] | None: """Fetch model information from CivitAI by SHA256 hash.""" url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}" - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - transient=True, - ) as progress: - progress.add_task("[cyan]Fetching from CivitAI...", total=None) - + def _do_fetch() -> dict[str, Any] | None: try: response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) if response.status_code == HTTPStatus.NOT_FOUND: @@ -101,12 +103,26 @@ def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Consol result: dict[str, Any] = response.json() return result except httpx.HTTPStatusError as e: - console.print(f"[red]API error: {e.response.status_code}[/red]") + if console: + console.print(f"[red]API error: {e.response.status_code}[/red]") return None except httpx.RequestError as e: - console.print(f"[red]Request error: {e}[/red]") + if console: + console.print(f"[red]Request error: {e}[/red]") return None + if console: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + transient=True, + ) as progress: + progress.add_task("[cyan]Fetching from CivitAI...", total=None) + return _do_fetch() + else: + return _do_fetch() + def _build_search_params( query: str | None, diff --git a/tensors/cli.py b/tensors/cli.py index e0cb436..6489510 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -29,6 +29,7 @@ from tensors.config import ( load_config, save_config, ) +from tensors.db import DB_PATH, Database from tensors.display import ( _format_size, display_civitai_data, @@ -511,12 +512,248 @@ def serve( uvicorn.run(create_app(config), host=host, port=port, log_level=log_level) +# ============================================================================= +# Database Commands +# ============================================================================= + +db_app = typer.Typer( + name="db", + help="Manage local models database and CivitAI cache.", + no_args_is_help=True, +) +app.add_typer(db_app, name="db") + + +@db_app.command("scan") +def db_scan( + directory: Annotated[Path, typer.Argument(help="Directory to scan for safetensor files")], + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Scan directory for safetensor files and add to database.""" + path = directory.resolve() + if not path.exists() or not path.is_dir(): + console.print(f"[red]Error: Directory not found: {path}[/red]") + raise typer.Exit(1) + + with Database() as db: + db.init_schema() + console.print(f"[cyan]Scanning {path}...[/cyan]") + results = db.scan_directory(path, console if not json_output else None) + + if json_output: + console.print_json(data=results) + else: + console.print(f"[green]Scanned {len(results)} file(s)[/green]") + for f in results: + console.print(f" • {f['file_path']}") + + +@db_app.command("link") +def db_link( + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Link unlinked local files to CivitAI by hash lookup.""" + key = api_key or load_api_key() + + with Database() as db: + db.init_schema() + unlinked = db.get_unlinked_files() + + if not unlinked: + console.print("[green]All files already linked.[/green]") + return + + console.print(f"[cyan]Found {len(unlinked)} unlinked file(s)[/cyan]") + linked: list[dict[str, Any]] = [] + + for file_info in unlinked: + sha256 = file_info["sha256"] + console.print(f"[dim]Looking up {sha256[:16]}...[/dim]") + + civitai_data = fetch_civitai_by_hash(sha256, key, console if not json_output else None) + if civitai_data: + version_id: int = civitai_data.get("id", 0) + model_id: int = civitai_data.get("modelId", 0) + if version_id and model_id: + db.link_file_to_civitai(file_info["id"], model_id, version_id) + linked.append( + { + "file": file_info["file_path"], + "model_id": model_id, + "version_id": version_id, + "name": civitai_data.get("name", ""), + } + ) + if not json_output: + console.print(f" [green]✓[/green] {civitai_data.get('name', 'N/A')}") + + if json_output: + console.print_json(data=linked) + else: + console.print(f"[green]Linked {len(linked)} file(s)[/green]") + + +@db_app.command("cache") +def db_cache( + model_id: Annotated[int, typer.Argument(help="CivitAI model ID to cache")], + api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Fetch and cache full CivitAI model data.""" + key = api_key or load_api_key() + + model_data = fetch_civitai_model(model_id, key, console if not json_output else None) + if not model_data: + console.print(f"[red]Error: Model {model_id} not found on CivitAI.[/red]") + raise typer.Exit(1) + + with Database() as db: + db.init_schema() + internal_id = db.cache_model(model_data) + + if json_output: + console.print_json(data={"model_id": model_id, "internal_id": internal_id, "name": model_data.get("name")}) + else: + console.print(f"[green]Cached:[/green] {model_data.get('name')} (internal ID: {internal_id})") + + +@db_app.command("list") +def db_list( + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """List local files with CivitAI info.""" + with Database() as db: + db.init_schema() + files = db.list_local_files() + + if json_output: + console.print_json(data=files) + return + + if not files: + console.print("[yellow]No files in database. Run 'tsr db scan' first.[/yellow]") + return + + table = Table(title="Local Files", show_header=True, header_style="bold magenta") + table.add_column("Path", style="cyan", max_width=50) + table.add_column("Model", style="green") + table.add_column("Version", style="white") + table.add_column("Type", style="yellow") + table.add_column("Base", style="dim") + + for f in files: + path = Path(f["file_path"]).name + model = f.get("model_name") or "[dim]unlinked[/dim]" + version = f.get("version_name") or "" + model_type = f.get("model_type") or "" + base = f.get("base_model") or "" + table.add_row(path, model, version, model_type, base) + + console.print(table) + + +@db_app.command("search") +def db_search( + query: Annotated[str | None, typer.Argument(help="Search query")] = None, + model_type: Annotated[str | None, typer.Option("-t", "--type", help="Model type filter")] = None, + base_model: Annotated[str | None, typer.Option("-b", "--base", help="Base model filter")] = None, + limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Search cached models offline.""" + with Database() as db: + db.init_schema() + results = db.search_models(query=query, model_type=model_type, base_model=base_model, limit=limit) + + if json_output: + console.print_json(data=results) + return + + if not results: + console.print("[yellow]No models found.[/yellow]") + return + + table = Table(title="Cached Models", show_header=True, header_style="bold magenta") + table.add_column("ID", style="dim") + table.add_column("Name", style="cyan") + table.add_column("Type", style="yellow") + table.add_column("Base", style="green") + table.add_column("Creator", style="dim") + table.add_column("Downloads", justify="right") + + for m in results: + table.add_row( + str(m.get("civitai_id", "")), + m.get("name", ""), + m.get("type", ""), + m.get("base_model", ""), + m.get("creator", ""), + str(m.get("download_count", 0)), + ) + + console.print(table) + + +@db_app.command("triggers") +def db_triggers( + file: Annotated[Path, typer.Argument(help="Path to safetensor file")], + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Show trigger words for a LoRA file.""" + file_path = file.resolve() + if not file_path.exists(): + console.print(f"[red]Error: File not found: {file_path}[/red]") + raise typer.Exit(1) + + with Database() as db: + db.init_schema() + triggers = db.get_triggers(str(file_path)) + + if json_output: + console.print_json(data=triggers) + return + + if not triggers: + console.print("[yellow]No trigger words found. File may not be linked to CivitAI.[/yellow]") + console.print("[dim]Run 'tsr db link' to link files to CivitAI.[/dim]") + return + + console.print(f"[bold]Trigger words for {file_path.name}:[/bold]") + for word in triggers: + console.print(f" • {word}") + + +@db_app.command("stats") +def db_stats( + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Show database statistics.""" + with Database() as db: + db.init_schema() + stats = db.get_stats() + + if json_output: + console.print_json(data={"db_path": str(DB_PATH), "stats": stats}) + return + + table = Table(title="Database Statistics", show_header=True, header_style="bold magenta") + table.add_column("Table", style="cyan") + table.add_column("Count", style="green", justify="right") + + for table_name, count in stats.items(): + table.add_row(table_name, str(count)) + + console.print(f"[dim]Database: {DB_PATH}[/dim]") + console.print(table) + + def main() -> int: """Main entry point.""" # Handle legacy invocation: tsr -> tsr info if len(sys.argv) > 1 and not sys.argv[1].startswith("-"): arg = sys.argv[1] - if arg not in ("info", "search", "get", "dl", "download", "config", "generate", "serve", "status", "reload") and ( + if arg not in ("info", "search", "get", "dl", "download", "config", "generate", "serve", "status", "reload", "db") and ( arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists() ): sys.argv = [sys.argv[0], "info", *sys.argv[1:]]