Add top-level generate and models commands with --remote support
Add `tsr generate` and `tsr models` as top-level CLI commands that call ComfyUI library functions directly or HTTP to a remote tensors server. Add `--remote` flag to existing `tsr search` and `tsr dl` commands. New file `tensors/remote.py` provides HTTP client functions for all four operations against the remote tensors API (generate, models, search, download with progress polling). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+352
-1
@@ -54,6 +54,14 @@ from tensors.hf import (
|
||||
list_safetensor_files,
|
||||
search_hf_models,
|
||||
)
|
||||
from tensors.remote import (
|
||||
remote_download,
|
||||
remote_download_status,
|
||||
remote_generate,
|
||||
remote_get_image,
|
||||
remote_models,
|
||||
remote_search,
|
||||
)
|
||||
from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata
|
||||
|
||||
# Key masking threshold
|
||||
@@ -232,6 +240,7 @@ def search(
|
||||
pipeline: Annotated[str | None, typer.Option("--pipeline", help="Pipeline tag (HuggingFace)")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
) -> None:
|
||||
"""Search models on CivitAI and/or Hugging Face.
|
||||
|
||||
@@ -242,7 +251,32 @@ def search(
|
||||
tsr search -t lora -b pony # CivitAI LoRAs for Pony
|
||||
tsr search -a stabilityai -P hf # HF by author
|
||||
tsr search --sfw -P civitai # CivitAI SFW only
|
||||
tsr search "pony" --remote junkpile # Search via remote server
|
||||
"""
|
||||
# Remote mode: delegate to remote tensors server
|
||||
if remote:
|
||||
civitai_results = remote_search(
|
||||
remote,
|
||||
query=query,
|
||||
model_type=model_type.to_api() if model_type else None,
|
||||
base_model=base.to_api() if base else None,
|
||||
sort=sort.value,
|
||||
limit=limit,
|
||||
page=page,
|
||||
nsfw=nsfw.value if nsfw else None,
|
||||
sfw=sfw,
|
||||
console=console,
|
||||
)
|
||||
if not civitai_results:
|
||||
console.print("[red]Remote search failed.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if json_output:
|
||||
console.print_json(data={"civitai": civitai_results})
|
||||
else:
|
||||
display_search_results(civitai_results, console)
|
||||
return
|
||||
|
||||
key = api_key or load_api_key()
|
||||
civitai_results: dict[str, Any] | None = None
|
||||
hf_results: list[dict[str, Any]] | None = None
|
||||
@@ -429,6 +463,78 @@ def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Pa
|
||||
return output_dir
|
||||
|
||||
|
||||
def _poll_remote_download(remote_name: str, download_id: str) -> None:
|
||||
"""Poll a remote download for completion with a progress bar."""
|
||||
import time # noqa: PLC0415
|
||||
|
||||
from rich.progress import BarColumn, DownloadColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn # noqa: PLC0415
|
||||
|
||||
status: dict[str, Any] | None = None
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
DownloadColumn(),
|
||||
console=console,
|
||||
) as progress:
|
||||
task = progress.add_task("[cyan]Downloading...", total=100)
|
||||
|
||||
while True:
|
||||
status = remote_download_status(remote_name, download_id)
|
||||
if not status:
|
||||
break
|
||||
|
||||
dl_status = status.get("status", "")
|
||||
pct = status.get("progress", 0)
|
||||
progress.update(task, completed=pct, description=f"[cyan]{dl_status.title()}...")
|
||||
|
||||
if dl_status in ("completed", "failed"):
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
if status and status.get("status") == "completed":
|
||||
console.print(f"[green]Download complete:[/green] {status.get('path', 'N/A')}")
|
||||
elif status and status.get("status") == "failed":
|
||||
console.print(f"[red]Download failed:[/red] {status.get('error', 'Unknown error')}")
|
||||
|
||||
|
||||
def _download_remote(
|
||||
remote_name: str,
|
||||
version_id: int | None,
|
||||
model_id: int | None,
|
||||
hash_val: str | None,
|
||||
output: Path | None,
|
||||
) -> None:
|
||||
"""Handle remote download flow."""
|
||||
if not version_id and not model_id:
|
||||
if hash_val:
|
||||
console.print("[yellow]Remote download does not support --hash. Use --version-id or --model-id.[/yellow]")
|
||||
else:
|
||||
console.print("[red]Error: Must specify --version-id or --model-id for remote download[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print("[dim]Starting download on remote server...[/dim]")
|
||||
result = remote_download(
|
||||
remote_name,
|
||||
version_id=version_id,
|
||||
model_id=model_id,
|
||||
output_dir=str(output) if output else None,
|
||||
console=console,
|
||||
)
|
||||
if not result:
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print(f"[green]Download started:[/green] {result.get('model_name', 'N/A')}")
|
||||
console.print(f"[dim]Version: {result.get('version_name', 'N/A')}[/dim]")
|
||||
console.print(f"[dim]Destination: {result.get('destination', 'N/A')}[/dim]")
|
||||
|
||||
download_id = result.get("download_id")
|
||||
if download_id:
|
||||
_poll_remote_download(remote_name, download_id)
|
||||
|
||||
|
||||
@app.command("dl")
|
||||
def download(
|
||||
version_id: Annotated[int | None, typer.Option("-v", "--version-id", help="Model version ID")] = None,
|
||||
@@ -437,8 +543,21 @@ def download(
|
||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
|
||||
no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False,
|
||||
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
) -> None:
|
||||
"""Download a model from CivitAI."""
|
||||
"""Download a model from CivitAI.
|
||||
|
||||
When --remote is specified, the download happens on the remote server.
|
||||
|
||||
Examples:
|
||||
tsr dl -v 12345 # Download by version ID
|
||||
tsr dl -m 67890 # Download latest version of model
|
||||
tsr dl -v 12345 --remote junkpile # Download on remote server
|
||||
"""
|
||||
if remote:
|
||||
_download_remote(remote, version_id, model_id, hash_val, output)
|
||||
return
|
||||
|
||||
key = api_key or load_api_key()
|
||||
|
||||
resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
|
||||
@@ -645,6 +764,236 @@ def serve(
|
||||
uvicorn.run(create_app(), host=host, port=port, log_level=log_level)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Top-Level Generate Command
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@app.command()
|
||||
def generate( # noqa: PLR0915
|
||||
prompt: Annotated[str, typer.Argument(help="Positive prompt text")],
|
||||
model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None,
|
||||
width: Annotated[int, typer.Option("-W", "--width", help="Image width")] = 1024,
|
||||
height: Annotated[int, typer.Option("-H", "--height", help="Image height")] = 1024,
|
||||
steps: Annotated[int, typer.Option("--steps", help="Sampling steps")] = 20,
|
||||
cfg: Annotated[float, typer.Option("--cfg", help="CFG scale")] = 7.0,
|
||||
seed: Annotated[int, typer.Option("--seed", "-s", help="Random seed (-1 for random)")] = -1,
|
||||
sampler: Annotated[str, typer.Option("--sampler", help="Sampler name")] = "euler",
|
||||
scheduler: Annotated[str, typer.Option("--scheduler", help="Scheduler name")] = "normal",
|
||||
vae: Annotated[str | None, typer.Option("--vae", help="VAE model name")] = None,
|
||||
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
||||
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
|
||||
negative: Annotated[str, typer.Option("-n", "--negative-prompt", help="Negative prompt")] = "",
|
||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Save path (default: current dir)")] = None,
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""Generate an image using text-to-image.
|
||||
|
||||
Calls ComfyUI directly when local, or the remote tensors API when --remote is given.
|
||||
|
||||
Examples:
|
||||
tsr generate "a cat on a windowsill"
|
||||
tsr generate "portrait photo" -m "flux1-dev-fp8.safetensors" --steps 30
|
||||
tsr generate "cyberpunk city" -o output.png
|
||||
tsr generate "landscape" --remote junkpile
|
||||
"""
|
||||
import random as rng # noqa: PLC0415
|
||||
|
||||
from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415
|
||||
|
||||
# Resolve remote (explicit flag, or default from config)
|
||||
remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None)
|
||||
|
||||
if remote_url:
|
||||
# ---- Remote mode: HTTP call to tensors server ----
|
||||
if not json_output:
|
||||
console.print(f"[dim]Remote: {remote_url}[/dim]")
|
||||
|
||||
result = remote_generate(
|
||||
remote or remote_url,
|
||||
prompt,
|
||||
negative_prompt=negative,
|
||||
model=model,
|
||||
width=width,
|
||||
height=height,
|
||||
steps=steps,
|
||||
cfg=cfg,
|
||||
seed=seed,
|
||||
sampler=sampler,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
lora_name=lora,
|
||||
lora_strength=lora_strength,
|
||||
console=console,
|
||||
)
|
||||
|
||||
if not result:
|
||||
if not json_output:
|
||||
console.print("[red]Generation failed[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=result)
|
||||
return
|
||||
|
||||
if not result.get("success"):
|
||||
console.print("[red]Generation failed[/red]")
|
||||
errors = result.get("errors", {})
|
||||
for node_id, err in errors.items():
|
||||
console.print(f" [yellow]Node {node_id}:[/yellow] {err}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
images = result.get("images", [])
|
||||
console.print(f"[green]Generated {len(images)} image(s)[/green]")
|
||||
console.print(f"[dim]Prompt ID: {result.get('prompt_id', 'N/A')}[/dim]")
|
||||
|
||||
# Download and save images if --output specified
|
||||
if output and images:
|
||||
for i, img_name in enumerate(images):
|
||||
img_data = remote_get_image(remote or remote_url, img_name)
|
||||
if img_data:
|
||||
save_path = output if len(images) == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
|
||||
save_path.write_bytes(img_data)
|
||||
console.print(f"[green]Saved:[/green] {save_path}")
|
||||
else:
|
||||
console.print(f"[yellow]Could not download image: {img_name}[/yellow]")
|
||||
elif images:
|
||||
for img_name in images:
|
||||
console.print(f" [dim]{img_name}[/dim]")
|
||||
|
||||
else:
|
||||
# ---- Local mode: direct library call ----
|
||||
from tensors.comfyui import generate_image, get_image # noqa: PLC0415
|
||||
|
||||
actual_seed = seed if seed >= 0 else rng.randint(0, 2**32 - 1)
|
||||
|
||||
result_local = generate_image(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative,
|
||||
model=model,
|
||||
width=width,
|
||||
height=height,
|
||||
steps=steps,
|
||||
cfg=cfg,
|
||||
seed=actual_seed,
|
||||
sampler=sampler,
|
||||
scheduler=scheduler,
|
||||
console=console if not json_output else None,
|
||||
lora_name=lora,
|
||||
lora_strength=lora_strength,
|
||||
vae=vae,
|
||||
)
|
||||
|
||||
if not result_local:
|
||||
if json_output:
|
||||
console.print_json(data={"success": False, "errors": {"generation": "Failed to generate"}})
|
||||
else:
|
||||
console.print("[red]Generation failed[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if not result_local.success:
|
||||
if json_output:
|
||||
console.print_json(data={"success": False, "errors": result_local.node_errors})
|
||||
else:
|
||||
console.print("[red]Generation failed[/red]")
|
||||
for node_id, errors in result_local.node_errors.items():
|
||||
console.print(f" [yellow]Node {node_id}:[/yellow] {errors}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Save images
|
||||
saved_paths: list[Path] = []
|
||||
for i, img_path in enumerate(result_local.images):
|
||||
if output:
|
||||
img_data = get_image(str(img_path))
|
||||
if img_data:
|
||||
save_path = (
|
||||
output if len(result_local.images) == 1 else output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}"
|
||||
)
|
||||
save_path.write_bytes(img_data)
|
||||
saved_paths.append(save_path)
|
||||
if not json_output:
|
||||
console.print(f"[green]Saved:[/green] {save_path}")
|
||||
|
||||
if json_output:
|
||||
console.print_json(
|
||||
data={
|
||||
"success": True,
|
||||
"prompt_id": result_local.prompt_id,
|
||||
"images": [str(p) for p in result_local.images],
|
||||
"saved": [str(p) for p in saved_paths],
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
console.print("[bold green]Generation complete![/bold green]")
|
||||
console.print(f"[dim]Prompt ID: {result_local.prompt_id}[/dim]")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Top-Level Models Command
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@app.command()
|
||||
def models(
|
||||
model_type: Annotated[str | None, typer.Option("-t", "--type", help="Filter by type (checkpoints, loras, vae)")] = None,
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""List available models from ComfyUI.
|
||||
|
||||
Shows checkpoints, LoRAs, VAEs, and other model types loaded in ComfyUI.
|
||||
Uses --remote to query a remote tensors server instead of local ComfyUI.
|
||||
|
||||
Examples:
|
||||
tsr models
|
||||
tsr models -t checkpoints
|
||||
tsr models --remote junkpile
|
||||
tsr models --json
|
||||
"""
|
||||
from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415
|
||||
|
||||
remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None)
|
||||
|
||||
if remote_url:
|
||||
if not json_output:
|
||||
console.print(f"[dim]Remote: {remote_url}[/dim]")
|
||||
result = remote_models(remote or remote_url, console=console)
|
||||
else:
|
||||
from tensors.comfyui import get_loaded_models # noqa: PLC0415
|
||||
|
||||
result = get_loaded_models(console=console if not json_output else None)
|
||||
|
||||
if not result:
|
||||
console.print("[red]Error: Could not fetch models[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Filter by type if requested
|
||||
if model_type:
|
||||
key = model_type.lower()
|
||||
filtered = {k: v for k, v in result.items() if k.lower() == key}
|
||||
if not filtered:
|
||||
console.print(f"[yellow]No models found for type '{model_type}'[/yellow]")
|
||||
console.print(f"[dim]Available types: {', '.join(sorted(result.keys()))}[/dim]")
|
||||
raise typer.Exit(1)
|
||||
result = filtered
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=result)
|
||||
return
|
||||
|
||||
console.print("[bold cyan]Available Models[/bold cyan]")
|
||||
|
||||
for mtype, model_list in sorted(result.items()):
|
||||
console.print()
|
||||
console.print(f"[bold]{mtype}:[/bold] ({len(model_list)})")
|
||||
for name in model_list[:MAX_MODEL_LIST_DISPLAY]:
|
||||
console.print(f" {name}")
|
||||
if len(model_list) > MAX_MODEL_LIST_DISPLAY:
|
||||
console.print(f" ... and {len(model_list) - MAX_MODEL_LIST_DISPLAY} more")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Database Commands
|
||||
# =============================================================================
|
||||
@@ -1455,6 +1804,8 @@ def main() -> int:
|
||||
"get",
|
||||
"dl",
|
||||
"download",
|
||||
"generate",
|
||||
"models",
|
||||
"config",
|
||||
"serve",
|
||||
"db",
|
||||
|
||||
@@ -0,0 +1,300 @@
|
||||
"""HTTP client for calling tensors server API remotely."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from rich.console import Console
|
||||
|
||||
from tensors.config import get_server_api_key, resolve_remote
|
||||
|
||||
|
||||
def _build_client(base_url: str, timeout: float = 300.0) -> httpx.Client:
|
||||
"""Build an httpx client with API key auth."""
|
||||
api_key = get_server_api_key()
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["X-API-Key"] = api_key
|
||||
return httpx.Client(base_url=base_url, headers=headers, timeout=timeout)
|
||||
|
||||
|
||||
def remote_generate(
|
||||
remote: str,
|
||||
prompt: str,
|
||||
*,
|
||||
negative_prompt: str = "",
|
||||
model: str | None = None,
|
||||
width: int = 1024,
|
||||
height: int = 1024,
|
||||
steps: int = 20,
|
||||
cfg: float = 7.0,
|
||||
seed: int = -1,
|
||||
sampler: str = "euler",
|
||||
scheduler: str = "normal",
|
||||
vae: str | None = None,
|
||||
lora_name: str | None = None,
|
||||
lora_strength: float = 0.8,
|
||||
console: Console | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate an image via remote tensors server.
|
||||
|
||||
Args:
|
||||
remote: Remote name or URL (resolved via config)
|
||||
prompt: Positive prompt text
|
||||
console: Rich console for error output
|
||||
|
||||
Returns:
|
||||
Response dict with success, prompt_id, images, errors — or None on connection error
|
||||
"""
|
||||
base_url = resolve_remote(remote)
|
||||
if not base_url:
|
||||
if console:
|
||||
console.print("[red]Error: Could not resolve remote server[/red]")
|
||||
return None
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"steps": steps,
|
||||
"cfg": cfg,
|
||||
"seed": seed,
|
||||
"sampler": sampler,
|
||||
"scheduler": scheduler,
|
||||
"lora_strength": lora_strength,
|
||||
}
|
||||
if model:
|
||||
payload["model"] = model
|
||||
if vae:
|
||||
payload["vae"] = vae
|
||||
if lora_name:
|
||||
payload["lora_name"] = lora_name
|
||||
|
||||
try:
|
||||
with _build_client(base_url) as client:
|
||||
response = client.post("/api/comfyui/generate", json=payload)
|
||||
response.raise_for_status()
|
||||
result: dict[str, Any] = response.json()
|
||||
return result
|
||||
except httpx.HTTPStatusError as e:
|
||||
if console:
|
||||
console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
|
||||
try:
|
||||
detail = e.response.json().get("detail", "")
|
||||
if detail:
|
||||
console.print(f" [yellow]{detail}[/yellow]")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
if console:
|
||||
console.print(f"[red]Remote connection error: {e}[/red]")
|
||||
return None
|
||||
|
||||
|
||||
def remote_get_image(remote: str, filename: str) -> bytes | None:
|
||||
"""Download a generated image from remote tensors server.
|
||||
|
||||
Args:
|
||||
remote: Remote name or URL
|
||||
filename: Image filename from generation result
|
||||
|
||||
Returns:
|
||||
Image bytes or None on error
|
||||
"""
|
||||
base_url = resolve_remote(remote)
|
||||
if not base_url:
|
||||
return None
|
||||
|
||||
try:
|
||||
with _build_client(base_url) as client:
|
||||
response = client.get("/api/comfyui/image/" + filename)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
except (httpx.HTTPStatusError, httpx.RequestError):
|
||||
return None
|
||||
|
||||
|
||||
def remote_models(
|
||||
remote: str,
|
||||
console: Console | None = None,
|
||||
) -> dict[str, list[str]] | None:
|
||||
"""List available models from remote tensors server.
|
||||
|
||||
Args:
|
||||
remote: Remote name or URL
|
||||
console: Rich console for error output
|
||||
|
||||
Returns:
|
||||
Dict mapping model type to list of model names, or None on error
|
||||
"""
|
||||
base_url = resolve_remote(remote)
|
||||
if not base_url:
|
||||
if console:
|
||||
console.print("[red]Error: Could not resolve remote server[/red]")
|
||||
return None
|
||||
|
||||
try:
|
||||
with _build_client(base_url) as client:
|
||||
response = client.get("/api/comfyui/models")
|
||||
response.raise_for_status()
|
||||
result: dict[str, list[str]] = response.json()
|
||||
return result
|
||||
except httpx.HTTPStatusError as e:
|
||||
if console:
|
||||
console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
if console:
|
||||
console.print(f"[red]Remote connection error: {e}[/red]")
|
||||
return None
|
||||
|
||||
|
||||
def remote_search(
|
||||
remote: str,
|
||||
*,
|
||||
query: str | None = None,
|
||||
model_type: str | None = None,
|
||||
base_model: str | None = None,
|
||||
sort: str = "downloads",
|
||||
limit: int = 20,
|
||||
page: int | None = None,
|
||||
nsfw: str | None = None,
|
||||
sfw: bool = False,
|
||||
console: Console | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Search CivitAI models via remote tensors server.
|
||||
|
||||
Args:
|
||||
remote: Remote name or URL
|
||||
console: Rich console for error output
|
||||
|
||||
Returns:
|
||||
Search results dict or None on error
|
||||
"""
|
||||
base_url = resolve_remote(remote)
|
||||
if not base_url:
|
||||
if console:
|
||||
console.print("[red]Error: Could not resolve remote server[/red]")
|
||||
return None
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"provider": "civitai",
|
||||
"sort": sort,
|
||||
"limit": limit,
|
||||
}
|
||||
if query:
|
||||
params["query"] = query
|
||||
if model_type:
|
||||
params["types"] = model_type
|
||||
if base_model:
|
||||
params["baseModels"] = base_model
|
||||
if page:
|
||||
params["page"] = page
|
||||
if sfw:
|
||||
params["sfw"] = True
|
||||
elif nsfw:
|
||||
params["nsfw"] = nsfw
|
||||
|
||||
try:
|
||||
with _build_client(base_url) as client:
|
||||
response = client.get("/api/search", params=params)
|
||||
response.raise_for_status()
|
||||
result: dict[str, Any] = response.json()
|
||||
# The remote API wraps CivitAI results under "civitai" key
|
||||
return result.get("civitai", result)
|
||||
except httpx.HTTPStatusError as e:
|
||||
if console:
|
||||
console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
if console:
|
||||
console.print(f"[red]Remote connection error: {e}[/red]")
|
||||
return None
|
||||
|
||||
|
||||
def remote_download(
|
||||
remote: str,
|
||||
*,
|
||||
version_id: int | None = None,
|
||||
model_id: int | None = None,
|
||||
output_dir: str | None = None,
|
||||
console: Console | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Start a model download on remote tensors server.
|
||||
|
||||
Args:
|
||||
remote: Remote name or URL
|
||||
version_id: CivitAI version ID
|
||||
model_id: CivitAI model ID (downloads latest version)
|
||||
output_dir: Override output directory on the remote
|
||||
console: Rich console for error output
|
||||
|
||||
Returns:
|
||||
Download status dict with download_id, or None on error
|
||||
"""
|
||||
base_url = resolve_remote(remote)
|
||||
if not base_url:
|
||||
if console:
|
||||
console.print("[red]Error: Could not resolve remote server[/red]")
|
||||
return None
|
||||
|
||||
payload: dict[str, Any] = {}
|
||||
if version_id:
|
||||
payload["version_id"] = version_id
|
||||
if model_id:
|
||||
payload["model_id"] = model_id
|
||||
if output_dir:
|
||||
payload["output_dir"] = output_dir
|
||||
|
||||
try:
|
||||
with _build_client(base_url) as client:
|
||||
response = client.post("/api/download", json=payload)
|
||||
response.raise_for_status()
|
||||
result: dict[str, Any] = response.json()
|
||||
return result
|
||||
except httpx.HTTPStatusError as e:
|
||||
if console:
|
||||
console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
|
||||
try:
|
||||
detail = e.response.json().get("detail", "")
|
||||
if detail:
|
||||
console.print(f" [yellow]{detail}[/yellow]")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
if console:
|
||||
console.print(f"[red]Remote connection error: {e}[/red]")
|
||||
return None
|
||||
|
||||
|
||||
def remote_download_status(
|
||||
remote: str,
|
||||
download_id: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Check download status on remote tensors server.
|
||||
|
||||
Args:
|
||||
remote: Remote name or URL
|
||||
download_id: Download ID from start_download response
|
||||
|
||||
Returns:
|
||||
Download status dict or None on error
|
||||
"""
|
||||
base_url = resolve_remote(remote)
|
||||
if not base_url:
|
||||
return None
|
||||
|
||||
try:
|
||||
with _build_client(base_url) as client:
|
||||
response = client.get(f"/api/download/status/{download_id}")
|
||||
response.raise_for_status()
|
||||
result: dict[str, Any] = response.json()
|
||||
return result
|
||||
except (httpx.HTTPStatusError, httpx.RequestError):
|
||||
return None
|
||||
Reference in New Issue
Block a user