Files
tensors/tensors/api.py
T
Adam Ladachowski a92c9fb83a Phase 2.2: Add tsr db CLI commands
Add database management commands to CLI:
- tsr db scan <directory> - Scan safetensors, compute hashes, store metadata
- tsr db link - Match unlinked files to CivitAI by hash lookup
- tsr db cache <model_id> - Fetch and cache full CivitAI model data
- tsr db list - List local files with CivitAI info
- tsr db search - Search cached models offline
- tsr db triggers <file> - Show trigger words for a LoRA
- tsr db stats - Show database statistics

Update API functions to accept optional console for quiet/batch operations.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-14 01:33:56 +01:00

305 lines
9.9 KiB
Python

"""CivitAI API functions."""
from __future__ import annotations
import re
from http import HTTPStatus
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from pathlib import Path
import httpx
from rich.progress import (
BarColumn,
DownloadColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeRemainingColumn,
TransferSpeedColumn,
)
from tensors.config import CIVITAI_API_BASE, CIVITAI_DOWNLOAD_BASE, BaseModel, ModelType, SortOrder
if TYPE_CHECKING:
from rich.console import Console
def _get_headers(api_key: str | None) -> dict[str, str]:
    """Return the HTTP headers for a CivitAI API call.

    A non-empty ``api_key`` is sent as an ``Authorization: Bearer`` header;
    a missing or empty key yields no headers at all.
    """
    return {"Authorization": f"Bearer {api_key}"} if api_key else {}
def fetch_civitai_model_version(version_id: int, api_key: str | None, console: Console | None = None) -> dict[str, Any] | None:
    """Look up a single model version on CivitAI by its version ID.

    Args:
        version_id: CivitAI model-version ID.
        api_key: Optional API key, sent as a bearer token.
        console: Optional console used to print error messages; when None,
            failures are silent.

    Returns:
        The decoded JSON body, or None on HTTP 404, any other HTTP error
        status, or a transport-level request error.
    """
    endpoint = f"{CIVITAI_API_BASE}/model-versions/{version_id}"
    error_message = None
    try:
        resp = httpx.get(endpoint, headers=_get_headers(api_key), timeout=30.0)
        # A 404 just means "no such version": fall through to return None
        # without reporting an error.
        if resp.status_code != HTTPStatus.NOT_FOUND:
            resp.raise_for_status()
            payload: dict[str, Any] = resp.json()
            return payload
    except httpx.HTTPStatusError as exc:
        error_message = f"[red]API error: {exc.response.status_code}[/red]"
    except httpx.RequestError as exc:
        error_message = f"[red]Request error: {exc}[/red]"
    if error_message and console:
        console.print(error_message)
    return None
def fetch_civitai_model(model_id: int, api_key: str | None, console: Console | None = None) -> dict[str, Any] | None:
    """Fetch a model's full record from CivitAI by model ID.

    Args:
        model_id: CivitAI model ID.
        api_key: Optional API key, sent as a bearer token.
        console: Optional console. When given, a transient spinner is shown
            while the request runs and errors are printed; when None, the
            call produces no output.

    Returns:
        The decoded JSON body, or None on HTTP 404, any other HTTP error
        status, or a transport-level request error.
    """
    endpoint = f"{CIVITAI_API_BASE}/models/{model_id}"

    def _request() -> dict[str, Any] | None:
        try:
            resp = httpx.get(endpoint, headers=_get_headers(api_key), timeout=30.0)
            if resp.status_code == HTTPStatus.NOT_FOUND:
                return None
            resp.raise_for_status()
            payload: dict[str, Any] = resp.json()
        except httpx.HTTPStatusError as exc:
            message = f"[red]API error: {exc.response.status_code}[/red]"
        except httpx.RequestError as exc:
            message = f"[red]Request error: {exc}[/red]"
        else:
            return payload
        if console:
            console.print(message)
        return None

    # Quiet/batch callers pass no console: skip the spinner entirely.
    if not console:
        return _request()
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
        transient=True,
    ) as spinner:
        spinner.add_task("[cyan]Fetching model from CivitAI...", total=None)
        return _request()
def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Console | None = None) -> dict[str, Any] | None:
    """Look up a model version on CivitAI by SHA256 hash.

    Args:
        sha256_hash: SHA256 hash string, inserted into the
            ``/model-versions/by-hash/`` endpoint path.
        api_key: Optional API key, sent as a bearer token.
        console: Optional console. When given, a transient spinner is shown
            while the request runs and errors are printed; when None, the
            call produces no output.

    Returns:
        The decoded JSON body, or None when nothing matches (HTTP 404), on
        any other HTTP error status, or on a transport-level request error.
    """
    endpoint = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}"

    def _lookup() -> dict[str, Any] | None:
        problem = None
        try:
            resp = httpx.get(endpoint, headers=_get_headers(api_key), timeout=30.0)
            if resp.status_code != HTTPStatus.NOT_FOUND:
                resp.raise_for_status()
                found: dict[str, Any] = resp.json()
                return found
        except httpx.HTTPStatusError as exc:
            problem = f"[red]API error: {exc.response.status_code}[/red]"
        except httpx.RequestError as exc:
            problem = f"[red]Request error: {exc}[/red]"
        if problem and console:
            console.print(problem)
        return None

    if console:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
            transient=True,
        ) as progress:
            progress.add_task("[cyan]Fetching from CivitAI...", total=None)
            return _lookup()
    return _lookup()
def _build_search_params(
    query: str | None,
    model_type: ModelType | None,
    base_model: BaseModel | None,
    sort: SortOrder,
    limit: int,
) -> tuple[dict[str, Any], bool]:
    """Assemble query parameters for the CivitAI ``/models`` search endpoint.

    Returns:
        ``(params, has_filters)`` where ``has_filters`` is True when a model
        type or base model was supplied.
    """
    # API quirk: a free-text query combined with type/base-model filters is
    # unreliable, so in that case the query is withheld from the API and
    # applied client-side instead (see _filter_results).
    has_filters = model_type is not None or base_model is not None
    filter_locally = bool(query) and has_filters

    # Client-side filtering needs the largest page the API allows (100).
    params: dict[str, Any] = {
        "limit": 100 if filter_locally else min(limit, 100),
        "nsfw": "true",
    }
    if query and not has_filters:
        params["query"] = query
    if model_type:
        params["types"] = model_type.to_api()
    if base_model:
        params["baseModels"] = base_model.to_api()
    params["sort"] = sort.to_api()
    return params, has_filters
def _filter_results(result: dict[str, Any], query: str | None, has_filters: bool, limit: int) -> dict[str, Any]:
    """Apply the text query client-side when it was combined with filters.

    When both ``query`` and filters are present, ``result["items"]`` is
    replaced in place by the items whose ``name`` contains ``query``
    (case-insensitively), truncated to ``limit``. Otherwise ``result`` is
    returned untouched. The same dict object is returned in both cases.
    """
    if not (query and has_filters):
        return result
    needle = query.lower()
    matching = [item for item in result.get("items", []) if needle in item.get("name", "").lower()]
    result["items"] = matching[:limit]
    return result
def search_civitai(
    query: str | None,
    model_type: ModelType | None,
    base_model: BaseModel | None,
    sort: SortOrder,
    limit: int,
    api_key: str | None,
    console: Console,
) -> dict[str, Any] | None:
    """Search CivitAI models, showing a transient spinner while waiting.

    Returns:
        The decoded search response (with client-side name filtering applied
        when a query was combined with filters), or None after printing an
        HTTP-status or request error to ``console``.
    """
    params, has_filters = _build_search_params(query, model_type, base_model, sort, limit)
    endpoint = f"{CIVITAI_API_BASE}/models"
    spinner_columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
    )
    with Progress(*spinner_columns, console=console, transient=True) as progress:
        progress.add_task("[cyan]Searching CivitAI...", total=None)
        try:
            resp = httpx.get(endpoint, params=params, headers=_get_headers(api_key), timeout=30.0)
            resp.raise_for_status()
            payload: dict[str, Any] = resp.json()
        except httpx.HTTPStatusError as exc:
            console.print(f"[red]API error: {exc.response.status_code}[/red]")
            return None
        except httpx.RequestError as exc:
            console.print(f"[red]Request error: {exc}[/red]")
            return None
        return _filter_results(payload, query, has_filters, limit)
def _setup_resume(dest_path: Path, resume: bool, console: Console) -> tuple[dict[str, str], str, int]:
    """Prepare request headers and file mode for a possibly-resumed download.

    Args:
        dest_path: Download target. This may be a directory, in which case the
            final filename is only known once the response arrives, so nothing
            can be resumed.
        resume: Whether to continue an existing partial file.
        console: Console used to announce that a download is being resumed.

    Returns:
        ``(headers, mode, initial_size)``. When resuming, this is a ``Range``
        header, append mode ``"ab"`` and the byte count already on disk.
        Otherwise it is ``({}, "wb", 0)``.
    """
    # Only a regular file can be a partial download. A directory's st_size is
    # not an offset into the model file; using it as a Range start would make
    # the new file silently miss its first bytes.
    if not resume or not dest_path.is_file():
        return {}, "wb", 0
    initial_size = dest_path.stat().st_size
    # An empty file has nothing to resume from; a fresh write is equivalent.
    if initial_size == 0:
        return {}, "wb", 0
    console.print(f"[cyan]Resuming download from {initial_size / (1024**2):.1f} MB[/cyan]")
    return {"Range": f"bytes={initial_size}-"}, "ab", initial_size
def _get_dest_from_response(response: httpx.Response, dest_path: Path) -> Path:
    """Resolve the final file path for a download.

    When ``dest_path`` is an existing directory and the response carries a
    ``Content-Disposition`` header with a ``filename=`` parameter, the file is
    placed inside that directory under the server-supplied name. In every
    other case ``dest_path`` is returned unchanged.

    The server-supplied name is untrusted input. Only its final path component
    is used, so a header such as ``filename="../../x"`` or ``filename="/etc/x"``
    cannot place the file outside ``dest_path``.
    """
    content_disp = response.headers.get("content-disposition", "")
    if "filename=" not in content_disp:
        return dest_path
    match = re.search(r'filename="?([^";\n]+)"?', content_disp)
    if not match or not dest_path.is_dir():
        return dest_path
    # Keep only the last component. Split on POSIX and Windows separators, and
    # on ':' so a drive prefix ("C:...") cannot redirect the join either.
    # (An absolute right-hand operand to pathlib's `/` replaces the left side.)
    filename = re.split(r"[/\\:]", match.group(1).strip())[-1]
    if filename in {"", ".", ".."}:
        return dest_path
    return dest_path / filename
def _stream_download(
    response: httpx.Response,
    dest_path: Path,
    mode: str,
    initial_size: int,
    console: Console,
) -> bool:
    """Write the body of ``response`` to ``dest_path`` while showing a progress bar.

    Args:
        response: Open streaming response whose body is the file content.
        dest_path: File to write to.
        mode: File open mode: ``"wb"`` for a fresh file, ``"ab"`` to append.
        initial_size: Bytes already on disk; the bar starts at this offset.
        console: Console rendering the progress bar and completion message.

    Returns:
        Always True. Failures surface as exceptions from httpx or file I/O.
    """
    body_length = response.headers.get("content-length")
    # Content-Length covers only the bytes in this response, so the already
    # downloaded prefix is added on top to get the full file size.
    expected_total = int(body_length) + initial_size if body_length else 0
    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
    )
    with Progress(*columns, console=console) as progress:
        # total=None when the size is unknown (no/zero Content-Length).
        task_id = progress.add_task(
            f"[cyan]Downloading {dest_path.name}...",
            total=expected_total if expected_total > 0 else None,
            completed=initial_size,
        )
        chunk_size = 1 << 20  # 1 MiB per write
        with dest_path.open(mode) as out:
            for piece in response.iter_bytes(chunk_size):
                out.write(piece)
                progress.update(task_id, advance=len(piece))
    console.print()
    console.print(f'[magenta]Downloaded:[/magenta] [green]"{dest_path}"[/green]')
    return True
def download_model(
    version_id: int,
    dest_path: Path,
    api_key: str | None,
    console: Console,
    resume: bool = True,
) -> bool:
    """Download a model file from CivitAI by version ID, resuming when possible.

    Args:
        version_id: CivitAI model-version ID to download.
        dest_path: Target file, or an existing directory. For a directory, the
            filename is taken from the response's Content-Disposition header.
        api_key: Optional API key, sent as the ``token`` query parameter.
        console: Console for progress and status messages.
        resume: If True and a partial file exists at ``dest_path``, request
            only the missing bytes and append them.

    Returns:
        True when the file was written, or when the server answered HTTP 416,
        which is taken to mean the local file is already complete. False on an
        HTTP error, a transport error, or when no filename could be determined
        for a directory target.
    """
    url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}"
    params: dict[str, str] = {"token": api_key} if api_key else {}
    headers, mode, initial_size = _setup_resume(dest_path, resume, console)
    try:
        with httpx.stream(
            "GET",
            url,
            params=params,
            headers=headers,
            follow_redirects=True,
            # 30 s for connect/write/pool, but no read timeout so that long
            # transfers are not aborted mid-stream.
            timeout=httpx.Timeout(30.0, read=None),
        ) as response:
            if response.status_code == HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE:
                console.print("[green]File already fully downloaded.[/green]")
                return True
            response.raise_for_status()
            if initial_size and response.status_code != HTTPStatus.PARTIAL_CONTENT:
                # A Range header was sent but the reply is a full body (e.g.
                # 200): the server ignored the range. Appending it would
                # duplicate the bytes already on disk, so start over instead.
                console.print("[yellow]Server ignored resume request; restarting download from scratch.[/yellow]")
                mode, initial_size = "wb", 0
            dest_path = _get_dest_from_response(response, dest_path)
            if dest_path.is_dir():
                # No usable filename was found for a directory target; opening
                # the directory itself for writing would raise.
                console.print("[red]Download error: server sent no filename; pass a file path instead of a directory[/red]")
                return False
            return _stream_download(response, dest_path, mode, initial_size, console)
    except httpx.HTTPStatusError as e:
        console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]")
        if e.response.status_code == HTTPStatus.UNAUTHORIZED:
            console.print("[yellow]Hint: This model may require an API key.[/yellow]")
        return False
    except httpx.RequestError as e:
        console.print(f"[red]Download error: {e}[/red]")
        return False