diff --git a/.coverage b/.coverage index c3b6e90..3e5f05a 100644 Binary files a/.coverage and b/.coverage differ diff --git a/tensors/cli.py b/tensors/cli.py index acead72..f6681b5 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -20,20 +20,15 @@ from tensors.api import ( fetch_civitai_model_version, search_civitai, ) -from tensors.client import TsrClient, TsrClientError from tensors.config import ( CONFIG_FILE, BaseModel, ModelType, SortOrder, get_default_output_path, - get_remotes, load_api_key, load_config, - resolve_remote, save_config, - save_remote, - set_default_remote, ) from tensors.db import DB_PATH, Database from tensors.display import ( @@ -49,18 +44,6 @@ from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_me # Key masking threshold MIN_KEY_LENGTH_FOR_MASKING = 8 -# Size threshold for GB display -_MB_PER_GB = 1024 - - -def _format_size_mb(size_mb: float | None) -> str: - """Format size in MB to human-readable string.""" - if not size_mb: - return "" - if size_mb >= _MB_PER_GB: - return f"{size_mb / _MB_PER_GB:.1f} GB" - return f"{size_mb:.0f} MB" - def _version_callback(value: bool) -> None: if value: @@ -329,74 +312,41 @@ def download( output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None, no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, - remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False, ) -> None: - """Download a model from CivitAI (locally or to remote server).""" - # Check if remote is specified or configured - remote_url = resolve_remote(remote) + """Download a model from CivitAI.""" + key = api_key or load_api_key() - if remote_url: - # Remote mode: use TsrClient API - if not version_id and not model_id and not hash_val: + resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key) + if not resolved_version_id: + if not version_id and not hash_val and not model_id: console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") - raise typer.Exit(1) + raise typer.Exit(1) - try: - with TsrClient(remote_url) as client: - console.print(f"[cyan]Starting download on {remote_url}...[/cyan]") - result = client.start_download( - version_id=version_id, - model_id=model_id, - hash_val=hash_val, - output_dir=str(output) if output else None, - ) - except TsrClientError as e: - console.print(f"[red]Error: {e}[/red]") - raise typer.Exit(1) from e + console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]") + version_info = fetch_civitai_model_version(resolved_version_id, key, console) + if not version_info: + console.print("[red]Error: Could not fetch model version info.[/red]") + raise typer.Exit(1) - if json_output: - console.print_json(data=result) - return + model_type_str: str | None = version_info.get("model", {}).get("type") + output_dir = _prepare_download_dir(output, model_type_str) + if not output_dir: + raise typer.Exit(1) - download_id = result.get("download_id") - console.print(f"[green]Download started:[/green] {download_id}") - console.print(f"[dim]Check status with: tsr images download-status {download_id} --remote {remote or 'default'}[/dim]") - else: - # Local mode: direct download - key = api_key or load_api_key() + files: list[dict[str, Any]] = version_info.get("files", []) + primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) + if not primary_file: + console.print("[red]Error: No files found for this version.[/red]") + raise typer.Exit(1) - resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key) - if not resolved_version_id: - if not version_id and not hash_val and not model_id: - console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") - raise typer.Exit(1) + filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors") + dest_path = output_dir / filename - console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]") - version_info = fetch_civitai_model_version(resolved_version_id, key, console) - if not version_info: - console.print("[red]Error: Could not fetch model version info.[/red]") - raise typer.Exit(1) + _display_download_info(version_info, filename, primary_file, dest_path) - model_type_str: str | None = version_info.get("model", {}).get("type") - output_dir = _prepare_download_dir(output, model_type_str) - if not output_dir: - raise typer.Exit(1) - - files: list[dict[str, Any]] = version_info.get("files", []) - primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) - if not primary_file: - console.print("[red]Error: No files found for this version.[/red]") - raise typer.Exit(1) - - filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors") - dest_path = output_dir / filename - - _display_download_info(version_info, filename, primary_file, dest_path) - - success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume) - if not success: - raise typer.Exit(1) + success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume) + if not success: + raise typer.Exit(1) def _display_download_info( @@ -449,167 +399,13 @@ def config( console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]") -@app.command() -def generate( - prompt: Annotated[str, typer.Argument(help="Text prompt for image generation.")], - remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, - model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model (remote mode only).")] = None, - host: Annotated[str, typer.Option(help="sd-server address (local mode).")] = "127.0.0.1", - port: Annotated[int, typer.Option(help="sd-server port (local mode).")] = 8080, - output: Annotated[str, typer.Option("-o", help="Output directory (local mode).")] = ".", - negative_prompt: Annotated[str, typer.Option("-n", help="Negative prompt.")] = "", - width: Annotated[int, typer.Option("-W", help="Image width.")] = 512, - height: Annotated[int, typer.Option("-H", help="Image height.")] = 512, - steps: Annotated[int, typer.Option(help="Sampling steps.")] = 20, - cfg_scale: Annotated[float, typer.Option(help="CFG scale.")] = 7.0, - seed: Annotated[int, typer.Option("-s", help="RNG seed (-1 for random).")] = -1, - sampler: Annotated[str, typer.Option(help="Sampler name.")] = "", - scheduler: Annotated[str, typer.Option(help="Scheduler name.")] = "", - batch_size: Annotated[int, typer.Option("-b", help="Number of images.")] = 1, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False, -) -> None: - """Generate images using sd-server (local or remote).""" - # Check if remote is specified or configured - remote_url = resolve_remote(remote) - - if remote_url: - # Remote mode: use TsrClient API - try: - with TsrClient(remote_url) as client: - # Switch model if specified - if model: - console.print(f"[cyan]Switching to model: {model}[/cyan]") - client.switch_model(model) - - console.print(f"[cyan]Generating {batch_size} image(s) on {remote_url}...[/cyan]") - result = client.generate( - prompt=prompt, - negative_prompt=negative_prompt, - width=width, - height=height, - steps=steps, - cfg_scale=cfg_scale, - seed=seed, - sampler_name=sampler, - scheduler=scheduler, - batch_size=batch_size, - ) - except TsrClientError as e: - console.print(f"[red]Error: {e}[/red]") - raise typer.Exit(1) from e - - if json_output: - console.print_json(data=result) - return - - images = result.get("images", []) - for img in images: - console.print(f"[green]Generated:[/green] {img.get('id', 'unknown')}") - else: - # Local mode: direct sd-server connection - if model: - console.print("[yellow]Warning: --model ignored in local mode (sd-server loads model at startup)[/yellow]") - - from tensors.generate import SDClient, Txt2ImgParams, save_images # noqa: PLC0415 - - params = Txt2ImgParams( - prompt=prompt, - negative_prompt=negative_prompt, - width=width, - height=height, - steps=steps, - cfg_scale=cfg_scale, - seed=seed, - batch_size=batch_size, - sampler_name=sampler, - scheduler=scheduler, - ) - - with SDClient(host=host, port=port) as client: - console.print(f"[cyan]Generating {batch_size} image(s)...[/cyan]") - images = client.generate.txt2img(params) - paths = save_images(images, output) - for p in paths: - console.print(f"[green]Saved:[/green] {p}") - - -@app.command() -def status( - remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, - host: Annotated[str, typer.Option(help="Wrapper API host (local mode).")] = "127.0.0.1", - port: Annotated[int, typer.Option(help="Wrapper API port (local mode).")] = 8080, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, -) -> None: - """Show sd-server wrapper status.""" - # Check if remote is specified or configured - remote_url = resolve_remote(remote) - - if remote_url: - # Remote mode: use TsrClient API - try: - with TsrClient(remote_url) as client: - data = client.status() - except TsrClientError as e: - console.print(f"[red]Error: {e}[/red]") - raise typer.Exit(1) from e - else: - # Local mode: direct HTTP call - import httpx # noqa: PLC0415 - - url = f"http://{host}:{port}/status" - try: - resp = httpx.get(url, timeout=10) - resp.raise_for_status() - data = resp.json() - except httpx.HTTPError as e: - console.print(f"[red]Error: Could not reach wrapper at {url}: {e}[/red]") - raise typer.Exit(1) from e - - if json_output: - console.print_json(data=data) - return - - table = Table(title="Server Status", show_header=True, header_style="bold magenta") - table.add_column("Property", style="cyan") - table.add_column("Value", style="green") - for key, value in data.items(): - table.add_row(key, str(value)) - console.print(table) - - -@app.command() -def reload( - model: Annotated[str, typer.Option(help="Path to model file for sd-server.")], - remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, - host: Annotated[str, typer.Option(help="Wrapper API host (local mode).")] = "127.0.0.1", - port: Annotated[int, typer.Option(help="Wrapper API port (local mode).")] = 8080, -) -> None: - """Reload sd-server with a new model.""" - import httpx # noqa: PLC0415 - - remote_url = resolve_remote(remote) - url = f"{remote_url.rstrip('/')}/reload" if remote_url else f"http://{host}:{port}/reload" - - console.print(f"[cyan]Reloading model: {model}[/cyan]") - try: - resp = httpx.post(url, json={"model": model}, timeout=300) - resp.raise_for_status() - data = resp.json() - except httpx.HTTPError as e: - console.print(f"[red]Error: Reload failed at {url}: {e}[/red]") - raise typer.Exit(1) from e - - console.print(f"[green]{data.get('status', 'OK')}[/green]") - - @app.command() def serve( - host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1", - port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080, - sd_server: Annotated[str | None, typer.Option(help="sd-server URL to proxy to.")] = None, + host: Annotated[str, typer.Option(help="Listen address.")] = "127.0.0.1", + port: Annotated[int, typer.Option(help="Listen port.")] = 8080, log_level: Annotated[str, typer.Option(help="Log level.")] = "info", ) -> None: - """Start the sd-server wrapper API (proxies to external sd-server).""" + """Start the tensors server (gallery and CivitAI management).""" try: import uvicorn # noqa: PLC0415 @@ -619,7 +415,7 @@ def serve( console.print(" pip install tensors[server]") raise typer.Exit(1) from None - uvicorn.run(create_app(sd_server_url=sd_server), host=host, port=port, log_level=log_level) + uvicorn.run(create_app(), host=host, port=port, log_level=log_level) # ============================================================================= @@ -858,356 +654,6 @@ def db_stats( console.print(table) -# ============================================================================= -# Images Commands (Remote) -# ============================================================================= - -images_app = typer.Typer( - name="images", - help="Manage images in remote gallery.", - no_args_is_help=True, -) -app.add_typer(images_app, name="images") - - -def _get_client(remote: str | None) -> TsrClient: - """Get TsrClient for remote or raise error.""" - url = resolve_remote(remote) - if not url: - console.print("[red]Error: No remote specified. Use --remote or set default_remote in config.[/red]") - raise typer.Exit(1) - return TsrClient(url) - - -@images_app.command("list") -def images_list( - remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, - limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 50, - offset: Annotated[int, typer.Option("--offset", help="Offset for pagination")] = 0, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, -) -> None: - """List images in remote gallery.""" - try: - with _get_client(remote) as client: - result = client.list_images(limit=limit, offset=offset) - except TsrClientError as e: - console.print(f"[red]Error: {e}[/red]") - raise typer.Exit(1) from e - - images = result.get("images", []) - total = result.get("total", len(images)) - - if json_output: - console.print_json(data=result) - return - - if not images: - console.print("[yellow]No images in gallery.[/yellow]") - return - - table = Table(title=f"Gallery Images ({len(images)}/{total})", show_header=True, header_style="bold magenta") - table.add_column("ID", style="cyan") - table.add_column("Filename", style="green") - table.add_column("Size", style="white") - table.add_column("Created", style="dim") - - for img in images: - size = f"{img.get('width', '?')}x{img.get('height', '?')}" - created = img.get("created_at", "") - if isinstance(created, (int, float)): - from datetime import datetime # noqa: PLC0415 - - created = datetime.fromtimestamp(created).strftime("%Y-%m-%d %H:%M") - table.add_row(img.get("id", ""), img.get("filename", ""), size, str(created)) - - console.print(table) - - -@images_app.command("show") -def images_show( - image_id: Annotated[str, typer.Argument(help="Image ID to show")], - remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, -) -> None: - """Show image metadata.""" - try: - with _get_client(remote) as client: - meta = client.get_image_meta(image_id) - except TsrClientError as e: - console.print(f"[red]Error: {e}[/red]") - raise typer.Exit(1) from e - - if json_output: - console.print_json(data=meta) - return - - table = Table(title=f"Image: {image_id}", show_header=True, header_style="bold magenta") - table.add_column("Property", style="cyan") - table.add_column("Value", style="green") - - for key, value in meta.items(): - display_value = json.dumps(value, indent=2) if isinstance(value, dict) else str(value) - table.add_row(key, display_value) - - console.print(table) - - -@images_app.command("delete") -def images_delete( - image_id: Annotated[str, typer.Argument(help="Image ID to delete")], - remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, - force: Annotated[bool, typer.Option("-f", "--force", help="Skip confirmation")] = False, -) -> None: - """Delete an image from the gallery.""" - if not force: - confirm = typer.confirm(f"Delete image {image_id}?") - if not confirm: - console.print("[yellow]Cancelled.[/yellow]") - raise typer.Exit(0) - - try: - with _get_client(remote) as client: - client.delete_image(image_id) - except TsrClientError as e: - console.print(f"[red]Error: {e}[/red]") - raise typer.Exit(1) from e - - console.print(f"[green]Deleted image: {image_id}[/green]") - - -@images_app.command("download") -def images_download( - image_id: Annotated[str, typer.Argument(help="Image ID to download")], - remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, - output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file or directory")] = None, -) -> None: - """Download an image from the remote gallery.""" - try: - with _get_client(remote) as client: - content = client.download_image(image_id) - except TsrClientError as e: - console.print(f"[red]Error: {e}[/red]") - raise typer.Exit(1) from e - - # Determine output path - if output is None: - dest = Path(f"{image_id}.png") - elif output.is_dir(): - dest = output / f"{image_id}.png" - else: - dest = output - - dest.write_bytes(content) - console.print(f"[green]Saved:[/green] {dest}") - - -# ============================================================================= -# Models Commands (Remote) -# ============================================================================= - -models_app = typer.Typer( - name="models", - help="Manage models on remote server.", - no_args_is_help=True, -) -app.add_typer(models_app, name="models") - - -@models_app.command("list") -def models_list( - remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, -) -> None: - """List available models on remote server.""" - try: - with _get_client(remote) as client: - result = client.list_models() - except TsrClientError as e: - console.print(f"[red]Error: {e}[/red]") - raise typer.Exit(1) from e - - if json_output: - console.print_json(data=result) - return - - models = result.get("models", []) - active = result.get("active", "") - - if not models: - console.print("[yellow]No models found.[/yellow]") - return - - table = Table(title="Available Models", show_header=True, header_style="bold magenta") - table.add_column("ID", style="dim", width=8) - table.add_column("Name", style="cyan") - table.add_column("File", style="white") - table.add_column("Size", style="green", justify="right") - - for model in models: - path = model.get("path", "") - name = model.get("name", Path(path).stem if path else "") - is_active = active in {path, name} - - civitai_id = model.get("civitai_model_id") - id_str = str(civitai_id) if civitai_id else "" - display_name = model.get("display_name", name) - if is_active: - display_name = f"[green]✓[/green] {display_name}" - filename = model.get("filename", Path(path).name if path else "") - size_str = _format_size_mb(model.get("size_mb")) - - table.add_row(id_str, display_name, filename, size_str) - - console.print(table) - - -@models_app.command("active") -def models_active( - remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, -) -> None: - """Show currently active model.""" - try: - with _get_client(remote) as client: - result = client.get_active_model() - except TsrClientError as e: - console.print(f"[red]Error: {e}[/red]") - raise typer.Exit(1) from e - - if json_output: - console.print_json(data=result) - return - - model = result.get("model", "None") - console.print(f"[bold]Active model:[/bold] {model}") - - -@models_app.command("switch") -def models_switch( - model: Annotated[str, typer.Argument(help="Model path or name to switch to")], - remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, -) -> None: - """Switch to a different model on the remote server.""" - console.print(f"[cyan]Switching to model: {model}[/cyan]") - try: - with _get_client(remote) as client: - result = client.switch_model(model) - except TsrClientError as e: - console.print(f"[red]Error: {e}[/red]") - raise typer.Exit(1) from e - - console.print(f"[green]{result.get('status', 'OK')}[/green]") - - -@models_app.command("loras") -def models_loras( - remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, -) -> None: - """List available LoRAs on remote server.""" - try: - with _get_client(remote) as client: - result = client.list_loras() - except TsrClientError as e: - console.print(f"[red]Error: {e}[/red]") - raise typer.Exit(1) from e - - if json_output: - console.print_json(data=result) - return - - loras = result.get("loras", []) - if not loras: - console.print("[yellow]No LoRAs found.[/yellow]") - return - - table = Table(title="Available LoRAs", show_header=True, header_style="bold magenta") - table.add_column("ID", style="dim", width=8) - table.add_column("Name", style="cyan") - table.add_column("File", style="white") - table.add_column("Size", style="green", justify="right") - - for lora in loras: - path = lora.get("path", "") - name = lora.get("name", Path(path).stem if path else "") - - civitai_id = lora.get("civitai_model_id") - id_str = str(civitai_id) if civitai_id else "" - display_name = lora.get("display_name", name) - filename = lora.get("filename", Path(path).name if path else "") - size_str = _format_size_mb(lora.get("size_mb")) - - table.add_row(id_str, display_name, filename, size_str) - - console.print(table) - - -# ============================================================================= -# Remote Configuration Commands -# ============================================================================= - -remote_app = typer.Typer( - name="remote", - help="Manage remote server configuration.", - no_args_is_help=True, -) -app.add_typer(remote_app, name="remote") - - -@remote_app.command("list") -def remote_list( - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, -) -> None: - """List configured remotes.""" - from tensors.config import get_default_remote # noqa: PLC0415 - - remotes = get_remotes() - default = get_default_remote() - - if json_output: - console.print_json(data={"remotes": remotes, "default": default}) - return - - if not remotes: - console.print("[yellow]No remotes configured.[/yellow]") - console.print("[dim]Add one with: tsr remote add NAME URL[/dim]") - return - - table = Table(title="Configured Remotes", show_header=True, header_style="bold magenta") - table.add_column("Default", style="dim", width=3) - table.add_column("Name", style="cyan") - table.add_column("URL", style="green") - - for name, url in remotes.items(): - is_default = name == default - status = "[green]✓[/green]" if is_default else "" - table.add_row(status, name, url) - - console.print(table) - - -@remote_app.command("add") -def remote_add( - name: Annotated[str, typer.Argument(help="Remote name")], - url: Annotated[str, typer.Argument(help="Remote URL (e.g., http://host:8080)")], -) -> None: - """Add a remote server.""" - save_remote(name, url) - console.print(f"[green]Added remote:[/green] {name} → {url}") - - -@remote_app.command("default") -def remote_default( - name: Annotated[str | None, typer.Argument(help="Remote name to set as default (omit to clear)")] = None, -) -> None: - """Set or clear the default remote.""" - set_default_remote(name) - if name: - console.print(f"[green]Default remote set to:[/green] {name}") - else: - console.print("[green]Default remote cleared.[/green]") - - # ============================================================================= # ComfyUI Commands # ============================================================================= @@ -1424,14 +870,8 @@ def main() -> int: "dl", "download", "config", - "generate", "serve", - "status", - "reload", "db", - "images", - "models", - "remote", "comfy", ) if len(sys.argv) > 1 and not sys.argv[1].startswith("-"): diff --git a/tensors/server/__init__.py b/tensors/server/__init__.py index c67a23b..1c8ac6e 100644 --- a/tensors/server/__init__.py +++ b/tensors/server/__init__.py @@ -1,4 +1,4 @@ -"""sd-server wrapper — FastAPI app for proxying to an external sd-server.""" +"""Tensors server — FastAPI app for gallery and CivitAI management.""" from __future__ import annotations @@ -7,19 +7,14 @@ from contextlib import asynccontextmanager from pathlib import Path from typing import TYPE_CHECKING -import httpx from fastapi import FastAPI from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles -from tensors.config import get_sd_server_api_key, get_sd_server_url from tensors.server.civitai_routes import create_civitai_router from tensors.server.db_routes import create_db_router from tensors.server.download_routes import create_download_router from tensors.server.gallery_routes import create_gallery_router -from tensors.server.generate_routes import create_generate_router -from tensors.server.models_routes import create_models_router -from tensors.server.routes import create_router if TYPE_CHECKING: from collections.abc import AsyncIterator @@ -29,28 +24,15 @@ __all__ = ["app", "create_app"] logger = logging.getLogger(__name__) -def create_app(sd_server_url: str | None = None) -> FastAPI: - """Build the FastAPI application that proxies to an external sd-server. - - Args: - sd_server_url: URL of the sd-server to proxy to. If None, uses - get_sd_server_url() to resolve from env/config. - """ - backend_url = sd_server_url or get_sd_server_url() - api_key = get_sd_server_api_key() +def create_app() -> FastAPI: + """Build the FastAPI application for gallery and model management.""" @asynccontextmanager async def lifespan(_app: FastAPI) -> AsyncIterator[None]: - _app.state.sd_server_url = backend_url - _app.state.sd_server_api_key = api_key - logger.info(f"Proxying to sd-server at: {backend_url}") - if api_key: - logger.info("Using API key authentication for sd-server") - async with httpx.AsyncClient(timeout=300) as client: - _app.state.client = client - yield + logger.info("Tensors server starting") + yield - app = FastAPI(title="sd-server wrapper", lifespan=lifespan) + app = FastAPI(title="tensors", lifespan=lifespan) # Serve Vue UI static files static_dir = Path(__file__).parent / "static" @@ -66,13 +48,14 @@ def create_app(sd_server_url: str | None = None) -> FastAPI: async def vite_icon() -> FileResponse: return FileResponse(static_dir / "vite.svg") - app.include_router(create_civitai_router()) # Must be before catch-all proxy - app.include_router(create_db_router()) # Must be before catch-all proxy - app.include_router(create_gallery_router()) # Must be before catch-all proxy - app.include_router(create_models_router()) # Must be before catch-all proxy - app.include_router(create_download_router()) # Must be before catch-all proxy - app.include_router(create_generate_router()) # Must be before catch-all proxy - app.include_router(create_router()) + @app.get("/status") + async def status() -> dict[str, str]: + return {"status": "ok"} + + app.include_router(create_civitai_router()) + app.include_router(create_db_router()) + app.include_router(create_gallery_router()) + app.include_router(create_download_router()) return app diff --git a/tests/conftest.py b/tests/conftest.py index 6cfc14e..6e27fc4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,11 +7,6 @@ import json import struct import pytest -import respx - -from tensors.generate import SDClient - -BASE_URL = "http://127.0.0.1:1234" # 1x1 red PNG for image response stubs TINY_PNG = ( @@ -43,18 +38,3 @@ def temp_safetensor(tmp_path): f.write(header_bytes) return file_path - - -@pytest.fixture() -def mock_api(): - """Activate respx mock for the sd-server base URL.""" - with respx.mock(base_url=BASE_URL, assert_all_called=False) as rsps: - yield rsps - - -@pytest.fixture() -def client(mock_api: respx.MockRouter) -> SDClient: # noqa: ARG001 - """SDClient wired to the mocked transport.""" - c = SDClient() - yield c # type: ignore[misc] - c.close() diff --git a/tests/test_client.py b/tests/test_client.py deleted file mode 100644 index ba2ff49..0000000 --- a/tests/test_client.py +++ /dev/null @@ -1,481 +0,0 @@ -"""Tests for the TsrClient HTTP client module.""" - -from __future__ import annotations - -import pytest -import respx -from httpx import Response - -from tensors.client import TsrClient, TsrClientError - -BASE_URL = "http://test-server:8080" - - -@pytest.fixture -def mock_server(): - """Activate respx mock for the test server.""" - with respx.mock(base_url=BASE_URL, assert_all_called=False) as rsps: - yield rsps - - -@pytest.fixture -def client(mock_server) -> TsrClient: # noqa: ARG001 - mock_server activates respx - """TsrClient connected to mock server.""" - return TsrClient(BASE_URL) - - -# ============================================================================= -# Status Tests -# ============================================================================= - - -class TestStatus: - """Tests for server status endpoint.""" - - def test_status_success(self, client: TsrClient, mock_server) -> None: - """Test getting server status.""" - mock_server.get("/status").mock(return_value=Response(200, json={"running": True, "pid": 12345, "model": "/test.gguf"})) - - with client: - result = client.status() - - assert result["running"] is True - assert result["pid"] == 12345 - - def test_status_error(self, client: TsrClient, mock_server) -> None: - """Test handling status error.""" - mock_server.get("/status").mock(return_value=Response(503, text="Service unavailable")) - - with client, pytest.raises(TsrClientError, match="HTTP 503"): - client.status() - - -# ============================================================================= -# Gallery Tests -# ============================================================================= - - -class TestGalleryImages: - """Tests for gallery image operations.""" - - def test_list_images(self, client: TsrClient, mock_server) -> None: - """Test listing gallery images.""" - mock_server.get("/api/images").mock( - return_value=Response( - 200, - json={ - "images": [ - {"id": "123_0", "filename": "123_0.png", "width": 512, "height": 512}, - {"id": "124_1", "filename": "124_1.png", "width": 1024, "height": 1024}, - ], - "total": 2, - }, - ) - ) - - with client: - result = client.list_images() - - assert len(result["images"]) == 2 - assert result["total"] == 2 - - def test_list_images_with_pagination(self, client: TsrClient, mock_server) -> None: - """Test listing images with pagination.""" - mock_server.get("/api/images", params={"limit": 10, "offset": 5}).mock( - return_value=Response(200, json={"images": [], "total": 100}) - ) - - with client: - result = client.list_images(limit=10, offset=5) - - assert result["total"] == 100 - - def test_get_image_meta(self, client: TsrClient, mock_server) -> None: - """Test getting image metadata.""" - mock_server.get("/api/images/123_0/meta").mock( - return_value=Response( - 200, - json={ - "id": "123_0", - "path": "/gallery/123_0.png", - "metadata": {"prompt": "test prompt", "seed": 42}, - }, - ) - ) - - with client: - result = client.get_image_meta("123_0") - - assert result["id"] == "123_0" - assert result["metadata"]["prompt"] == "test prompt" - - def test_delete_image(self, client: TsrClient, mock_server) -> None: - """Test deleting an image.""" - mock_server.delete("/api/images/123_0").mock(return_value=Response(200, json={"deleted": True, "id": "123_0"})) - - with client: - result = client.delete_image("123_0") - - assert result["deleted"] is True - - def test_edit_image(self, client: TsrClient, mock_server) -> None: - """Test editing image metadata.""" - mock_server.post("/api/images/123_0/edit").mock( - return_value=Response(200, json={"id": "123_0", "metadata": {"tags": ["favorite"], "rating": 5}}) - ) - - with client: - result = client.edit_image("123_0", {"tags": ["favorite"], "rating": 5}) - - assert result["metadata"]["tags"] == ["favorite"] - - def test_download_image(self, client: TsrClient, mock_server) -> None: - """Test downloading image bytes.""" - image_bytes = b"\x89PNG test image data" - mock_server.get("/api/images/123_0").mock(return_value=Response(200, content=image_bytes)) - - with client: - result = client.download_image("123_0") - - assert result == image_bytes - - -# ============================================================================= -# Models Tests -# ============================================================================= - - -class TestModels: - """Tests for model management operations.""" - - def test_list_models(self, client: TsrClient, mock_server) -> None: - """Test listing available models.""" - mock_server.get("/api/models").mock( - return_value=Response( - 200, - json={ - "models": [ - {"name": "sdxl_base", "path": "/models/sdxl_base.safetensors"}, - {"name": "pony_v6", "path": "/models/pony_v6.safetensors"}, - ], - "active": "/models/sdxl_base.safetensors", - }, - ) - ) - - with client: - result = client.list_models() - - assert len(result["models"]) == 2 - assert result["active"] == "/models/sdxl_base.safetensors" - - def test_get_active_model(self, client: TsrClient, mock_server) -> None: - """Test getting active model.""" - mock_server.get("/api/models/active").mock(return_value=Response(200, json={"model": "/models/sdxl_base.safetensors"})) - - with client: - result = client.get_active_model() - - assert result["model"] == "/models/sdxl_base.safetensors" - - def test_switch_model(self, client: TsrClient, mock_server) -> None: - """Test switching model.""" - mock_server.post("/api/models/switch").mock( - return_value=Response(200, json={"status": "ok", "model": "/models/pony_v6.safetensors"}) - ) - - with client: - result = client.switch_model("/models/pony_v6.safetensors") - - assert result["status"] == "ok" - - def test_list_loras(self, client: TsrClient, mock_server) -> None: - """Test listing LoRAs.""" - mock_server.get("/api/models/loras").mock( - return_value=Response( - 200, - json={ - "loras": [ - {"name": "detail_tweaker", "path": "/loras/detail_tweaker.safetensors"}, - ] - }, - ) - ) - - with client: - result = client.list_loras() - - assert len(result["loras"]) == 1 - - def test_scan_models(self, client: TsrClient, mock_server) -> None: - """Test scanning models.""" - mock_server.get("/api/models/scan").mock(return_value=Response(200, json={"scanned": 5})) - - with client: - result = client.scan_models() - - assert result["scanned"] == 5 - - -# ============================================================================= -# Generation Tests -# ============================================================================= - - -class TestGeneration: - """Tests for image generation.""" - - def test_generate(self, client: TsrClient, mock_server) -> None: - """Test generating an image.""" - mock_server.post("/api/generate").mock( - return_value=Response( - 200, - json={ - "images": [{"id": "999_42", "seed": 42}], - "parameters": {"prompt": "test prompt", "seed": 42}, - }, - ) - ) - - with client: - result = client.generate( - prompt="test prompt", - width=512, - height=512, - seed=42, - ) - - assert len(result["images"]) == 1 - assert result["images"][0]["seed"] == 42 - - def test_generate_with_all_params(self, client: TsrClient, mock_server) -> None: - """Test generation with all parameters.""" - mock_server.post("/api/generate").mock(return_value=Response(200, json={"images": []})) - - with client: - result = client.generate( - prompt="detailed test prompt", - negative_prompt="bad quality", - width=1024, - height=1024, - steps=30, - cfg_scale=5.5, - seed=12345, - sampler_name="DPM++ 2M", - scheduler="karras", - batch_size=2, - save_to_gallery=False, - return_base64=True, - ) - - assert "images" in result - - def test_list_samplers(self, client: TsrClient, mock_server) -> None: - """Test listing samplers.""" - mock_server.get("/api/samplers").mock(return_value=Response(200, json={"samplers": ["Euler", "DPM++ 2M", "Euler a"]})) - - with client: - result = client.list_samplers() - - assert "samplers" in result - - def test_list_schedulers(self, client: TsrClient, mock_server) -> None: - """Test listing schedulers.""" - mock_server.get("/api/schedulers").mock( - return_value=Response(200, json={"schedulers": ["simple", "karras", "sgm_uniform"]}) - ) - - with client: - result = client.list_schedulers() - - assert "schedulers" in result - - -# ============================================================================= -# Download Tests -# ============================================================================= - - -class TestDownload: - """Tests for CivitAI download operations.""" - - def test_start_download_by_version(self, client: TsrClient, mock_server) -> None: - """Test starting download by version ID.""" - mock_server.post("/api/download").mock( - return_value=Response(200, json={"download_id": "abc123", "status": "started", "version_id": 12345}) - ) - - with client: - result = client.start_download(version_id=12345) - - assert result["download_id"] == "abc123" - - def test_start_download_by_hash(self, client: TsrClient, mock_server) -> None: - """Test starting download by hash.""" - mock_server.post("/api/download").mock(return_value=Response(200, json={"download_id": "def456", "status": "started"})) - - with client: - result = client.start_download(hash_val="ABC123DEF456") - - assert result["status"] == "started" - - def test_get_download_status(self, client: TsrClient, mock_server) -> None: - """Test getting download status.""" - mock_server.get("/api/download/status/abc123").mock( - return_value=Response(200, json={"download_id": "abc123", "status": "downloading", "progress": 0.5}) - ) - - with client: - result = client.get_download_status("abc123") - - assert result["progress"] == 0.5 - - def test_list_downloads(self, client: TsrClient, mock_server) -> None: - """Test listing active downloads.""" - mock_server.get("/api/download/active").mock( - return_value=Response(200, json={"downloads": [{"id": "abc123", "progress": 0.75}]}) - ) - - with client: - result = client.list_downloads() - - assert len(result["downloads"]) == 1 - - -# ============================================================================= -# Database Tests -# ============================================================================= - - -class TestDatabase: - """Tests for database operations.""" - - def test_db_list_files(self, client: TsrClient, mock_server) -> None: - """Test listing local files.""" - mock_server.get("/api/db/files").mock( - return_value=Response(200, json=[{"id": 1, "file_path": "/models/test.safetensors", "sha256": "abc123"}]) - ) - - with client: - result = client.db_list_files() - - assert len(result) == 1 - assert result[0]["sha256"] == "abc123" - - def test_db_search_models(self, client: TsrClient, mock_server) -> None: - """Test searching cached models.""" - mock_server.get("/api/db/models").mock( - return_value=Response(200, json=[{"civitai_id": 12345, "name": "Test Model", "type": "LORA"}]) - ) - - with client: - result = client.db_search_models(query="Test", model_type="LORA") - - assert len(result) == 1 - assert result[0]["name"] == "Test Model" - - def test_db_get_model(self, client: TsrClient, mock_server) -> None: - """Test getting cached model.""" - mock_server.get("/api/db/models/12345").mock( - return_value=Response(200, json={"civitai_id": 12345, "name": "Test Model", "type": "Checkpoint"}) - ) - - with client: - result = client.db_get_model(12345) - - assert result["name"] == "Test Model" - - def test_db_get_triggers(self, client: TsrClient, mock_server) -> None: - """Test getting trigger words.""" - mock_server.get("/api/db/triggers/12345").mock(return_value=Response(200, json=["trigger1", "trigger2"])) - - with client: - result = client.db_get_triggers(version_id=12345) - - assert result == ["trigger1", "trigger2"] - - def test_db_stats(self, client: TsrClient, mock_server) -> None: - """Test getting database stats.""" - mock_server.get("/api/db/stats").mock( - return_value=Response(200, json={"local_files": 10, "models": 5, "model_versions": 15}) - ) - - with client: - result = client.db_stats() - - assert result["local_files"] == 10 - - def test_db_scan(self, client: TsrClient, mock_server) -> None: - """Test scanning directory.""" - mock_server.post("/api/db/scan").mock(return_value=Response(200, json={"scanned": 3, "files": []})) - - with client: - result = client.db_scan("/models") - - assert result["scanned"] == 3 - - def test_db_link(self, client: TsrClient, mock_server) -> None: - """Test linking files to CivitAI.""" - mock_server.post("/api/db/link").mock(return_value=Response(200, json={"linked": 2})) - - with client: - result = client.db_link() - - assert result["linked"] == 2 - - def test_db_cache(self, client: TsrClient, mock_server) -> None: - """Test caching model data.""" - mock_server.post("/api/db/cache").mock(return_value=Response(200, json={"model_id": 12345, "cached": True})) - - with client: - result = client.db_cache(12345) - - assert result["cached"] is True - - -# ============================================================================= -# Error Handling Tests -# ============================================================================= - - -class TestErrorHandling: - """Tests for error handling.""" - - def test_http_error(self, client: TsrClient, mock_server) -> None: - """Test HTTP error handling.""" - mock_server.get("/api/images").mock(return_value=Response(500, text="Internal server error")) - - with client, pytest.raises(TsrClientError, match="HTTP 500"): - client.list_images() - - def test_not_found_error(self, client: TsrClient, mock_server) -> None: - """Test 404 error handling.""" - mock_server.get("/api/images/nonexistent/meta").mock(return_value=Response(404, json={"detail": "Image not found"})) - - with client, pytest.raises(TsrClientError, match="HTTP 404"): - client.get_image_meta("nonexistent") - - -# ============================================================================= -# Context Manager Tests -# ============================================================================= - - -class TestContextManager: - """Tests for context manager usage.""" - - def test_context_manager(self, mock_server) -> None: - """Test client works as context manager.""" - mock_server.get("/status").mock(return_value=Response(200, json={"running": True})) - - with TsrClient(BASE_URL) as client: - result = client.status() - assert result["running"] is True - - def test_client_without_context(self, mock_server) -> None: - """Test client works without context manager.""" - mock_server.get("/status").mock(return_value=Response(200, json={"running": True})) - - client = TsrClient(BASE_URL) - result = client.status() - assert result["running"] is True diff --git a/tests/test_generate.py b/tests/test_generate.py deleted file mode 100644 index fdfe63d..0000000 --- a/tests/test_generate.py +++ /dev/null @@ -1,303 +0,0 @@ -"""Tests for tensors.generate package.""" - -from __future__ import annotations - -import base64 -import json -from pathlib import Path - -import httpx -import pytest -import respx - -from tensors.generate import SDClient -from tensors.generate._http import HttpTransport -from tensors.generate.params import Img2ImgParams, Txt2ImgParams -from tensors.generate.util import save_images, to_b64 -from tests.conftest import BASE_URL, TINY_PNG, TINY_PNG_B64 - -# ── util ────────────────────────────────────────────────────────────── - - -class TestToB64: - def test_bytes_input(self): - raw = b"hello" - assert to_b64(raw) == base64.b64encode(raw).decode() - - def test_file_path(self, tmp_path: Path): - f = tmp_path / "img.png" - f.write_bytes(b"\x89PNG") - result = to_b64(str(f)) - assert base64.b64decode(result) == b"\x89PNG" - - def test_pathlib_path(self, tmp_path: Path): - f = tmp_path / "img.png" - f.write_bytes(b"data") - result = to_b64(f) - assert base64.b64decode(result) == b"data" - - def test_passthrough_string(self): - b64 = base64.b64encode(b"already").decode() - assert to_b64(b64) == b64 - - def test_unsupported_type(self): - with pytest.raises(TypeError, match="unsupported image type"): - to_b64(12345) # type: ignore[arg-type] - - -class TestSaveImages: - def test_saves_files(self, tmp_path: Path): - images = [b"img0", b"img1", b"img2"] - paths = save_images(images, str(tmp_path), prefix="test") - assert len(paths) == 3 - for i, p in enumerate(paths): - assert p.name == f"test_{i:04d}.png" - assert p.read_bytes() == images[i] - - def test_creates_directory(self, tmp_path: Path): - out = tmp_path / "sub" / "dir" - save_images([b"x"], str(out)) - assert (out / "output_0000.png").exists() - - -# ── params ──────────────────────────────────────────────────────────── - - -class TestTxt2ImgParams: - def test_minimal_body(self): - p = Txt2ImgParams(prompt="a cat") - body = p.to_body() - assert body["prompt"] == "a cat" - assert body["width"] == 512 - assert body["height"] == 512 - assert body["steps"] == 20 - assert body["seed"] == -1 - assert "sampler_name" not in body - assert "scheduler" not in body - assert "clip_skip" not in body - assert "lora" not in body - - def test_optional_fields_included(self): - p = Txt2ImgParams( - prompt="test", - sampler_name="euler_a", - scheduler="karras", - clip_skip=2, - lora=[{"path": "x.safetensors", "multiplier": 0.5}], - ) - body = p.to_body() - assert body["sampler_name"] == "euler_a" - assert body["scheduler"] == "karras" - assert body["clip_skip"] == 2 - assert len(body["lora"]) == 1 - - -class TestImg2ImgParams: - def test_minimal_body(self, tmp_path: Path): - img = tmp_path / "init.png" - img.write_bytes(b"\x89PNG") - p = Img2ImgParams(prompt="paint it", init_image=str(img)) - body = p.to_body() - assert body["prompt"] == "paint it" - assert body["denoising_strength"] == 0.75 - decoded = base64.b64decode(body["init_images"][0]) - assert decoded == b"\x89PNG" - assert "width" not in body - assert "height" not in body - assert "mask" not in body - - def test_all_optional_fields(self, tmp_path: Path): - img = tmp_path / "init.png" - img.write_bytes(b"img") - mask = tmp_path / "mask.png" - mask.write_bytes(b"mask") - extra = tmp_path / "extra.png" - extra.write_bytes(b"extra") - - p = Img2ImgParams( - prompt="test", - init_image=str(img), - mask=str(mask), - width=768, - height=768, - inpainting_mask_invert=True, - sampler_name="euler", - scheduler="simple", - clip_skip=1, - lora=[{"path": "a.gguf", "multiplier": 1.0}], - extra_images=[str(extra)], - ) - body = p.to_body() - assert body["width"] == 768 - assert body["mask"] - assert body["inpainting_mask_invert"] == 1 - assert body["sampler_name"] == "euler" - assert len(body["extra_images"]) == 1 - - -# ── _http ───────────────────────────────────────────────────────────── - - -class TestHttpTransport: - def test_get_success(self): - with respx.mock(base_url=BASE_URL) as rsps: - rsps.get("/test").respond(json={"ok": True}) - t = HttpTransport(BASE_URL) - assert t.get("/test") == {"ok": True} - t.close() - - def test_post_success(self): - with respx.mock(base_url=BASE_URL) as rsps: - rsps.post("/gen").respond(json={"images": []}) - t = HttpTransport(BASE_URL) - assert t.post("/gen", {"prompt": "x"}) == {"images": []} - t.close() - - def test_get_http_error(self): - with respx.mock(base_url=BASE_URL) as rsps: - rsps.get("/bad").respond(status_code=404, text="not found") - t = HttpTransport(BASE_URL) - with pytest.raises(httpx.HTTPStatusError): - t.get("/bad") - t.close() - - def test_post_http_error(self): - with respx.mock(base_url=BASE_URL) as rsps: - rsps.post("/bad").respond(status_code=500, text="error") - t = HttpTransport(BASE_URL) - with pytest.raises(httpx.HTTPStatusError): - t.post("/bad", {}) - t.close() - - def test_get_connection_error(self): - with respx.mock(base_url=BASE_URL) as rsps: - rsps.get("/fail").mock(side_effect=httpx.ConnectError("refused")) - t = HttpTransport(BASE_URL) - with pytest.raises(httpx.ConnectError): - t.get("/fail") - t.close() - - -# ── info ────────────────────────────────────────────────────────────── - - -class TestInfoAPI: - def test_models(self, mock_api: respx.MockRouter, client: SDClient): - mock_api.get("/v1/models").respond(json={"data": [{"id": "sd-cpp-local", "object": "model", "owned_by": "local"}]}) - result = client.info.models() - assert len(result) == 1 - assert result[0]["id"] == "sd-cpp-local" - - def test_sd_models(self, mock_api: respx.MockRouter, client: SDClient): - mock_api.get("/sdapi/v1/sd-models").respond( - json=[{"title": "sdxl", "model_name": "sdxl", "filename": "sdxl.safetensors"}] - ) - result = client.info.sd_models() - assert result[0]["title"] == "sdxl" - - def test_options(self, mock_api: respx.MockRouter, client: SDClient): - mock_api.get("/sdapi/v1/options").respond( - json={ - "samples_format": "png", - "sd_model_checkpoint": "v1-5", - } - ) - result = client.info.options() - assert result["sd_model_checkpoint"] == "v1-5" - - def test_loras(self, mock_api: respx.MockRouter, client: SDClient): - mock_api.get("/sdapi/v1/loras").respond( - json=[ - {"name": "style", "path": "style.safetensors"}, - ] - ) - result = client.info.loras() - assert len(result) == 1 - assert result[0]["name"] == "style" - - def test_samplers(self, mock_api: respx.MockRouter, client: SDClient): - mock_api.get("/sdapi/v1/samplers").respond( - json=[ - {"name": "euler", "aliases": ["euler"], "options": {}}, - {"name": "euler_a", "aliases": ["euler_a"], "options": {}}, - ] - ) - result = client.info.samplers() - assert result == ["euler", "euler_a"] - - def test_schedulers(self, mock_api: respx.MockRouter, client: SDClient): - mock_api.get("/sdapi/v1/schedulers").respond( - json=[ - {"name": "discrete", "label": "discrete"}, - {"name": "karras", "label": "karras"}, - ] - ) - result = client.info.schedulers() - assert result == ["discrete", "karras"] - - -# ── generation ──────────────────────────────────────────────────────── - - -class TestTxt2Img: - def test_returns_decoded_images(self, mock_api: respx.MockRouter, client: SDClient): - mock_api.post("/sdapi/v1/txt2img").respond( - json={ - "images": [TINY_PNG_B64], - "parameters": {}, - "info": "", - } - ) - images = client.generate.txt2img(Txt2ImgParams(prompt="a cat")) - assert len(images) == 1 - assert images[0] == TINY_PNG - - def test_multiple_images(self, mock_api: respx.MockRouter, client: SDClient): - mock_api.post("/sdapi/v1/txt2img").respond( - json={ - "images": [TINY_PNG_B64, TINY_PNG_B64, TINY_PNG_B64], - "parameters": {}, - "info": "", - } - ) - params = Txt2ImgParams(prompt="cats", batch_size=3) - images = client.generate.txt2img(params) - assert len(images) == 3 - - def test_sends_correct_body(self, mock_api: respx.MockRouter, client: SDClient): - route = mock_api.post("/sdapi/v1/txt2img").respond( - json={ - "images": [TINY_PNG_B64], - "parameters": {}, - "info": "", - } - ) - params = Txt2ImgParams( - prompt="hello", - width=768, - height=768, - steps=30, - sampler_name="euler_a", - ) - client.generate.txt2img(params) - sent = json.loads(route.calls[0].request.content) - assert sent["prompt"] == "hello" - assert sent["width"] == 768 - assert sent["sampler_name"] == "euler_a" - - -class TestImg2Img: - def test_returns_decoded_images(self, mock_api: respx.MockRouter, client: SDClient, tmp_path: Path): - mock_api.post("/sdapi/v1/img2img").respond( - json={ - "images": [TINY_PNG_B64], - "parameters": {}, - "info": "", - } - ) - img = tmp_path / "init.png" - img.write_bytes(TINY_PNG) - params = Img2ImgParams(prompt="paint", init_image=str(img)) - images = client.generate.img2img(params) - assert len(images) == 1 - assert images[0] == TINY_PNG diff --git a/tests/test_server.py b/tests/test_server.py index 8050c49..907a9d2 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,12 +1,8 @@ -"""Tests for tensors.server package (FastAPI sd-server proxy wrapper).""" +"""Tests for tensors.server package (gallery and CivitAI management).""" from __future__ import annotations -from unittest.mock import AsyncMock - -import httpx import pytest -import respx from fastapi.testclient import TestClient from tensors.server import create_app @@ -14,87 +10,17 @@ from tensors.server import create_app @pytest.fixture() def api() -> TestClient: - """Create test client with mock sd-server URL.""" - return TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) + """Create test client.""" + return TestClient(create_app()) class TestStatus: - @respx.mock - def test_status_when_backend_reachable(self) -> None: - """Test status endpoint when sd-server is reachable.""" - respx.get("http://mock-sd-server:1234/").mock(return_value=httpx.Response(200)) - - with TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) as client: - r = client.get("/status") - assert r.status_code == 200 - data = r.json() - assert data["status"] == "ok" - assert data["sd_server_url"] == "http://mock-sd-server:1234" - - @respx.mock - def test_status_when_backend_unreachable(self) -> None: - """Test status endpoint when sd-server is not reachable.""" - respx.get("http://mock-sd-server:1234/").mock(side_effect=httpx.ConnectError("Connection refused")) - - with TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) as client: - r = client.get("/status") - assert r.status_code == 200 - data = r.json() - assert data["status"] == "error" - assert "Connection refused" in data["error"] - - -class TestProxy: - def test_proxy_forwards_request(self, api: TestClient) -> None: - """Test proxy forwards GET requests to backend.""" - upstream_response = httpx.Response( - 200, - json={"data": [{"id": "model-1"}]}, - headers={"content-type": "application/json"}, - ) - mock_client = AsyncMock() - mock_client.request.return_value = upstream_response - api.app.state.client = mock_client # type: ignore[attr-defined] - api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined] - - r = api.get("/v1/models") + def test_status_ok(self, api: TestClient) -> None: + """Test status endpoint returns ok.""" + r = api.get("/status") assert r.status_code == 200 - assert r.json() == {"data": [{"id": "model-1"}]} - mock_client.request.assert_called_once() - - def test_proxy_forwards_post_with_body(self, api: TestClient) -> None: - """Test proxy forwards POST requests with body.""" - upstream_response = httpx.Response(200, json={"ok": True}) - mock_client = AsyncMock() - mock_client.request.return_value = upstream_response - api.app.state.client = mock_client # type: ignore[attr-defined] - api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined] - - r = api.post("/sdapi/v1/txt2img", json={"prompt": "hello"}) - assert r.status_code == 200 - mock_client.request.assert_called_once() - - def test_proxy_503_on_connect_error(self, api: TestClient) -> None: - """Test proxy returns 503 when backend is unreachable.""" - mock_client = AsyncMock() - mock_client.request.side_effect = httpx.ConnectError("Connection refused") - api.app.state.client = mock_client # type: ignore[attr-defined] - api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined] - - r = api.get("/v1/models") - assert r.status_code == 503 - assert "Cannot connect" in r.json()["error"] - - def test_proxy_504_on_timeout(self, api: TestClient) -> None: - """Test proxy returns 504 on timeout.""" - mock_client = AsyncMock() - mock_client.request.side_effect = httpx.TimeoutException("Timeout") - api.app.state.client = mock_client # type: ignore[attr-defined] - api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined] - - r = api.get("/v1/models") - assert r.status_code == 504 - assert "Timeout" in r.json()["error"] + data = r.json() + assert data["status"] == "ok" # =============================================================================