Add top-level generate and models commands with --remote support

Add `tsr generate` and `tsr models` as top-level CLI commands that call
ComfyUI library functions directly or HTTP to a remote tensors server.
Add `--remote` flag to existing `tsr search` and `tsr dl` commands.

New file `tensors/remote.py` provides HTTP client functions for all four
operations against the remote tensors API (generate, models, search,
download with progress polling).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-06 01:26:55 +02:00
parent 56d5233962
commit 4a2fdce115
2 changed files with 652 additions and 1 deletions
+352 -1
View File
@@ -54,6 +54,14 @@ from tensors.hf import (
list_safetensor_files, list_safetensor_files,
search_hf_models, search_hf_models,
) )
from tensors.remote import (
remote_download,
remote_download_status,
remote_generate,
remote_get_image,
remote_models,
remote_search,
)
from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata
# Key masking threshold # Key masking threshold
@@ -232,6 +240,7 @@ def search(
pipeline: Annotated[str | None, typer.Option("--pipeline", help="Pipeline tag (HuggingFace)")] = None, pipeline: Annotated[str | None, typer.Option("--pipeline", help="Pipeline tag (HuggingFace)")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
) -> None: ) -> None:
"""Search models on CivitAI and/or Hugging Face. """Search models on CivitAI and/or Hugging Face.
@@ -242,7 +251,32 @@ def search(
tsr search -t lora -b pony # CivitAI LoRAs for Pony tsr search -t lora -b pony # CivitAI LoRAs for Pony
tsr search -a stabilityai -P hf # HF by author tsr search -a stabilityai -P hf # HF by author
tsr search --sfw -P civitai # CivitAI SFW only tsr search --sfw -P civitai # CivitAI SFW only
tsr search "pony" --remote junkpile # Search via remote server
""" """
# Remote mode: delegate to remote tensors server
if remote:
civitai_results = remote_search(
remote,
query=query,
model_type=model_type.to_api() if model_type else None,
base_model=base.to_api() if base else None,
sort=sort.value,
limit=limit,
page=page,
nsfw=nsfw.value if nsfw else None,
sfw=sfw,
console=console,
)
if not civitai_results:
console.print("[red]Remote search failed.[/red]")
raise typer.Exit(1)
if json_output:
console.print_json(data={"civitai": civitai_results})
else:
display_search_results(civitai_results, console)
return
key = api_key or load_api_key() key = api_key or load_api_key()
civitai_results: dict[str, Any] | None = None civitai_results: dict[str, Any] | None = None
hf_results: list[dict[str, Any]] | None = None hf_results: list[dict[str, Any]] | None = None
@@ -429,6 +463,78 @@ def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Pa
return output_dir return output_dir
def _poll_remote_download(remote_name: str, download_id: str) -> None:
"""Poll a remote download for completion with a progress bar."""
import time # noqa: PLC0415
from rich.progress import BarColumn, DownloadColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn # noqa: PLC0415
status: dict[str, Any] | None = None
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
DownloadColumn(),
console=console,
) as progress:
task = progress.add_task("[cyan]Downloading...", total=100)
while True:
status = remote_download_status(remote_name, download_id)
if not status:
break
dl_status = status.get("status", "")
pct = status.get("progress", 0)
progress.update(task, completed=pct, description=f"[cyan]{dl_status.title()}...")
if dl_status in ("completed", "failed"):
break
time.sleep(1)
if status and status.get("status") == "completed":
console.print(f"[green]Download complete:[/green] {status.get('path', 'N/A')}")
elif status and status.get("status") == "failed":
console.print(f"[red]Download failed:[/red] {status.get('error', 'Unknown error')}")
def _download_remote(
remote_name: str,
version_id: int | None,
model_id: int | None,
hash_val: str | None,
output: Path | None,
) -> None:
"""Handle remote download flow."""
if not version_id and not model_id:
if hash_val:
console.print("[yellow]Remote download does not support --hash. Use --version-id or --model-id.[/yellow]")
else:
console.print("[red]Error: Must specify --version-id or --model-id for remote download[/red]")
raise typer.Exit(1)
console.print("[dim]Starting download on remote server...[/dim]")
result = remote_download(
remote_name,
version_id=version_id,
model_id=model_id,
output_dir=str(output) if output else None,
console=console,
)
if not result:
raise typer.Exit(1)
console.print(f"[green]Download started:[/green] {result.get('model_name', 'N/A')}")
console.print(f"[dim]Version: {result.get('version_name', 'N/A')}[/dim]")
console.print(f"[dim]Destination: {result.get('destination', 'N/A')}[/dim]")
download_id = result.get("download_id")
if download_id:
_poll_remote_download(remote_name, download_id)
@app.command("dl") @app.command("dl")
def download( def download(
version_id: Annotated[int | None, typer.Option("-v", "--version-id", help="Model version ID")] = None, version_id: Annotated[int | None, typer.Option("-v", "--version-id", help="Model version ID")] = None,
@@ -437,8 +543,21 @@ def download(
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None, output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False, no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False,
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
) -> None: ) -> None:
"""Download a model from CivitAI.""" """Download a model from CivitAI.
When --remote is specified, the download happens on the remote server.
Examples:
tsr dl -v 12345 # Download by version ID
tsr dl -m 67890 # Download latest version of model
tsr dl -v 12345 --remote junkpile # Download on remote server
"""
if remote:
_download_remote(remote, version_id, model_id, hash_val, output)
return
key = api_key or load_api_key() key = api_key or load_api_key()
resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key) resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
@@ -645,6 +764,236 @@ def serve(
uvicorn.run(create_app(), host=host, port=port, log_level=log_level) uvicorn.run(create_app(), host=host, port=port, log_level=log_level)
# =============================================================================
# Top-Level Generate Command
# =============================================================================
@app.command()
def generate( # noqa: PLR0915
prompt: Annotated[str, typer.Argument(help="Positive prompt text")],
model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None,
width: Annotated[int, typer.Option("-W", "--width", help="Image width")] = 1024,
height: Annotated[int, typer.Option("-H", "--height", help="Image height")] = 1024,
steps: Annotated[int, typer.Option("--steps", help="Sampling steps")] = 20,
cfg: Annotated[float, typer.Option("--cfg", help="CFG scale")] = 7.0,
seed: Annotated[int, typer.Option("--seed", "-s", help="Random seed (-1 for random)")] = -1,
sampler: Annotated[str, typer.Option("--sampler", help="Sampler name")] = "euler",
scheduler: Annotated[str, typer.Option("--scheduler", help="Scheduler name")] = "normal",
vae: Annotated[str | None, typer.Option("--vae", help="VAE model name")] = None,
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
negative: Annotated[str, typer.Option("-n", "--negative-prompt", help="Negative prompt")] = "",
output: Annotated[Path | None, typer.Option("-o", "--output", help="Save path (default: current dir)")] = None,
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""Generate an image using text-to-image.
Calls ComfyUI directly when local, or the remote tensors API when --remote is given.
Examples:
tsr generate "a cat on a windowsill"
tsr generate "portrait photo" -m "flux1-dev-fp8.safetensors" --steps 30
tsr generate "cyberpunk city" -o output.png
tsr generate "landscape" --remote junkpile
"""
import random as rng # noqa: PLC0415
from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415
# Resolve remote (explicit flag, or default from config)
remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None)
if remote_url:
# ---- Remote mode: HTTP call to tensors server ----
if not json_output:
console.print(f"[dim]Remote: {remote_url}[/dim]")
result = remote_generate(
remote or remote_url,
prompt,
negative_prompt=negative,
model=model,
width=width,
height=height,
steps=steps,
cfg=cfg,
seed=seed,
sampler=sampler,
scheduler=scheduler,
vae=vae,
lora_name=lora,
lora_strength=lora_strength,
console=console,
)
if not result:
if not json_output:
console.print("[red]Generation failed[/red]")
raise typer.Exit(1)
if json_output:
console.print_json(data=result)
return
if not result.get("success"):
console.print("[red]Generation failed[/red]")
errors = result.get("errors", {})
for node_id, err in errors.items():
console.print(f" [yellow]Node {node_id}:[/yellow] {err}")
raise typer.Exit(1)
images = result.get("images", [])
console.print(f"[green]Generated {len(images)} image(s)[/green]")
console.print(f"[dim]Prompt ID: {result.get('prompt_id', 'N/A')}[/dim]")
# Download and save images if --output specified
if output and images:
for i, img_name in enumerate(images):
img_data = remote_get_image(remote or remote_url, img_name)
if img_data:
save_path = output if len(images) == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
save_path.write_bytes(img_data)
console.print(f"[green]Saved:[/green] {save_path}")
else:
console.print(f"[yellow]Could not download image: {img_name}[/yellow]")
elif images:
for img_name in images:
console.print(f" [dim]{img_name}[/dim]")
else:
# ---- Local mode: direct library call ----
from tensors.comfyui import generate_image, get_image # noqa: PLC0415
actual_seed = seed if seed >= 0 else rng.randint(0, 2**32 - 1)
result_local = generate_image(
prompt=prompt,
negative_prompt=negative,
model=model,
width=width,
height=height,
steps=steps,
cfg=cfg,
seed=actual_seed,
sampler=sampler,
scheduler=scheduler,
console=console if not json_output else None,
lora_name=lora,
lora_strength=lora_strength,
vae=vae,
)
if not result_local:
if json_output:
console.print_json(data={"success": False, "errors": {"generation": "Failed to generate"}})
else:
console.print("[red]Generation failed[/red]")
raise typer.Exit(1)
if not result_local.success:
if json_output:
console.print_json(data={"success": False, "errors": result_local.node_errors})
else:
console.print("[red]Generation failed[/red]")
for node_id, errors in result_local.node_errors.items():
console.print(f" [yellow]Node {node_id}:[/yellow] {errors}")
raise typer.Exit(1)
# Save images
saved_paths: list[Path] = []
for i, img_path in enumerate(result_local.images):
if output:
img_data = get_image(str(img_path))
if img_data:
save_path = (
output if len(result_local.images) == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
)
save_path.write_bytes(img_data)
saved_paths.append(save_path)
if not json_output:
console.print(f"[green]Saved:[/green] {save_path}")
if json_output:
console.print_json(
data={
"success": True,
"prompt_id": result_local.prompt_id,
"images": [str(p) for p in result_local.images],
"saved": [str(p) for p in saved_paths],
}
)
return
console.print("[bold green]Generation complete![/bold green]")
console.print(f"[dim]Prompt ID: {result_local.prompt_id}[/dim]")
# =============================================================================
# Top-Level Models Command
# =============================================================================
@app.command()
def models(
model_type: Annotated[str | None, typer.Option("-t", "--type", help="Filter by type (checkpoints, loras, vae)")] = None,
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""List available models from ComfyUI.
Shows checkpoints, LoRAs, VAEs, and other model types loaded in ComfyUI.
Uses --remote to query a remote tensors server instead of local ComfyUI.
Examples:
tsr models
tsr models -t checkpoints
tsr models --remote junkpile
tsr models --json
"""
from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415
remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None)
if remote_url:
if not json_output:
console.print(f"[dim]Remote: {remote_url}[/dim]")
result = remote_models(remote or remote_url, console=console)
else:
from tensors.comfyui import get_loaded_models # noqa: PLC0415
result = get_loaded_models(console=console if not json_output else None)
if not result:
console.print("[red]Error: Could not fetch models[/red]")
raise typer.Exit(1)
# Filter by type if requested
if model_type:
key = model_type.lower()
filtered = {k: v for k, v in result.items() if k.lower() == key}
if not filtered:
console.print(f"[yellow]No models found for type '{model_type}'[/yellow]")
console.print(f"[dim]Available types: {', '.join(sorted(result.keys()))}[/dim]")
raise typer.Exit(1)
result = filtered
if json_output:
console.print_json(data=result)
return
console.print("[bold cyan]Available Models[/bold cyan]")
for mtype, model_list in sorted(result.items()):
console.print()
console.print(f"[bold]{mtype}:[/bold] ({len(model_list)})")
for name in model_list[:MAX_MODEL_LIST_DISPLAY]:
console.print(f" {name}")
if len(model_list) > MAX_MODEL_LIST_DISPLAY:
console.print(f" ... and {len(model_list) - MAX_MODEL_LIST_DISPLAY} more")
# ============================================================================= # =============================================================================
# Database Commands # Database Commands
# ============================================================================= # =============================================================================
@@ -1455,6 +1804,8 @@ def main() -> int:
"get", "get",
"dl", "dl",
"download", "download",
"generate",
"models",
"config", "config",
"serve", "serve",
"db", "db",
+300
View File
@@ -0,0 +1,300 @@
"""HTTP client for calling tensors server API remotely."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import httpx
if TYPE_CHECKING:
from rich.console import Console
from tensors.config import get_server_api_key, resolve_remote
def _build_client(base_url: str, timeout: float = 300.0) -> httpx.Client:
"""Build an httpx client with API key auth."""
api_key = get_server_api_key()
headers: dict[str, str] = {}
if api_key:
headers["X-API-Key"] = api_key
return httpx.Client(base_url=base_url, headers=headers, timeout=timeout)
def remote_generate(
remote: str,
prompt: str,
*,
negative_prompt: str = "",
model: str | None = None,
width: int = 1024,
height: int = 1024,
steps: int = 20,
cfg: float = 7.0,
seed: int = -1,
sampler: str = "euler",
scheduler: str = "normal",
vae: str | None = None,
lora_name: str | None = None,
lora_strength: float = 0.8,
console: Console | None = None,
) -> dict[str, Any] | None:
"""Generate an image via remote tensors server.
Args:
remote: Remote name or URL (resolved via config)
prompt: Positive prompt text
console: Rich console for error output
Returns:
Response dict with success, prompt_id, images, errors — or None on connection error
"""
base_url = resolve_remote(remote)
if not base_url:
if console:
console.print("[red]Error: Could not resolve remote server[/red]")
return None
payload: dict[str, Any] = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"width": width,
"height": height,
"steps": steps,
"cfg": cfg,
"seed": seed,
"sampler": sampler,
"scheduler": scheduler,
"lora_strength": lora_strength,
}
if model:
payload["model"] = model
if vae:
payload["vae"] = vae
if lora_name:
payload["lora_name"] = lora_name
try:
with _build_client(base_url) as client:
response = client.post("/api/comfyui/generate", json=payload)
response.raise_for_status()
result: dict[str, Any] = response.json()
return result
except httpx.HTTPStatusError as e:
if console:
console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
try:
detail = e.response.json().get("detail", "")
if detail:
console.print(f" [yellow]{detail}[/yellow]")
except Exception:
pass
return None
except httpx.RequestError as e:
if console:
console.print(f"[red]Remote connection error: {e}[/red]")
return None
def remote_get_image(remote: str, filename: str) -> bytes | None:
"""Download a generated image from remote tensors server.
Args:
remote: Remote name or URL
filename: Image filename from generation result
Returns:
Image bytes or None on error
"""
base_url = resolve_remote(remote)
if not base_url:
return None
try:
with _build_client(base_url) as client:
response = client.get("/api/comfyui/image/" + filename)
response.raise_for_status()
return response.content
except (httpx.HTTPStatusError, httpx.RequestError):
return None
def remote_models(
remote: str,
console: Console | None = None,
) -> dict[str, list[str]] | None:
"""List available models from remote tensors server.
Args:
remote: Remote name or URL
console: Rich console for error output
Returns:
Dict mapping model type to list of model names, or None on error
"""
base_url = resolve_remote(remote)
if not base_url:
if console:
console.print("[red]Error: Could not resolve remote server[/red]")
return None
try:
with _build_client(base_url) as client:
response = client.get("/api/comfyui/models")
response.raise_for_status()
result: dict[str, list[str]] = response.json()
return result
except httpx.HTTPStatusError as e:
if console:
console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
return None
except httpx.RequestError as e:
if console:
console.print(f"[red]Remote connection error: {e}[/red]")
return None
def remote_search(
remote: str,
*,
query: str | None = None,
model_type: str | None = None,
base_model: str | None = None,
sort: str = "downloads",
limit: int = 20,
page: int | None = None,
nsfw: str | None = None,
sfw: bool = False,
console: Console | None = None,
) -> dict[str, Any] | None:
"""Search CivitAI models via remote tensors server.
Args:
remote: Remote name or URL
console: Rich console for error output
Returns:
Search results dict or None on error
"""
base_url = resolve_remote(remote)
if not base_url:
if console:
console.print("[red]Error: Could not resolve remote server[/red]")
return None
params: dict[str, Any] = {
"provider": "civitai",
"sort": sort,
"limit": limit,
}
if query:
params["query"] = query
if model_type:
params["types"] = model_type
if base_model:
params["baseModels"] = base_model
if page:
params["page"] = page
if sfw:
params["sfw"] = True
elif nsfw:
params["nsfw"] = nsfw
try:
with _build_client(base_url) as client:
response = client.get("/api/search", params=params)
response.raise_for_status()
result: dict[str, Any] = response.json()
# The remote API wraps CivitAI results under "civitai" key
return result.get("civitai", result)
except httpx.HTTPStatusError as e:
if console:
console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
return None
except httpx.RequestError as e:
if console:
console.print(f"[red]Remote connection error: {e}[/red]")
return None
def remote_download(
remote: str,
*,
version_id: int | None = None,
model_id: int | None = None,
output_dir: str | None = None,
console: Console | None = None,
) -> dict[str, Any] | None:
"""Start a model download on remote tensors server.
Args:
remote: Remote name or URL
version_id: CivitAI version ID
model_id: CivitAI model ID (downloads latest version)
output_dir: Override output directory on the remote
console: Rich console for error output
Returns:
Download status dict with download_id, or None on error
"""
base_url = resolve_remote(remote)
if not base_url:
if console:
console.print("[red]Error: Could not resolve remote server[/red]")
return None
payload: dict[str, Any] = {}
if version_id:
payload["version_id"] = version_id
if model_id:
payload["model_id"] = model_id
if output_dir:
payload["output_dir"] = output_dir
try:
with _build_client(base_url) as client:
response = client.post("/api/download", json=payload)
response.raise_for_status()
result: dict[str, Any] = response.json()
return result
except httpx.HTTPStatusError as e:
if console:
console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
try:
detail = e.response.json().get("detail", "")
if detail:
console.print(f" [yellow]{detail}[/yellow]")
except Exception:
pass
return None
except httpx.RequestError as e:
if console:
console.print(f"[red]Remote connection error: {e}[/red]")
return None
def remote_download_status(
remote: str,
download_id: str,
) -> dict[str, Any] | None:
"""Check download status on remote tensors server.
Args:
remote: Remote name or URL
download_id: Download ID from start_download response
Returns:
Download status dict or None on error
"""
base_url = resolve_remote(remote)
if not base_url:
return None
try:
with _build_client(base_url) as client:
response = client.get(f"/api/download/status/{download_id}")
response.raise_for_status()
result: dict[str, Any] = response.json()
return result
except (httpx.HTTPStatusError, httpx.RequestError):
return None