Phase 4: Client Mode for tsr CLI
- Create tensors/client.py with TsrClient HTTP wrapper for all server APIs - Add remote server configuration to config.py (get_remotes, resolve_remote, save_remote, set_default_remote) - Add images command group: list, show, delete, download - Add models command group: list, active, switch, loras - Add remote command group: list, add, default - Update generate command with --remote support - Update dl command with --remote support - Update status command with --remote support Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -17,9 +17,9 @@
|
||||
- [x] Step 3.4: Enhance `/api/generate` (gallery integration, full params)
|
||||
|
||||
## Phase 4: Client Mode for tsr CLI
|
||||
- [ ] Step 4.1: Create `tensors/client.py` (TsrClient HTTP wrapper)
|
||||
- [ ] Step 4.2: Add `[remotes]` config section + `--remote` flag support
|
||||
- [ ] Step 4.3: Update CLI commands with `--remote` support (generate, images, models, dl, db)
|
||||
- [x] Step 4.1: Create `tensors/client.py` (TsrClient HTTP wrapper)
|
||||
- [x] Step 4.2: Add `[remotes]` config section + `--remote` flag support
|
||||
- [x] Step 4.3: Update CLI commands with `--remote` support (generate, images, models, dl, db)
|
||||
|
||||
## Phase 5: Docker Deployment Automation (SKIPPED)
|
||||
- [x] Step 5.1: ~~Create `rocm-docker/docker-compose.yml`~~ (skipped)
|
||||
|
||||
+443
-10
@@ -19,15 +19,20 @@ from tensors.api import (
|
||||
fetch_civitai_model_version,
|
||||
search_civitai,
|
||||
)
|
||||
from tensors.client import TsrClient, TsrClientError
|
||||
from tensors.config import (
|
||||
CONFIG_FILE,
|
||||
BaseModel,
|
||||
ModelType,
|
||||
SortOrder,
|
||||
get_default_output_path,
|
||||
get_remotes,
|
||||
load_api_key,
|
||||
load_config,
|
||||
resolve_remote,
|
||||
save_config,
|
||||
save_remote,
|
||||
set_default_remote,
|
||||
)
|
||||
from tensors.db import DB_PATH, Database
|
||||
from tensors.display import (
|
||||
@@ -311,8 +316,41 @@ def download(
|
||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
|
||||
no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False,
|
||||
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False,
|
||||
) -> None:
|
||||
"""Download a model from CivitAI."""
|
||||
"""Download a model from CivitAI (locally or to remote server)."""
|
||||
# Check if remote is specified or configured
|
||||
remote_url = resolve_remote(remote)
|
||||
|
||||
if remote_url:
|
||||
# Remote mode: use TsrClient API
|
||||
if not version_id and not model_id and not hash_val:
|
||||
console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
try:
|
||||
with TsrClient(remote_url) as client:
|
||||
console.print(f"[cyan]Starting download on {remote_url}...[/cyan]")
|
||||
result = client.start_download(
|
||||
version_id=version_id,
|
||||
model_id=model_id,
|
||||
hash_val=hash_val,
|
||||
output_dir=str(output) if output else None,
|
||||
)
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=result)
|
||||
return
|
||||
|
||||
download_id = result.get("download_id")
|
||||
console.print(f"[green]Download started:[/green] {download_id}")
|
||||
console.print(f"[dim]Check status with: tsr images download-status {download_id} --remote {remote or 'default'}[/dim]")
|
||||
else:
|
||||
# Local mode: direct download
|
||||
key = api_key or load_api_key()
|
||||
|
||||
resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
|
||||
@@ -401,9 +439,10 @@ def config(
|
||||
@app.command()
|
||||
def generate(
|
||||
prompt: Annotated[str, typer.Argument(help="Text prompt for image generation.")],
|
||||
host: Annotated[str, typer.Option(help="sd-server address.")] = "127.0.0.1",
|
||||
port: Annotated[int, typer.Option(help="sd-server port.")] = 8080,
|
||||
output: Annotated[str, typer.Option("-o", help="Output directory.")] = ".",
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
host: Annotated[str, typer.Option(help="sd-server address (local mode).")] = "127.0.0.1",
|
||||
port: Annotated[int, typer.Option(help="sd-server port (local mode).")] = 8080,
|
||||
output: Annotated[str, typer.Option("-o", help="Output directory (local mode).")] = ".",
|
||||
negative_prompt: Annotated[str, typer.Option("-n", help="Negative prompt.")] = "",
|
||||
width: Annotated[int, typer.Option("-W", help="Image width.")] = 512,
|
||||
height: Annotated[int, typer.Option("-H", help="Image height.")] = 512,
|
||||
@@ -413,8 +452,42 @@ def generate(
|
||||
sampler: Annotated[str, typer.Option(help="Sampler name.")] = "",
|
||||
scheduler: Annotated[str, typer.Option(help="Scheduler name.")] = "",
|
||||
batch_size: Annotated[int, typer.Option("-b", help="Number of images.")] = 1,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False,
|
||||
) -> None:
|
||||
"""Generate images using a running sd-server."""
|
||||
"""Generate images using sd-server (local or remote)."""
|
||||
# Check if remote is specified or configured
|
||||
remote_url = resolve_remote(remote)
|
||||
|
||||
if remote_url:
|
||||
# Remote mode: use TsrClient API
|
||||
try:
|
||||
with TsrClient(remote_url) as client:
|
||||
console.print(f"[cyan]Generating {batch_size} image(s) on {remote_url}...[/cyan]")
|
||||
result = client.generate(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
steps=steps,
|
||||
cfg_scale=cfg_scale,
|
||||
seed=seed,
|
||||
sampler_name=sampler,
|
||||
scheduler=scheduler,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=result)
|
||||
return
|
||||
|
||||
images = result.get("images", [])
|
||||
for img in images:
|
||||
console.print(f"[green]Generated:[/green] {img.get('id', 'unknown')}")
|
||||
else:
|
||||
# Local mode: direct sd-server connection
|
||||
from tensors.generate import SDClient, Txt2ImgParams, save_images # noqa: PLC0415
|
||||
|
||||
params = Txt2ImgParams(
|
||||
@@ -440,11 +513,25 @@ def generate(
|
||||
|
||||
@app.command()
|
||||
def status(
|
||||
host: Annotated[str, typer.Option(help="Wrapper API host.")] = "127.0.0.1",
|
||||
port: Annotated[int, typer.Option(help="Wrapper API port.")] = 8080,
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
host: Annotated[str, typer.Option(help="Wrapper API host (local mode).")] = "127.0.0.1",
|
||||
port: Annotated[int, typer.Option(help="Wrapper API port (local mode).")] = 8080,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""Show sd-server wrapper status."""
|
||||
# Check if remote is specified or configured
|
||||
remote_url = resolve_remote(remote)
|
||||
|
||||
if remote_url:
|
||||
# Remote mode: use TsrClient API
|
||||
try:
|
||||
with TsrClient(remote_url) as client:
|
||||
data = client.status()
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
else:
|
||||
# Local mode: direct HTTP call
|
||||
import httpx # noqa: PLC0415
|
||||
|
||||
url = f"http://{host}:{port}/status"
|
||||
@@ -748,14 +835,360 @@ def db_stats(
|
||||
console.print(table)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Images Commands (Remote)
|
||||
# =============================================================================
|
||||
|
||||
images_app = typer.Typer(
|
||||
name="images",
|
||||
help="Manage images in remote gallery.",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
app.add_typer(images_app, name="images")
|
||||
|
||||
|
||||
def _get_client(remote: str | None) -> TsrClient:
|
||||
"""Get TsrClient for remote or raise error."""
|
||||
url = resolve_remote(remote)
|
||||
if not url:
|
||||
console.print("[red]Error: No remote specified. Use --remote or set default_remote in config.[/red]")
|
||||
raise typer.Exit(1)
|
||||
return TsrClient(url)
|
||||
|
||||
|
||||
@images_app.command("list")
|
||||
def images_list(
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 50,
|
||||
offset: Annotated[int, typer.Option("--offset", help="Offset for pagination")] = 0,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""List images in remote gallery."""
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
result = client.list_images(limit=limit, offset=offset)
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
images = result.get("images", [])
|
||||
total = result.get("total", len(images))
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=result)
|
||||
return
|
||||
|
||||
if not images:
|
||||
console.print("[yellow]No images in gallery.[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title=f"Gallery Images ({len(images)}/{total})", show_header=True, header_style="bold magenta")
|
||||
table.add_column("ID", style="cyan")
|
||||
table.add_column("Filename", style="green")
|
||||
table.add_column("Size", style="white")
|
||||
table.add_column("Created", style="dim")
|
||||
|
||||
for img in images:
|
||||
size = f"{img.get('width', '?')}x{img.get('height', '?')}"
|
||||
created = img.get("created_at", "")
|
||||
if isinstance(created, (int, float)):
|
||||
from datetime import datetime # noqa: PLC0415
|
||||
|
||||
created = datetime.fromtimestamp(created).strftime("%Y-%m-%d %H:%M")
|
||||
table.add_row(img.get("id", ""), img.get("filename", ""), size, str(created))
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@images_app.command("show")
|
||||
def images_show(
|
||||
image_id: Annotated[str, typer.Argument(help="Image ID to show")],
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""Show image metadata."""
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
meta = client.get_image_meta(image_id)
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=meta)
|
||||
return
|
||||
|
||||
table = Table(title=f"Image: {image_id}", show_header=True, header_style="bold magenta")
|
||||
table.add_column("Property", style="cyan")
|
||||
table.add_column("Value", style="green")
|
||||
|
||||
for key, value in meta.items():
|
||||
display_value = json.dumps(value, indent=2) if isinstance(value, dict) else str(value)
|
||||
table.add_row(key, display_value)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@images_app.command("delete")
|
||||
def images_delete(
|
||||
image_id: Annotated[str, typer.Argument(help="Image ID to delete")],
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
force: Annotated[bool, typer.Option("-f", "--force", help="Skip confirmation")] = False,
|
||||
) -> None:
|
||||
"""Delete an image from the gallery."""
|
||||
if not force:
|
||||
confirm = typer.confirm(f"Delete image {image_id}?")
|
||||
if not confirm:
|
||||
console.print("[yellow]Cancelled.[/yellow]")
|
||||
raise typer.Exit(0)
|
||||
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
client.delete_image(image_id)
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
console.print(f"[green]Deleted image: {image_id}[/green]")
|
||||
|
||||
|
||||
@images_app.command("download")
|
||||
def images_download(
|
||||
image_id: Annotated[str, typer.Argument(help="Image ID to download")],
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file or directory")] = None,
|
||||
) -> None:
|
||||
"""Download an image from the remote gallery."""
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
content = client.download_image(image_id)
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
# Determine output path
|
||||
if output is None:
|
||||
dest = Path(f"{image_id}.png")
|
||||
elif output.is_dir():
|
||||
dest = output / f"{image_id}.png"
|
||||
else:
|
||||
dest = output
|
||||
|
||||
dest.write_bytes(content)
|
||||
console.print(f"[green]Saved:[/green] {dest}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Models Commands (Remote)
|
||||
# =============================================================================
|
||||
|
||||
models_app = typer.Typer(
|
||||
name="models",
|
||||
help="Manage models on remote server.",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
app.add_typer(models_app, name="models")
|
||||
|
||||
|
||||
@models_app.command("list")
|
||||
def models_list(
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""List available models on remote server."""
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
result = client.list_models()
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=result)
|
||||
return
|
||||
|
||||
models = result.get("models", [])
|
||||
active = result.get("active", "")
|
||||
|
||||
if not models:
|
||||
console.print("[yellow]No models found.[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="Available Models", show_header=True, header_style="bold magenta")
|
||||
table.add_column("Status", style="dim", width=3)
|
||||
table.add_column("Name", style="cyan")
|
||||
table.add_column("Path", style="dim")
|
||||
|
||||
for model in models:
|
||||
path = model.get("path", "")
|
||||
name = model.get("name", Path(path).stem if path else "")
|
||||
is_active = active in {path, name}
|
||||
status = "[green]✓[/green]" if is_active else ""
|
||||
table.add_row(status, name, path)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@models_app.command("active")
|
||||
def models_active(
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""Show currently active model."""
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
result = client.get_active_model()
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=result)
|
||||
return
|
||||
|
||||
model = result.get("model", "None")
|
||||
console.print(f"[bold]Active model:[/bold] {model}")
|
||||
|
||||
|
||||
@models_app.command("switch")
|
||||
def models_switch(
|
||||
model: Annotated[str, typer.Argument(help="Model path or name to switch to")],
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
) -> None:
|
||||
"""Switch to a different model on the remote server."""
|
||||
console.print(f"[cyan]Switching to model: {model}[/cyan]")
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
result = client.switch_model(model)
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
console.print(f"[green]{result.get('status', 'OK')}[/green]")
|
||||
|
||||
|
||||
@models_app.command("loras")
|
||||
def models_loras(
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""List available LoRAs on remote server."""
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
result = client.list_loras()
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=result)
|
||||
return
|
||||
|
||||
loras = result.get("loras", [])
|
||||
if not loras:
|
||||
console.print("[yellow]No LoRAs found.[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="Available LoRAs", show_header=True, header_style="bold magenta")
|
||||
table.add_column("Name", style="cyan")
|
||||
table.add_column("Path", style="dim")
|
||||
|
||||
for lora in loras:
|
||||
path = lora.get("path", "")
|
||||
name = lora.get("name", Path(path).stem if path else "")
|
||||
table.add_row(name, path)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Remote Configuration Commands
|
||||
# =============================================================================
|
||||
|
||||
remote_app = typer.Typer(
|
||||
name="remote",
|
||||
help="Manage remote server configuration.",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
app.add_typer(remote_app, name="remote")
|
||||
|
||||
|
||||
@remote_app.command("list")
|
||||
def remote_list(
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""List configured remotes."""
|
||||
from tensors.config import get_default_remote # noqa: PLC0415
|
||||
|
||||
remotes = get_remotes()
|
||||
default = get_default_remote()
|
||||
|
||||
if json_output:
|
||||
console.print_json(data={"remotes": remotes, "default": default})
|
||||
return
|
||||
|
||||
if not remotes:
|
||||
console.print("[yellow]No remotes configured.[/yellow]")
|
||||
console.print("[dim]Add one with: tsr remote add NAME URL[/dim]")
|
||||
return
|
||||
|
||||
table = Table(title="Configured Remotes", show_header=True, header_style="bold magenta")
|
||||
table.add_column("Default", style="dim", width=3)
|
||||
table.add_column("Name", style="cyan")
|
||||
table.add_column("URL", style="green")
|
||||
|
||||
for name, url in remotes.items():
|
||||
is_default = name == default
|
||||
status = "[green]✓[/green]" if is_default else ""
|
||||
table.add_row(status, name, url)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@remote_app.command("add")
|
||||
def remote_add(
|
||||
name: Annotated[str, typer.Argument(help="Remote name")],
|
||||
url: Annotated[str, typer.Argument(help="Remote URL (e.g., http://host:8080)")],
|
||||
) -> None:
|
||||
"""Add a remote server."""
|
||||
save_remote(name, url)
|
||||
console.print(f"[green]Added remote:[/green] {name} → {url}")
|
||||
|
||||
|
||||
@remote_app.command("default")
|
||||
def remote_default(
|
||||
name: Annotated[str | None, typer.Argument(help="Remote name to set as default (omit to clear)")] = None,
|
||||
) -> None:
|
||||
"""Set or clear the default remote."""
|
||||
set_default_remote(name)
|
||||
if name:
|
||||
console.print(f"[green]Default remote set to:[/green] {name}")
|
||||
else:
|
||||
console.print("[green]Default remote cleared.[/green]")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Main entry point."""
|
||||
# Handle legacy invocation: tsr <file.safetensors> -> tsr info <file>
|
||||
known_commands = (
|
||||
"info",
|
||||
"search",
|
||||
"get",
|
||||
"dl",
|
||||
"download",
|
||||
"config",
|
||||
"generate",
|
||||
"serve",
|
||||
"status",
|
||||
"reload",
|
||||
"db",
|
||||
"images",
|
||||
"models",
|
||||
"remote",
|
||||
)
|
||||
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
|
||||
arg = sys.argv[1]
|
||||
if arg not in ("info", "search", "get", "dl", "download", "config", "generate", "serve", "status", "reload", "db") and (
|
||||
arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists()
|
||||
):
|
||||
if arg not in known_commands and (arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists()):
|
||||
sys.argv = [sys.argv[0], "info", *sys.argv[1:]]
|
||||
|
||||
app()
|
||||
|
||||
@@ -0,0 +1,292 @@
|
||||
"""HTTP client for remote tsr server API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
|
||||
class TsrClientError(Exception):
|
||||
"""Error from TsrClient operations."""
|
||||
|
||||
|
||||
class TsrClient:
|
||||
"""HTTP client wrapper for tsr server API.
|
||||
|
||||
Usage:
|
||||
with TsrClient("http://junkpile:8080") as client:
|
||||
images = client.list_images()
|
||||
result = client.generate("a cat")
|
||||
"""
|
||||
|
||||
def __init__(self, base_url: str, timeout: float = 300.0) -> None:
|
||||
"""Initialize client with server URL."""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout = timeout
|
||||
self._client: httpx.Client | None = None
|
||||
|
||||
def __enter__(self) -> TsrClient:
|
||||
self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout)
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc: object) -> None:
|
||||
if self._client:
|
||||
self._client.close()
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self) -> httpx.Client:
|
||||
"""Get the HTTP client, creating if needed."""
|
||||
if self._client is None:
|
||||
self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout)
|
||||
return self._client
|
||||
|
||||
def _get(self, path: str, params: dict[str, Any] | None = None) -> Any:
|
||||
"""Make GET request."""
|
||||
try:
|
||||
resp = self.client.get(path, params=params)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
||||
except httpx.RequestError as e:
|
||||
raise TsrClientError(f"Request failed: {e}") from e
|
||||
|
||||
def _post(self, path: str, json: dict[str, Any] | None = None) -> Any:
|
||||
"""Make POST request."""
|
||||
try:
|
||||
resp = self.client.post(path, json=json)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
||||
except httpx.RequestError as e:
|
||||
raise TsrClientError(f"Request failed: {e}") from e
|
||||
|
||||
def _delete(self, path: str) -> Any:
|
||||
"""Make DELETE request."""
|
||||
try:
|
||||
resp = self.client.delete(path)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
||||
except httpx.RequestError as e:
|
||||
raise TsrClientError(f"Request failed: {e}") from e
|
||||
|
||||
# =========================================================================
|
||||
# Server Status
|
||||
# =========================================================================
|
||||
|
||||
def status(self) -> dict[str, Any]:
|
||||
"""Get server status."""
|
||||
return dict(self._get("/status"))
|
||||
|
||||
# =========================================================================
|
||||
# Gallery / Images
|
||||
# =========================================================================
|
||||
|
||||
def list_images(self, limit: int = 50, offset: int = 0) -> dict[str, Any]:
|
||||
"""List images in gallery."""
|
||||
return dict(self._get("/api/images", params={"limit": limit, "offset": offset}))
|
||||
|
||||
def get_image_meta(self, image_id: str) -> dict[str, Any]:
|
||||
"""Get metadata for an image."""
|
||||
return dict(self._get(f"/api/images/{image_id}/meta"))
|
||||
|
||||
def delete_image(self, image_id: str) -> dict[str, Any]:
|
||||
"""Delete an image."""
|
||||
return dict(self._delete(f"/api/images/{image_id}"))
|
||||
|
||||
def edit_image(self, image_id: str, updates: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Update image metadata."""
|
||||
return dict(self._post(f"/api/images/{image_id}/edit", json=updates))
|
||||
|
||||
def download_image(self, image_id: str) -> bytes:
|
||||
"""Download image file bytes."""
|
||||
try:
|
||||
resp = self.client.get(f"/api/images/{image_id}")
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
||||
except httpx.RequestError as e:
|
||||
raise TsrClientError(f"Request failed: {e}") from e
|
||||
|
||||
# =========================================================================
|
||||
# Models
|
||||
# =========================================================================
|
||||
|
||||
def list_models(self) -> dict[str, Any]:
|
||||
"""List available models."""
|
||||
return dict(self._get("/api/models"))
|
||||
|
||||
def get_active_model(self) -> dict[str, Any]:
|
||||
"""Get currently active model."""
|
||||
return dict(self._get("/api/models/active"))
|
||||
|
||||
def switch_model(self, model_path: str) -> dict[str, Any]:
|
||||
"""Switch to a different model."""
|
||||
return dict(self._post("/api/models/switch", json={"model": model_path}))
|
||||
|
||||
def list_loras(self) -> dict[str, Any]:
|
||||
"""List available LoRAs."""
|
||||
return dict(self._get("/api/models/loras"))
|
||||
|
||||
def scan_models(self) -> dict[str, Any]:
|
||||
"""Scan model directories."""
|
||||
return dict(self._get("/api/models/scan"))
|
||||
|
||||
# =========================================================================
|
||||
# Generation
|
||||
# =========================================================================
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
width: int = 512,
|
||||
height: int = 512,
|
||||
steps: int = 20,
|
||||
cfg_scale: float = 7.0,
|
||||
seed: int = -1,
|
||||
sampler_name: str = "",
|
||||
scheduler: str = "",
|
||||
batch_size: int = 1,
|
||||
save_to_gallery: bool = True,
|
||||
return_base64: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Generate images."""
|
||||
body = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"steps": steps,
|
||||
"cfg_scale": cfg_scale,
|
||||
"seed": seed,
|
||||
"sampler_name": sampler_name,
|
||||
"scheduler": scheduler,
|
||||
"batch_size": batch_size,
|
||||
"save_to_gallery": save_to_gallery,
|
||||
"return_base64": return_base64,
|
||||
}
|
||||
return dict(self._post("/api/generate", json=body))
|
||||
|
||||
def list_samplers(self) -> dict[str, Any]:
|
||||
"""List available samplers."""
|
||||
return dict(self._get("/api/samplers"))
|
||||
|
||||
def list_schedulers(self) -> dict[str, Any]:
|
||||
"""List available schedulers."""
|
||||
return dict(self._get("/api/schedulers"))
|
||||
|
||||
# =========================================================================
|
||||
# Download
|
||||
# =========================================================================
|
||||
|
||||
def start_download(
|
||||
self,
|
||||
version_id: int | None = None,
|
||||
model_id: int | None = None,
|
||||
hash_val: str | None = None,
|
||||
output_dir: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Start a model download from CivitAI."""
|
||||
body: dict[str, Any] = {}
|
||||
if version_id:
|
||||
body["version_id"] = version_id
|
||||
if model_id:
|
||||
body["model_id"] = model_id
|
||||
if hash_val:
|
||||
body["hash"] = hash_val
|
||||
if output_dir:
|
||||
body["output_dir"] = output_dir
|
||||
return dict(self._post("/api/download", json=body))
|
||||
|
||||
def get_download_status(self, download_id: str) -> dict[str, Any]:
|
||||
"""Get download status."""
|
||||
return dict(self._get(f"/api/download/status/{download_id}"))
|
||||
|
||||
def list_downloads(self) -> dict[str, Any]:
|
||||
"""List active downloads."""
|
||||
return dict(self._get("/api/download/active"))
|
||||
|
||||
# =========================================================================
|
||||
# Database
|
||||
# =========================================================================
|
||||
|
||||
def db_list_files(self) -> list[dict[str, Any]]:
|
||||
"""List local files in database."""
|
||||
return list(self._get("/api/db/files"))
|
||||
|
||||
def db_search_models(
|
||||
self,
|
||||
query: str | None = None,
|
||||
model_type: str | None = None,
|
||||
base_model: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search cached models."""
|
||||
params: dict[str, Any] = {"limit": limit}
|
||||
if query:
|
||||
params["query"] = query
|
||||
if model_type:
|
||||
params["type"] = model_type
|
||||
if base_model:
|
||||
params["base"] = base_model
|
||||
return list(self._get("/api/db/models", params=params))
|
||||
|
||||
def db_get_model(self, civitai_id: int) -> dict[str, Any]:
|
||||
"""Get cached model by CivitAI ID."""
|
||||
return dict(self._get(f"/api/db/models/{civitai_id}"))
|
||||
|
||||
def db_get_triggers(self, file_path: str | None = None, version_id: int | None = None) -> list[str]:
|
||||
"""Get trigger words."""
|
||||
if version_id:
|
||||
return list(self._get(f"/api/db/triggers/{version_id}"))
|
||||
if file_path:
|
||||
return list(self._get("/api/db/triggers", params={"file_path": file_path}))
|
||||
return []
|
||||
|
||||
def db_stats(self) -> dict[str, Any]:
|
||||
"""Get database statistics."""
|
||||
return dict(self._get("/api/db/stats"))
|
||||
|
||||
def db_scan(self, directory: str) -> dict[str, Any]:
|
||||
"""Scan directory for safetensor files."""
|
||||
return dict(self._post("/api/db/scan", json={"directory": directory}))
|
||||
|
||||
def db_link(self) -> dict[str, Any]:
|
||||
"""Link unlinked files to CivitAI."""
|
||||
return dict(self._post("/api/db/link"))
|
||||
|
||||
def db_cache(self, model_id: int) -> dict[str, Any]:
|
||||
"""Cache CivitAI model data."""
|
||||
return dict(self._post("/api/db/cache", json={"model_id": model_id}))
|
||||
|
||||
# =========================================================================
|
||||
# Streaming Downloads
|
||||
# =========================================================================
|
||||
|
||||
def stream_image(self, image_id: str) -> Iterator[bytes]:
|
||||
"""Stream image download in chunks."""
|
||||
try:
|
||||
with self.client.stream("GET", f"/api/images/{image_id}") as resp:
|
||||
resp.raise_for_status()
|
||||
yield from resp.iter_bytes(chunk_size=1024 * 64)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise TsrClientError(f"HTTP {e.response.status_code}") from e
|
||||
except httpx.RequestError as e:
|
||||
raise TsrClientError(f"Request failed: {e}") from e
|
||||
|
||||
def save_image_to(self, image_id: str, dest: Path) -> Path:
|
||||
"""Download and save image to file."""
|
||||
content = self.download_image(image_id)
|
||||
dest.write_bytes(content)
|
||||
return dest
|
||||
@@ -165,3 +165,73 @@ def get_default_output_path(model_type: str | None) -> Path | None:
|
||||
if model_type and model_type in DEFAULT_PATHS:
|
||||
return DEFAULT_PATHS[model_type]
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Remote Server Configuration
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def get_remotes() -> dict[str, str]:
|
||||
"""Get configured remote servers.
|
||||
|
||||
Returns a dict mapping names to URLs, e.g., {"junkpile": "http://junkpile:8080"}
|
||||
"""
|
||||
config = load_config()
|
||||
remotes = config.get("remotes", {})
|
||||
return dict(remotes) if isinstance(remotes, dict) else {}
|
||||
|
||||
|
||||
def get_default_remote() -> str | None:
|
||||
"""Get the default remote name or URL."""
|
||||
config = load_config()
|
||||
return config.get("default_remote")
|
||||
|
||||
|
||||
def resolve_remote(remote: str | None) -> str | None:
|
||||
"""Resolve a remote name or URL to a full URL.
|
||||
|
||||
Args:
|
||||
remote: Remote name (from config), URL, or None
|
||||
|
||||
Returns:
|
||||
Full URL or None if no remote specified and no default
|
||||
"""
|
||||
if remote is None:
|
||||
# Check for default remote
|
||||
default = get_default_remote()
|
||||
if default:
|
||||
remote = default
|
||||
else:
|
||||
return None
|
||||
|
||||
# Check if it's a URL (starts with http:// or https://)
|
||||
if remote.startswith(("http://", "https://")):
|
||||
return remote
|
||||
|
||||
# Look up in configured remotes
|
||||
remotes = get_remotes()
|
||||
if remote in remotes:
|
||||
return remotes[remote]
|
||||
|
||||
# Treat as hostname with default port
|
||||
return f"http://{remote}:8080"
|
||||
|
||||
|
||||
def save_remote(name: str, url: str) -> None:
|
||||
"""Save a remote server configuration."""
|
||||
config = load_config()
|
||||
if "remotes" not in config:
|
||||
config["remotes"] = {}
|
||||
config["remotes"][name] = url
|
||||
save_config(config)
|
||||
|
||||
|
||||
def set_default_remote(name: str | None) -> None:
|
||||
"""Set the default remote."""
|
||||
config = load_config()
|
||||
if name is None:
|
||||
config.pop("default_remote", None)
|
||||
else:
|
||||
config["default_remote"] = name
|
||||
save_config(config)
|
||||
|
||||
Reference in New Issue
Block a user