Add top-level generate and models commands with --remote support
Add `tsr generate` and `tsr models` as top-level CLI commands that call ComfyUI library functions directly or HTTP to a remote tensors server. Add `--remote` flag to existing `tsr search` and `tsr dl` commands. New file `tensors/remote.py` provides HTTP client functions for all four operations against the remote tensors API (generate, models, search, download with progress polling). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+352
-1
@@ -54,6 +54,14 @@ from tensors.hf import (
|
|||||||
list_safetensor_files,
|
list_safetensor_files,
|
||||||
search_hf_models,
|
search_hf_models,
|
||||||
)
|
)
|
||||||
|
from tensors.remote import (
|
||||||
|
remote_download,
|
||||||
|
remote_download_status,
|
||||||
|
remote_generate,
|
||||||
|
remote_get_image,
|
||||||
|
remote_models,
|
||||||
|
remote_search,
|
||||||
|
)
|
||||||
from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata
|
from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata
|
||||||
|
|
||||||
# Key masking threshold
|
# Key masking threshold
|
||||||
@@ -232,6 +240,7 @@ def search(
|
|||||||
pipeline: Annotated[str | None, typer.Option("--pipeline", help="Pipeline tag (HuggingFace)")] = None,
|
pipeline: Annotated[str | None, typer.Option("--pipeline", help="Pipeline tag (HuggingFace)")] = None,
|
||||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||||
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Search models on CivitAI and/or Hugging Face.
|
"""Search models on CivitAI and/or Hugging Face.
|
||||||
|
|
||||||
@@ -242,7 +251,32 @@ def search(
|
|||||||
tsr search -t lora -b pony # CivitAI LoRAs for Pony
|
tsr search -t lora -b pony # CivitAI LoRAs for Pony
|
||||||
tsr search -a stabilityai -P hf # HF by author
|
tsr search -a stabilityai -P hf # HF by author
|
||||||
tsr search --sfw -P civitai # CivitAI SFW only
|
tsr search --sfw -P civitai # CivitAI SFW only
|
||||||
|
tsr search "pony" --remote junkpile # Search via remote server
|
||||||
"""
|
"""
|
||||||
|
# Remote mode: delegate to remote tensors server
|
||||||
|
if remote:
|
||||||
|
civitai_results = remote_search(
|
||||||
|
remote,
|
||||||
|
query=query,
|
||||||
|
model_type=model_type.to_api() if model_type else None,
|
||||||
|
base_model=base.to_api() if base else None,
|
||||||
|
sort=sort.value,
|
||||||
|
limit=limit,
|
||||||
|
page=page,
|
||||||
|
nsfw=nsfw.value if nsfw else None,
|
||||||
|
sfw=sfw,
|
||||||
|
console=console,
|
||||||
|
)
|
||||||
|
if not civitai_results:
|
||||||
|
console.print("[red]Remote search failed.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data={"civitai": civitai_results})
|
||||||
|
else:
|
||||||
|
display_search_results(civitai_results, console)
|
||||||
|
return
|
||||||
|
|
||||||
key = api_key or load_api_key()
|
key = api_key or load_api_key()
|
||||||
civitai_results: dict[str, Any] | None = None
|
civitai_results: dict[str, Any] | None = None
|
||||||
hf_results: list[dict[str, Any]] | None = None
|
hf_results: list[dict[str, Any]] | None = None
|
||||||
@@ -429,6 +463,78 @@ def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Pa
|
|||||||
return output_dir
|
return output_dir
|
||||||
|
|
||||||
|
|
||||||
|
def _poll_remote_download(remote_name: str, download_id: str) -> None:
|
||||||
|
"""Poll a remote download for completion with a progress bar."""
|
||||||
|
import time # noqa: PLC0415
|
||||||
|
|
||||||
|
from rich.progress import BarColumn, DownloadColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn # noqa: PLC0415
|
||||||
|
|
||||||
|
status: dict[str, Any] | None = None
|
||||||
|
with Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TaskProgressColumn(),
|
||||||
|
DownloadColumn(),
|
||||||
|
console=console,
|
||||||
|
) as progress:
|
||||||
|
task = progress.add_task("[cyan]Downloading...", total=100)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
status = remote_download_status(remote_name, download_id)
|
||||||
|
if not status:
|
||||||
|
break
|
||||||
|
|
||||||
|
dl_status = status.get("status", "")
|
||||||
|
pct = status.get("progress", 0)
|
||||||
|
progress.update(task, completed=pct, description=f"[cyan]{dl_status.title()}...")
|
||||||
|
|
||||||
|
if dl_status in ("completed", "failed"):
|
||||||
|
break
|
||||||
|
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
if status and status.get("status") == "completed":
|
||||||
|
console.print(f"[green]Download complete:[/green] {status.get('path', 'N/A')}")
|
||||||
|
elif status and status.get("status") == "failed":
|
||||||
|
console.print(f"[red]Download failed:[/red] {status.get('error', 'Unknown error')}")
|
||||||
|
|
||||||
|
|
||||||
|
def _download_remote(
|
||||||
|
remote_name: str,
|
||||||
|
version_id: int | None,
|
||||||
|
model_id: int | None,
|
||||||
|
hash_val: str | None,
|
||||||
|
output: Path | None,
|
||||||
|
) -> None:
|
||||||
|
"""Handle remote download flow."""
|
||||||
|
if not version_id and not model_id:
|
||||||
|
if hash_val:
|
||||||
|
console.print("[yellow]Remote download does not support --hash. Use --version-id or --model-id.[/yellow]")
|
||||||
|
else:
|
||||||
|
console.print("[red]Error: Must specify --version-id or --model-id for remote download[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
console.print("[dim]Starting download on remote server...[/dim]")
|
||||||
|
result = remote_download(
|
||||||
|
remote_name,
|
||||||
|
version_id=version_id,
|
||||||
|
model_id=model_id,
|
||||||
|
output_dir=str(output) if output else None,
|
||||||
|
console=console,
|
||||||
|
)
|
||||||
|
if not result:
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
console.print(f"[green]Download started:[/green] {result.get('model_name', 'N/A')}")
|
||||||
|
console.print(f"[dim]Version: {result.get('version_name', 'N/A')}[/dim]")
|
||||||
|
console.print(f"[dim]Destination: {result.get('destination', 'N/A')}[/dim]")
|
||||||
|
|
||||||
|
download_id = result.get("download_id")
|
||||||
|
if download_id:
|
||||||
|
_poll_remote_download(remote_name, download_id)
|
||||||
|
|
||||||
|
|
||||||
@app.command("dl")
|
@app.command("dl")
|
||||||
def download(
|
def download(
|
||||||
version_id: Annotated[int | None, typer.Option("-v", "--version-id", help="Model version ID")] = None,
|
version_id: Annotated[int | None, typer.Option("-v", "--version-id", help="Model version ID")] = None,
|
||||||
@@ -437,8 +543,21 @@ def download(
|
|||||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
|
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
|
||||||
no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False,
|
no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False,
|
||||||
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||||
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Download a model from CivitAI."""
|
"""Download a model from CivitAI.
|
||||||
|
|
||||||
|
When --remote is specified, the download happens on the remote server.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
tsr dl -v 12345 # Download by version ID
|
||||||
|
tsr dl -m 67890 # Download latest version of model
|
||||||
|
tsr dl -v 12345 --remote junkpile # Download on remote server
|
||||||
|
"""
|
||||||
|
if remote:
|
||||||
|
_download_remote(remote, version_id, model_id, hash_val, output)
|
||||||
|
return
|
||||||
|
|
||||||
key = api_key or load_api_key()
|
key = api_key or load_api_key()
|
||||||
|
|
||||||
resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
|
resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
|
||||||
@@ -645,6 +764,236 @@ def serve(
|
|||||||
uvicorn.run(create_app(), host=host, port=port, log_level=log_level)
|
uvicorn.run(create_app(), host=host, port=port, log_level=log_level)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Top-Level Generate Command
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def generate( # noqa: PLR0915
|
||||||
|
prompt: Annotated[str, typer.Argument(help="Positive prompt text")],
|
||||||
|
model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None,
|
||||||
|
width: Annotated[int, typer.Option("-W", "--width", help="Image width")] = 1024,
|
||||||
|
height: Annotated[int, typer.Option("-H", "--height", help="Image height")] = 1024,
|
||||||
|
steps: Annotated[int, typer.Option("--steps", help="Sampling steps")] = 20,
|
||||||
|
cfg: Annotated[float, typer.Option("--cfg", help="CFG scale")] = 7.0,
|
||||||
|
seed: Annotated[int, typer.Option("--seed", "-s", help="Random seed (-1 for random)")] = -1,
|
||||||
|
sampler: Annotated[str, typer.Option("--sampler", help="Sampler name")] = "euler",
|
||||||
|
scheduler: Annotated[str, typer.Option("--scheduler", help="Scheduler name")] = "normal",
|
||||||
|
vae: Annotated[str | None, typer.Option("--vae", help="VAE model name")] = None,
|
||||||
|
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
||||||
|
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
|
||||||
|
negative: Annotated[str, typer.Option("-n", "--negative-prompt", help="Negative prompt")] = "",
|
||||||
|
output: Annotated[Path | None, typer.Option("-o", "--output", help="Save path (default: current dir)")] = None,
|
||||||
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""Generate an image using text-to-image.
|
||||||
|
|
||||||
|
Calls ComfyUI directly when local, or the remote tensors API when --remote is given.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
tsr generate "a cat on a windowsill"
|
||||||
|
tsr generate "portrait photo" -m "flux1-dev-fp8.safetensors" --steps 30
|
||||||
|
tsr generate "cyberpunk city" -o output.png
|
||||||
|
tsr generate "landscape" --remote junkpile
|
||||||
|
"""
|
||||||
|
import random as rng # noqa: PLC0415
|
||||||
|
|
||||||
|
from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415
|
||||||
|
|
||||||
|
# Resolve remote (explicit flag, or default from config)
|
||||||
|
remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None)
|
||||||
|
|
||||||
|
if remote_url:
|
||||||
|
# ---- Remote mode: HTTP call to tensors server ----
|
||||||
|
if not json_output:
|
||||||
|
console.print(f"[dim]Remote: {remote_url}[/dim]")
|
||||||
|
|
||||||
|
result = remote_generate(
|
||||||
|
remote or remote_url,
|
||||||
|
prompt,
|
||||||
|
negative_prompt=negative,
|
||||||
|
model=model,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
steps=steps,
|
||||||
|
cfg=cfg,
|
||||||
|
seed=seed,
|
||||||
|
sampler=sampler,
|
||||||
|
scheduler=scheduler,
|
||||||
|
vae=vae,
|
||||||
|
lora_name=lora,
|
||||||
|
lora_strength=lora_strength,
|
||||||
|
console=console,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
if not json_output:
|
||||||
|
console.print("[red]Generation failed[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=result)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
console.print("[red]Generation failed[/red]")
|
||||||
|
errors = result.get("errors", {})
|
||||||
|
for node_id, err in errors.items():
|
||||||
|
console.print(f" [yellow]Node {node_id}:[/yellow] {err}")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
images = result.get("images", [])
|
||||||
|
console.print(f"[green]Generated {len(images)} image(s)[/green]")
|
||||||
|
console.print(f"[dim]Prompt ID: {result.get('prompt_id', 'N/A')}[/dim]")
|
||||||
|
|
||||||
|
# Download and save images if --output specified
|
||||||
|
if output and images:
|
||||||
|
for i, img_name in enumerate(images):
|
||||||
|
img_data = remote_get_image(remote or remote_url, img_name)
|
||||||
|
if img_data:
|
||||||
|
save_path = output if len(images) == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
|
||||||
|
save_path.write_bytes(img_data)
|
||||||
|
console.print(f"[green]Saved:[/green] {save_path}")
|
||||||
|
else:
|
||||||
|
console.print(f"[yellow]Could not download image: {img_name}[/yellow]")
|
||||||
|
elif images:
|
||||||
|
for img_name in images:
|
||||||
|
console.print(f" [dim]{img_name}[/dim]")
|
||||||
|
|
||||||
|
else:
|
||||||
|
# ---- Local mode: direct library call ----
|
||||||
|
from tensors.comfyui import generate_image, get_image # noqa: PLC0415
|
||||||
|
|
||||||
|
actual_seed = seed if seed >= 0 else rng.randint(0, 2**32 - 1)
|
||||||
|
|
||||||
|
result_local = generate_image(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative,
|
||||||
|
model=model,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
steps=steps,
|
||||||
|
cfg=cfg,
|
||||||
|
seed=actual_seed,
|
||||||
|
sampler=sampler,
|
||||||
|
scheduler=scheduler,
|
||||||
|
console=console if not json_output else None,
|
||||||
|
lora_name=lora,
|
||||||
|
lora_strength=lora_strength,
|
||||||
|
vae=vae,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result_local:
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data={"success": False, "errors": {"generation": "Failed to generate"}})
|
||||||
|
else:
|
||||||
|
console.print("[red]Generation failed[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
if not result_local.success:
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data={"success": False, "errors": result_local.node_errors})
|
||||||
|
else:
|
||||||
|
console.print("[red]Generation failed[/red]")
|
||||||
|
for node_id, errors in result_local.node_errors.items():
|
||||||
|
console.print(f" [yellow]Node {node_id}:[/yellow] {errors}")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
# Save images
|
||||||
|
saved_paths: list[Path] = []
|
||||||
|
for i, img_path in enumerate(result_local.images):
|
||||||
|
if output:
|
||||||
|
img_data = get_image(str(img_path))
|
||||||
|
if img_data:
|
||||||
|
save_path = (
|
||||||
|
output if len(result_local.images) == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
|
||||||
|
)
|
||||||
|
save_path.write_bytes(img_data)
|
||||||
|
saved_paths.append(save_path)
|
||||||
|
if not json_output:
|
||||||
|
console.print(f"[green]Saved:[/green] {save_path}")
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(
|
||||||
|
data={
|
||||||
|
"success": True,
|
||||||
|
"prompt_id": result_local.prompt_id,
|
||||||
|
"images": [str(p) for p in result_local.images],
|
||||||
|
"saved": [str(p) for p in saved_paths],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print("[bold green]Generation complete![/bold green]")
|
||||||
|
console.print(f"[dim]Prompt ID: {result_local.prompt_id}[/dim]")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Top-Level Models Command
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def models(
|
||||||
|
model_type: Annotated[str | None, typer.Option("-t", "--type", help="Filter by type (checkpoints, loras, vae)")] = None,
|
||||||
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""List available models from ComfyUI.
|
||||||
|
|
||||||
|
Shows checkpoints, LoRAs, VAEs, and other model types loaded in ComfyUI.
|
||||||
|
Uses --remote to query a remote tensors server instead of local ComfyUI.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
tsr models
|
||||||
|
tsr models -t checkpoints
|
||||||
|
tsr models --remote junkpile
|
||||||
|
tsr models --json
|
||||||
|
"""
|
||||||
|
from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415
|
||||||
|
|
||||||
|
remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None)
|
||||||
|
|
||||||
|
if remote_url:
|
||||||
|
if not json_output:
|
||||||
|
console.print(f"[dim]Remote: {remote_url}[/dim]")
|
||||||
|
result = remote_models(remote or remote_url, console=console)
|
||||||
|
else:
|
||||||
|
from tensors.comfyui import get_loaded_models # noqa: PLC0415
|
||||||
|
|
||||||
|
result = get_loaded_models(console=console if not json_output else None)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
console.print("[red]Error: Could not fetch models[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
# Filter by type if requested
|
||||||
|
if model_type:
|
||||||
|
key = model_type.lower()
|
||||||
|
filtered = {k: v for k, v in result.items() if k.lower() == key}
|
||||||
|
if not filtered:
|
||||||
|
console.print(f"[yellow]No models found for type '{model_type}'[/yellow]")
|
||||||
|
console.print(f"[dim]Available types: {', '.join(sorted(result.keys()))}[/dim]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
result = filtered
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=result)
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print("[bold cyan]Available Models[/bold cyan]")
|
||||||
|
|
||||||
|
for mtype, model_list in sorted(result.items()):
|
||||||
|
console.print()
|
||||||
|
console.print(f"[bold]{mtype}:[/bold] ({len(model_list)})")
|
||||||
|
for name in model_list[:MAX_MODEL_LIST_DISPLAY]:
|
||||||
|
console.print(f" {name}")
|
||||||
|
if len(model_list) > MAX_MODEL_LIST_DISPLAY:
|
||||||
|
console.print(f" ... and {len(model_list) - MAX_MODEL_LIST_DISPLAY} more")
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Database Commands
|
# Database Commands
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -1455,6 +1804,8 @@ def main() -> int:
|
|||||||
"get",
|
"get",
|
||||||
"dl",
|
"dl",
|
||||||
"download",
|
"download",
|
||||||
|
"generate",
|
||||||
|
"models",
|
||||||
"config",
|
"config",
|
||||||
"serve",
|
"serve",
|
||||||
"db",
|
"db",
|
||||||
|
|||||||
@@ -0,0 +1,300 @@
|
|||||||
|
"""HTTP client for calling tensors server API remotely."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from rich.console import Console
|
||||||
|
|
||||||
|
from tensors.config import get_server_api_key, resolve_remote
|
||||||
|
|
||||||
|
|
||||||
|
def _build_client(base_url: str, timeout: float = 300.0) -> httpx.Client:
|
||||||
|
"""Build an httpx client with API key auth."""
|
||||||
|
api_key = get_server_api_key()
|
||||||
|
headers: dict[str, str] = {}
|
||||||
|
if api_key:
|
||||||
|
headers["X-API-Key"] = api_key
|
||||||
|
return httpx.Client(base_url=base_url, headers=headers, timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
|
def remote_generate(
|
||||||
|
remote: str,
|
||||||
|
prompt: str,
|
||||||
|
*,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
model: str | None = None,
|
||||||
|
width: int = 1024,
|
||||||
|
height: int = 1024,
|
||||||
|
steps: int = 20,
|
||||||
|
cfg: float = 7.0,
|
||||||
|
seed: int = -1,
|
||||||
|
sampler: str = "euler",
|
||||||
|
scheduler: str = "normal",
|
||||||
|
vae: str | None = None,
|
||||||
|
lora_name: str | None = None,
|
||||||
|
lora_strength: float = 0.8,
|
||||||
|
console: Console | None = None,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Generate an image via remote tensors server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote: Remote name or URL (resolved via config)
|
||||||
|
prompt: Positive prompt text
|
||||||
|
console: Rich console for error output
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response dict with success, prompt_id, images, errors — or None on connection error
|
||||||
|
"""
|
||||||
|
base_url = resolve_remote(remote)
|
||||||
|
if not base_url:
|
||||||
|
if console:
|
||||||
|
console.print("[red]Error: Could not resolve remote server[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"steps": steps,
|
||||||
|
"cfg": cfg,
|
||||||
|
"seed": seed,
|
||||||
|
"sampler": sampler,
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"lora_strength": lora_strength,
|
||||||
|
}
|
||||||
|
if model:
|
||||||
|
payload["model"] = model
|
||||||
|
if vae:
|
||||||
|
payload["vae"] = vae
|
||||||
|
if lora_name:
|
||||||
|
payload["lora_name"] = lora_name
|
||||||
|
|
||||||
|
try:
|
||||||
|
with _build_client(base_url) as client:
|
||||||
|
response = client.post("/api/comfyui/generate", json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
result: dict[str, Any] = response.json()
|
||||||
|
return result
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if console:
|
||||||
|
console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
|
||||||
|
try:
|
||||||
|
detail = e.response.json().get("detail", "")
|
||||||
|
if detail:
|
||||||
|
console.print(f" [yellow]{detail}[/yellow]")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
if console:
|
||||||
|
console.print(f"[red]Remote connection error: {e}[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def remote_get_image(remote: str, filename: str) -> bytes | None:
|
||||||
|
"""Download a generated image from remote tensors server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote: Remote name or URL
|
||||||
|
filename: Image filename from generation result
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image bytes or None on error
|
||||||
|
"""
|
||||||
|
base_url = resolve_remote(remote)
|
||||||
|
if not base_url:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with _build_client(base_url) as client:
|
||||||
|
response = client.get("/api/comfyui/image/" + filename)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.content
|
||||||
|
except (httpx.HTTPStatusError, httpx.RequestError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def remote_models(
|
||||||
|
remote: str,
|
||||||
|
console: Console | None = None,
|
||||||
|
) -> dict[str, list[str]] | None:
|
||||||
|
"""List available models from remote tensors server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote: Remote name or URL
|
||||||
|
console: Rich console for error output
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping model type to list of model names, or None on error
|
||||||
|
"""
|
||||||
|
base_url = resolve_remote(remote)
|
||||||
|
if not base_url:
|
||||||
|
if console:
|
||||||
|
console.print("[red]Error: Could not resolve remote server[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with _build_client(base_url) as client:
|
||||||
|
response = client.get("/api/comfyui/models")
|
||||||
|
response.raise_for_status()
|
||||||
|
result: dict[str, list[str]] = response.json()
|
||||||
|
return result
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if console:
|
||||||
|
console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
if console:
|
||||||
|
console.print(f"[red]Remote connection error: {e}[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def remote_search(
|
||||||
|
remote: str,
|
||||||
|
*,
|
||||||
|
query: str | None = None,
|
||||||
|
model_type: str | None = None,
|
||||||
|
base_model: str | None = None,
|
||||||
|
sort: str = "downloads",
|
||||||
|
limit: int = 20,
|
||||||
|
page: int | None = None,
|
||||||
|
nsfw: str | None = None,
|
||||||
|
sfw: bool = False,
|
||||||
|
console: Console | None = None,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Search CivitAI models via remote tensors server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote: Remote name or URL
|
||||||
|
console: Rich console for error output
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Search results dict or None on error
|
||||||
|
"""
|
||||||
|
base_url = resolve_remote(remote)
|
||||||
|
if not base_url:
|
||||||
|
if console:
|
||||||
|
console.print("[red]Error: Could not resolve remote server[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"provider": "civitai",
|
||||||
|
"sort": sort,
|
||||||
|
"limit": limit,
|
||||||
|
}
|
||||||
|
if query:
|
||||||
|
params["query"] = query
|
||||||
|
if model_type:
|
||||||
|
params["types"] = model_type
|
||||||
|
if base_model:
|
||||||
|
params["baseModels"] = base_model
|
||||||
|
if page:
|
||||||
|
params["page"] = page
|
||||||
|
if sfw:
|
||||||
|
params["sfw"] = True
|
||||||
|
elif nsfw:
|
||||||
|
params["nsfw"] = nsfw
|
||||||
|
|
||||||
|
try:
|
||||||
|
with _build_client(base_url) as client:
|
||||||
|
response = client.get("/api/search", params=params)
|
||||||
|
response.raise_for_status()
|
||||||
|
result: dict[str, Any] = response.json()
|
||||||
|
# The remote API wraps CivitAI results under "civitai" key
|
||||||
|
return result.get("civitai", result)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if console:
|
||||||
|
console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
if console:
|
||||||
|
console.print(f"[red]Remote connection error: {e}[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def remote_download(
|
||||||
|
remote: str,
|
||||||
|
*,
|
||||||
|
version_id: int | None = None,
|
||||||
|
model_id: int | None = None,
|
||||||
|
output_dir: str | None = None,
|
||||||
|
console: Console | None = None,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Start a model download on remote tensors server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote: Remote name or URL
|
||||||
|
version_id: CivitAI version ID
|
||||||
|
model_id: CivitAI model ID (downloads latest version)
|
||||||
|
output_dir: Override output directory on the remote
|
||||||
|
console: Rich console for error output
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Download status dict with download_id, or None on error
|
||||||
|
"""
|
||||||
|
base_url = resolve_remote(remote)
|
||||||
|
if not base_url:
|
||||||
|
if console:
|
||||||
|
console.print("[red]Error: Could not resolve remote server[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
payload: dict[str, Any] = {}
|
||||||
|
if version_id:
|
||||||
|
payload["version_id"] = version_id
|
||||||
|
if model_id:
|
||||||
|
payload["model_id"] = model_id
|
||||||
|
if output_dir:
|
||||||
|
payload["output_dir"] = output_dir
|
||||||
|
|
||||||
|
try:
|
||||||
|
with _build_client(base_url) as client:
|
||||||
|
response = client.post("/api/download", json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
result: dict[str, Any] = response.json()
|
||||||
|
return result
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if console:
|
||||||
|
console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
|
||||||
|
try:
|
||||||
|
detail = e.response.json().get("detail", "")
|
||||||
|
if detail:
|
||||||
|
console.print(f" [yellow]{detail}[/yellow]")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
if console:
|
||||||
|
console.print(f"[red]Remote connection error: {e}[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def remote_download_status(
|
||||||
|
remote: str,
|
||||||
|
download_id: str,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Check download status on remote tensors server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote: Remote name or URL
|
||||||
|
download_id: Download ID from start_download response
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Download status dict or None on error
|
||||||
|
"""
|
||||||
|
base_url = resolve_remote(remote)
|
||||||
|
if not base_url:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with _build_client(base_url) as client:
|
||||||
|
response = client.get(f"/api/download/status/{download_id}")
|
||||||
|
response.raise_for_status()
|
||||||
|
result: dict[str, Any] = response.json()
|
||||||
|
return result
|
||||||
|
except (httpx.HTTPStatusError, httpx.RequestError):
|
||||||
|
return None
|
||||||
Reference in New Issue
Block a user