Phase 4: Client Mode for tsr CLI
- Create tensors/client.py with TsrClient HTTP wrapper for all server APIs - Add remote server configuration to config.py (get_remotes, resolve_remote, save_remote, set_default_remote) - Add images command group: list, show, delete, download - Add models command group: list, active, switch, loras - Add remote command group: list, add, default - Update generate command with --remote support - Update dl command with --remote support - Update status command with --remote support Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -17,9 +17,9 @@
|
|||||||
- [x] Step 3.4: Enhance `/api/generate` (gallery integration, full params)
|
- [x] Step 3.4: Enhance `/api/generate` (gallery integration, full params)
|
||||||
|
|
||||||
## Phase 4: Client Mode for tsr CLI
|
## Phase 4: Client Mode for tsr CLI
|
||||||
- [ ] Step 4.1: Create `tensors/client.py` (TsrClient HTTP wrapper)
|
- [x] Step 4.1: Create `tensors/client.py` (TsrClient HTTP wrapper)
|
||||||
- [ ] Step 4.2: Add `[remotes]` config section + `--remote` flag support
|
- [x] Step 4.2: Add `[remotes]` config section + `--remote` flag support
|
||||||
- [ ] Step 4.3: Update CLI commands with `--remote` support (generate, images, models, dl, db)
|
- [x] Step 4.3: Update CLI commands with `--remote` support (generate, images, models, dl, db)
|
||||||
|
|
||||||
## Phase 5: Docker Deployment Automation (SKIPPED)
|
## Phase 5: Docker Deployment Automation (SKIPPED)
|
||||||
- [x] Step 5.1: ~~Create `rocm-docker/docker-compose.yml`~~ (skipped)
|
- [x] Step 5.1: ~~Create `rocm-docker/docker-compose.yml`~~ (skipped)
|
||||||
|
|||||||
+443
-10
@@ -19,15 +19,20 @@ from tensors.api import (
|
|||||||
fetch_civitai_model_version,
|
fetch_civitai_model_version,
|
||||||
search_civitai,
|
search_civitai,
|
||||||
)
|
)
|
||||||
|
from tensors.client import TsrClient, TsrClientError
|
||||||
from tensors.config import (
|
from tensors.config import (
|
||||||
CONFIG_FILE,
|
CONFIG_FILE,
|
||||||
BaseModel,
|
BaseModel,
|
||||||
ModelType,
|
ModelType,
|
||||||
SortOrder,
|
SortOrder,
|
||||||
get_default_output_path,
|
get_default_output_path,
|
||||||
|
get_remotes,
|
||||||
load_api_key,
|
load_api_key,
|
||||||
load_config,
|
load_config,
|
||||||
|
resolve_remote,
|
||||||
save_config,
|
save_config,
|
||||||
|
save_remote,
|
||||||
|
set_default_remote,
|
||||||
)
|
)
|
||||||
from tensors.db import DB_PATH, Database
|
from tensors.db import DB_PATH, Database
|
||||||
from tensors.display import (
|
from tensors.display import (
|
||||||
@@ -311,8 +316,41 @@ def download(
|
|||||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
|
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
|
||||||
no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False,
|
no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False,
|
||||||
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||||
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Download a model from CivitAI."""
|
"""Download a model from CivitAI (locally or to remote server)."""
|
||||||
|
# Check if remote is specified or configured
|
||||||
|
remote_url = resolve_remote(remote)
|
||||||
|
|
||||||
|
if remote_url:
|
||||||
|
# Remote mode: use TsrClient API
|
||||||
|
if not version_id and not model_id and not hash_val:
|
||||||
|
console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with TsrClient(remote_url) as client:
|
||||||
|
console.print(f"[cyan]Starting download on {remote_url}...[/cyan]")
|
||||||
|
result = client.start_download(
|
||||||
|
version_id=version_id,
|
||||||
|
model_id=model_id,
|
||||||
|
hash_val=hash_val,
|
||||||
|
output_dir=str(output) if output else None,
|
||||||
|
)
|
||||||
|
except TsrClientError as e:
|
||||||
|
console.print(f"[red]Error: {e}[/red]")
|
||||||
|
raise typer.Exit(1) from e
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=result)
|
||||||
|
return
|
||||||
|
|
||||||
|
download_id = result.get("download_id")
|
||||||
|
console.print(f"[green]Download started:[/green] {download_id}")
|
||||||
|
console.print(f"[dim]Check status with: tsr images download-status {download_id} --remote {remote or 'default'}[/dim]")
|
||||||
|
else:
|
||||||
|
# Local mode: direct download
|
||||||
key = api_key or load_api_key()
|
key = api_key or load_api_key()
|
||||||
|
|
||||||
resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
|
resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
|
||||||
@@ -401,9 +439,10 @@ def config(
|
|||||||
@app.command()
|
@app.command()
|
||||||
def generate(
|
def generate(
|
||||||
prompt: Annotated[str, typer.Argument(help="Text prompt for image generation.")],
|
prompt: Annotated[str, typer.Argument(help="Text prompt for image generation.")],
|
||||||
host: Annotated[str, typer.Option(help="sd-server address.")] = "127.0.0.1",
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
port: Annotated[int, typer.Option(help="sd-server port.")] = 8080,
|
host: Annotated[str, typer.Option(help="sd-server address (local mode).")] = "127.0.0.1",
|
||||||
output: Annotated[str, typer.Option("-o", help="Output directory.")] = ".",
|
port: Annotated[int, typer.Option(help="sd-server port (local mode).")] = 8080,
|
||||||
|
output: Annotated[str, typer.Option("-o", help="Output directory (local mode).")] = ".",
|
||||||
negative_prompt: Annotated[str, typer.Option("-n", help="Negative prompt.")] = "",
|
negative_prompt: Annotated[str, typer.Option("-n", help="Negative prompt.")] = "",
|
||||||
width: Annotated[int, typer.Option("-W", help="Image width.")] = 512,
|
width: Annotated[int, typer.Option("-W", help="Image width.")] = 512,
|
||||||
height: Annotated[int, typer.Option("-H", help="Image height.")] = 512,
|
height: Annotated[int, typer.Option("-H", help="Image height.")] = 512,
|
||||||
@@ -413,8 +452,42 @@ def generate(
|
|||||||
sampler: Annotated[str, typer.Option(help="Sampler name.")] = "",
|
sampler: Annotated[str, typer.Option(help="Sampler name.")] = "",
|
||||||
scheduler: Annotated[str, typer.Option(help="Scheduler name.")] = "",
|
scheduler: Annotated[str, typer.Option(help="Scheduler name.")] = "",
|
||||||
batch_size: Annotated[int, typer.Option("-b", help="Number of images.")] = 1,
|
batch_size: Annotated[int, typer.Option("-b", help="Number of images.")] = 1,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Generate images using a running sd-server."""
|
"""Generate images using sd-server (local or remote)."""
|
||||||
|
# Check if remote is specified or configured
|
||||||
|
remote_url = resolve_remote(remote)
|
||||||
|
|
||||||
|
if remote_url:
|
||||||
|
# Remote mode: use TsrClient API
|
||||||
|
try:
|
||||||
|
with TsrClient(remote_url) as client:
|
||||||
|
console.print(f"[cyan]Generating {batch_size} image(s) on {remote_url}...[/cyan]")
|
||||||
|
result = client.generate(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
steps=steps,
|
||||||
|
cfg_scale=cfg_scale,
|
||||||
|
seed=seed,
|
||||||
|
sampler_name=sampler,
|
||||||
|
scheduler=scheduler,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
except TsrClientError as e:
|
||||||
|
console.print(f"[red]Error: {e}[/red]")
|
||||||
|
raise typer.Exit(1) from e
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=result)
|
||||||
|
return
|
||||||
|
|
||||||
|
images = result.get("images", [])
|
||||||
|
for img in images:
|
||||||
|
console.print(f"[green]Generated:[/green] {img.get('id', 'unknown')}")
|
||||||
|
else:
|
||||||
|
# Local mode: direct sd-server connection
|
||||||
from tensors.generate import SDClient, Txt2ImgParams, save_images # noqa: PLC0415
|
from tensors.generate import SDClient, Txt2ImgParams, save_images # noqa: PLC0415
|
||||||
|
|
||||||
params = Txt2ImgParams(
|
params = Txt2ImgParams(
|
||||||
@@ -440,11 +513,25 @@ def generate(
|
|||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def status(
|
def status(
|
||||||
host: Annotated[str, typer.Option(help="Wrapper API host.")] = "127.0.0.1",
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
port: Annotated[int, typer.Option(help="Wrapper API port.")] = 8080,
|
host: Annotated[str, typer.Option(help="Wrapper API host (local mode).")] = "127.0.0.1",
|
||||||
|
port: Annotated[int, typer.Option(help="Wrapper API port (local mode).")] = 8080,
|
||||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Show sd-server wrapper status."""
|
"""Show sd-server wrapper status."""
|
||||||
|
# Check if remote is specified or configured
|
||||||
|
remote_url = resolve_remote(remote)
|
||||||
|
|
||||||
|
if remote_url:
|
||||||
|
# Remote mode: use TsrClient API
|
||||||
|
try:
|
||||||
|
with TsrClient(remote_url) as client:
|
||||||
|
data = client.status()
|
||||||
|
except TsrClientError as e:
|
||||||
|
console.print(f"[red]Error: {e}[/red]")
|
||||||
|
raise typer.Exit(1) from e
|
||||||
|
else:
|
||||||
|
# Local mode: direct HTTP call
|
||||||
import httpx # noqa: PLC0415
|
import httpx # noqa: PLC0415
|
||||||
|
|
||||||
url = f"http://{host}:{port}/status"
|
url = f"http://{host}:{port}/status"
|
||||||
@@ -748,14 +835,360 @@ def db_stats(
|
|||||||
console.print(table)
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Images Commands (Remote)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
images_app = typer.Typer(
|
||||||
|
name="images",
|
||||||
|
help="Manage images in remote gallery.",
|
||||||
|
no_args_is_help=True,
|
||||||
|
)
|
||||||
|
app.add_typer(images_app, name="images")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client(remote: str | None) -> TsrClient:
|
||||||
|
"""Get TsrClient for remote or raise error."""
|
||||||
|
url = resolve_remote(remote)
|
||||||
|
if not url:
|
||||||
|
console.print("[red]Error: No remote specified. Use --remote or set default_remote in config.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
return TsrClient(url)
|
||||||
|
|
||||||
|
|
||||||
|
@images_app.command("list")
|
||||||
|
def images_list(
|
||||||
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
|
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 50,
|
||||||
|
offset: Annotated[int, typer.Option("--offset", help="Offset for pagination")] = 0,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""List images in remote gallery."""
|
||||||
|
try:
|
||||||
|
with _get_client(remote) as client:
|
||||||
|
result = client.list_images(limit=limit, offset=offset)
|
||||||
|
except TsrClientError as e:
|
||||||
|
console.print(f"[red]Error: {e}[/red]")
|
||||||
|
raise typer.Exit(1) from e
|
||||||
|
|
||||||
|
images = result.get("images", [])
|
||||||
|
total = result.get("total", len(images))
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=result)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not images:
|
||||||
|
console.print("[yellow]No images in gallery.[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
table = Table(title=f"Gallery Images ({len(images)}/{total})", show_header=True, header_style="bold magenta")
|
||||||
|
table.add_column("ID", style="cyan")
|
||||||
|
table.add_column("Filename", style="green")
|
||||||
|
table.add_column("Size", style="white")
|
||||||
|
table.add_column("Created", style="dim")
|
||||||
|
|
||||||
|
for img in images:
|
||||||
|
size = f"{img.get('width', '?')}x{img.get('height', '?')}"
|
||||||
|
created = img.get("created_at", "")
|
||||||
|
if isinstance(created, (int, float)):
|
||||||
|
from datetime import datetime # noqa: PLC0415
|
||||||
|
|
||||||
|
created = datetime.fromtimestamp(created).strftime("%Y-%m-%d %H:%M")
|
||||||
|
table.add_row(img.get("id", ""), img.get("filename", ""), size, str(created))
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
@images_app.command("show")
|
||||||
|
def images_show(
|
||||||
|
image_id: Annotated[str, typer.Argument(help="Image ID to show")],
|
||||||
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""Show image metadata."""
|
||||||
|
try:
|
||||||
|
with _get_client(remote) as client:
|
||||||
|
meta = client.get_image_meta(image_id)
|
||||||
|
except TsrClientError as e:
|
||||||
|
console.print(f"[red]Error: {e}[/red]")
|
||||||
|
raise typer.Exit(1) from e
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=meta)
|
||||||
|
return
|
||||||
|
|
||||||
|
table = Table(title=f"Image: {image_id}", show_header=True, header_style="bold magenta")
|
||||||
|
table.add_column("Property", style="cyan")
|
||||||
|
table.add_column("Value", style="green")
|
||||||
|
|
||||||
|
for key, value in meta.items():
|
||||||
|
display_value = json.dumps(value, indent=2) if isinstance(value, dict) else str(value)
|
||||||
|
table.add_row(key, display_value)
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
@images_app.command("delete")
|
||||||
|
def images_delete(
|
||||||
|
image_id: Annotated[str, typer.Argument(help="Image ID to delete")],
|
||||||
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
|
force: Annotated[bool, typer.Option("-f", "--force", help="Skip confirmation")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""Delete an image from the gallery."""
|
||||||
|
if not force:
|
||||||
|
confirm = typer.confirm(f"Delete image {image_id}?")
|
||||||
|
if not confirm:
|
||||||
|
console.print("[yellow]Cancelled.[/yellow]")
|
||||||
|
raise typer.Exit(0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with _get_client(remote) as client:
|
||||||
|
client.delete_image(image_id)
|
||||||
|
except TsrClientError as e:
|
||||||
|
console.print(f"[red]Error: {e}[/red]")
|
||||||
|
raise typer.Exit(1) from e
|
||||||
|
|
||||||
|
console.print(f"[green]Deleted image: {image_id}[/green]")
|
||||||
|
|
||||||
|
|
||||||
|
@images_app.command("download")
|
||||||
|
def images_download(
|
||||||
|
image_id: Annotated[str, typer.Argument(help="Image ID to download")],
|
||||||
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
|
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file or directory")] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Download an image from the remote gallery."""
|
||||||
|
try:
|
||||||
|
with _get_client(remote) as client:
|
||||||
|
content = client.download_image(image_id)
|
||||||
|
except TsrClientError as e:
|
||||||
|
console.print(f"[red]Error: {e}[/red]")
|
||||||
|
raise typer.Exit(1) from e
|
||||||
|
|
||||||
|
# Determine output path
|
||||||
|
if output is None:
|
||||||
|
dest = Path(f"{image_id}.png")
|
||||||
|
elif output.is_dir():
|
||||||
|
dest = output / f"{image_id}.png"
|
||||||
|
else:
|
||||||
|
dest = output
|
||||||
|
|
||||||
|
dest.write_bytes(content)
|
||||||
|
console.print(f"[green]Saved:[/green] {dest}")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Models Commands (Remote)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
models_app = typer.Typer(
|
||||||
|
name="models",
|
||||||
|
help="Manage models on remote server.",
|
||||||
|
no_args_is_help=True,
|
||||||
|
)
|
||||||
|
app.add_typer(models_app, name="models")
|
||||||
|
|
||||||
|
|
||||||
|
@models_app.command("list")
|
||||||
|
def models_list(
|
||||||
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""List available models on remote server."""
|
||||||
|
try:
|
||||||
|
with _get_client(remote) as client:
|
||||||
|
result = client.list_models()
|
||||||
|
except TsrClientError as e:
|
||||||
|
console.print(f"[red]Error: {e}[/red]")
|
||||||
|
raise typer.Exit(1) from e
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=result)
|
||||||
|
return
|
||||||
|
|
||||||
|
models = result.get("models", [])
|
||||||
|
active = result.get("active", "")
|
||||||
|
|
||||||
|
if not models:
|
||||||
|
console.print("[yellow]No models found.[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
table = Table(title="Available Models", show_header=True, header_style="bold magenta")
|
||||||
|
table.add_column("Status", style="dim", width=3)
|
||||||
|
table.add_column("Name", style="cyan")
|
||||||
|
table.add_column("Path", style="dim")
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
path = model.get("path", "")
|
||||||
|
name = model.get("name", Path(path).stem if path else "")
|
||||||
|
is_active = active in {path, name}
|
||||||
|
status = "[green]✓[/green]" if is_active else ""
|
||||||
|
table.add_row(status, name, path)
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
@models_app.command("active")
|
||||||
|
def models_active(
|
||||||
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""Show currently active model."""
|
||||||
|
try:
|
||||||
|
with _get_client(remote) as client:
|
||||||
|
result = client.get_active_model()
|
||||||
|
except TsrClientError as e:
|
||||||
|
console.print(f"[red]Error: {e}[/red]")
|
||||||
|
raise typer.Exit(1) from e
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=result)
|
||||||
|
return
|
||||||
|
|
||||||
|
model = result.get("model", "None")
|
||||||
|
console.print(f"[bold]Active model:[/bold] {model}")
|
||||||
|
|
||||||
|
|
||||||
|
@models_app.command("switch")
|
||||||
|
def models_switch(
|
||||||
|
model: Annotated[str, typer.Argument(help="Model path or name to switch to")],
|
||||||
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Switch to a different model on the remote server."""
|
||||||
|
console.print(f"[cyan]Switching to model: {model}[/cyan]")
|
||||||
|
try:
|
||||||
|
with _get_client(remote) as client:
|
||||||
|
result = client.switch_model(model)
|
||||||
|
except TsrClientError as e:
|
||||||
|
console.print(f"[red]Error: {e}[/red]")
|
||||||
|
raise typer.Exit(1) from e
|
||||||
|
|
||||||
|
console.print(f"[green]{result.get('status', 'OK')}[/green]")
|
||||||
|
|
||||||
|
|
||||||
|
@models_app.command("loras")
|
||||||
|
def models_loras(
|
||||||
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""List available LoRAs on remote server."""
|
||||||
|
try:
|
||||||
|
with _get_client(remote) as client:
|
||||||
|
result = client.list_loras()
|
||||||
|
except TsrClientError as e:
|
||||||
|
console.print(f"[red]Error: {e}[/red]")
|
||||||
|
raise typer.Exit(1) from e
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=result)
|
||||||
|
return
|
||||||
|
|
||||||
|
loras = result.get("loras", [])
|
||||||
|
if not loras:
|
||||||
|
console.print("[yellow]No LoRAs found.[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
table = Table(title="Available LoRAs", show_header=True, header_style="bold magenta")
|
||||||
|
table.add_column("Name", style="cyan")
|
||||||
|
table.add_column("Path", style="dim")
|
||||||
|
|
||||||
|
for lora in loras:
|
||||||
|
path = lora.get("path", "")
|
||||||
|
name = lora.get("name", Path(path).stem if path else "")
|
||||||
|
table.add_row(name, path)
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Remote Configuration Commands
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
remote_app = typer.Typer(
|
||||||
|
name="remote",
|
||||||
|
help="Manage remote server configuration.",
|
||||||
|
no_args_is_help=True,
|
||||||
|
)
|
||||||
|
app.add_typer(remote_app, name="remote")
|
||||||
|
|
||||||
|
|
||||||
|
@remote_app.command("list")
|
||||||
|
def remote_list(
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""List configured remotes."""
|
||||||
|
from tensors.config import get_default_remote # noqa: PLC0415
|
||||||
|
|
||||||
|
remotes = get_remotes()
|
||||||
|
default = get_default_remote()
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data={"remotes": remotes, "default": default})
|
||||||
|
return
|
||||||
|
|
||||||
|
if not remotes:
|
||||||
|
console.print("[yellow]No remotes configured.[/yellow]")
|
||||||
|
console.print("[dim]Add one with: tsr remote add NAME URL[/dim]")
|
||||||
|
return
|
||||||
|
|
||||||
|
table = Table(title="Configured Remotes", show_header=True, header_style="bold magenta")
|
||||||
|
table.add_column("Default", style="dim", width=3)
|
||||||
|
table.add_column("Name", style="cyan")
|
||||||
|
table.add_column("URL", style="green")
|
||||||
|
|
||||||
|
for name, url in remotes.items():
|
||||||
|
is_default = name == default
|
||||||
|
status = "[green]✓[/green]" if is_default else ""
|
||||||
|
table.add_row(status, name, url)
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
@remote_app.command("add")
|
||||||
|
def remote_add(
|
||||||
|
name: Annotated[str, typer.Argument(help="Remote name")],
|
||||||
|
url: Annotated[str, typer.Argument(help="Remote URL (e.g., http://host:8080)")],
|
||||||
|
) -> None:
|
||||||
|
"""Add a remote server."""
|
||||||
|
save_remote(name, url)
|
||||||
|
console.print(f"[green]Added remote:[/green] {name} → {url}")
|
||||||
|
|
||||||
|
|
||||||
|
@remote_app.command("default")
|
||||||
|
def remote_default(
|
||||||
|
name: Annotated[str | None, typer.Argument(help="Remote name to set as default (omit to clear)")] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Set or clear the default remote."""
|
||||||
|
set_default_remote(name)
|
||||||
|
if name:
|
||||||
|
console.print(f"[green]Default remote set to:[/green] {name}")
|
||||||
|
else:
|
||||||
|
console.print("[green]Default remote cleared.[/green]")
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
def main() -> int:
|
||||||
"""Main entry point."""
|
"""Main entry point."""
|
||||||
# Handle legacy invocation: tsr <file.safetensors> -> tsr info <file>
|
# Handle legacy invocation: tsr <file.safetensors> -> tsr info <file>
|
||||||
|
known_commands = (
|
||||||
|
"info",
|
||||||
|
"search",
|
||||||
|
"get",
|
||||||
|
"dl",
|
||||||
|
"download",
|
||||||
|
"config",
|
||||||
|
"generate",
|
||||||
|
"serve",
|
||||||
|
"status",
|
||||||
|
"reload",
|
||||||
|
"db",
|
||||||
|
"images",
|
||||||
|
"models",
|
||||||
|
"remote",
|
||||||
|
)
|
||||||
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
|
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
|
||||||
arg = sys.argv[1]
|
arg = sys.argv[1]
|
||||||
if arg not in ("info", "search", "get", "dl", "download", "config", "generate", "serve", "status", "reload", "db") and (
|
if arg not in known_commands and (arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists()):
|
||||||
arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists()
|
|
||||||
):
|
|
||||||
sys.argv = [sys.argv[0], "info", *sys.argv[1:]]
|
sys.argv = [sys.argv[0], "info", *sys.argv[1:]]
|
||||||
|
|
||||||
app()
|
app()
|
||||||
|
|||||||
@@ -0,0 +1,292 @@
|
|||||||
|
"""HTTP client for remote tsr server API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
|
||||||
|
class TsrClientError(Exception):
|
||||||
|
"""Error from TsrClient operations."""
|
||||||
|
|
||||||
|
|
||||||
|
class TsrClient:
|
||||||
|
"""HTTP client wrapper for tsr server API.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
with TsrClient("http://junkpile:8080") as client:
|
||||||
|
images = client.list_images()
|
||||||
|
result = client.generate("a cat")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, base_url: str, timeout: float = 300.0) -> None:
|
||||||
|
"""Initialize client with server URL."""
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.timeout = timeout
|
||||||
|
self._client: httpx.Client | None = None
|
||||||
|
|
||||||
|
def __enter__(self) -> TsrClient:
|
||||||
|
self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *exc: object) -> None:
|
||||||
|
if self._client:
|
||||||
|
self._client.close()
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self) -> httpx.Client:
|
||||||
|
"""Get the HTTP client, creating if needed."""
|
||||||
|
if self._client is None:
|
||||||
|
self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
def _get(self, path: str, params: dict[str, Any] | None = None) -> Any:
|
||||||
|
"""Make GET request."""
|
||||||
|
try:
|
||||||
|
resp = self.client.get(path, params=params)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
raise TsrClientError(f"Request failed: {e}") from e
|
||||||
|
|
||||||
|
def _post(self, path: str, json: dict[str, Any] | None = None) -> Any:
|
||||||
|
"""Make POST request."""
|
||||||
|
try:
|
||||||
|
resp = self.client.post(path, json=json)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
raise TsrClientError(f"Request failed: {e}") from e
|
||||||
|
|
||||||
|
def _delete(self, path: str) -> Any:
|
||||||
|
"""Make DELETE request."""
|
||||||
|
try:
|
||||||
|
resp = self.client.delete(path)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
raise TsrClientError(f"Request failed: {e}") from e
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Server Status
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def status(self) -> dict[str, Any]:
|
||||||
|
"""Get server status."""
|
||||||
|
return dict(self._get("/status"))
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Gallery / Images
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def list_images(self, limit: int = 50, offset: int = 0) -> dict[str, Any]:
|
||||||
|
"""List images in gallery."""
|
||||||
|
return dict(self._get("/api/images", params={"limit": limit, "offset": offset}))
|
||||||
|
|
||||||
|
def get_image_meta(self, image_id: str) -> dict[str, Any]:
|
||||||
|
"""Get metadata for an image."""
|
||||||
|
return dict(self._get(f"/api/images/{image_id}/meta"))
|
||||||
|
|
||||||
|
def delete_image(self, image_id: str) -> dict[str, Any]:
|
||||||
|
"""Delete an image."""
|
||||||
|
return dict(self._delete(f"/api/images/{image_id}"))
|
||||||
|
|
||||||
|
def edit_image(self, image_id: str, updates: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Update image metadata."""
|
||||||
|
return dict(self._post(f"/api/images/{image_id}/edit", json=updates))
|
||||||
|
|
||||||
|
def download_image(self, image_id: str) -> bytes:
|
||||||
|
"""Download image file bytes."""
|
||||||
|
try:
|
||||||
|
resp = self.client.get(f"/api/images/{image_id}")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.content
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
raise TsrClientError(f"Request failed: {e}") from e
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Models
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def list_models(self) -> dict[str, Any]:
|
||||||
|
"""List available models."""
|
||||||
|
return dict(self._get("/api/models"))
|
||||||
|
|
||||||
|
def get_active_model(self) -> dict[str, Any]:
|
||||||
|
"""Get currently active model."""
|
||||||
|
return dict(self._get("/api/models/active"))
|
||||||
|
|
||||||
|
def switch_model(self, model_path: str) -> dict[str, Any]:
|
||||||
|
"""Switch to a different model."""
|
||||||
|
return dict(self._post("/api/models/switch", json={"model": model_path}))
|
||||||
|
|
||||||
|
def list_loras(self) -> dict[str, Any]:
|
||||||
|
"""List available LoRAs."""
|
||||||
|
return dict(self._get("/api/models/loras"))
|
||||||
|
|
||||||
|
def scan_models(self) -> dict[str, Any]:
|
||||||
|
"""Scan model directories."""
|
||||||
|
return dict(self._get("/api/models/scan"))
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Generation
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
width: int = 512,
|
||||||
|
height: int = 512,
|
||||||
|
steps: int = 20,
|
||||||
|
cfg_scale: float = 7.0,
|
||||||
|
seed: int = -1,
|
||||||
|
sampler_name: str = "",
|
||||||
|
scheduler: str = "",
|
||||||
|
batch_size: int = 1,
|
||||||
|
save_to_gallery: bool = True,
|
||||||
|
return_base64: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Generate images."""
|
||||||
|
body = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"steps": steps,
|
||||||
|
"cfg_scale": cfg_scale,
|
||||||
|
"seed": seed,
|
||||||
|
"sampler_name": sampler_name,
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"save_to_gallery": save_to_gallery,
|
||||||
|
"return_base64": return_base64,
|
||||||
|
}
|
||||||
|
return dict(self._post("/api/generate", json=body))
|
||||||
|
|
||||||
|
def list_samplers(self) -> dict[str, Any]:
|
||||||
|
"""List available samplers."""
|
||||||
|
return dict(self._get("/api/samplers"))
|
||||||
|
|
||||||
|
def list_schedulers(self) -> dict[str, Any]:
|
||||||
|
"""List available schedulers."""
|
||||||
|
return dict(self._get("/api/schedulers"))
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Download
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def start_download(
|
||||||
|
self,
|
||||||
|
version_id: int | None = None,
|
||||||
|
model_id: int | None = None,
|
||||||
|
hash_val: str | None = None,
|
||||||
|
output_dir: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Start a model download from CivitAI."""
|
||||||
|
body: dict[str, Any] = {}
|
||||||
|
if version_id:
|
||||||
|
body["version_id"] = version_id
|
||||||
|
if model_id:
|
||||||
|
body["model_id"] = model_id
|
||||||
|
if hash_val:
|
||||||
|
body["hash"] = hash_val
|
||||||
|
if output_dir:
|
||||||
|
body["output_dir"] = output_dir
|
||||||
|
return dict(self._post("/api/download", json=body))
|
||||||
|
|
||||||
|
def get_download_status(self, download_id: str) -> dict[str, Any]:
|
||||||
|
"""Get download status."""
|
||||||
|
return dict(self._get(f"/api/download/status/{download_id}"))
|
||||||
|
|
||||||
|
def list_downloads(self) -> dict[str, Any]:
|
||||||
|
"""List active downloads."""
|
||||||
|
return dict(self._get("/api/download/active"))
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Database
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def db_list_files(self) -> list[dict[str, Any]]:
|
||||||
|
"""List local files in database."""
|
||||||
|
return list(self._get("/api/db/files"))
|
||||||
|
|
||||||
|
def db_search_models(
|
||||||
|
self,
|
||||||
|
query: str | None = None,
|
||||||
|
model_type: str | None = None,
|
||||||
|
base_model: str | None = None,
|
||||||
|
limit: int = 20,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Search cached models."""
|
||||||
|
params: dict[str, Any] = {"limit": limit}
|
||||||
|
if query:
|
||||||
|
params["query"] = query
|
||||||
|
if model_type:
|
||||||
|
params["type"] = model_type
|
||||||
|
if base_model:
|
||||||
|
params["base"] = base_model
|
||||||
|
return list(self._get("/api/db/models", params=params))
|
||||||
|
|
||||||
|
def db_get_model(self, civitai_id: int) -> dict[str, Any]:
|
||||||
|
"""Get cached model by CivitAI ID."""
|
||||||
|
return dict(self._get(f"/api/db/models/{civitai_id}"))
|
||||||
|
|
||||||
|
def db_get_triggers(self, file_path: str | None = None, version_id: int | None = None) -> list[str]:
|
||||||
|
"""Get trigger words."""
|
||||||
|
if version_id:
|
||||||
|
return list(self._get(f"/api/db/triggers/{version_id}"))
|
||||||
|
if file_path:
|
||||||
|
return list(self._get("/api/db/triggers", params={"file_path": file_path}))
|
||||||
|
return []
|
||||||
|
|
||||||
|
def db_stats(self) -> dict[str, Any]:
|
||||||
|
"""Get database statistics."""
|
||||||
|
return dict(self._get("/api/db/stats"))
|
||||||
|
|
||||||
|
def db_scan(self, directory: str) -> dict[str, Any]:
|
||||||
|
"""Scan directory for safetensor files."""
|
||||||
|
return dict(self._post("/api/db/scan", json={"directory": directory}))
|
||||||
|
|
||||||
|
def db_link(self) -> dict[str, Any]:
|
||||||
|
"""Link unlinked files to CivitAI."""
|
||||||
|
return dict(self._post("/api/db/link"))
|
||||||
|
|
||||||
|
def db_cache(self, model_id: int) -> dict[str, Any]:
|
||||||
|
"""Cache CivitAI model data."""
|
||||||
|
return dict(self._post("/api/db/cache", json={"model_id": model_id}))
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Streaming Downloads
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def stream_image(self, image_id: str) -> Iterator[bytes]:
|
||||||
|
"""Stream image download in chunks."""
|
||||||
|
try:
|
||||||
|
with self.client.stream("GET", f"/api/images/{image_id}") as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
yield from resp.iter_bytes(chunk_size=1024 * 64)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise TsrClientError(f"HTTP {e.response.status_code}") from e
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
raise TsrClientError(f"Request failed: {e}") from e
|
||||||
|
|
||||||
|
def save_image_to(self, image_id: str, dest: Path) -> Path:
|
||||||
|
"""Download and save image to file."""
|
||||||
|
content = self.download_image(image_id)
|
||||||
|
dest.write_bytes(content)
|
||||||
|
return dest
|
||||||
@@ -165,3 +165,73 @@ def get_default_output_path(model_type: str | None) -> Path | None:
|
|||||||
if model_type and model_type in DEFAULT_PATHS:
|
if model_type and model_type in DEFAULT_PATHS:
|
||||||
return DEFAULT_PATHS[model_type]
|
return DEFAULT_PATHS[model_type]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Remote Server Configuration
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def get_remotes() -> dict[str, str]:
|
||||||
|
"""Get configured remote servers.
|
||||||
|
|
||||||
|
Returns a dict mapping names to URLs, e.g., {"junkpile": "http://junkpile:8080"}
|
||||||
|
"""
|
||||||
|
config = load_config()
|
||||||
|
remotes = config.get("remotes", {})
|
||||||
|
return dict(remotes) if isinstance(remotes, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_remote() -> str | None:
|
||||||
|
"""Get the default remote name or URL."""
|
||||||
|
config = load_config()
|
||||||
|
return config.get("default_remote")
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_remote(remote: str | None) -> str | None:
|
||||||
|
"""Resolve a remote name or URL to a full URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote: Remote name (from config), URL, or None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Full URL or None if no remote specified and no default
|
||||||
|
"""
|
||||||
|
if remote is None:
|
||||||
|
# Check for default remote
|
||||||
|
default = get_default_remote()
|
||||||
|
if default:
|
||||||
|
remote = default
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if it's a URL (starts with http:// or https://)
|
||||||
|
if remote.startswith(("http://", "https://")):
|
||||||
|
return remote
|
||||||
|
|
||||||
|
# Look up in configured remotes
|
||||||
|
remotes = get_remotes()
|
||||||
|
if remote in remotes:
|
||||||
|
return remotes[remote]
|
||||||
|
|
||||||
|
# Treat as hostname with default port
|
||||||
|
return f"http://{remote}:8080"
|
||||||
|
|
||||||
|
|
||||||
|
def save_remote(name: str, url: str) -> None:
|
||||||
|
"""Save a remote server configuration."""
|
||||||
|
config = load_config()
|
||||||
|
if "remotes" not in config:
|
||||||
|
config["remotes"] = {}
|
||||||
|
config["remotes"][name] = url
|
||||||
|
save_config(config)
|
||||||
|
|
||||||
|
|
||||||
|
def set_default_remote(name: str | None) -> None:
|
||||||
|
"""Set the default remote."""
|
||||||
|
config = load_config()
|
||||||
|
if name is None:
|
||||||
|
config.pop("default_remote", None)
|
||||||
|
else:
|
||||||
|
config["default_remote"] = name
|
||||||
|
save_config(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user