💬 Commit message: Update 2026-02-15 06:21:35, 7 files, 1559 lines

📁 Files changed: 7
📝 Lines changed: 1559

  • .coverage
  • cli.py
  • __init__.py
  • conftest.py
  • test_client.py
  • test_generate.py
  • test_server.py
This commit is contained in:
Adam Ladachowski
2026-02-15 06:21:35 +01:00
parent c419e443ae
commit 356d8fd156
7 changed files with 52 additions and 1507 deletions
BIN
View File
Binary file not shown.
+30 -590
View File
@@ -20,20 +20,15 @@ from tensors.api import (
fetch_civitai_model_version, fetch_civitai_model_version,
search_civitai, search_civitai,
) )
from tensors.client import TsrClient, TsrClientError
from tensors.config import ( from tensors.config import (
CONFIG_FILE, CONFIG_FILE,
BaseModel, BaseModel,
ModelType, ModelType,
SortOrder, SortOrder,
get_default_output_path, get_default_output_path,
get_remotes,
load_api_key, load_api_key,
load_config, load_config,
resolve_remote,
save_config, save_config,
save_remote,
set_default_remote,
) )
from tensors.db import DB_PATH, Database from tensors.db import DB_PATH, Database
from tensors.display import ( from tensors.display import (
@@ -49,18 +44,6 @@ from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_me
# Key masking threshold # Key masking threshold
MIN_KEY_LENGTH_FOR_MASKING = 8 MIN_KEY_LENGTH_FOR_MASKING = 8
# Size threshold for GB display
_MB_PER_GB = 1024
def _format_size_mb(size_mb: float | None) -> str:
"""Format size in MB to human-readable string."""
if not size_mb:
return ""
if size_mb >= _MB_PER_GB:
return f"{size_mb / _MB_PER_GB:.1f} GB"
return f"{size_mb:.0f} MB"
def _version_callback(value: bool) -> None: def _version_callback(value: bool) -> None:
if value: if value:
@@ -329,74 +312,41 @@ def download(
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None, output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False, no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False,
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False,
) -> None: ) -> None:
"""Download a model from CivitAI (locally or to remote server).""" """Download a model from CivitAI."""
# Check if remote is specified or configured key = api_key or load_api_key()
remote_url = resolve_remote(remote)
if remote_url: resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
# Remote mode: use TsrClient API if not resolved_version_id:
if not version_id and not model_id and not hash_val: if not version_id and not hash_val and not model_id:
console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
raise typer.Exit(1) raise typer.Exit(1)
try: console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]")
with TsrClient(remote_url) as client: version_info = fetch_civitai_model_version(resolved_version_id, key, console)
console.print(f"[cyan]Starting download on {remote_url}...[/cyan]") if not version_info:
result = client.start_download( console.print("[red]Error: Could not fetch model version info.[/red]")
version_id=version_id, raise typer.Exit(1)
model_id=model_id,
hash_val=hash_val,
output_dir=str(output) if output else None,
)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output: model_type_str: str | None = version_info.get("model", {}).get("type")
console.print_json(data=result) output_dir = _prepare_download_dir(output, model_type_str)
return if not output_dir:
raise typer.Exit(1)
download_id = result.get("download_id") files: list[dict[str, Any]] = version_info.get("files", [])
console.print(f"[green]Download started:[/green] {download_id}") primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
console.print(f"[dim]Check status with: tsr images download-status {download_id} --remote {remote or 'default'}[/dim]") if not primary_file:
else: console.print("[red]Error: No files found for this version.[/red]")
# Local mode: direct download raise typer.Exit(1)
key = api_key or load_api_key()
resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key) filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors")
if not resolved_version_id: dest_path = output_dir / filename
if not version_id and not hash_val and not model_id:
console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
raise typer.Exit(1)
console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]") _display_download_info(version_info, filename, primary_file, dest_path)
version_info = fetch_civitai_model_version(resolved_version_id, key, console)
if not version_info:
console.print("[red]Error: Could not fetch model version info.[/red]")
raise typer.Exit(1)
model_type_str: str | None = version_info.get("model", {}).get("type") success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume)
output_dir = _prepare_download_dir(output, model_type_str) if not success:
if not output_dir: raise typer.Exit(1)
raise typer.Exit(1)
files: list[dict[str, Any]] = version_info.get("files", [])
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
if not primary_file:
console.print("[red]Error: No files found for this version.[/red]")
raise typer.Exit(1)
filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors")
dest_path = output_dir / filename
_display_download_info(version_info, filename, primary_file, dest_path)
success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume)
if not success:
raise typer.Exit(1)
def _display_download_info( def _display_download_info(
@@ -449,167 +399,13 @@ def config(
console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]") console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]")
@app.command()
def generate(
prompt: Annotated[str, typer.Argument(help="Text prompt for image generation.")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model (remote mode only).")] = None,
host: Annotated[str, typer.Option(help="sd-server address (local mode).")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="sd-server port (local mode).")] = 8080,
output: Annotated[str, typer.Option("-o", help="Output directory (local mode).")] = ".",
negative_prompt: Annotated[str, typer.Option("-n", help="Negative prompt.")] = "",
width: Annotated[int, typer.Option("-W", help="Image width.")] = 512,
height: Annotated[int, typer.Option("-H", help="Image height.")] = 512,
steps: Annotated[int, typer.Option(help="Sampling steps.")] = 20,
cfg_scale: Annotated[float, typer.Option(help="CFG scale.")] = 7.0,
seed: Annotated[int, typer.Option("-s", help="RNG seed (-1 for random).")] = -1,
sampler: Annotated[str, typer.Option(help="Sampler name.")] = "",
scheduler: Annotated[str, typer.Option(help="Scheduler name.")] = "",
batch_size: Annotated[int, typer.Option("-b", help="Number of images.")] = 1,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False,
) -> None:
"""Generate images using sd-server (local or remote)."""
# Check if remote is specified or configured
remote_url = resolve_remote(remote)
if remote_url:
# Remote mode: use TsrClient API
try:
with TsrClient(remote_url) as client:
# Switch model if specified
if model:
console.print(f"[cyan]Switching to model: {model}[/cyan]")
client.switch_model(model)
console.print(f"[cyan]Generating {batch_size} image(s) on {remote_url}...[/cyan]")
result = client.generate(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
steps=steps,
cfg_scale=cfg_scale,
seed=seed,
sampler_name=sampler,
scheduler=scheduler,
batch_size=batch_size,
)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=result)
return
images = result.get("images", [])
for img in images:
console.print(f"[green]Generated:[/green] {img.get('id', 'unknown')}")
else:
# Local mode: direct sd-server connection
if model:
console.print("[yellow]Warning: --model ignored in local mode (sd-server loads model at startup)[/yellow]")
from tensors.generate import SDClient, Txt2ImgParams, save_images # noqa: PLC0415
params = Txt2ImgParams(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
steps=steps,
cfg_scale=cfg_scale,
seed=seed,
batch_size=batch_size,
sampler_name=sampler,
scheduler=scheduler,
)
with SDClient(host=host, port=port) as client:
console.print(f"[cyan]Generating {batch_size} image(s)...[/cyan]")
images = client.generate.txt2img(params)
paths = save_images(images, output)
for p in paths:
console.print(f"[green]Saved:[/green] {p}")
@app.command()
def status(
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
host: Annotated[str, typer.Option(help="Wrapper API host (local mode).")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="Wrapper API port (local mode).")] = 8080,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""Show sd-server wrapper status."""
# Check if remote is specified or configured
remote_url = resolve_remote(remote)
if remote_url:
# Remote mode: use TsrClient API
try:
with TsrClient(remote_url) as client:
data = client.status()
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
else:
# Local mode: direct HTTP call
import httpx # noqa: PLC0415
url = f"http://{host}:{port}/status"
try:
resp = httpx.get(url, timeout=10)
resp.raise_for_status()
data = resp.json()
except httpx.HTTPError as e:
console.print(f"[red]Error: Could not reach wrapper at {url}: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=data)
return
table = Table(title="Server Status", show_header=True, header_style="bold magenta")
table.add_column("Property", style="cyan")
table.add_column("Value", style="green")
for key, value in data.items():
table.add_row(key, str(value))
console.print(table)
@app.command()
def reload(
model: Annotated[str, typer.Option(help="Path to model file for sd-server.")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
host: Annotated[str, typer.Option(help="Wrapper API host (local mode).")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="Wrapper API port (local mode).")] = 8080,
) -> None:
"""Reload sd-server with a new model."""
import httpx # noqa: PLC0415
remote_url = resolve_remote(remote)
url = f"{remote_url.rstrip('/')}/reload" if remote_url else f"http://{host}:{port}/reload"
console.print(f"[cyan]Reloading model: {model}[/cyan]")
try:
resp = httpx.post(url, json={"model": model}, timeout=300)
resp.raise_for_status()
data = resp.json()
except httpx.HTTPError as e:
console.print(f"[red]Error: Reload failed at {url}: {e}[/red]")
raise typer.Exit(1) from e
console.print(f"[green]{data.get('status', 'OK')}[/green]")
@app.command() @app.command()
def serve( def serve(
host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1", host: Annotated[str, typer.Option(help="Listen address.")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080, port: Annotated[int, typer.Option(help="Listen port.")] = 8080,
sd_server: Annotated[str | None, typer.Option(help="sd-server URL to proxy to.")] = None,
log_level: Annotated[str, typer.Option(help="Log level.")] = "info", log_level: Annotated[str, typer.Option(help="Log level.")] = "info",
) -> None: ) -> None:
"""Start the sd-server wrapper API (proxies to external sd-server).""" """Start the tensors server (gallery and CivitAI management)."""
try: try:
import uvicorn # noqa: PLC0415 import uvicorn # noqa: PLC0415
@@ -619,7 +415,7 @@ def serve(
console.print(" pip install tensors[server]") console.print(" pip install tensors[server]")
raise typer.Exit(1) from None raise typer.Exit(1) from None
uvicorn.run(create_app(sd_server_url=sd_server), host=host, port=port, log_level=log_level) uvicorn.run(create_app(), host=host, port=port, log_level=log_level)
# ============================================================================= # =============================================================================
@@ -858,356 +654,6 @@ def db_stats(
console.print(table) console.print(table)
# =============================================================================
# Images Commands (Remote)
# =============================================================================
images_app = typer.Typer(
name="images",
help="Manage images in remote gallery.",
no_args_is_help=True,
)
app.add_typer(images_app, name="images")
def _get_client(remote: str | None) -> TsrClient:
"""Get TsrClient for remote or raise error."""
url = resolve_remote(remote)
if not url:
console.print("[red]Error: No remote specified. Use --remote or set default_remote in config.[/red]")
raise typer.Exit(1)
return TsrClient(url)
@images_app.command("list")
def images_list(
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 50,
offset: Annotated[int, typer.Option("--offset", help="Offset for pagination")] = 0,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""List images in remote gallery."""
try:
with _get_client(remote) as client:
result = client.list_images(limit=limit, offset=offset)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
images = result.get("images", [])
total = result.get("total", len(images))
if json_output:
console.print_json(data=result)
return
if not images:
console.print("[yellow]No images in gallery.[/yellow]")
return
table = Table(title=f"Gallery Images ({len(images)}/{total})", show_header=True, header_style="bold magenta")
table.add_column("ID", style="cyan")
table.add_column("Filename", style="green")
table.add_column("Size", style="white")
table.add_column("Created", style="dim")
for img in images:
size = f"{img.get('width', '?')}x{img.get('height', '?')}"
created = img.get("created_at", "")
if isinstance(created, (int, float)):
from datetime import datetime # noqa: PLC0415
created = datetime.fromtimestamp(created).strftime("%Y-%m-%d %H:%M")
table.add_row(img.get("id", ""), img.get("filename", ""), size, str(created))
console.print(table)
@images_app.command("show")
def images_show(
image_id: Annotated[str, typer.Argument(help="Image ID to show")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""Show image metadata."""
try:
with _get_client(remote) as client:
meta = client.get_image_meta(image_id)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=meta)
return
table = Table(title=f"Image: {image_id}", show_header=True, header_style="bold magenta")
table.add_column("Property", style="cyan")
table.add_column("Value", style="green")
for key, value in meta.items():
display_value = json.dumps(value, indent=2) if isinstance(value, dict) else str(value)
table.add_row(key, display_value)
console.print(table)
@images_app.command("delete")
def images_delete(
image_id: Annotated[str, typer.Argument(help="Image ID to delete")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
force: Annotated[bool, typer.Option("-f", "--force", help="Skip confirmation")] = False,
) -> None:
"""Delete an image from the gallery."""
if not force:
confirm = typer.confirm(f"Delete image {image_id}?")
if not confirm:
console.print("[yellow]Cancelled.[/yellow]")
raise typer.Exit(0)
try:
with _get_client(remote) as client:
client.delete_image(image_id)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
console.print(f"[green]Deleted image: {image_id}[/green]")
@images_app.command("download")
def images_download(
image_id: Annotated[str, typer.Argument(help="Image ID to download")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file or directory")] = None,
) -> None:
"""Download an image from the remote gallery."""
try:
with _get_client(remote) as client:
content = client.download_image(image_id)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
# Determine output path
if output is None:
dest = Path(f"{image_id}.png")
elif output.is_dir():
dest = output / f"{image_id}.png"
else:
dest = output
dest.write_bytes(content)
console.print(f"[green]Saved:[/green] {dest}")
# =============================================================================
# Models Commands (Remote)
# =============================================================================
models_app = typer.Typer(
name="models",
help="Manage models on remote server.",
no_args_is_help=True,
)
app.add_typer(models_app, name="models")
@models_app.command("list")
def models_list(
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""List available models on remote server."""
try:
with _get_client(remote) as client:
result = client.list_models()
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=result)
return
models = result.get("models", [])
active = result.get("active", "")
if not models:
console.print("[yellow]No models found.[/yellow]")
return
table = Table(title="Available Models", show_header=True, header_style="bold magenta")
table.add_column("ID", style="dim", width=8)
table.add_column("Name", style="cyan")
table.add_column("File", style="white")
table.add_column("Size", style="green", justify="right")
for model in models:
path = model.get("path", "")
name = model.get("name", Path(path).stem if path else "")
is_active = active in {path, name}
civitai_id = model.get("civitai_model_id")
id_str = str(civitai_id) if civitai_id else ""
display_name = model.get("display_name", name)
if is_active:
display_name = f"[green]✓[/green] {display_name}"
filename = model.get("filename", Path(path).name if path else "")
size_str = _format_size_mb(model.get("size_mb"))
table.add_row(id_str, display_name, filename, size_str)
console.print(table)
@models_app.command("active")
def models_active(
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""Show currently active model."""
try:
with _get_client(remote) as client:
result = client.get_active_model()
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=result)
return
model = result.get("model", "None")
console.print(f"[bold]Active model:[/bold] {model}")
@models_app.command("switch")
def models_switch(
model: Annotated[str, typer.Argument(help="Model path or name to switch to")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
) -> None:
"""Switch to a different model on the remote server."""
console.print(f"[cyan]Switching to model: {model}[/cyan]")
try:
with _get_client(remote) as client:
result = client.switch_model(model)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
console.print(f"[green]{result.get('status', 'OK')}[/green]")
@models_app.command("loras")
def models_loras(
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""List available LoRAs on remote server."""
try:
with _get_client(remote) as client:
result = client.list_loras()
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=result)
return
loras = result.get("loras", [])
if not loras:
console.print("[yellow]No LoRAs found.[/yellow]")
return
table = Table(title="Available LoRAs", show_header=True, header_style="bold magenta")
table.add_column("ID", style="dim", width=8)
table.add_column("Name", style="cyan")
table.add_column("File", style="white")
table.add_column("Size", style="green", justify="right")
for lora in loras:
path = lora.get("path", "")
name = lora.get("name", Path(path).stem if path else "")
civitai_id = lora.get("civitai_model_id")
id_str = str(civitai_id) if civitai_id else ""
display_name = lora.get("display_name", name)
filename = lora.get("filename", Path(path).name if path else "")
size_str = _format_size_mb(lora.get("size_mb"))
table.add_row(id_str, display_name, filename, size_str)
console.print(table)
# =============================================================================
# Remote Configuration Commands
# =============================================================================
remote_app = typer.Typer(
name="remote",
help="Manage remote server configuration.",
no_args_is_help=True,
)
app.add_typer(remote_app, name="remote")
@remote_app.command("list")
def remote_list(
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""List configured remotes."""
from tensors.config import get_default_remote # noqa: PLC0415
remotes = get_remotes()
default = get_default_remote()
if json_output:
console.print_json(data={"remotes": remotes, "default": default})
return
if not remotes:
console.print("[yellow]No remotes configured.[/yellow]")
console.print("[dim]Add one with: tsr remote add NAME URL[/dim]")
return
table = Table(title="Configured Remotes", show_header=True, header_style="bold magenta")
table.add_column("Default", style="dim", width=3)
table.add_column("Name", style="cyan")
table.add_column("URL", style="green")
for name, url in remotes.items():
is_default = name == default
status = "[green]✓[/green]" if is_default else ""
table.add_row(status, name, url)
console.print(table)
@remote_app.command("add")
def remote_add(
name: Annotated[str, typer.Argument(help="Remote name")],
url: Annotated[str, typer.Argument(help="Remote URL (e.g., http://host:8080)")],
) -> None:
"""Add a remote server."""
save_remote(name, url)
console.print(f"[green]Added remote:[/green] {name}{url}")
@remote_app.command("default")
def remote_default(
name: Annotated[str | None, typer.Argument(help="Remote name to set as default (omit to clear)")] = None,
) -> None:
"""Set or clear the default remote."""
set_default_remote(name)
if name:
console.print(f"[green]Default remote set to:[/green] {name}")
else:
console.print("[green]Default remote cleared.[/green]")
# ============================================================================= # =============================================================================
# ComfyUI Commands # ComfyUI Commands
# ============================================================================= # =============================================================================
@@ -1424,14 +870,8 @@ def main() -> int:
"dl", "dl",
"download", "download",
"config", "config",
"generate",
"serve", "serve",
"status",
"reload",
"db", "db",
"images",
"models",
"remote",
"comfy", "comfy",
) )
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"): if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
+14 -31
View File
@@ -1,4 +1,4 @@
"""sd-server wrapper — FastAPI app for proxying to an external sd-server.""" """Tensors server — FastAPI app for gallery and CivitAI management."""
from __future__ import annotations from __future__ import annotations
@@ -7,19 +7,14 @@ from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import httpx
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from tensors.config import get_sd_server_api_key, get_sd_server_url
from tensors.server.civitai_routes import create_civitai_router from tensors.server.civitai_routes import create_civitai_router
from tensors.server.db_routes import create_db_router from tensors.server.db_routes import create_db_router
from tensors.server.download_routes import create_download_router from tensors.server.download_routes import create_download_router
from tensors.server.gallery_routes import create_gallery_router from tensors.server.gallery_routes import create_gallery_router
from tensors.server.generate_routes import create_generate_router
from tensors.server.models_routes import create_models_router
from tensors.server.routes import create_router
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
@@ -29,28 +24,15 @@ __all__ = ["app", "create_app"]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def create_app(sd_server_url: str | None = None) -> FastAPI: def create_app() -> FastAPI:
"""Build the FastAPI application that proxies to an external sd-server. """Build the FastAPI application for gallery and model management."""
Args:
sd_server_url: URL of the sd-server to proxy to. If None, uses
get_sd_server_url() to resolve from env/config.
"""
backend_url = sd_server_url or get_sd_server_url()
api_key = get_sd_server_api_key()
@asynccontextmanager @asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]: async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
_app.state.sd_server_url = backend_url logger.info("Tensors server starting")
_app.state.sd_server_api_key = api_key yield
logger.info(f"Proxying to sd-server at: {backend_url}")
if api_key:
logger.info("Using API key authentication for sd-server")
async with httpx.AsyncClient(timeout=300) as client:
_app.state.client = client
yield
app = FastAPI(title="sd-server wrapper", lifespan=lifespan) app = FastAPI(title="tensors", lifespan=lifespan)
# Serve Vue UI static files # Serve Vue UI static files
static_dir = Path(__file__).parent / "static" static_dir = Path(__file__).parent / "static"
@@ -66,13 +48,14 @@ def create_app(sd_server_url: str | None = None) -> FastAPI:
async def vite_icon() -> FileResponse: async def vite_icon() -> FileResponse:
return FileResponse(static_dir / "vite.svg") return FileResponse(static_dir / "vite.svg")
app.include_router(create_civitai_router()) # Must be before catch-all proxy @app.get("/status")
app.include_router(create_db_router()) # Must be before catch-all proxy async def status() -> dict[str, str]:
app.include_router(create_gallery_router()) # Must be before catch-all proxy return {"status": "ok"}
app.include_router(create_models_router()) # Must be before catch-all proxy
app.include_router(create_download_router()) # Must be before catch-all proxy app.include_router(create_civitai_router())
app.include_router(create_generate_router()) # Must be before catch-all proxy app.include_router(create_db_router())
app.include_router(create_router()) app.include_router(create_gallery_router())
app.include_router(create_download_router())
return app return app
-20
View File
@@ -7,11 +7,6 @@ import json
import struct import struct
import pytest import pytest
import respx
from tensors.generate import SDClient
BASE_URL = "http://127.0.0.1:1234"
# 1x1 red PNG for image response stubs # 1x1 red PNG for image response stubs
TINY_PNG = ( TINY_PNG = (
@@ -43,18 +38,3 @@ def temp_safetensor(tmp_path):
f.write(header_bytes) f.write(header_bytes)
return file_path return file_path
@pytest.fixture()
def mock_api():
"""Activate respx mock for the sd-server base URL."""
with respx.mock(base_url=BASE_URL, assert_all_called=False) as rsps:
yield rsps
@pytest.fixture()
def client(mock_api: respx.MockRouter) -> SDClient: # noqa: ARG001
"""SDClient wired to the mocked transport."""
c = SDClient()
yield c # type: ignore[misc]
c.close()
-481
View File
@@ -1,481 +0,0 @@
"""Tests for the TsrClient HTTP client module."""
from __future__ import annotations
import pytest
import respx
from httpx import Response
from tensors.client import TsrClient, TsrClientError
BASE_URL = "http://test-server:8080"
@pytest.fixture
def mock_server():
"""Activate respx mock for the test server."""
with respx.mock(base_url=BASE_URL, assert_all_called=False) as rsps:
yield rsps
@pytest.fixture
def client(mock_server) -> TsrClient: # noqa: ARG001 - mock_server activates respx
"""TsrClient connected to mock server."""
return TsrClient(BASE_URL)
# =============================================================================
# Status Tests
# =============================================================================
class TestStatus:
"""Tests for server status endpoint."""
def test_status_success(self, client: TsrClient, mock_server) -> None:
"""Test getting server status."""
mock_server.get("/status").mock(return_value=Response(200, json={"running": True, "pid": 12345, "model": "/test.gguf"}))
with client:
result = client.status()
assert result["running"] is True
assert result["pid"] == 12345
def test_status_error(self, client: TsrClient, mock_server) -> None:
"""Test handling status error."""
mock_server.get("/status").mock(return_value=Response(503, text="Service unavailable"))
with client, pytest.raises(TsrClientError, match="HTTP 503"):
client.status()
# =============================================================================
# Gallery Tests
# =============================================================================
class TestGalleryImages:
"""Tests for gallery image operations."""
def test_list_images(self, client: TsrClient, mock_server) -> None:
"""Test listing gallery images."""
mock_server.get("/api/images").mock(
return_value=Response(
200,
json={
"images": [
{"id": "123_0", "filename": "123_0.png", "width": 512, "height": 512},
{"id": "124_1", "filename": "124_1.png", "width": 1024, "height": 1024},
],
"total": 2,
},
)
)
with client:
result = client.list_images()
assert len(result["images"]) == 2
assert result["total"] == 2
def test_list_images_with_pagination(self, client: TsrClient, mock_server) -> None:
"""Test listing images with pagination."""
mock_server.get("/api/images", params={"limit": 10, "offset": 5}).mock(
return_value=Response(200, json={"images": [], "total": 100})
)
with client:
result = client.list_images(limit=10, offset=5)
assert result["total"] == 100
def test_get_image_meta(self, client: TsrClient, mock_server) -> None:
"""Test getting image metadata."""
mock_server.get("/api/images/123_0/meta").mock(
return_value=Response(
200,
json={
"id": "123_0",
"path": "/gallery/123_0.png",
"metadata": {"prompt": "test prompt", "seed": 42},
},
)
)
with client:
result = client.get_image_meta("123_0")
assert result["id"] == "123_0"
assert result["metadata"]["prompt"] == "test prompt"
def test_delete_image(self, client: TsrClient, mock_server) -> None:
"""Test deleting an image."""
mock_server.delete("/api/images/123_0").mock(return_value=Response(200, json={"deleted": True, "id": "123_0"}))
with client:
result = client.delete_image("123_0")
assert result["deleted"] is True
def test_edit_image(self, client: TsrClient, mock_server) -> None:
"""Test editing image metadata."""
mock_server.post("/api/images/123_0/edit").mock(
return_value=Response(200, json={"id": "123_0", "metadata": {"tags": ["favorite"], "rating": 5}})
)
with client:
result = client.edit_image("123_0", {"tags": ["favorite"], "rating": 5})
assert result["metadata"]["tags"] == ["favorite"]
def test_download_image(self, client: TsrClient, mock_server) -> None:
"""Test downloading image bytes."""
image_bytes = b"\x89PNG test image data"
mock_server.get("/api/images/123_0").mock(return_value=Response(200, content=image_bytes))
with client:
result = client.download_image("123_0")
assert result == image_bytes
# =============================================================================
# Models Tests
# =============================================================================
class TestModels:
"""Tests for model management operations."""
def test_list_models(self, client: TsrClient, mock_server) -> None:
"""Test listing available models."""
mock_server.get("/api/models").mock(
return_value=Response(
200,
json={
"models": [
{"name": "sdxl_base", "path": "/models/sdxl_base.safetensors"},
{"name": "pony_v6", "path": "/models/pony_v6.safetensors"},
],
"active": "/models/sdxl_base.safetensors",
},
)
)
with client:
result = client.list_models()
assert len(result["models"]) == 2
assert result["active"] == "/models/sdxl_base.safetensors"
def test_get_active_model(self, client: TsrClient, mock_server) -> None:
"""Test getting active model."""
mock_server.get("/api/models/active").mock(return_value=Response(200, json={"model": "/models/sdxl_base.safetensors"}))
with client:
result = client.get_active_model()
assert result["model"] == "/models/sdxl_base.safetensors"
def test_switch_model(self, client: TsrClient, mock_server) -> None:
"""Test switching model."""
mock_server.post("/api/models/switch").mock(
return_value=Response(200, json={"status": "ok", "model": "/models/pony_v6.safetensors"})
)
with client:
result = client.switch_model("/models/pony_v6.safetensors")
assert result["status"] == "ok"
def test_list_loras(self, client: TsrClient, mock_server) -> None:
"""Test listing LoRAs."""
mock_server.get("/api/models/loras").mock(
return_value=Response(
200,
json={
"loras": [
{"name": "detail_tweaker", "path": "/loras/detail_tweaker.safetensors"},
]
},
)
)
with client:
result = client.list_loras()
assert len(result["loras"]) == 1
def test_scan_models(self, client: TsrClient, mock_server) -> None:
"""Test scanning models."""
mock_server.get("/api/models/scan").mock(return_value=Response(200, json={"scanned": 5}))
with client:
result = client.scan_models()
assert result["scanned"] == 5
# =============================================================================
# Generation Tests
# =============================================================================
class TestGeneration:
"""Tests for image generation."""
def test_generate(self, client: TsrClient, mock_server) -> None:
"""Test generating an image."""
mock_server.post("/api/generate").mock(
return_value=Response(
200,
json={
"images": [{"id": "999_42", "seed": 42}],
"parameters": {"prompt": "test prompt", "seed": 42},
},
)
)
with client:
result = client.generate(
prompt="test prompt",
width=512,
height=512,
seed=42,
)
assert len(result["images"]) == 1
assert result["images"][0]["seed"] == 42
def test_generate_with_all_params(self, client: TsrClient, mock_server) -> None:
"""Test generation with all parameters."""
mock_server.post("/api/generate").mock(return_value=Response(200, json={"images": []}))
with client:
result = client.generate(
prompt="detailed test prompt",
negative_prompt="bad quality",
width=1024,
height=1024,
steps=30,
cfg_scale=5.5,
seed=12345,
sampler_name="DPM++ 2M",
scheduler="karras",
batch_size=2,
save_to_gallery=False,
return_base64=True,
)
assert "images" in result
def test_list_samplers(self, client: TsrClient, mock_server) -> None:
"""Test listing samplers."""
mock_server.get("/api/samplers").mock(return_value=Response(200, json={"samplers": ["Euler", "DPM++ 2M", "Euler a"]}))
with client:
result = client.list_samplers()
assert "samplers" in result
def test_list_schedulers(self, client: TsrClient, mock_server) -> None:
"""Test listing schedulers."""
mock_server.get("/api/schedulers").mock(
return_value=Response(200, json={"schedulers": ["simple", "karras", "sgm_uniform"]})
)
with client:
result = client.list_schedulers()
assert "schedulers" in result
# =============================================================================
# Download Tests
# =============================================================================
class TestDownload:
"""Tests for CivitAI download operations."""
def test_start_download_by_version(self, client: TsrClient, mock_server) -> None:
"""Test starting download by version ID."""
mock_server.post("/api/download").mock(
return_value=Response(200, json={"download_id": "abc123", "status": "started", "version_id": 12345})
)
with client:
result = client.start_download(version_id=12345)
assert result["download_id"] == "abc123"
def test_start_download_by_hash(self, client: TsrClient, mock_server) -> None:
"""Test starting download by hash."""
mock_server.post("/api/download").mock(return_value=Response(200, json={"download_id": "def456", "status": "started"}))
with client:
result = client.start_download(hash_val="ABC123DEF456")
assert result["status"] == "started"
def test_get_download_status(self, client: TsrClient, mock_server) -> None:
"""Test getting download status."""
mock_server.get("/api/download/status/abc123").mock(
return_value=Response(200, json={"download_id": "abc123", "status": "downloading", "progress": 0.5})
)
with client:
result = client.get_download_status("abc123")
assert result["progress"] == 0.5
def test_list_downloads(self, client: TsrClient, mock_server) -> None:
"""Test listing active downloads."""
mock_server.get("/api/download/active").mock(
return_value=Response(200, json={"downloads": [{"id": "abc123", "progress": 0.75}]})
)
with client:
result = client.list_downloads()
assert len(result["downloads"]) == 1
# =============================================================================
# Database Tests
# =============================================================================
class TestDatabase:
"""Tests for database operations."""
def test_db_list_files(self, client: TsrClient, mock_server) -> None:
"""Test listing local files."""
mock_server.get("/api/db/files").mock(
return_value=Response(200, json=[{"id": 1, "file_path": "/models/test.safetensors", "sha256": "abc123"}])
)
with client:
result = client.db_list_files()
assert len(result) == 1
assert result[0]["sha256"] == "abc123"
def test_db_search_models(self, client: TsrClient, mock_server) -> None:
"""Test searching cached models."""
mock_server.get("/api/db/models").mock(
return_value=Response(200, json=[{"civitai_id": 12345, "name": "Test Model", "type": "LORA"}])
)
with client:
result = client.db_search_models(query="Test", model_type="LORA")
assert len(result) == 1
assert result[0]["name"] == "Test Model"
def test_db_get_model(self, client: TsrClient, mock_server) -> None:
"""Test getting cached model."""
mock_server.get("/api/db/models/12345").mock(
return_value=Response(200, json={"civitai_id": 12345, "name": "Test Model", "type": "Checkpoint"})
)
with client:
result = client.db_get_model(12345)
assert result["name"] == "Test Model"
def test_db_get_triggers(self, client: TsrClient, mock_server) -> None:
"""Test getting trigger words."""
mock_server.get("/api/db/triggers/12345").mock(return_value=Response(200, json=["trigger1", "trigger2"]))
with client:
result = client.db_get_triggers(version_id=12345)
assert result == ["trigger1", "trigger2"]
def test_db_stats(self, client: TsrClient, mock_server) -> None:
"""Test getting database stats."""
mock_server.get("/api/db/stats").mock(
return_value=Response(200, json={"local_files": 10, "models": 5, "model_versions": 15})
)
with client:
result = client.db_stats()
assert result["local_files"] == 10
def test_db_scan(self, client: TsrClient, mock_server) -> None:
"""Test scanning directory."""
mock_server.post("/api/db/scan").mock(return_value=Response(200, json={"scanned": 3, "files": []}))
with client:
result = client.db_scan("/models")
assert result["scanned"] == 3
def test_db_link(self, client: TsrClient, mock_server) -> None:
"""Test linking files to CivitAI."""
mock_server.post("/api/db/link").mock(return_value=Response(200, json={"linked": 2}))
with client:
result = client.db_link()
assert result["linked"] == 2
def test_db_cache(self, client: TsrClient, mock_server) -> None:
"""Test caching model data."""
mock_server.post("/api/db/cache").mock(return_value=Response(200, json={"model_id": 12345, "cached": True}))
with client:
result = client.db_cache(12345)
assert result["cached"] is True
# =============================================================================
# Error Handling Tests
# =============================================================================
class TestErrorHandling:
"""Tests for error handling."""
def test_http_error(self, client: TsrClient, mock_server) -> None:
"""Test HTTP error handling."""
mock_server.get("/api/images").mock(return_value=Response(500, text="Internal server error"))
with client, pytest.raises(TsrClientError, match="HTTP 500"):
client.list_images()
def test_not_found_error(self, client: TsrClient, mock_server) -> None:
"""Test 404 error handling."""
mock_server.get("/api/images/nonexistent/meta").mock(return_value=Response(404, json={"detail": "Image not found"}))
with client, pytest.raises(TsrClientError, match="HTTP 404"):
client.get_image_meta("nonexistent")
# =============================================================================
# Context Manager Tests
# =============================================================================
class TestContextManager:
"""Tests for context manager usage."""
def test_context_manager(self, mock_server) -> None:
"""Test client works as context manager."""
mock_server.get("/status").mock(return_value=Response(200, json={"running": True}))
with TsrClient(BASE_URL) as client:
result = client.status()
assert result["running"] is True
def test_client_without_context(self, mock_server) -> None:
"""Test client works without context manager."""
mock_server.get("/status").mock(return_value=Response(200, json={"running": True}))
client = TsrClient(BASE_URL)
result = client.status()
assert result["running"] is True
-303
View File
@@ -1,303 +0,0 @@
"""Tests for tensors.generate package."""
from __future__ import annotations
import base64
import json
from pathlib import Path
import httpx
import pytest
import respx
from tensors.generate import SDClient
from tensors.generate._http import HttpTransport
from tensors.generate.params import Img2ImgParams, Txt2ImgParams
from tensors.generate.util import save_images, to_b64
from tests.conftest import BASE_URL, TINY_PNG, TINY_PNG_B64
# ── util ──────────────────────────────────────────────────────────────
class TestToB64:
def test_bytes_input(self):
raw = b"hello"
assert to_b64(raw) == base64.b64encode(raw).decode()
def test_file_path(self, tmp_path: Path):
f = tmp_path / "img.png"
f.write_bytes(b"\x89PNG")
result = to_b64(str(f))
assert base64.b64decode(result) == b"\x89PNG"
def test_pathlib_path(self, tmp_path: Path):
f = tmp_path / "img.png"
f.write_bytes(b"data")
result = to_b64(f)
assert base64.b64decode(result) == b"data"
def test_passthrough_string(self):
b64 = base64.b64encode(b"already").decode()
assert to_b64(b64) == b64
def test_unsupported_type(self):
with pytest.raises(TypeError, match="unsupported image type"):
to_b64(12345) # type: ignore[arg-type]
class TestSaveImages:
def test_saves_files(self, tmp_path: Path):
images = [b"img0", b"img1", b"img2"]
paths = save_images(images, str(tmp_path), prefix="test")
assert len(paths) == 3
for i, p in enumerate(paths):
assert p.name == f"test_{i:04d}.png"
assert p.read_bytes() == images[i]
def test_creates_directory(self, tmp_path: Path):
out = tmp_path / "sub" / "dir"
save_images([b"x"], str(out))
assert (out / "output_0000.png").exists()
# ── params ────────────────────────────────────────────────────────────
class TestTxt2ImgParams:
def test_minimal_body(self):
p = Txt2ImgParams(prompt="a cat")
body = p.to_body()
assert body["prompt"] == "a cat"
assert body["width"] == 512
assert body["height"] == 512
assert body["steps"] == 20
assert body["seed"] == -1
assert "sampler_name" not in body
assert "scheduler" not in body
assert "clip_skip" not in body
assert "lora" not in body
def test_optional_fields_included(self):
p = Txt2ImgParams(
prompt="test",
sampler_name="euler_a",
scheduler="karras",
clip_skip=2,
lora=[{"path": "x.safetensors", "multiplier": 0.5}],
)
body = p.to_body()
assert body["sampler_name"] == "euler_a"
assert body["scheduler"] == "karras"
assert body["clip_skip"] == 2
assert len(body["lora"]) == 1
class TestImg2ImgParams:
def test_minimal_body(self, tmp_path: Path):
img = tmp_path / "init.png"
img.write_bytes(b"\x89PNG")
p = Img2ImgParams(prompt="paint it", init_image=str(img))
body = p.to_body()
assert body["prompt"] == "paint it"
assert body["denoising_strength"] == 0.75
decoded = base64.b64decode(body["init_images"][0])
assert decoded == b"\x89PNG"
assert "width" not in body
assert "height" not in body
assert "mask" not in body
def test_all_optional_fields(self, tmp_path: Path):
img = tmp_path / "init.png"
img.write_bytes(b"img")
mask = tmp_path / "mask.png"
mask.write_bytes(b"mask")
extra = tmp_path / "extra.png"
extra.write_bytes(b"extra")
p = Img2ImgParams(
prompt="test",
init_image=str(img),
mask=str(mask),
width=768,
height=768,
inpainting_mask_invert=True,
sampler_name="euler",
scheduler="simple",
clip_skip=1,
lora=[{"path": "a.gguf", "multiplier": 1.0}],
extra_images=[str(extra)],
)
body = p.to_body()
assert body["width"] == 768
assert body["mask"]
assert body["inpainting_mask_invert"] == 1
assert body["sampler_name"] == "euler"
assert len(body["extra_images"]) == 1
# ── _http ─────────────────────────────────────────────────────────────
class TestHttpTransport:
def test_get_success(self):
with respx.mock(base_url=BASE_URL) as rsps:
rsps.get("/test").respond(json={"ok": True})
t = HttpTransport(BASE_URL)
assert t.get("/test") == {"ok": True}
t.close()
def test_post_success(self):
with respx.mock(base_url=BASE_URL) as rsps:
rsps.post("/gen").respond(json={"images": []})
t = HttpTransport(BASE_URL)
assert t.post("/gen", {"prompt": "x"}) == {"images": []}
t.close()
def test_get_http_error(self):
with respx.mock(base_url=BASE_URL) as rsps:
rsps.get("/bad").respond(status_code=404, text="not found")
t = HttpTransport(BASE_URL)
with pytest.raises(httpx.HTTPStatusError):
t.get("/bad")
t.close()
def test_post_http_error(self):
with respx.mock(base_url=BASE_URL) as rsps:
rsps.post("/bad").respond(status_code=500, text="error")
t = HttpTransport(BASE_URL)
with pytest.raises(httpx.HTTPStatusError):
t.post("/bad", {})
t.close()
def test_get_connection_error(self):
with respx.mock(base_url=BASE_URL) as rsps:
rsps.get("/fail").mock(side_effect=httpx.ConnectError("refused"))
t = HttpTransport(BASE_URL)
with pytest.raises(httpx.ConnectError):
t.get("/fail")
t.close()
# ── info ──────────────────────────────────────────────────────────────
class TestInfoAPI:
def test_models(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.get("/v1/models").respond(json={"data": [{"id": "sd-cpp-local", "object": "model", "owned_by": "local"}]})
result = client.info.models()
assert len(result) == 1
assert result[0]["id"] == "sd-cpp-local"
def test_sd_models(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.get("/sdapi/v1/sd-models").respond(
json=[{"title": "sdxl", "model_name": "sdxl", "filename": "sdxl.safetensors"}]
)
result = client.info.sd_models()
assert result[0]["title"] == "sdxl"
def test_options(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.get("/sdapi/v1/options").respond(
json={
"samples_format": "png",
"sd_model_checkpoint": "v1-5",
}
)
result = client.info.options()
assert result["sd_model_checkpoint"] == "v1-5"
def test_loras(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.get("/sdapi/v1/loras").respond(
json=[
{"name": "style", "path": "style.safetensors"},
]
)
result = client.info.loras()
assert len(result) == 1
assert result[0]["name"] == "style"
def test_samplers(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.get("/sdapi/v1/samplers").respond(
json=[
{"name": "euler", "aliases": ["euler"], "options": {}},
{"name": "euler_a", "aliases": ["euler_a"], "options": {}},
]
)
result = client.info.samplers()
assert result == ["euler", "euler_a"]
def test_schedulers(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.get("/sdapi/v1/schedulers").respond(
json=[
{"name": "discrete", "label": "discrete"},
{"name": "karras", "label": "karras"},
]
)
result = client.info.schedulers()
assert result == ["discrete", "karras"]
# ── generation ────────────────────────────────────────────────────────
class TestTxt2Img:
def test_returns_decoded_images(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.post("/sdapi/v1/txt2img").respond(
json={
"images": [TINY_PNG_B64],
"parameters": {},
"info": "",
}
)
images = client.generate.txt2img(Txt2ImgParams(prompt="a cat"))
assert len(images) == 1
assert images[0] == TINY_PNG
def test_multiple_images(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.post("/sdapi/v1/txt2img").respond(
json={
"images": [TINY_PNG_B64, TINY_PNG_B64, TINY_PNG_B64],
"parameters": {},
"info": "",
}
)
params = Txt2ImgParams(prompt="cats", batch_size=3)
images = client.generate.txt2img(params)
assert len(images) == 3
def test_sends_correct_body(self, mock_api: respx.MockRouter, client: SDClient):
route = mock_api.post("/sdapi/v1/txt2img").respond(
json={
"images": [TINY_PNG_B64],
"parameters": {},
"info": "",
}
)
params = Txt2ImgParams(
prompt="hello",
width=768,
height=768,
steps=30,
sampler_name="euler_a",
)
client.generate.txt2img(params)
sent = json.loads(route.calls[0].request.content)
assert sent["prompt"] == "hello"
assert sent["width"] == 768
assert sent["sampler_name"] == "euler_a"
class TestImg2Img:
def test_returns_decoded_images(self, mock_api: respx.MockRouter, client: SDClient, tmp_path: Path):
mock_api.post("/sdapi/v1/img2img").respond(
json={
"images": [TINY_PNG_B64],
"parameters": {},
"info": "",
}
)
img = tmp_path / "init.png"
img.write_bytes(TINY_PNG)
params = Img2ImgParams(prompt="paint", init_image=str(img))
images = client.generate.img2img(params)
assert len(images) == 1
assert images[0] == TINY_PNG
+8 -82
View File
@@ -1,12 +1,8 @@
"""Tests for tensors.server package (FastAPI sd-server proxy wrapper).""" """Tests for tensors.server package (gallery and CivitAI management)."""
from __future__ import annotations from __future__ import annotations
from unittest.mock import AsyncMock
import httpx
import pytest import pytest
import respx
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from tensors.server import create_app from tensors.server import create_app
@@ -14,87 +10,17 @@ from tensors.server import create_app
@pytest.fixture() @pytest.fixture()
def api() -> TestClient: def api() -> TestClient:
"""Create test client with mock sd-server URL.""" """Create test client."""
return TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) return TestClient(create_app())
class TestStatus: class TestStatus:
@respx.mock def test_status_ok(self, api: TestClient) -> None:
def test_status_when_backend_reachable(self) -> None: """Test status endpoint returns ok."""
"""Test status endpoint when sd-server is reachable.""" r = api.get("/status")
respx.get("http://mock-sd-server:1234/").mock(return_value=httpx.Response(200))
with TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) as client:
r = client.get("/status")
assert r.status_code == 200
data = r.json()
assert data["status"] == "ok"
assert data["sd_server_url"] == "http://mock-sd-server:1234"
@respx.mock
def test_status_when_backend_unreachable(self) -> None:
"""Test status endpoint when sd-server is not reachable."""
respx.get("http://mock-sd-server:1234/").mock(side_effect=httpx.ConnectError("Connection refused"))
with TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) as client:
r = client.get("/status")
assert r.status_code == 200
data = r.json()
assert data["status"] == "error"
assert "Connection refused" in data["error"]
class TestProxy:
def test_proxy_forwards_request(self, api: TestClient) -> None:
"""Test proxy forwards GET requests to backend."""
upstream_response = httpx.Response(
200,
json={"data": [{"id": "model-1"}]},
headers={"content-type": "application/json"},
)
mock_client = AsyncMock()
mock_client.request.return_value = upstream_response
api.app.state.client = mock_client # type: ignore[attr-defined]
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
r = api.get("/v1/models")
assert r.status_code == 200 assert r.status_code == 200
assert r.json() == {"data": [{"id": "model-1"}]} data = r.json()
mock_client.request.assert_called_once() assert data["status"] == "ok"
def test_proxy_forwards_post_with_body(self, api: TestClient) -> None:
"""Test proxy forwards POST requests with body."""
upstream_response = httpx.Response(200, json={"ok": True})
mock_client = AsyncMock()
mock_client.request.return_value = upstream_response
api.app.state.client = mock_client # type: ignore[attr-defined]
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
r = api.post("/sdapi/v1/txt2img", json={"prompt": "hello"})
assert r.status_code == 200
mock_client.request.assert_called_once()
def test_proxy_503_on_connect_error(self, api: TestClient) -> None:
"""Test proxy returns 503 when backend is unreachable."""
mock_client = AsyncMock()
mock_client.request.side_effect = httpx.ConnectError("Connection refused")
api.app.state.client = mock_client # type: ignore[attr-defined]
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
r = api.get("/v1/models")
assert r.status_code == 503
assert "Cannot connect" in r.json()["error"]
def test_proxy_504_on_timeout(self, api: TestClient) -> None:
"""Test proxy returns 504 on timeout."""
mock_client = AsyncMock()
mock_client.request.side_effect = httpx.TimeoutException("Timeout")
api.app.state.client = mock_client # type: ignore[attr-defined]
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
r = api.get("/v1/models")
assert r.status_code == 504
assert "Timeout" in r.json()["error"]
# ============================================================================= # =============================================================================