From 4a2fdce115019bc63f0048bc06508cf1c9b0ca0a Mon Sep 17 00:00:00 2001 From: aladac Date: Mon, 6 Apr 2026 01:26:55 +0200 Subject: [PATCH] Add top-level generate and models commands with --remote support Add `tsr generate` and `tsr models` as top-level CLI commands that call ComfyUI library functions directly or HTTP to a remote tensors server. Add `--remote` flag to existing `tsr search` and `tsr dl` commands. New file `tensors/remote.py` provides HTTP client functions for all four operations against the remote tensors API (generate, models, search, download with progress polling). Co-Authored-By: Claude Opus 4.6 (1M context) --- tensors/cli.py | 353 +++++++++++++++++++++++++++++++++++++++++++++- tensors/remote.py | 300 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 652 insertions(+), 1 deletion(-) create mode 100644 tensors/remote.py diff --git a/tensors/cli.py b/tensors/cli.py index 6cea831..e64c2aa 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -54,6 +54,14 @@ from tensors.hf import ( list_safetensor_files, search_hf_models, ) +from tensors.remote import ( + remote_download, + remote_download_status, + remote_generate, + remote_get_image, + remote_models, + remote_search, +) from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata # Key masking threshold @@ -232,6 +240,7 @@ def search( pipeline: Annotated[str | None, typer.Option("--pipeline", help="Pipeline tag (HuggingFace)")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, + remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, ) -> None: """Search models on CivitAI and/or Hugging Face. @@ -242,7 +251,32 @@ def search( tsr search -t lora -b pony # CivitAI LoRAs for Pony tsr search -a stabilityai -P hf # HF by author tsr search --sfw -P civitai # CivitAI SFW only + tsr search "pony" --remote junkpile # Search via remote server """ + # Remote mode: delegate to remote tensors server + if remote: + civitai_results = remote_search( + remote, + query=query, + model_type=model_type.to_api() if model_type else None, + base_model=base.to_api() if base else None, + sort=sort.value, + limit=limit, + page=page, + nsfw=nsfw.value if nsfw else None, + sfw=sfw, + console=console, + ) + if not civitai_results: + console.print("[red]Remote search failed.[/red]") + raise typer.Exit(1) + + if json_output: + console.print_json(data={"civitai": civitai_results}) + else: + display_search_results(civitai_results, console) + return + key = api_key or load_api_key() civitai_results: dict[str, Any] | None = None hf_results: list[dict[str, Any]] | None = None @@ -429,6 +463,78 @@ def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Pa return output_dir +def _poll_remote_download(remote_name: str, download_id: str) -> None: + """Poll a remote download for completion with a progress bar.""" + import time # noqa: PLC0415 + + from rich.progress import BarColumn, DownloadColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn # noqa: PLC0415 + + status: dict[str, Any] | None = None + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + DownloadColumn(), + console=console, + ) as progress: + task = progress.add_task("[cyan]Downloading...", total=100) + + while True: + status = remote_download_status(remote_name, download_id) + if not status: + break + + dl_status = status.get("status", "") + pct = status.get("progress", 0) + progress.update(task, completed=pct, description=f"[cyan]{dl_status.title()}...") + + if dl_status in ("completed", "failed"): + break + + time.sleep(1) + + if status and status.get("status") == "completed": + console.print(f"[green]Download complete:[/green] {status.get('path', 'N/A')}") + elif status and status.get("status") == "failed": + console.print(f"[red]Download failed:[/red] {status.get('error', 'Unknown error')}") + + +def _download_remote( + remote_name: str, + version_id: int | None, + model_id: int | None, + hash_val: str | None, + output: Path | None, +) -> None: + """Handle remote download flow.""" + if not version_id and not model_id: + if hash_val: + console.print("[yellow]Remote download does not support --hash. Use --version-id or --model-id.[/yellow]") + else: + console.print("[red]Error: Must specify --version-id or --model-id for remote download[/red]") + raise typer.Exit(1) + + console.print("[dim]Starting download on remote server...[/dim]") + result = remote_download( + remote_name, + version_id=version_id, + model_id=model_id, + output_dir=str(output) if output else None, + console=console, + ) + if not result: + raise typer.Exit(1) + + console.print(f"[green]Download started:[/green] {result.get('model_name', 'N/A')}") + console.print(f"[dim]Version: {result.get('version_name', 'N/A')}[/dim]") + console.print(f"[dim]Destination: {result.get('destination', 'N/A')}[/dim]") + + download_id = result.get("download_id") + if download_id: + _poll_remote_download(remote_name, download_id) + + @app.command("dl") def download( version_id: Annotated[int | None, typer.Option("-v", "--version-id", help="Model version ID")] = None, @@ -437,8 +543,21 @@ def download( output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None, no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, + remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, ) -> None: - """Download a model from CivitAI.""" + """Download a model from CivitAI. + + When --remote is specified, the download happens on the remote server. + + Examples: + tsr dl -v 12345 # Download by version ID + tsr dl -m 67890 # Download latest version of model + tsr dl -v 12345 --remote junkpile # Download on remote server + """ + if remote: + _download_remote(remote, version_id, model_id, hash_val, output) + return + key = api_key or load_api_key() resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key) @@ -645,6 +764,236 @@ def serve( uvicorn.run(create_app(), host=host, port=port, log_level=log_level) +# ============================================================================= +# Top-Level Generate Command +# ============================================================================= + + +@app.command() +def generate( # noqa: PLR0915 + prompt: Annotated[str, typer.Argument(help="Positive prompt text")], + model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None, + width: Annotated[int, typer.Option("-W", "--width", help="Image width")] = 1024, + height: Annotated[int, typer.Option("-H", "--height", help="Image height")] = 1024, + steps: Annotated[int, typer.Option("--steps", help="Sampling steps")] = 20, + cfg: Annotated[float, typer.Option("--cfg", help="CFG scale")] = 7.0, + seed: Annotated[int, typer.Option("--seed", "-s", help="Random seed (-1 for random)")] = -1, + sampler: Annotated[str, typer.Option("--sampler", help="Sampler name")] = "euler", + scheduler: Annotated[str, typer.Option("--scheduler", help="Scheduler name")] = "normal", + vae: Annotated[str | None, typer.Option("--vae", help="VAE model name")] = None, + lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None, + lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8, + negative: Annotated[str, typer.Option("-n", "--negative-prompt", help="Negative prompt")] = "", + output: Annotated[Path | None, typer.Option("-o", "--output", help="Save path (default: current dir)")] = None, + remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Generate an image using text-to-image. + + Calls ComfyUI directly when local, or the remote tensors API when --remote is given. + + Examples: + tsr generate "a cat on a windowsill" + tsr generate "portrait photo" -m "flux1-dev-fp8.safetensors" --steps 30 + tsr generate "cyberpunk city" -o output.png + tsr generate "landscape" --remote junkpile + """ + import random as rng # noqa: PLC0415 + + from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415 + + # Resolve remote (explicit flag, or default from config) + remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None) + + if remote_url: + # ---- Remote mode: HTTP call to tensors server ---- + if not json_output: + console.print(f"[dim]Remote: {remote_url}[/dim]") + + result = remote_generate( + remote or remote_url, + prompt, + negative_prompt=negative, + model=model, + width=width, + height=height, + steps=steps, + cfg=cfg, + seed=seed, + sampler=sampler, + scheduler=scheduler, + vae=vae, + lora_name=lora, + lora_strength=lora_strength, + console=console, + ) + + if not result: + if not json_output: + console.print("[red]Generation failed[/red]") + raise typer.Exit(1) + + if json_output: + console.print_json(data=result) + return + + if not result.get("success"): + console.print("[red]Generation failed[/red]") + errors = result.get("errors", {}) + for node_id, err in errors.items(): + console.print(f" [yellow]Node {node_id}:[/yellow] {err}") + raise typer.Exit(1) + + images = result.get("images", []) + console.print(f"[green]Generated {len(images)} image(s)[/green]") + console.print(f"[dim]Prompt ID: {result.get('prompt_id', 'N/A')}[/dim]") + + # Download and save images if --output specified + if output and images: + for i, img_name in enumerate(images): + img_data = remote_get_image(remote or remote_url, img_name) + if img_data: + save_path = output if len(images) == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}" + save_path.write_bytes(img_data) + console.print(f"[green]Saved:[/green] {save_path}") + else: + console.print(f"[yellow]Could not download image: {img_name}[/yellow]") + elif images: + for img_name in images: + console.print(f" [dim]{img_name}[/dim]") + + else: + # ---- Local mode: direct library call ---- + from tensors.comfyui import generate_image, get_image # noqa: PLC0415 + + actual_seed = seed if seed >= 0 else rng.randint(0, 2**32 - 1) + + result_local = generate_image( + prompt=prompt, + negative_prompt=negative, + model=model, + width=width, + height=height, + steps=steps, + cfg=cfg, + seed=actual_seed, + sampler=sampler, + scheduler=scheduler, + console=console if not json_output else None, + lora_name=lora, + lora_strength=lora_strength, + vae=vae, + ) + + if not result_local: + if json_output: + console.print_json(data={"success": False, "errors": {"generation": "Failed to generate"}}) + else: + console.print("[red]Generation failed[/red]") + raise typer.Exit(1) + + if not result_local.success: + if json_output: + console.print_json(data={"success": False, "errors": result_local.node_errors}) + else: + console.print("[red]Generation failed[/red]") + for node_id, errors in result_local.node_errors.items(): + console.print(f" [yellow]Node {node_id}:[/yellow] {errors}") + raise typer.Exit(1) + + # Save images + saved_paths: list[Path] = [] + for i, img_path in enumerate(result_local.images): + if output: + img_data = get_image(str(img_path)) + if img_data: + save_path = ( + output if len(result_local.images) == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}" + ) + save_path.write_bytes(img_data) + saved_paths.append(save_path) + if not json_output: + console.print(f"[green]Saved:[/green] {save_path}") + + if json_output: + console.print_json( + data={ + "success": True, + "prompt_id": result_local.prompt_id, + "images": [str(p) for p in result_local.images], + "saved": [str(p) for p in saved_paths], + } + ) + return + + console.print("[bold green]Generation complete![/bold green]") + console.print(f"[dim]Prompt ID: {result_local.prompt_id}[/dim]") + + +# ============================================================================= +# Top-Level Models Command +# ============================================================================= + + +@app.command() +def models( + model_type: Annotated[str | None, typer.Option("-t", "--type", help="Filter by type (checkpoints, loras, vae)")] = None, + remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """List available models from ComfyUI. + + Shows checkpoints, LoRAs, VAEs, and other model types loaded in ComfyUI. + Uses --remote to query a remote tensors server instead of local ComfyUI. + + Examples: + tsr models + tsr models -t checkpoints + tsr models --remote junkpile + tsr models --json + """ + from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415 + + remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None) + + if remote_url: + if not json_output: + console.print(f"[dim]Remote: {remote_url}[/dim]") + result = remote_models(remote or remote_url, console=console) + else: + from tensors.comfyui import get_loaded_models # noqa: PLC0415 + + result = get_loaded_models(console=console if not json_output else None) + + if not result: + console.print("[red]Error: Could not fetch models[/red]") + raise typer.Exit(1) + + # Filter by type if requested + if model_type: + key = model_type.lower() + filtered = {k: v for k, v in result.items() if k.lower() == key} + if not filtered: + console.print(f"[yellow]No models found for type '{model_type}'[/yellow]") + console.print(f"[dim]Available types: {', '.join(sorted(result.keys()))}[/dim]") + raise typer.Exit(1) + result = filtered + + if json_output: + console.print_json(data=result) + return + + console.print("[bold cyan]Available Models[/bold cyan]") + + for mtype, model_list in sorted(result.items()): + console.print() + console.print(f"[bold]{mtype}:[/bold] ({len(model_list)})") + for name in model_list[:MAX_MODEL_LIST_DISPLAY]: + console.print(f" {name}") + if len(model_list) > MAX_MODEL_LIST_DISPLAY: + console.print(f" ... and {len(model_list) - MAX_MODEL_LIST_DISPLAY} more") + + # ============================================================================= # Database Commands # ============================================================================= @@ -1455,6 +1804,8 @@ def main() -> int: "get", "dl", "download", + "generate", + "models", "config", "serve", "db", diff --git a/tensors/remote.py b/tensors/remote.py new file mode 100644 index 0000000..bafabbb --- /dev/null +++ b/tensors/remote.py @@ -0,0 +1,300 @@ +"""HTTP client for calling tensors server API remotely.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import httpx + +if TYPE_CHECKING: + from rich.console import Console + +from tensors.config import get_server_api_key, resolve_remote + + +def _build_client(base_url: str, timeout: float = 300.0) -> httpx.Client: + """Build an httpx client with API key auth.""" + api_key = get_server_api_key() + headers: dict[str, str] = {} + if api_key: + headers["X-API-Key"] = api_key + return httpx.Client(base_url=base_url, headers=headers, timeout=timeout) + + +def remote_generate( + remote: str, + prompt: str, + *, + negative_prompt: str = "", + model: str | None = None, + width: int = 1024, + height: int = 1024, + steps: int = 20, + cfg: float = 7.0, + seed: int = -1, + sampler: str = "euler", + scheduler: str = "normal", + vae: str | None = None, + lora_name: str | None = None, + lora_strength: float = 0.8, + console: Console | None = None, +) -> dict[str, Any] | None: + """Generate an image via remote tensors server. + + Args: + remote: Remote name or URL (resolved via config) + prompt: Positive prompt text + console: Rich console for error output + + Returns: + Response dict with success, prompt_id, images, errors — or None on connection error + """ + base_url = resolve_remote(remote) + if not base_url: + if console: + console.print("[red]Error: Could not resolve remote server[/red]") + return None + + payload: dict[str, Any] = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "width": width, + "height": height, + "steps": steps, + "cfg": cfg, + "seed": seed, + "sampler": sampler, + "scheduler": scheduler, + "lora_strength": lora_strength, + } + if model: + payload["model"] = model + if vae: + payload["vae"] = vae + if lora_name: + payload["lora_name"] = lora_name + + try: + with _build_client(base_url) as client: + response = client.post("/api/comfyui/generate", json=payload) + response.raise_for_status() + result: dict[str, Any] = response.json() + return result + except httpx.HTTPStatusError as e: + if console: + console.print(f"[red]Remote API error: {e.response.status_code}[/red]") + try: + detail = e.response.json().get("detail", "") + if detail: + console.print(f" [yellow]{detail}[/yellow]") + except Exception: + pass + return None + except httpx.RequestError as e: + if console: + console.print(f"[red]Remote connection error: {e}[/red]") + return None + + +def remote_get_image(remote: str, filename: str) -> bytes | None: + """Download a generated image from remote tensors server. + + Args: + remote: Remote name or URL + filename: Image filename from generation result + + Returns: + Image bytes or None on error + """ + base_url = resolve_remote(remote) + if not base_url: + return None + + try: + with _build_client(base_url) as client: + response = client.get("/api/comfyui/image/" + filename) + response.raise_for_status() + return response.content + except (httpx.HTTPStatusError, httpx.RequestError): + return None + + +def remote_models( + remote: str, + console: Console | None = None, +) -> dict[str, list[str]] | None: + """List available models from remote tensors server. + + Args: + remote: Remote name or URL + console: Rich console for error output + + Returns: + Dict mapping model type to list of model names, or None on error + """ + base_url = resolve_remote(remote) + if not base_url: + if console: + console.print("[red]Error: Could not resolve remote server[/red]") + return None + + try: + with _build_client(base_url) as client: + response = client.get("/api/comfyui/models") + response.raise_for_status() + result: dict[str, list[str]] = response.json() + return result + except httpx.HTTPStatusError as e: + if console: + console.print(f"[red]Remote API error: {e.response.status_code}[/red]") + return None + except httpx.RequestError as e: + if console: + console.print(f"[red]Remote connection error: {e}[/red]") + return None + + +def remote_search( + remote: str, + *, + query: str | None = None, + model_type: str | None = None, + base_model: str | None = None, + sort: str = "downloads", + limit: int = 20, + page: int | None = None, + nsfw: str | None = None, + sfw: bool = False, + console: Console | None = None, +) -> dict[str, Any] | None: + """Search CivitAI models via remote tensors server. + + Args: + remote: Remote name or URL + console: Rich console for error output + + Returns: + Search results dict or None on error + """ + base_url = resolve_remote(remote) + if not base_url: + if console: + console.print("[red]Error: Could not resolve remote server[/red]") + return None + + params: dict[str, Any] = { + "provider": "civitai", + "sort": sort, + "limit": limit, + } + if query: + params["query"] = query + if model_type: + params["types"] = model_type + if base_model: + params["baseModels"] = base_model + if page: + params["page"] = page + if sfw: + params["sfw"] = True + elif nsfw: + params["nsfw"] = nsfw + + try: + with _build_client(base_url) as client: + response = client.get("/api/search", params=params) + response.raise_for_status() + result: dict[str, Any] = response.json() + # The remote API wraps CivitAI results under "civitai" key + return result.get("civitai", result) + except httpx.HTTPStatusError as e: + if console: + console.print(f"[red]Remote API error: {e.response.status_code}[/red]") + return None + except httpx.RequestError as e: + if console: + console.print(f"[red]Remote connection error: {e}[/red]") + return None + + +def remote_download( + remote: str, + *, + version_id: int | None = None, + model_id: int | None = None, + output_dir: str | None = None, + console: Console | None = None, +) -> dict[str, Any] | None: + """Start a model download on remote tensors server. + + Args: + remote: Remote name or URL + version_id: CivitAI version ID + model_id: CivitAI model ID (downloads latest version) + output_dir: Override output directory on the remote + console: Rich console for error output + + Returns: + Download status dict with download_id, or None on error + """ + base_url = resolve_remote(remote) + if not base_url: + if console: + console.print("[red]Error: Could not resolve remote server[/red]") + return None + + payload: dict[str, Any] = {} + if version_id: + payload["version_id"] = version_id + if model_id: + payload["model_id"] = model_id + if output_dir: + payload["output_dir"] = output_dir + + try: + with _build_client(base_url) as client: + response = client.post("/api/download", json=payload) + response.raise_for_status() + result: dict[str, Any] = response.json() + return result + except httpx.HTTPStatusError as e: + if console: + console.print(f"[red]Remote API error: {e.response.status_code}[/red]") + try: + detail = e.response.json().get("detail", "") + if detail: + console.print(f" [yellow]{detail}[/yellow]") + except Exception: + pass + return None + except httpx.RequestError as e: + if console: + console.print(f"[red]Remote connection error: {e}[/red]") + return None + + +def remote_download_status( + remote: str, + download_id: str, +) -> dict[str, Any] | None: + """Check download status on remote tensors server. + + Args: + remote: Remote name or URL + download_id: Download ID from start_download response + + Returns: + Download status dict or None on error + """ + base_url = resolve_remote(remote) + if not base_url: + return None + + try: + with _build_client(base_url) as client: + response = client.get(f"/api/download/status/{download_id}") + response.raise_for_status() + result: dict[str, Any] = response.json() + return result + except (httpx.HTTPStatusError, httpx.RequestError): + return None