Phase 4: Client Mode for tsr CLI

- Create tensors/client.py with TsrClient HTTP wrapper for all server APIs
- Add remote server configuration to config.py (get_remotes, resolve_remote, save_remote, set_default_remote)
- Add images command group: list, show, delete, download
- Add models command group: list, active, switch, loras
- Add remote command group: list, add, default
- Update generate command with --remote support
- Update dl command with --remote support
- Update status command with --remote support

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Adam Ladachowski
2026-02-14 01:47:32 +01:00
parent 11a289ebd0
commit 7efec1a033
4 changed files with 861 additions and 66 deletions
+3 -3
View File
@@ -17,9 +17,9 @@
- [x] Step 3.4: Enhance `/api/generate` (gallery integration, full params)
## Phase 4: Client Mode for tsr CLI
- [ ] Step 4.1: Create `tensors/client.py` (TsrClient HTTP wrapper)
- [ ] Step 4.2: Add `[remotes]` config section + `--remote` flag support
- [ ] Step 4.3: Update CLI commands with `--remote` support (generate, images, models, dl, db)
- [x] Step 4.1: Create `tensors/client.py` (TsrClient HTTP wrapper)
- [x] Step 4.2: Add `[remotes]` config section + `--remote` flag support
- [x] Step 4.3: Update CLI commands with `--remote` support (generate, images, models, dl, db)
## Phase 5: Docker Deployment Automation (SKIPPED)
- [x] Step 5.1: ~~Create `rocm-docker/docker-compose.yml`~~ (skipped)
+443 -10
View File
@@ -19,15 +19,20 @@ from tensors.api import (
fetch_civitai_model_version,
search_civitai,
)
from tensors.client import TsrClient, TsrClientError
from tensors.config import (
CONFIG_FILE,
BaseModel,
ModelType,
SortOrder,
get_default_output_path,
get_remotes,
load_api_key,
load_config,
resolve_remote,
save_config,
save_remote,
set_default_remote,
)
from tensors.db import DB_PATH, Database
from tensors.display import (
@@ -311,8 +316,41 @@ def download(
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False,
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False,
) -> None:
"""Download a model from CivitAI."""
"""Download a model from CivitAI (locally or to remote server)."""
# Check if remote is specified or configured
remote_url = resolve_remote(remote)
if remote_url:
# Remote mode: use TsrClient API
if not version_id and not model_id and not hash_val:
console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
raise typer.Exit(1)
try:
with TsrClient(remote_url) as client:
console.print(f"[cyan]Starting download on {remote_url}...[/cyan]")
result = client.start_download(
version_id=version_id,
model_id=model_id,
hash_val=hash_val,
output_dir=str(output) if output else None,
)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=result)
return
download_id = result.get("download_id")
console.print(f"[green]Download started:[/green] {download_id}")
console.print(f"[dim]Check status with: tsr images download-status {download_id} --remote {remote or 'default'}[/dim]")
else:
# Local mode: direct download
key = api_key or load_api_key()
resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
@@ -401,9 +439,10 @@ def config(
@app.command()
def generate(
prompt: Annotated[str, typer.Argument(help="Text prompt for image generation.")],
host: Annotated[str, typer.Option(help="sd-server address.")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="sd-server port.")] = 8080,
output: Annotated[str, typer.Option("-o", help="Output directory.")] = ".",
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
host: Annotated[str, typer.Option(help="sd-server address (local mode).")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="sd-server port (local mode).")] = 8080,
output: Annotated[str, typer.Option("-o", help="Output directory (local mode).")] = ".",
negative_prompt: Annotated[str, typer.Option("-n", help="Negative prompt.")] = "",
width: Annotated[int, typer.Option("-W", help="Image width.")] = 512,
height: Annotated[int, typer.Option("-H", help="Image height.")] = 512,
@@ -413,8 +452,42 @@ def generate(
sampler: Annotated[str, typer.Option(help="Sampler name.")] = "",
scheduler: Annotated[str, typer.Option(help="Scheduler name.")] = "",
batch_size: Annotated[int, typer.Option("-b", help="Number of images.")] = 1,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False,
) -> None:
"""Generate images using a running sd-server."""
"""Generate images using sd-server (local or remote)."""
# Check if remote is specified or configured
remote_url = resolve_remote(remote)
if remote_url:
# Remote mode: use TsrClient API
try:
with TsrClient(remote_url) as client:
console.print(f"[cyan]Generating {batch_size} image(s) on {remote_url}...[/cyan]")
result = client.generate(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
steps=steps,
cfg_scale=cfg_scale,
seed=seed,
sampler_name=sampler,
scheduler=scheduler,
batch_size=batch_size,
)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=result)
return
images = result.get("images", [])
for img in images:
console.print(f"[green]Generated:[/green] {img.get('id', 'unknown')}")
else:
# Local mode: direct sd-server connection
from tensors.generate import SDClient, Txt2ImgParams, save_images # noqa: PLC0415
params = Txt2ImgParams(
@@ -440,11 +513,25 @@ def generate(
@app.command()
def status(
host: Annotated[str, typer.Option(help="Wrapper API host.")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="Wrapper API port.")] = 8080,
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
host: Annotated[str, typer.Option(help="Wrapper API host (local mode).")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="Wrapper API port (local mode).")] = 8080,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""Show sd-server wrapper status."""
# Check if remote is specified or configured
remote_url = resolve_remote(remote)
if remote_url:
# Remote mode: use TsrClient API
try:
with TsrClient(remote_url) as client:
data = client.status()
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
else:
# Local mode: direct HTTP call
import httpx # noqa: PLC0415
url = f"http://{host}:{port}/status"
@@ -748,14 +835,360 @@ def db_stats(
console.print(table)
# =============================================================================
# Images Commands (Remote)
# =============================================================================
images_app = typer.Typer(
name="images",
help="Manage images in remote gallery.",
no_args_is_help=True,
)
app.add_typer(images_app, name="images")
def _get_client(remote: str | None) -> TsrClient:
"""Get TsrClient for remote or raise error."""
url = resolve_remote(remote)
if not url:
console.print("[red]Error: No remote specified. Use --remote or set default_remote in config.[/red]")
raise typer.Exit(1)
return TsrClient(url)
@images_app.command("list")
def images_list(
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 50,
offset: Annotated[int, typer.Option("--offset", help="Offset for pagination")] = 0,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""List images in remote gallery."""
try:
with _get_client(remote) as client:
result = client.list_images(limit=limit, offset=offset)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
images = result.get("images", [])
total = result.get("total", len(images))
if json_output:
console.print_json(data=result)
return
if not images:
console.print("[yellow]No images in gallery.[/yellow]")
return
table = Table(title=f"Gallery Images ({len(images)}/{total})", show_header=True, header_style="bold magenta")
table.add_column("ID", style="cyan")
table.add_column("Filename", style="green")
table.add_column("Size", style="white")
table.add_column("Created", style="dim")
for img in images:
size = f"{img.get('width', '?')}x{img.get('height', '?')}"
created = img.get("created_at", "")
if isinstance(created, (int, float)):
from datetime import datetime # noqa: PLC0415
created = datetime.fromtimestamp(created).strftime("%Y-%m-%d %H:%M")
table.add_row(img.get("id", ""), img.get("filename", ""), size, str(created))
console.print(table)
@images_app.command("show")
def images_show(
image_id: Annotated[str, typer.Argument(help="Image ID to show")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""Show image metadata."""
try:
with _get_client(remote) as client:
meta = client.get_image_meta(image_id)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=meta)
return
table = Table(title=f"Image: {image_id}", show_header=True, header_style="bold magenta")
table.add_column("Property", style="cyan")
table.add_column("Value", style="green")
for key, value in meta.items():
display_value = json.dumps(value, indent=2) if isinstance(value, dict) else str(value)
table.add_row(key, display_value)
console.print(table)
@images_app.command("delete")
def images_delete(
image_id: Annotated[str, typer.Argument(help="Image ID to delete")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
force: Annotated[bool, typer.Option("-f", "--force", help="Skip confirmation")] = False,
) -> None:
"""Delete an image from the gallery."""
if not force:
confirm = typer.confirm(f"Delete image {image_id}?")
if not confirm:
console.print("[yellow]Cancelled.[/yellow]")
raise typer.Exit(0)
try:
with _get_client(remote) as client:
client.delete_image(image_id)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
console.print(f"[green]Deleted image: {image_id}[/green]")
@images_app.command("download")
def images_download(
image_id: Annotated[str, typer.Argument(help="Image ID to download")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file or directory")] = None,
) -> None:
"""Download an image from the remote gallery."""
try:
with _get_client(remote) as client:
content = client.download_image(image_id)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
# Determine output path
if output is None:
dest = Path(f"{image_id}.png")
elif output.is_dir():
dest = output / f"{image_id}.png"
else:
dest = output
dest.write_bytes(content)
console.print(f"[green]Saved:[/green] {dest}")
# =============================================================================
# Models Commands (Remote)
# =============================================================================
models_app = typer.Typer(
name="models",
help="Manage models on remote server.",
no_args_is_help=True,
)
app.add_typer(models_app, name="models")
@models_app.command("list")
def models_list(
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""List available models on remote server."""
try:
with _get_client(remote) as client:
result = client.list_models()
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=result)
return
models = result.get("models", [])
active = result.get("active", "")
if not models:
console.print("[yellow]No models found.[/yellow]")
return
table = Table(title="Available Models", show_header=True, header_style="bold magenta")
table.add_column("Status", style="dim", width=3)
table.add_column("Name", style="cyan")
table.add_column("Path", style="dim")
for model in models:
path = model.get("path", "")
name = model.get("name", Path(path).stem if path else "")
is_active = active in {path, name}
status = "[green]✓[/green]" if is_active else ""
table.add_row(status, name, path)
console.print(table)
@models_app.command("active")
def models_active(
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""Show currently active model."""
try:
with _get_client(remote) as client:
result = client.get_active_model()
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=result)
return
model = result.get("model", "None")
console.print(f"[bold]Active model:[/bold] {model}")
@models_app.command("switch")
def models_switch(
model: Annotated[str, typer.Argument(help="Model path or name to switch to")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
) -> None:
"""Switch to a different model on the remote server."""
console.print(f"[cyan]Switching to model: {model}[/cyan]")
try:
with _get_client(remote) as client:
result = client.switch_model(model)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
console.print(f"[green]{result.get('status', 'OK')}[/green]")
@models_app.command("loras")
def models_loras(
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""List available LoRAs on remote server."""
try:
with _get_client(remote) as client:
result = client.list_loras()
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=result)
return
loras = result.get("loras", [])
if not loras:
console.print("[yellow]No LoRAs found.[/yellow]")
return
table = Table(title="Available LoRAs", show_header=True, header_style="bold magenta")
table.add_column("Name", style="cyan")
table.add_column("Path", style="dim")
for lora in loras:
path = lora.get("path", "")
name = lora.get("name", Path(path).stem if path else "")
table.add_row(name, path)
console.print(table)
# =============================================================================
# Remote Configuration Commands
# =============================================================================
remote_app = typer.Typer(
name="remote",
help="Manage remote server configuration.",
no_args_is_help=True,
)
app.add_typer(remote_app, name="remote")
@remote_app.command("list")
def remote_list(
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""List configured remotes."""
from tensors.config import get_default_remote # noqa: PLC0415
remotes = get_remotes()
default = get_default_remote()
if json_output:
console.print_json(data={"remotes": remotes, "default": default})
return
if not remotes:
console.print("[yellow]No remotes configured.[/yellow]")
console.print("[dim]Add one with: tsr remote add NAME URL[/dim]")
return
table = Table(title="Configured Remotes", show_header=True, header_style="bold magenta")
table.add_column("Default", style="dim", width=3)
table.add_column("Name", style="cyan")
table.add_column("URL", style="green")
for name, url in remotes.items():
is_default = name == default
status = "[green]✓[/green]" if is_default else ""
table.add_row(status, name, url)
console.print(table)
@remote_app.command("add")
def remote_add(
name: Annotated[str, typer.Argument(help="Remote name")],
url: Annotated[str, typer.Argument(help="Remote URL (e.g., http://host:8080)")],
) -> None:
"""Add a remote server."""
save_remote(name, url)
console.print(f"[green]Added remote:[/green] {name}{url}")
@remote_app.command("default")
def remote_default(
name: Annotated[str | None, typer.Argument(help="Remote name to set as default (omit to clear)")] = None,
) -> None:
"""Set or clear the default remote."""
set_default_remote(name)
if name:
console.print(f"[green]Default remote set to:[/green] {name}")
else:
console.print("[green]Default remote cleared.[/green]")
def main() -> int:
"""Main entry point."""
# Handle legacy invocation: tsr <file.safetensors> -> tsr info <file>
known_commands = (
"info",
"search",
"get",
"dl",
"download",
"config",
"generate",
"serve",
"status",
"reload",
"db",
"images",
"models",
"remote",
)
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
arg = sys.argv[1]
if arg not in ("info", "search", "get", "dl", "download", "config", "generate", "serve", "status", "reload", "db") and (
arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists()
):
if arg not in known_commands and (arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists()):
sys.argv = [sys.argv[0], "info", *sys.argv[1:]]
app()
+292
View File
@@ -0,0 +1,292 @@
"""HTTP client for remote tsr server API."""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Any
import httpx
if TYPE_CHECKING:
from collections.abc import Iterator
class TsrClientError(Exception):
"""Error from TsrClient operations."""
class TsrClient:
"""HTTP client wrapper for tsr server API.
Usage:
with TsrClient("http://junkpile:8080") as client:
images = client.list_images()
result = client.generate("a cat")
"""
def __init__(self, base_url: str, timeout: float = 300.0) -> None:
"""Initialize client with server URL."""
self.base_url = base_url.rstrip("/")
self.timeout = timeout
self._client: httpx.Client | None = None
def __enter__(self) -> TsrClient:
self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout)
return self
def __exit__(self, *exc: object) -> None:
if self._client:
self._client.close()
self._client = None
@property
def client(self) -> httpx.Client:
"""Get the HTTP client, creating if needed."""
if self._client is None:
self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout)
return self._client
def _get(self, path: str, params: dict[str, Any] | None = None) -> Any:
"""Make GET request."""
try:
resp = self.client.get(path, params=params)
resp.raise_for_status()
return resp.json()
except httpx.HTTPStatusError as e:
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
except httpx.RequestError as e:
raise TsrClientError(f"Request failed: {e}") from e
def _post(self, path: str, json: dict[str, Any] | None = None) -> Any:
"""Make POST request."""
try:
resp = self.client.post(path, json=json)
resp.raise_for_status()
return resp.json()
except httpx.HTTPStatusError as e:
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
except httpx.RequestError as e:
raise TsrClientError(f"Request failed: {e}") from e
def _delete(self, path: str) -> Any:
"""Make DELETE request."""
try:
resp = self.client.delete(path)
resp.raise_for_status()
return resp.json()
except httpx.HTTPStatusError as e:
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
except httpx.RequestError as e:
raise TsrClientError(f"Request failed: {e}") from e
# =========================================================================
# Server Status
# =========================================================================
def status(self) -> dict[str, Any]:
"""Get server status."""
return dict(self._get("/status"))
# =========================================================================
# Gallery / Images
# =========================================================================
def list_images(self, limit: int = 50, offset: int = 0) -> dict[str, Any]:
"""List images in gallery."""
return dict(self._get("/api/images", params={"limit": limit, "offset": offset}))
def get_image_meta(self, image_id: str) -> dict[str, Any]:
"""Get metadata for an image."""
return dict(self._get(f"/api/images/{image_id}/meta"))
def delete_image(self, image_id: str) -> dict[str, Any]:
"""Delete an image."""
return dict(self._delete(f"/api/images/{image_id}"))
def edit_image(self, image_id: str, updates: dict[str, Any]) -> dict[str, Any]:
"""Update image metadata."""
return dict(self._post(f"/api/images/{image_id}/edit", json=updates))
def download_image(self, image_id: str) -> bytes:
"""Download image file bytes."""
try:
resp = self.client.get(f"/api/images/{image_id}")
resp.raise_for_status()
return resp.content
except httpx.HTTPStatusError as e:
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
except httpx.RequestError as e:
raise TsrClientError(f"Request failed: {e}") from e
# =========================================================================
# Models
# =========================================================================
def list_models(self) -> dict[str, Any]:
"""List available models."""
return dict(self._get("/api/models"))
def get_active_model(self) -> dict[str, Any]:
"""Get currently active model."""
return dict(self._get("/api/models/active"))
def switch_model(self, model_path: str) -> dict[str, Any]:
"""Switch to a different model."""
return dict(self._post("/api/models/switch", json={"model": model_path}))
def list_loras(self) -> dict[str, Any]:
"""List available LoRAs."""
return dict(self._get("/api/models/loras"))
def scan_models(self) -> dict[str, Any]:
"""Scan model directories."""
return dict(self._get("/api/models/scan"))
# =========================================================================
# Generation
# =========================================================================
def generate(
self,
prompt: str,
negative_prompt: str = "",
width: int = 512,
height: int = 512,
steps: int = 20,
cfg_scale: float = 7.0,
seed: int = -1,
sampler_name: str = "",
scheduler: str = "",
batch_size: int = 1,
save_to_gallery: bool = True,
return_base64: bool = False,
) -> dict[str, Any]:
"""Generate images."""
body = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"width": width,
"height": height,
"steps": steps,
"cfg_scale": cfg_scale,
"seed": seed,
"sampler_name": sampler_name,
"scheduler": scheduler,
"batch_size": batch_size,
"save_to_gallery": save_to_gallery,
"return_base64": return_base64,
}
return dict(self._post("/api/generate", json=body))
def list_samplers(self) -> dict[str, Any]:
"""List available samplers."""
return dict(self._get("/api/samplers"))
def list_schedulers(self) -> dict[str, Any]:
"""List available schedulers."""
return dict(self._get("/api/schedulers"))
# =========================================================================
# Download
# =========================================================================
def start_download(
self,
version_id: int | None = None,
model_id: int | None = None,
hash_val: str | None = None,
output_dir: str | None = None,
) -> dict[str, Any]:
"""Start a model download from CivitAI."""
body: dict[str, Any] = {}
if version_id:
body["version_id"] = version_id
if model_id:
body["model_id"] = model_id
if hash_val:
body["hash"] = hash_val
if output_dir:
body["output_dir"] = output_dir
return dict(self._post("/api/download", json=body))
def get_download_status(self, download_id: str) -> dict[str, Any]:
"""Get download status."""
return dict(self._get(f"/api/download/status/{download_id}"))
def list_downloads(self) -> dict[str, Any]:
"""List active downloads."""
return dict(self._get("/api/download/active"))
# =========================================================================
# Database
# =========================================================================
def db_list_files(self) -> list[dict[str, Any]]:
"""List local files in database."""
return list(self._get("/api/db/files"))
def db_search_models(
self,
query: str | None = None,
model_type: str | None = None,
base_model: str | None = None,
limit: int = 20,
) -> list[dict[str, Any]]:
"""Search cached models."""
params: dict[str, Any] = {"limit": limit}
if query:
params["query"] = query
if model_type:
params["type"] = model_type
if base_model:
params["base"] = base_model
return list(self._get("/api/db/models", params=params))
def db_get_model(self, civitai_id: int) -> dict[str, Any]:
"""Get cached model by CivitAI ID."""
return dict(self._get(f"/api/db/models/{civitai_id}"))
def db_get_triggers(self, file_path: str | None = None, version_id: int | None = None) -> list[str]:
"""Get trigger words."""
if version_id:
return list(self._get(f"/api/db/triggers/{version_id}"))
if file_path:
return list(self._get("/api/db/triggers", params={"file_path": file_path}))
return []
def db_stats(self) -> dict[str, Any]:
"""Get database statistics."""
return dict(self._get("/api/db/stats"))
def db_scan(self, directory: str) -> dict[str, Any]:
"""Scan directory for safetensor files."""
return dict(self._post("/api/db/scan", json={"directory": directory}))
def db_link(self) -> dict[str, Any]:
"""Link unlinked files to CivitAI."""
return dict(self._post("/api/db/link"))
def db_cache(self, model_id: int) -> dict[str, Any]:
"""Cache CivitAI model data."""
return dict(self._post("/api/db/cache", json={"model_id": model_id}))
# =========================================================================
# Streaming Downloads
# =========================================================================
def stream_image(self, image_id: str) -> Iterator[bytes]:
"""Stream image download in chunks."""
try:
with self.client.stream("GET", f"/api/images/{image_id}") as resp:
resp.raise_for_status()
yield from resp.iter_bytes(chunk_size=1024 * 64)
except httpx.HTTPStatusError as e:
raise TsrClientError(f"HTTP {e.response.status_code}") from e
except httpx.RequestError as e:
raise TsrClientError(f"Request failed: {e}") from e
def save_image_to(self, image_id: str, dest: Path) -> Path:
"""Download and save image to file."""
content = self.download_image(image_id)
dest.write_bytes(content)
return dest
+70
View File
@@ -165,3 +165,73 @@ def get_default_output_path(model_type: str | None) -> Path | None:
if model_type and model_type in DEFAULT_PATHS:
return DEFAULT_PATHS[model_type]
return None
# ============================================================================
# Remote Server Configuration
# ============================================================================
def get_remotes() -> dict[str, str]:
"""Get configured remote servers.
Returns a dict mapping names to URLs, e.g., {"junkpile": "http://junkpile:8080"}
"""
config = load_config()
remotes = config.get("remotes", {})
return dict(remotes) if isinstance(remotes, dict) else {}
def get_default_remote() -> str | None:
"""Get the default remote name or URL."""
config = load_config()
return config.get("default_remote")
def resolve_remote(remote: str | None) -> str | None:
"""Resolve a remote name or URL to a full URL.
Args:
remote: Remote name (from config), URL, or None
Returns:
Full URL or None if no remote specified and no default
"""
if remote is None:
# Check for default remote
default = get_default_remote()
if default:
remote = default
else:
return None
# Check if it's a URL (starts with http:// or https://)
if remote.startswith(("http://", "https://")):
return remote
# Look up in configured remotes
remotes = get_remotes()
if remote in remotes:
return remotes[remote]
# Treat as hostname with default port
return f"http://{remote}:8080"
def save_remote(name: str, url: str) -> None:
"""Save a remote server configuration."""
config = load_config()
if "remotes" not in config:
config["remotes"] = {}
config["remotes"][name] = url
save_config(config)
def set_default_remote(name: str | None) -> None:
"""Set the default remote."""
config = load_config()
if name is None:
config.pop("default_remote", None)
else:
config["default_remote"] = name
save_config(config)