💬 Commit message: Update 2026-02-15 06:21:35, 7 files, 1559 lines
📁 Files changed: 7 📝 Lines changed: 1559 • .coverage • cli.py • __init__.py • conftest.py • test_client.py • test_generate.py • test_server.py
This commit is contained in:
+30
-590
@@ -20,20 +20,15 @@ from tensors.api import (
|
||||
fetch_civitai_model_version,
|
||||
search_civitai,
|
||||
)
|
||||
from tensors.client import TsrClient, TsrClientError
|
||||
from tensors.config import (
|
||||
CONFIG_FILE,
|
||||
BaseModel,
|
||||
ModelType,
|
||||
SortOrder,
|
||||
get_default_output_path,
|
||||
get_remotes,
|
||||
load_api_key,
|
||||
load_config,
|
||||
resolve_remote,
|
||||
save_config,
|
||||
save_remote,
|
||||
set_default_remote,
|
||||
)
|
||||
from tensors.db import DB_PATH, Database
|
||||
from tensors.display import (
|
||||
@@ -49,18 +44,6 @@ from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_me
|
||||
# Key masking threshold
|
||||
MIN_KEY_LENGTH_FOR_MASKING = 8
|
||||
|
||||
# Size threshold for GB display
|
||||
_MB_PER_GB = 1024
|
||||
|
||||
|
||||
def _format_size_mb(size_mb: float | None) -> str:
|
||||
"""Format size in MB to human-readable string."""
|
||||
if not size_mb:
|
||||
return ""
|
||||
if size_mb >= _MB_PER_GB:
|
||||
return f"{size_mb / _MB_PER_GB:.1f} GB"
|
||||
return f"{size_mb:.0f} MB"
|
||||
|
||||
|
||||
def _version_callback(value: bool) -> None:
|
||||
if value:
|
||||
@@ -329,74 +312,41 @@ def download(
|
||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
|
||||
no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False,
|
||||
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False,
|
||||
) -> None:
|
||||
"""Download a model from CivitAI (locally or to remote server)."""
|
||||
# Check if remote is specified or configured
|
||||
remote_url = resolve_remote(remote)
|
||||
"""Download a model from CivitAI."""
|
||||
key = api_key or load_api_key()
|
||||
|
||||
if remote_url:
|
||||
# Remote mode: use TsrClient API
|
||||
if not version_id and not model_id and not hash_val:
|
||||
resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
|
||||
if not resolved_version_id:
|
||||
if not version_id and not hash_val and not model_id:
|
||||
console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
|
||||
raise typer.Exit(1)
|
||||
raise typer.Exit(1)
|
||||
|
||||
try:
|
||||
with TsrClient(remote_url) as client:
|
||||
console.print(f"[cyan]Starting download on {remote_url}...[/cyan]")
|
||||
result = client.start_download(
|
||||
version_id=version_id,
|
||||
model_id=model_id,
|
||||
hash_val=hash_val,
|
||||
output_dir=str(output) if output else None,
|
||||
)
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]")
|
||||
version_info = fetch_civitai_model_version(resolved_version_id, key, console)
|
||||
if not version_info:
|
||||
console.print("[red]Error: Could not fetch model version info.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=result)
|
||||
return
|
||||
model_type_str: str | None = version_info.get("model", {}).get("type")
|
||||
output_dir = _prepare_download_dir(output, model_type_str)
|
||||
if not output_dir:
|
||||
raise typer.Exit(1)
|
||||
|
||||
download_id = result.get("download_id")
|
||||
console.print(f"[green]Download started:[/green] {download_id}")
|
||||
console.print(f"[dim]Check status with: tsr images download-status {download_id} --remote {remote or 'default'}[/dim]")
|
||||
else:
|
||||
# Local mode: direct download
|
||||
key = api_key or load_api_key()
|
||||
files: list[dict[str, Any]] = version_info.get("files", [])
|
||||
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
|
||||
if not primary_file:
|
||||
console.print("[red]Error: No files found for this version.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
|
||||
if not resolved_version_id:
|
||||
if not version_id and not hash_val and not model_id:
|
||||
console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
|
||||
raise typer.Exit(1)
|
||||
filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors")
|
||||
dest_path = output_dir / filename
|
||||
|
||||
console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]")
|
||||
version_info = fetch_civitai_model_version(resolved_version_id, key, console)
|
||||
if not version_info:
|
||||
console.print("[red]Error: Could not fetch model version info.[/red]")
|
||||
raise typer.Exit(1)
|
||||
_display_download_info(version_info, filename, primary_file, dest_path)
|
||||
|
||||
model_type_str: str | None = version_info.get("model", {}).get("type")
|
||||
output_dir = _prepare_download_dir(output, model_type_str)
|
||||
if not output_dir:
|
||||
raise typer.Exit(1)
|
||||
|
||||
files: list[dict[str, Any]] = version_info.get("files", [])
|
||||
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
|
||||
if not primary_file:
|
||||
console.print("[red]Error: No files found for this version.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors")
|
||||
dest_path = output_dir / filename
|
||||
|
||||
_display_download_info(version_info, filename, primary_file, dest_path)
|
||||
|
||||
success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume)
|
||||
if not success:
|
||||
raise typer.Exit(1)
|
||||
success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume)
|
||||
if not success:
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def _display_download_info(
|
||||
@@ -449,167 +399,13 @@ def config(
|
||||
console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]")
|
||||
|
||||
|
||||
@app.command()
|
||||
def generate(
|
||||
prompt: Annotated[str, typer.Argument(help="Text prompt for image generation.")],
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model (remote mode only).")] = None,
|
||||
host: Annotated[str, typer.Option(help="sd-server address (local mode).")] = "127.0.0.1",
|
||||
port: Annotated[int, typer.Option(help="sd-server port (local mode).")] = 8080,
|
||||
output: Annotated[str, typer.Option("-o", help="Output directory (local mode).")] = ".",
|
||||
negative_prompt: Annotated[str, typer.Option("-n", help="Negative prompt.")] = "",
|
||||
width: Annotated[int, typer.Option("-W", help="Image width.")] = 512,
|
||||
height: Annotated[int, typer.Option("-H", help="Image height.")] = 512,
|
||||
steps: Annotated[int, typer.Option(help="Sampling steps.")] = 20,
|
||||
cfg_scale: Annotated[float, typer.Option(help="CFG scale.")] = 7.0,
|
||||
seed: Annotated[int, typer.Option("-s", help="RNG seed (-1 for random).")] = -1,
|
||||
sampler: Annotated[str, typer.Option(help="Sampler name.")] = "",
|
||||
scheduler: Annotated[str, typer.Option(help="Scheduler name.")] = "",
|
||||
batch_size: Annotated[int, typer.Option("-b", help="Number of images.")] = 1,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False,
|
||||
) -> None:
|
||||
"""Generate images using sd-server (local or remote)."""
|
||||
# Check if remote is specified or configured
|
||||
remote_url = resolve_remote(remote)
|
||||
|
||||
if remote_url:
|
||||
# Remote mode: use TsrClient API
|
||||
try:
|
||||
with TsrClient(remote_url) as client:
|
||||
# Switch model if specified
|
||||
if model:
|
||||
console.print(f"[cyan]Switching to model: {model}[/cyan]")
|
||||
client.switch_model(model)
|
||||
|
||||
console.print(f"[cyan]Generating {batch_size} image(s) on {remote_url}...[/cyan]")
|
||||
result = client.generate(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
steps=steps,
|
||||
cfg_scale=cfg_scale,
|
||||
seed=seed,
|
||||
sampler_name=sampler,
|
||||
scheduler=scheduler,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=result)
|
||||
return
|
||||
|
||||
images = result.get("images", [])
|
||||
for img in images:
|
||||
console.print(f"[green]Generated:[/green] {img.get('id', 'unknown')}")
|
||||
else:
|
||||
# Local mode: direct sd-server connection
|
||||
if model:
|
||||
console.print("[yellow]Warning: --model ignored in local mode (sd-server loads model at startup)[/yellow]")
|
||||
|
||||
from tensors.generate import SDClient, Txt2ImgParams, save_images # noqa: PLC0415
|
||||
|
||||
params = Txt2ImgParams(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
steps=steps,
|
||||
cfg_scale=cfg_scale,
|
||||
seed=seed,
|
||||
batch_size=batch_size,
|
||||
sampler_name=sampler,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
with SDClient(host=host, port=port) as client:
|
||||
console.print(f"[cyan]Generating {batch_size} image(s)...[/cyan]")
|
||||
images = client.generate.txt2img(params)
|
||||
paths = save_images(images, output)
|
||||
for p in paths:
|
||||
console.print(f"[green]Saved:[/green] {p}")
|
||||
|
||||
|
||||
@app.command()
|
||||
def status(
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
host: Annotated[str, typer.Option(help="Wrapper API host (local mode).")] = "127.0.0.1",
|
||||
port: Annotated[int, typer.Option(help="Wrapper API port (local mode).")] = 8080,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""Show sd-server wrapper status."""
|
||||
# Check if remote is specified or configured
|
||||
remote_url = resolve_remote(remote)
|
||||
|
||||
if remote_url:
|
||||
# Remote mode: use TsrClient API
|
||||
try:
|
||||
with TsrClient(remote_url) as client:
|
||||
data = client.status()
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
else:
|
||||
# Local mode: direct HTTP call
|
||||
import httpx # noqa: PLC0415
|
||||
|
||||
url = f"http://{host}:{port}/status"
|
||||
try:
|
||||
resp = httpx.get(url, timeout=10)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
except httpx.HTTPError as e:
|
||||
console.print(f"[red]Error: Could not reach wrapper at {url}: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=data)
|
||||
return
|
||||
|
||||
table = Table(title="Server Status", show_header=True, header_style="bold magenta")
|
||||
table.add_column("Property", style="cyan")
|
||||
table.add_column("Value", style="green")
|
||||
for key, value in data.items():
|
||||
table.add_row(key, str(value))
|
||||
console.print(table)
|
||||
|
||||
|
||||
@app.command()
|
||||
def reload(
|
||||
model: Annotated[str, typer.Option(help="Path to model file for sd-server.")],
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
host: Annotated[str, typer.Option(help="Wrapper API host (local mode).")] = "127.0.0.1",
|
||||
port: Annotated[int, typer.Option(help="Wrapper API port (local mode).")] = 8080,
|
||||
) -> None:
|
||||
"""Reload sd-server with a new model."""
|
||||
import httpx # noqa: PLC0415
|
||||
|
||||
remote_url = resolve_remote(remote)
|
||||
url = f"{remote_url.rstrip('/')}/reload" if remote_url else f"http://{host}:{port}/reload"
|
||||
|
||||
console.print(f"[cyan]Reloading model: {model}[/cyan]")
|
||||
try:
|
||||
resp = httpx.post(url, json={"model": model}, timeout=300)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
except httpx.HTTPError as e:
|
||||
console.print(f"[red]Error: Reload failed at {url}: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
console.print(f"[green]{data.get('status', 'OK')}[/green]")
|
||||
|
||||
|
||||
@app.command()
|
||||
def serve(
|
||||
host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1",
|
||||
port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080,
|
||||
sd_server: Annotated[str | None, typer.Option(help="sd-server URL to proxy to.")] = None,
|
||||
host: Annotated[str, typer.Option(help="Listen address.")] = "127.0.0.1",
|
||||
port: Annotated[int, typer.Option(help="Listen port.")] = 8080,
|
||||
log_level: Annotated[str, typer.Option(help="Log level.")] = "info",
|
||||
) -> None:
|
||||
"""Start the sd-server wrapper API (proxies to external sd-server)."""
|
||||
"""Start the tensors server (gallery and CivitAI management)."""
|
||||
try:
|
||||
import uvicorn # noqa: PLC0415
|
||||
|
||||
@@ -619,7 +415,7 @@ def serve(
|
||||
console.print(" pip install tensors[server]")
|
||||
raise typer.Exit(1) from None
|
||||
|
||||
uvicorn.run(create_app(sd_server_url=sd_server), host=host, port=port, log_level=log_level)
|
||||
uvicorn.run(create_app(), host=host, port=port, log_level=log_level)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -858,356 +654,6 @@ def db_stats(
|
||||
console.print(table)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Images Commands (Remote)
|
||||
# =============================================================================
|
||||
|
||||
images_app = typer.Typer(
|
||||
name="images",
|
||||
help="Manage images in remote gallery.",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
app.add_typer(images_app, name="images")
|
||||
|
||||
|
||||
def _get_client(remote: str | None) -> TsrClient:
|
||||
"""Get TsrClient for remote or raise error."""
|
||||
url = resolve_remote(remote)
|
||||
if not url:
|
||||
console.print("[red]Error: No remote specified. Use --remote or set default_remote in config.[/red]")
|
||||
raise typer.Exit(1)
|
||||
return TsrClient(url)
|
||||
|
||||
|
||||
@images_app.command("list")
|
||||
def images_list(
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 50,
|
||||
offset: Annotated[int, typer.Option("--offset", help="Offset for pagination")] = 0,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""List images in remote gallery."""
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
result = client.list_images(limit=limit, offset=offset)
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
images = result.get("images", [])
|
||||
total = result.get("total", len(images))
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=result)
|
||||
return
|
||||
|
||||
if not images:
|
||||
console.print("[yellow]No images in gallery.[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title=f"Gallery Images ({len(images)}/{total})", show_header=True, header_style="bold magenta")
|
||||
table.add_column("ID", style="cyan")
|
||||
table.add_column("Filename", style="green")
|
||||
table.add_column("Size", style="white")
|
||||
table.add_column("Created", style="dim")
|
||||
|
||||
for img in images:
|
||||
size = f"{img.get('width', '?')}x{img.get('height', '?')}"
|
||||
created = img.get("created_at", "")
|
||||
if isinstance(created, (int, float)):
|
||||
from datetime import datetime # noqa: PLC0415
|
||||
|
||||
created = datetime.fromtimestamp(created).strftime("%Y-%m-%d %H:%M")
|
||||
table.add_row(img.get("id", ""), img.get("filename", ""), size, str(created))
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@images_app.command("show")
|
||||
def images_show(
|
||||
image_id: Annotated[str, typer.Argument(help="Image ID to show")],
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""Show image metadata."""
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
meta = client.get_image_meta(image_id)
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=meta)
|
||||
return
|
||||
|
||||
table = Table(title=f"Image: {image_id}", show_header=True, header_style="bold magenta")
|
||||
table.add_column("Property", style="cyan")
|
||||
table.add_column("Value", style="green")
|
||||
|
||||
for key, value in meta.items():
|
||||
display_value = json.dumps(value, indent=2) if isinstance(value, dict) else str(value)
|
||||
table.add_row(key, display_value)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@images_app.command("delete")
|
||||
def images_delete(
|
||||
image_id: Annotated[str, typer.Argument(help="Image ID to delete")],
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
force: Annotated[bool, typer.Option("-f", "--force", help="Skip confirmation")] = False,
|
||||
) -> None:
|
||||
"""Delete an image from the gallery."""
|
||||
if not force:
|
||||
confirm = typer.confirm(f"Delete image {image_id}?")
|
||||
if not confirm:
|
||||
console.print("[yellow]Cancelled.[/yellow]")
|
||||
raise typer.Exit(0)
|
||||
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
client.delete_image(image_id)
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
console.print(f"[green]Deleted image: {image_id}[/green]")
|
||||
|
||||
|
||||
@images_app.command("download")
|
||||
def images_download(
|
||||
image_id: Annotated[str, typer.Argument(help="Image ID to download")],
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file or directory")] = None,
|
||||
) -> None:
|
||||
"""Download an image from the remote gallery."""
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
content = client.download_image(image_id)
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
# Determine output path
|
||||
if output is None:
|
||||
dest = Path(f"{image_id}.png")
|
||||
elif output.is_dir():
|
||||
dest = output / f"{image_id}.png"
|
||||
else:
|
||||
dest = output
|
||||
|
||||
dest.write_bytes(content)
|
||||
console.print(f"[green]Saved:[/green] {dest}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Models Commands (Remote)
|
||||
# =============================================================================
|
||||
|
||||
models_app = typer.Typer(
|
||||
name="models",
|
||||
help="Manage models on remote server.",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
app.add_typer(models_app, name="models")
|
||||
|
||||
|
||||
@models_app.command("list")
|
||||
def models_list(
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""List available models on remote server."""
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
result = client.list_models()
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=result)
|
||||
return
|
||||
|
||||
models = result.get("models", [])
|
||||
active = result.get("active", "")
|
||||
|
||||
if not models:
|
||||
console.print("[yellow]No models found.[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="Available Models", show_header=True, header_style="bold magenta")
|
||||
table.add_column("ID", style="dim", width=8)
|
||||
table.add_column("Name", style="cyan")
|
||||
table.add_column("File", style="white")
|
||||
table.add_column("Size", style="green", justify="right")
|
||||
|
||||
for model in models:
|
||||
path = model.get("path", "")
|
||||
name = model.get("name", Path(path).stem if path else "")
|
||||
is_active = active in {path, name}
|
||||
|
||||
civitai_id = model.get("civitai_model_id")
|
||||
id_str = str(civitai_id) if civitai_id else ""
|
||||
display_name = model.get("display_name", name)
|
||||
if is_active:
|
||||
display_name = f"[green]✓[/green] {display_name}"
|
||||
filename = model.get("filename", Path(path).name if path else "")
|
||||
size_str = _format_size_mb(model.get("size_mb"))
|
||||
|
||||
table.add_row(id_str, display_name, filename, size_str)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@models_app.command("active")
|
||||
def models_active(
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""Show currently active model."""
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
result = client.get_active_model()
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=result)
|
||||
return
|
||||
|
||||
model = result.get("model", "None")
|
||||
console.print(f"[bold]Active model:[/bold] {model}")
|
||||
|
||||
|
||||
@models_app.command("switch")
|
||||
def models_switch(
|
||||
model: Annotated[str, typer.Argument(help="Model path or name to switch to")],
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
) -> None:
|
||||
"""Switch to a different model on the remote server."""
|
||||
console.print(f"[cyan]Switching to model: {model}[/cyan]")
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
result = client.switch_model(model)
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
console.print(f"[green]{result.get('status', 'OK')}[/green]")
|
||||
|
||||
|
||||
@models_app.command("loras")
|
||||
def models_loras(
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""List available LoRAs on remote server."""
|
||||
try:
|
||||
with _get_client(remote) as client:
|
||||
result = client.list_loras()
|
||||
except TsrClientError as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=result)
|
||||
return
|
||||
|
||||
loras = result.get("loras", [])
|
||||
if not loras:
|
||||
console.print("[yellow]No LoRAs found.[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="Available LoRAs", show_header=True, header_style="bold magenta")
|
||||
table.add_column("ID", style="dim", width=8)
|
||||
table.add_column("Name", style="cyan")
|
||||
table.add_column("File", style="white")
|
||||
table.add_column("Size", style="green", justify="right")
|
||||
|
||||
for lora in loras:
|
||||
path = lora.get("path", "")
|
||||
name = lora.get("name", Path(path).stem if path else "")
|
||||
|
||||
civitai_id = lora.get("civitai_model_id")
|
||||
id_str = str(civitai_id) if civitai_id else ""
|
||||
display_name = lora.get("display_name", name)
|
||||
filename = lora.get("filename", Path(path).name if path else "")
|
||||
size_str = _format_size_mb(lora.get("size_mb"))
|
||||
|
||||
table.add_row(id_str, display_name, filename, size_str)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Remote Configuration Commands
|
||||
# =============================================================================
|
||||
|
||||
remote_app = typer.Typer(
|
||||
name="remote",
|
||||
help="Manage remote server configuration.",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
app.add_typer(remote_app, name="remote")
|
||||
|
||||
|
||||
@remote_app.command("list")
|
||||
def remote_list(
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""List configured remotes."""
|
||||
from tensors.config import get_default_remote # noqa: PLC0415
|
||||
|
||||
remotes = get_remotes()
|
||||
default = get_default_remote()
|
||||
|
||||
if json_output:
|
||||
console.print_json(data={"remotes": remotes, "default": default})
|
||||
return
|
||||
|
||||
if not remotes:
|
||||
console.print("[yellow]No remotes configured.[/yellow]")
|
||||
console.print("[dim]Add one with: tsr remote add NAME URL[/dim]")
|
||||
return
|
||||
|
||||
table = Table(title="Configured Remotes", show_header=True, header_style="bold magenta")
|
||||
table.add_column("Default", style="dim", width=3)
|
||||
table.add_column("Name", style="cyan")
|
||||
table.add_column("URL", style="green")
|
||||
|
||||
for name, url in remotes.items():
|
||||
is_default = name == default
|
||||
status = "[green]✓[/green]" if is_default else ""
|
||||
table.add_row(status, name, url)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@remote_app.command("add")
|
||||
def remote_add(
|
||||
name: Annotated[str, typer.Argument(help="Remote name")],
|
||||
url: Annotated[str, typer.Argument(help="Remote URL (e.g., http://host:8080)")],
|
||||
) -> None:
|
||||
"""Add a remote server."""
|
||||
save_remote(name, url)
|
||||
console.print(f"[green]Added remote:[/green] {name} → {url}")
|
||||
|
||||
|
||||
@remote_app.command("default")
|
||||
def remote_default(
|
||||
name: Annotated[str | None, typer.Argument(help="Remote name to set as default (omit to clear)")] = None,
|
||||
) -> None:
|
||||
"""Set or clear the default remote."""
|
||||
set_default_remote(name)
|
||||
if name:
|
||||
console.print(f"[green]Default remote set to:[/green] {name}")
|
||||
else:
|
||||
console.print("[green]Default remote cleared.[/green]")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ComfyUI Commands
|
||||
# =============================================================================
|
||||
@@ -1424,14 +870,8 @@ def main() -> int:
|
||||
"dl",
|
||||
"download",
|
||||
"config",
|
||||
"generate",
|
||||
"serve",
|
||||
"status",
|
||||
"reload",
|
||||
"db",
|
||||
"images",
|
||||
"models",
|
||||
"remote",
|
||||
"comfy",
|
||||
)
|
||||
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
|
||||
|
||||
+14
-31
@@ -1,4 +1,4 @@
|
||||
"""sd-server wrapper — FastAPI app for proxying to an external sd-server."""
|
||||
"""Tensors server — FastAPI app for gallery and CivitAI management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -7,19 +7,14 @@ from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from tensors.config import get_sd_server_api_key, get_sd_server_url
|
||||
from tensors.server.civitai_routes import create_civitai_router
|
||||
from tensors.server.db_routes import create_db_router
|
||||
from tensors.server.download_routes import create_download_router
|
||||
from tensors.server.gallery_routes import create_gallery_router
|
||||
from tensors.server.generate_routes import create_generate_router
|
||||
from tensors.server.models_routes import create_models_router
|
||||
from tensors.server.routes import create_router
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
@@ -29,28 +24,15 @@ __all__ = ["app", "create_app"]
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_app(sd_server_url: str | None = None) -> FastAPI:
|
||||
"""Build the FastAPI application that proxies to an external sd-server.
|
||||
|
||||
Args:
|
||||
sd_server_url: URL of the sd-server to proxy to. If None, uses
|
||||
get_sd_server_url() to resolve from env/config.
|
||||
"""
|
||||
backend_url = sd_server_url or get_sd_server_url()
|
||||
api_key = get_sd_server_api_key()
|
||||
def create_app() -> FastAPI:
|
||||
"""Build the FastAPI application for gallery and model management."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
||||
_app.state.sd_server_url = backend_url
|
||||
_app.state.sd_server_api_key = api_key
|
||||
logger.info(f"Proxying to sd-server at: {backend_url}")
|
||||
if api_key:
|
||||
logger.info("Using API key authentication for sd-server")
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
_app.state.client = client
|
||||
yield
|
||||
logger.info("Tensors server starting")
|
||||
yield
|
||||
|
||||
app = FastAPI(title="sd-server wrapper", lifespan=lifespan)
|
||||
app = FastAPI(title="tensors", lifespan=lifespan)
|
||||
|
||||
# Serve Vue UI static files
|
||||
static_dir = Path(__file__).parent / "static"
|
||||
@@ -66,13 +48,14 @@ def create_app(sd_server_url: str | None = None) -> FastAPI:
|
||||
async def vite_icon() -> FileResponse:
|
||||
return FileResponse(static_dir / "vite.svg")
|
||||
|
||||
app.include_router(create_civitai_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_db_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_gallery_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_models_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_download_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_generate_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_router())
|
||||
@app.get("/status")
|
||||
async def status() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
app.include_router(create_civitai_router())
|
||||
app.include_router(create_db_router())
|
||||
app.include_router(create_gallery_router())
|
||||
app.include_router(create_download_router())
|
||||
return app
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user