💬 Commit message: Update 2026-02-15 06:21:35, 7 files, 1559 lines

📁 Files changed: 7
📝 Lines changed: 1559

  • .coverage
  • cli.py
  • __init__.py
  • conftest.py
  • test_client.py
  • test_generate.py
  • test_server.py
This commit is contained in:
Adam Ladachowski
2026-02-15 06:21:35 +01:00
parent c419e443ae
commit 356d8fd156
7 changed files with 52 additions and 1507 deletions
+30 -590
View File
@@ -20,20 +20,15 @@ from tensors.api import (
fetch_civitai_model_version,
search_civitai,
)
from tensors.client import TsrClient, TsrClientError
from tensors.config import (
CONFIG_FILE,
BaseModel,
ModelType,
SortOrder,
get_default_output_path,
get_remotes,
load_api_key,
load_config,
resolve_remote,
save_config,
save_remote,
set_default_remote,
)
from tensors.db import DB_PATH, Database
from tensors.display import (
@@ -49,18 +44,6 @@ from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_me
# Key masking threshold
MIN_KEY_LENGTH_FOR_MASKING = 8
# Size threshold for GB display
_MB_PER_GB = 1024
def _format_size_mb(size_mb: float | None) -> str:
"""Format size in MB to human-readable string."""
if not size_mb:
return ""
if size_mb >= _MB_PER_GB:
return f"{size_mb / _MB_PER_GB:.1f} GB"
return f"{size_mb:.0f} MB"
def _version_callback(value: bool) -> None:
if value:
@@ -329,74 +312,41 @@ def download(
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False,
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False,
) -> None:
"""Download a model from CivitAI (locally or to remote server)."""
# Check if remote is specified or configured
remote_url = resolve_remote(remote)
"""Download a model from CivitAI."""
key = api_key or load_api_key()
if remote_url:
# Remote mode: use TsrClient API
if not version_id and not model_id and not hash_val:
resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
if not resolved_version_id:
if not version_id and not hash_val and not model_id:
console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
raise typer.Exit(1)
raise typer.Exit(1)
try:
with TsrClient(remote_url) as client:
console.print(f"[cyan]Starting download on {remote_url}...[/cyan]")
result = client.start_download(
version_id=version_id,
model_id=model_id,
hash_val=hash_val,
output_dir=str(output) if output else None,
)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]")
version_info = fetch_civitai_model_version(resolved_version_id, key, console)
if not version_info:
console.print("[red]Error: Could not fetch model version info.[/red]")
raise typer.Exit(1)
if json_output:
console.print_json(data=result)
return
model_type_str: str | None = version_info.get("model", {}).get("type")
output_dir = _prepare_download_dir(output, model_type_str)
if not output_dir:
raise typer.Exit(1)
download_id = result.get("download_id")
console.print(f"[green]Download started:[/green] {download_id}")
console.print(f"[dim]Check status with: tsr images download-status {download_id} --remote {remote or 'default'}[/dim]")
else:
# Local mode: direct download
key = api_key or load_api_key()
files: list[dict[str, Any]] = version_info.get("files", [])
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
if not primary_file:
console.print("[red]Error: No files found for this version.[/red]")
raise typer.Exit(1)
resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
if not resolved_version_id:
if not version_id and not hash_val and not model_id:
console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
raise typer.Exit(1)
filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors")
dest_path = output_dir / filename
console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]")
version_info = fetch_civitai_model_version(resolved_version_id, key, console)
if not version_info:
console.print("[red]Error: Could not fetch model version info.[/red]")
raise typer.Exit(1)
_display_download_info(version_info, filename, primary_file, dest_path)
model_type_str: str | None = version_info.get("model", {}).get("type")
output_dir = _prepare_download_dir(output, model_type_str)
if not output_dir:
raise typer.Exit(1)
files: list[dict[str, Any]] = version_info.get("files", [])
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
if not primary_file:
console.print("[red]Error: No files found for this version.[/red]")
raise typer.Exit(1)
filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors")
dest_path = output_dir / filename
_display_download_info(version_info, filename, primary_file, dest_path)
success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume)
if not success:
raise typer.Exit(1)
success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume)
if not success:
raise typer.Exit(1)
def _display_download_info(
@@ -449,167 +399,13 @@ def config(
console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]")
@app.command()
def generate(
prompt: Annotated[str, typer.Argument(help="Text prompt for image generation.")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model (remote mode only).")] = None,
host: Annotated[str, typer.Option(help="sd-server address (local mode).")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="sd-server port (local mode).")] = 8080,
output: Annotated[str, typer.Option("-o", help="Output directory (local mode).")] = ".",
negative_prompt: Annotated[str, typer.Option("-n", help="Negative prompt.")] = "",
width: Annotated[int, typer.Option("-W", help="Image width.")] = 512,
height: Annotated[int, typer.Option("-H", help="Image height.")] = 512,
steps: Annotated[int, typer.Option(help="Sampling steps.")] = 20,
cfg_scale: Annotated[float, typer.Option(help="CFG scale.")] = 7.0,
seed: Annotated[int, typer.Option("-s", help="RNG seed (-1 for random).")] = -1,
sampler: Annotated[str, typer.Option(help="Sampler name.")] = "",
scheduler: Annotated[str, typer.Option(help="Scheduler name.")] = "",
batch_size: Annotated[int, typer.Option("-b", help="Number of images.")] = 1,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON (remote mode)")] = False,
) -> None:
"""Generate images using sd-server (local or remote)."""
# Check if remote is specified or configured
remote_url = resolve_remote(remote)
if remote_url:
# Remote mode: use TsrClient API
try:
with TsrClient(remote_url) as client:
# Switch model if specified
if model:
console.print(f"[cyan]Switching to model: {model}[/cyan]")
client.switch_model(model)
console.print(f"[cyan]Generating {batch_size} image(s) on {remote_url}...[/cyan]")
result = client.generate(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
steps=steps,
cfg_scale=cfg_scale,
seed=seed,
sampler_name=sampler,
scheduler=scheduler,
batch_size=batch_size,
)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=result)
return
images = result.get("images", [])
for img in images:
console.print(f"[green]Generated:[/green] {img.get('id', 'unknown')}")
else:
# Local mode: direct sd-server connection
if model:
console.print("[yellow]Warning: --model ignored in local mode (sd-server loads model at startup)[/yellow]")
from tensors.generate import SDClient, Txt2ImgParams, save_images # noqa: PLC0415
params = Txt2ImgParams(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
steps=steps,
cfg_scale=cfg_scale,
seed=seed,
batch_size=batch_size,
sampler_name=sampler,
scheduler=scheduler,
)
with SDClient(host=host, port=port) as client:
console.print(f"[cyan]Generating {batch_size} image(s)...[/cyan]")
images = client.generate.txt2img(params)
paths = save_images(images, output)
for p in paths:
console.print(f"[green]Saved:[/green] {p}")
@app.command()
def status(
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
host: Annotated[str, typer.Option(help="Wrapper API host (local mode).")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="Wrapper API port (local mode).")] = 8080,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""Show sd-server wrapper status."""
# Check if remote is specified or configured
remote_url = resolve_remote(remote)
if remote_url:
# Remote mode: use TsrClient API
try:
with TsrClient(remote_url) as client:
data = client.status()
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
else:
# Local mode: direct HTTP call
import httpx # noqa: PLC0415
url = f"http://{host}:{port}/status"
try:
resp = httpx.get(url, timeout=10)
resp.raise_for_status()
data = resp.json()
except httpx.HTTPError as e:
console.print(f"[red]Error: Could not reach wrapper at {url}: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=data)
return
table = Table(title="Server Status", show_header=True, header_style="bold magenta")
table.add_column("Property", style="cyan")
table.add_column("Value", style="green")
for key, value in data.items():
table.add_row(key, str(value))
console.print(table)
@app.command()
def reload(
model: Annotated[str, typer.Option(help="Path to model file for sd-server.")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
host: Annotated[str, typer.Option(help="Wrapper API host (local mode).")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="Wrapper API port (local mode).")] = 8080,
) -> None:
"""Reload sd-server with a new model."""
import httpx # noqa: PLC0415
remote_url = resolve_remote(remote)
url = f"{remote_url.rstrip('/')}/reload" if remote_url else f"http://{host}:{port}/reload"
console.print(f"[cyan]Reloading model: {model}[/cyan]")
try:
resp = httpx.post(url, json={"model": model}, timeout=300)
resp.raise_for_status()
data = resp.json()
except httpx.HTTPError as e:
console.print(f"[red]Error: Reload failed at {url}: {e}[/red]")
raise typer.Exit(1) from e
console.print(f"[green]{data.get('status', 'OK')}[/green]")
@app.command()
def serve(
host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080,
sd_server: Annotated[str | None, typer.Option(help="sd-server URL to proxy to.")] = None,
host: Annotated[str, typer.Option(help="Listen address.")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="Listen port.")] = 8080,
log_level: Annotated[str, typer.Option(help="Log level.")] = "info",
) -> None:
"""Start the sd-server wrapper API (proxies to external sd-server)."""
"""Start the tensors server (gallery and CivitAI management)."""
try:
import uvicorn # noqa: PLC0415
@@ -619,7 +415,7 @@ def serve(
console.print(" pip install tensors[server]")
raise typer.Exit(1) from None
uvicorn.run(create_app(sd_server_url=sd_server), host=host, port=port, log_level=log_level)
uvicorn.run(create_app(), host=host, port=port, log_level=log_level)
# =============================================================================
@@ -858,356 +654,6 @@ def db_stats(
console.print(table)
# =============================================================================
# Images Commands (Remote)
# =============================================================================
images_app = typer.Typer(
name="images",
help="Manage images in remote gallery.",
no_args_is_help=True,
)
app.add_typer(images_app, name="images")
def _get_client(remote: str | None) -> TsrClient:
"""Get TsrClient for remote or raise error."""
url = resolve_remote(remote)
if not url:
console.print("[red]Error: No remote specified. Use --remote or set default_remote in config.[/red]")
raise typer.Exit(1)
return TsrClient(url)
@images_app.command("list")
def images_list(
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 50,
offset: Annotated[int, typer.Option("--offset", help="Offset for pagination")] = 0,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""List images in remote gallery."""
try:
with _get_client(remote) as client:
result = client.list_images(limit=limit, offset=offset)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
images = result.get("images", [])
total = result.get("total", len(images))
if json_output:
console.print_json(data=result)
return
if not images:
console.print("[yellow]No images in gallery.[/yellow]")
return
table = Table(title=f"Gallery Images ({len(images)}/{total})", show_header=True, header_style="bold magenta")
table.add_column("ID", style="cyan")
table.add_column("Filename", style="green")
table.add_column("Size", style="white")
table.add_column("Created", style="dim")
for img in images:
size = f"{img.get('width', '?')}x{img.get('height', '?')}"
created = img.get("created_at", "")
if isinstance(created, (int, float)):
from datetime import datetime # noqa: PLC0415
created = datetime.fromtimestamp(created).strftime("%Y-%m-%d %H:%M")
table.add_row(img.get("id", ""), img.get("filename", ""), size, str(created))
console.print(table)
@images_app.command("show")
def images_show(
image_id: Annotated[str, typer.Argument(help="Image ID to show")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""Show image metadata."""
try:
with _get_client(remote) as client:
meta = client.get_image_meta(image_id)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=meta)
return
table = Table(title=f"Image: {image_id}", show_header=True, header_style="bold magenta")
table.add_column("Property", style="cyan")
table.add_column("Value", style="green")
for key, value in meta.items():
display_value = json.dumps(value, indent=2) if isinstance(value, dict) else str(value)
table.add_row(key, display_value)
console.print(table)
@images_app.command("delete")
def images_delete(
image_id: Annotated[str, typer.Argument(help="Image ID to delete")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
force: Annotated[bool, typer.Option("-f", "--force", help="Skip confirmation")] = False,
) -> None:
"""Delete an image from the gallery."""
if not force:
confirm = typer.confirm(f"Delete image {image_id}?")
if not confirm:
console.print("[yellow]Cancelled.[/yellow]")
raise typer.Exit(0)
try:
with _get_client(remote) as client:
client.delete_image(image_id)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
console.print(f"[green]Deleted image: {image_id}[/green]")
@images_app.command("download")
def images_download(
image_id: Annotated[str, typer.Argument(help="Image ID to download")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file or directory")] = None,
) -> None:
"""Download an image from the remote gallery."""
try:
with _get_client(remote) as client:
content = client.download_image(image_id)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
# Determine output path
if output is None:
dest = Path(f"{image_id}.png")
elif output.is_dir():
dest = output / f"{image_id}.png"
else:
dest = output
dest.write_bytes(content)
console.print(f"[green]Saved:[/green] {dest}")
# =============================================================================
# Models Commands (Remote)
# =============================================================================
models_app = typer.Typer(
name="models",
help="Manage models on remote server.",
no_args_is_help=True,
)
app.add_typer(models_app, name="models")
@models_app.command("list")
def models_list(
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""List available models on remote server."""
try:
with _get_client(remote) as client:
result = client.list_models()
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=result)
return
models = result.get("models", [])
active = result.get("active", "")
if not models:
console.print("[yellow]No models found.[/yellow]")
return
table = Table(title="Available Models", show_header=True, header_style="bold magenta")
table.add_column("ID", style="dim", width=8)
table.add_column("Name", style="cyan")
table.add_column("File", style="white")
table.add_column("Size", style="green", justify="right")
for model in models:
path = model.get("path", "")
name = model.get("name", Path(path).stem if path else "")
is_active = active in {path, name}
civitai_id = model.get("civitai_model_id")
id_str = str(civitai_id) if civitai_id else ""
display_name = model.get("display_name", name)
if is_active:
display_name = f"[green]✓[/green] {display_name}"
filename = model.get("filename", Path(path).name if path else "")
size_str = _format_size_mb(model.get("size_mb"))
table.add_row(id_str, display_name, filename, size_str)
console.print(table)
@models_app.command("active")
def models_active(
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""Show currently active model."""
try:
with _get_client(remote) as client:
result = client.get_active_model()
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=result)
return
model = result.get("model", "None")
console.print(f"[bold]Active model:[/bold] {model}")
@models_app.command("switch")
def models_switch(
model: Annotated[str, typer.Argument(help="Model path or name to switch to")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
) -> None:
"""Switch to a different model on the remote server."""
console.print(f"[cyan]Switching to model: {model}[/cyan]")
try:
with _get_client(remote) as client:
result = client.switch_model(model)
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
console.print(f"[green]{result.get('status', 'OK')}[/green]")
@models_app.command("loras")
def models_loras(
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""List available LoRAs on remote server."""
try:
with _get_client(remote) as client:
result = client.list_loras()
except TsrClientError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
if json_output:
console.print_json(data=result)
return
loras = result.get("loras", [])
if not loras:
console.print("[yellow]No LoRAs found.[/yellow]")
return
table = Table(title="Available LoRAs", show_header=True, header_style="bold magenta")
table.add_column("ID", style="dim", width=8)
table.add_column("Name", style="cyan")
table.add_column("File", style="white")
table.add_column("Size", style="green", justify="right")
for lora in loras:
path = lora.get("path", "")
name = lora.get("name", Path(path).stem if path else "")
civitai_id = lora.get("civitai_model_id")
id_str = str(civitai_id) if civitai_id else ""
display_name = lora.get("display_name", name)
filename = lora.get("filename", Path(path).name if path else "")
size_str = _format_size_mb(lora.get("size_mb"))
table.add_row(id_str, display_name, filename, size_str)
console.print(table)
# =============================================================================
# Remote Configuration Commands
# =============================================================================
remote_app = typer.Typer(
name="remote",
help="Manage remote server configuration.",
no_args_is_help=True,
)
app.add_typer(remote_app, name="remote")
@remote_app.command("list")
def remote_list(
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""List configured remotes."""
from tensors.config import get_default_remote # noqa: PLC0415
remotes = get_remotes()
default = get_default_remote()
if json_output:
console.print_json(data={"remotes": remotes, "default": default})
return
if not remotes:
console.print("[yellow]No remotes configured.[/yellow]")
console.print("[dim]Add one with: tsr remote add NAME URL[/dim]")
return
table = Table(title="Configured Remotes", show_header=True, header_style="bold magenta")
table.add_column("Default", style="dim", width=3)
table.add_column("Name", style="cyan")
table.add_column("URL", style="green")
for name, url in remotes.items():
is_default = name == default
status = "[green]✓[/green]" if is_default else ""
table.add_row(status, name, url)
console.print(table)
@remote_app.command("add")
def remote_add(
name: Annotated[str, typer.Argument(help="Remote name")],
url: Annotated[str, typer.Argument(help="Remote URL (e.g., http://host:8080)")],
) -> None:
"""Add a remote server."""
save_remote(name, url)
console.print(f"[green]Added remote:[/green] {name}{url}")
@remote_app.command("default")
def remote_default(
name: Annotated[str | None, typer.Argument(help="Remote name to set as default (omit to clear)")] = None,
) -> None:
"""Set or clear the default remote."""
set_default_remote(name)
if name:
console.print(f"[green]Default remote set to:[/green] {name}")
else:
console.print("[green]Default remote cleared.[/green]")
# =============================================================================
# ComfyUI Commands
# =============================================================================
@@ -1424,14 +870,8 @@ def main() -> int:
"dl",
"download",
"config",
"generate",
"serve",
"status",
"reload",
"db",
"images",
"models",
"remote",
"comfy",
)
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
+14 -31
View File
@@ -1,4 +1,4 @@
"""sd-server wrapper — FastAPI app for proxying to an external sd-server."""
"""Tensors server — FastAPI app for gallery and CivitAI management."""
from __future__ import annotations
@@ -7,19 +7,14 @@ from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING
import httpx
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from tensors.config import get_sd_server_api_key, get_sd_server_url
from tensors.server.civitai_routes import create_civitai_router
from tensors.server.db_routes import create_db_router
from tensors.server.download_routes import create_download_router
from tensors.server.gallery_routes import create_gallery_router
from tensors.server.generate_routes import create_generate_router
from tensors.server.models_routes import create_models_router
from tensors.server.routes import create_router
if TYPE_CHECKING:
from collections.abc import AsyncIterator
@@ -29,28 +24,15 @@ __all__ = ["app", "create_app"]
logger = logging.getLogger(__name__)
def create_app(sd_server_url: str | None = None) -> FastAPI:
"""Build the FastAPI application that proxies to an external sd-server.
Args:
sd_server_url: URL of the sd-server to proxy to. If None, uses
get_sd_server_url() to resolve from env/config.
"""
backend_url = sd_server_url or get_sd_server_url()
api_key = get_sd_server_api_key()
def create_app() -> FastAPI:
"""Build the FastAPI application for gallery and model management."""
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
_app.state.sd_server_url = backend_url
_app.state.sd_server_api_key = api_key
logger.info(f"Proxying to sd-server at: {backend_url}")
if api_key:
logger.info("Using API key authentication for sd-server")
async with httpx.AsyncClient(timeout=300) as client:
_app.state.client = client
yield
logger.info("Tensors server starting")
yield
app = FastAPI(title="sd-server wrapper", lifespan=lifespan)
app = FastAPI(title="tensors", lifespan=lifespan)
# Serve Vue UI static files
static_dir = Path(__file__).parent / "static"
@@ -66,13 +48,14 @@ def create_app(sd_server_url: str | None = None) -> FastAPI:
async def vite_icon() -> FileResponse:
return FileResponse(static_dir / "vite.svg")
app.include_router(create_civitai_router()) # Must be before catch-all proxy
app.include_router(create_db_router()) # Must be before catch-all proxy
app.include_router(create_gallery_router()) # Must be before catch-all proxy
app.include_router(create_models_router()) # Must be before catch-all proxy
app.include_router(create_download_router()) # Must be before catch-all proxy
app.include_router(create_generate_router()) # Must be before catch-all proxy
app.include_router(create_router())
@app.get("/status")
async def status() -> dict[str, str]:
return {"status": "ok"}
app.include_router(create_civitai_router())
app.include_router(create_db_router())
app.include_router(create_gallery_router())
app.include_router(create_download_router())
return app