Phase 2.2: Add tsr db CLI commands
Add database management commands to CLI: - tsr db scan <directory> - Scan safetensors, compute hashes, store metadata - tsr db link - Match unlinked files to CivitAI by hash lookup - tsr db cache <model_id> - Fetch and cache full CivitAI model data - tsr db list - List local files with CivitAI info - tsr db search - Search cached models offline - tsr db triggers <file> - Show trigger words for a LoRA - tsr db stats - Show database statistics Update API functions to accept optional console for quiet/batch operations. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
## Phase 2: Models Database in tensors
|
## Phase 2: Models Database in tensors
|
||||||
- [x] Step 2.1: Create `tensors/db.py` + `tensors/schema.sql` (SQLite wrapper, schema, CRUD)
|
- [x] Step 2.1: Create `tensors/db.py` + `tensors/schema.sql` (SQLite wrapper, schema, CRUD)
|
||||||
- [ ] Step 2.2: Add `tsr db` CLI commands (scan, link, cache, list, search, triggers)
|
- [x] Step 2.2: Add `tsr db` CLI commands (scan, link, cache, list, search, triggers, stats)
|
||||||
- [ ] Step 2.3: Add `/api/db/*` endpoints (files, models, triggers, scan, link)
|
- [ ] Step 2.3: Add `/api/db/*` endpoints (files, models, triggers, scan, link)
|
||||||
|
|
||||||
## Phase 3: Enhanced Server API
|
## Phase 3: Enhanced Server API
|
||||||
|
|||||||
+37
-21
@@ -35,7 +35,7 @@ def _get_headers(api_key: str | None) -> dict[str, str]:
|
|||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
|
||||||
def fetch_civitai_model_version(version_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None:
|
def fetch_civitai_model_version(version_id: int, api_key: str | None, console: Console | None = None) -> dict[str, Any] | None:
|
||||||
"""Fetch model version information from CivitAI by version ID."""
|
"""Fetch model version information from CivitAI by version ID."""
|
||||||
url = f"{CIVITAI_API_BASE}/model-versions/{version_id}"
|
url = f"{CIVITAI_API_BASE}/model-versions/{version_id}"
|
||||||
|
|
||||||
@@ -47,17 +47,37 @@ def fetch_civitai_model_version(version_id: int, api_key: str | None, console: C
|
|||||||
result: dict[str, Any] = response.json()
|
result: dict[str, Any] = response.json()
|
||||||
return result
|
return result
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
|
if console:
|
||||||
console.print(f"[red]API error: {e.response.status_code}[/red]")
|
console.print(f"[red]API error: {e.response.status_code}[/red]")
|
||||||
return None
|
return None
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
|
if console:
|
||||||
console.print(f"[red]Request error: {e}[/red]")
|
console.print(f"[red]Request error: {e}[/red]")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def fetch_civitai_model(model_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None:
|
def fetch_civitai_model(model_id: int, api_key: str | None, console: Console | None = None) -> dict[str, Any] | None:
|
||||||
"""Fetch model information from CivitAI by model ID."""
|
"""Fetch model information from CivitAI by model ID."""
|
||||||
url = f"{CIVITAI_API_BASE}/models/{model_id}"
|
url = f"{CIVITAI_API_BASE}/models/{model_id}"
|
||||||
|
|
||||||
|
def _do_fetch() -> dict[str, Any] | None:
|
||||||
|
try:
|
||||||
|
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
||||||
|
if response.status_code == HTTPStatus.NOT_FOUND:
|
||||||
|
return None
|
||||||
|
response.raise_for_status()
|
||||||
|
result: dict[str, Any] = response.json()
|
||||||
|
return result
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if console:
|
||||||
|
console.print(f"[red]API error: {e.response.status_code}[/red]")
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
if console:
|
||||||
|
console.print(f"[red]Request error: {e}[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if console:
|
||||||
with Progress(
|
with Progress(
|
||||||
SpinnerColumn(),
|
SpinnerColumn(),
|
||||||
TextColumn("[progress.description]{task.description}"),
|
TextColumn("[progress.description]{task.description}"),
|
||||||
@@ -65,7 +85,16 @@ def fetch_civitai_model(model_id: int, api_key: str | None, console: Console) ->
|
|||||||
transient=True,
|
transient=True,
|
||||||
) as progress:
|
) as progress:
|
||||||
progress.add_task("[cyan]Fetching model from CivitAI...", total=None)
|
progress.add_task("[cyan]Fetching model from CivitAI...", total=None)
|
||||||
|
return _do_fetch()
|
||||||
|
else:
|
||||||
|
return _do_fetch()
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Console | None = None) -> dict[str, Any] | None:
|
||||||
|
"""Fetch model information from CivitAI by SHA256 hash."""
|
||||||
|
url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}"
|
||||||
|
|
||||||
|
def _do_fetch() -> dict[str, Any] | None:
|
||||||
try:
|
try:
|
||||||
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
||||||
if response.status_code == HTTPStatus.NOT_FOUND:
|
if response.status_code == HTTPStatus.NOT_FOUND:
|
||||||
@@ -74,17 +103,15 @@ def fetch_civitai_model(model_id: int, api_key: str | None, console: Console) ->
|
|||||||
result: dict[str, Any] = response.json()
|
result: dict[str, Any] = response.json()
|
||||||
return result
|
return result
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
|
if console:
|
||||||
console.print(f"[red]API error: {e.response.status_code}[/red]")
|
console.print(f"[red]API error: {e.response.status_code}[/red]")
|
||||||
return None
|
return None
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
|
if console:
|
||||||
console.print(f"[red]Request error: {e}[/red]")
|
console.print(f"[red]Request error: {e}[/red]")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if console:
|
||||||
def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Console) -> dict[str, Any] | None:
|
|
||||||
"""Fetch model information from CivitAI by SHA256 hash."""
|
|
||||||
url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}"
|
|
||||||
|
|
||||||
with Progress(
|
with Progress(
|
||||||
SpinnerColumn(),
|
SpinnerColumn(),
|
||||||
TextColumn("[progress.description]{task.description}"),
|
TextColumn("[progress.description]{task.description}"),
|
||||||
@@ -92,20 +119,9 @@ def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Consol
|
|||||||
transient=True,
|
transient=True,
|
||||||
) as progress:
|
) as progress:
|
||||||
progress.add_task("[cyan]Fetching from CivitAI...", total=None)
|
progress.add_task("[cyan]Fetching from CivitAI...", total=None)
|
||||||
|
return _do_fetch()
|
||||||
try:
|
else:
|
||||||
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
return _do_fetch()
|
||||||
if response.status_code == HTTPStatus.NOT_FOUND:
|
|
||||||
return None
|
|
||||||
response.raise_for_status()
|
|
||||||
result: dict[str, Any] = response.json()
|
|
||||||
return result
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
console.print(f"[red]API error: {e.response.status_code}[/red]")
|
|
||||||
return None
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
console.print(f"[red]Request error: {e}[/red]")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _build_search_params(
|
def _build_search_params(
|
||||||
|
|||||||
+238
-1
@@ -29,6 +29,7 @@ from tensors.config import (
|
|||||||
load_config,
|
load_config,
|
||||||
save_config,
|
save_config,
|
||||||
)
|
)
|
||||||
|
from tensors.db import DB_PATH, Database
|
||||||
from tensors.display import (
|
from tensors.display import (
|
||||||
_format_size,
|
_format_size,
|
||||||
display_civitai_data,
|
display_civitai_data,
|
||||||
@@ -511,12 +512,248 @@ def serve(
|
|||||||
uvicorn.run(create_app(config), host=host, port=port, log_level=log_level)
|
uvicorn.run(create_app(config), host=host, port=port, log_level=log_level)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Database Commands
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
db_app = typer.Typer(
|
||||||
|
name="db",
|
||||||
|
help="Manage local models database and CivitAI cache.",
|
||||||
|
no_args_is_help=True,
|
||||||
|
)
|
||||||
|
app.add_typer(db_app, name="db")
|
||||||
|
|
||||||
|
|
||||||
|
@db_app.command("scan")
|
||||||
|
def db_scan(
|
||||||
|
directory: Annotated[Path, typer.Argument(help="Directory to scan for safetensor files")],
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""Scan directory for safetensor files and add to database."""
|
||||||
|
path = directory.resolve()
|
||||||
|
if not path.exists() or not path.is_dir():
|
||||||
|
console.print(f"[red]Error: Directory not found: {path}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
console.print(f"[cyan]Scanning {path}...[/cyan]")
|
||||||
|
results = db.scan_directory(path, console if not json_output else None)
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=results)
|
||||||
|
else:
|
||||||
|
console.print(f"[green]Scanned {len(results)} file(s)[/green]")
|
||||||
|
for f in results:
|
||||||
|
console.print(f" • {f['file_path']}")
|
||||||
|
|
||||||
|
|
||||||
|
@db_app.command("link")
|
||||||
|
def db_link(
|
||||||
|
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""Link unlinked local files to CivitAI by hash lookup."""
|
||||||
|
key = api_key or load_api_key()
|
||||||
|
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
unlinked = db.get_unlinked_files()
|
||||||
|
|
||||||
|
if not unlinked:
|
||||||
|
console.print("[green]All files already linked.[/green]")
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print(f"[cyan]Found {len(unlinked)} unlinked file(s)[/cyan]")
|
||||||
|
linked: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
for file_info in unlinked:
|
||||||
|
sha256 = file_info["sha256"]
|
||||||
|
console.print(f"[dim]Looking up {sha256[:16]}...[/dim]")
|
||||||
|
|
||||||
|
civitai_data = fetch_civitai_by_hash(sha256, key, console if not json_output else None)
|
||||||
|
if civitai_data:
|
||||||
|
version_id: int = civitai_data.get("id", 0)
|
||||||
|
model_id: int = civitai_data.get("modelId", 0)
|
||||||
|
if version_id and model_id:
|
||||||
|
db.link_file_to_civitai(file_info["id"], model_id, version_id)
|
||||||
|
linked.append(
|
||||||
|
{
|
||||||
|
"file": file_info["file_path"],
|
||||||
|
"model_id": model_id,
|
||||||
|
"version_id": version_id,
|
||||||
|
"name": civitai_data.get("name", ""),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if not json_output:
|
||||||
|
console.print(f" [green]✓[/green] {civitai_data.get('name', 'N/A')}")
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=linked)
|
||||||
|
else:
|
||||||
|
console.print(f"[green]Linked {len(linked)} file(s)[/green]")
|
||||||
|
|
||||||
|
|
||||||
|
@db_app.command("cache")
|
||||||
|
def db_cache(
|
||||||
|
model_id: Annotated[int, typer.Argument(help="CivitAI model ID to cache")],
|
||||||
|
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""Fetch and cache full CivitAI model data."""
|
||||||
|
key = api_key or load_api_key()
|
||||||
|
|
||||||
|
model_data = fetch_civitai_model(model_id, key, console if not json_output else None)
|
||||||
|
if not model_data:
|
||||||
|
console.print(f"[red]Error: Model {model_id} not found on CivitAI.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
internal_id = db.cache_model(model_data)
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data={"model_id": model_id, "internal_id": internal_id, "name": model_data.get("name")})
|
||||||
|
else:
|
||||||
|
console.print(f"[green]Cached:[/green] {model_data.get('name')} (internal ID: {internal_id})")
|
||||||
|
|
||||||
|
|
||||||
|
@db_app.command("list")
|
||||||
|
def db_list(
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""List local files with CivitAI info."""
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
files = db.list_local_files()
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=files)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not files:
|
||||||
|
console.print("[yellow]No files in database. Run 'tsr db scan' first.[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
table = Table(title="Local Files", show_header=True, header_style="bold magenta")
|
||||||
|
table.add_column("Path", style="cyan", max_width=50)
|
||||||
|
table.add_column("Model", style="green")
|
||||||
|
table.add_column("Version", style="white")
|
||||||
|
table.add_column("Type", style="yellow")
|
||||||
|
table.add_column("Base", style="dim")
|
||||||
|
|
||||||
|
for f in files:
|
||||||
|
path = Path(f["file_path"]).name
|
||||||
|
model = f.get("model_name") or "[dim]unlinked[/dim]"
|
||||||
|
version = f.get("version_name") or ""
|
||||||
|
model_type = f.get("model_type") or ""
|
||||||
|
base = f.get("base_model") or ""
|
||||||
|
table.add_row(path, model, version, model_type, base)
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
@db_app.command("search")
|
||||||
|
def db_search(
|
||||||
|
query: Annotated[str | None, typer.Argument(help="Search query")] = None,
|
||||||
|
model_type: Annotated[str | None, typer.Option("-t", "--type", help="Model type filter")] = None,
|
||||||
|
base_model: Annotated[str | None, typer.Option("-b", "--base", help="Base model filter")] = None,
|
||||||
|
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""Search cached models offline."""
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
results = db.search_models(query=query, model_type=model_type, base_model=base_model, limit=limit)
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=results)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
console.print("[yellow]No models found.[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
table = Table(title="Cached Models", show_header=True, header_style="bold magenta")
|
||||||
|
table.add_column("ID", style="dim")
|
||||||
|
table.add_column("Name", style="cyan")
|
||||||
|
table.add_column("Type", style="yellow")
|
||||||
|
table.add_column("Base", style="green")
|
||||||
|
table.add_column("Creator", style="dim")
|
||||||
|
table.add_column("Downloads", justify="right")
|
||||||
|
|
||||||
|
for m in results:
|
||||||
|
table.add_row(
|
||||||
|
str(m.get("civitai_id", "")),
|
||||||
|
m.get("name", ""),
|
||||||
|
m.get("type", ""),
|
||||||
|
m.get("base_model", ""),
|
||||||
|
m.get("creator", ""),
|
||||||
|
str(m.get("download_count", 0)),
|
||||||
|
)
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
@db_app.command("triggers")
|
||||||
|
def db_triggers(
|
||||||
|
file: Annotated[Path, typer.Argument(help="Path to safetensor file")],
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""Show trigger words for a LoRA file."""
|
||||||
|
file_path = file.resolve()
|
||||||
|
if not file_path.exists():
|
||||||
|
console.print(f"[red]Error: File not found: {file_path}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
triggers = db.get_triggers(str(file_path))
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=triggers)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not triggers:
|
||||||
|
console.print("[yellow]No trigger words found. File may not be linked to CivitAI.[/yellow]")
|
||||||
|
console.print("[dim]Run 'tsr db link' to link files to CivitAI.[/dim]")
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print(f"[bold]Trigger words for {file_path.name}:[/bold]")
|
||||||
|
for word in triggers:
|
||||||
|
console.print(f" • {word}")
|
||||||
|
|
||||||
|
|
||||||
|
@db_app.command("stats")
|
||||||
|
def db_stats(
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""Show database statistics."""
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
stats = db.get_stats()
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data={"db_path": str(DB_PATH), "stats": stats})
|
||||||
|
return
|
||||||
|
|
||||||
|
table = Table(title="Database Statistics", show_header=True, header_style="bold magenta")
|
||||||
|
table.add_column("Table", style="cyan")
|
||||||
|
table.add_column("Count", style="green", justify="right")
|
||||||
|
|
||||||
|
for table_name, count in stats.items():
|
||||||
|
table.add_row(table_name, str(count))
|
||||||
|
|
||||||
|
console.print(f"[dim]Database: {DB_PATH}[/dim]")
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
def main() -> int:
|
||||||
"""Main entry point."""
|
"""Main entry point."""
|
||||||
# Handle legacy invocation: tsr <file.safetensors> -> tsr info <file>
|
# Handle legacy invocation: tsr <file.safetensors> -> tsr info <file>
|
||||||
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
|
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
|
||||||
arg = sys.argv[1]
|
arg = sys.argv[1]
|
||||||
if arg not in ("info", "search", "get", "dl", "download", "config", "generate", "serve", "status", "reload") and (
|
if arg not in ("info", "search", "get", "dl", "download", "config", "generate", "serve", "status", "reload", "db") and (
|
||||||
arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists()
|
arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists()
|
||||||
):
|
):
|
||||||
sys.argv = [sys.argv[0], "info", *sys.argv[1:]]
|
sys.argv = [sys.argv[0], "info", *sys.argv[1:]]
|
||||||
|
|||||||
Reference in New Issue
Block a user