Remove internal sd-server management, proxy to external sd-server
- Remove ProcessManager and process.py - Add get_sd_server_url() config (env/config/default) - Update routes to proxy to external sd-server URL - Remove model switching (handled by external sd-server) - Update CLI serve command - Update tests for new architecture Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "tensors"
|
||||
version = "0.1.11"
|
||||
version = "0.1.12"
|
||||
description = "Read safetensor metadata and fetch CivitAI model information"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 23 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 54 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 486 KiB |
+11
-3
@@ -55,9 +55,17 @@ def sync_to_junkpile() -> None:
|
||||
"""Sync project to junkpile."""
|
||||
print("\n[3/4] Syncing to junkpile...")
|
||||
excludes = [
|
||||
".git", ".venv", "__pycache__", "*.pyc", ".mypy_cache",
|
||||
".pytest_cache", ".ruff_cache", ".coverage", "*.egg-info",
|
||||
"node_modules", ".tmp",
|
||||
".git",
|
||||
".venv",
|
||||
"__pycache__",
|
||||
"*.pyc",
|
||||
".mypy_cache",
|
||||
".pytest_cache",
|
||||
".ruff_cache",
|
||||
".coverage",
|
||||
"*.egg-info",
|
||||
"node_modules",
|
||||
".tmp",
|
||||
]
|
||||
cmd = ["rsync", "-avz", "--delete"]
|
||||
for exc in excludes:
|
||||
|
||||
+4
-6
@@ -582,24 +582,22 @@ def reload(
|
||||
|
||||
@app.command()
|
||||
def serve(
|
||||
model: Annotated[str, typer.Option(help="Path to model file for sd-server.")],
|
||||
host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1",
|
||||
port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080,
|
||||
sd_port: Annotated[int, typer.Option(help="sd-server listen port.")] = 1234,
|
||||
sd_server: Annotated[str | None, typer.Option(help="sd-server URL to proxy to.")] = None,
|
||||
log_level: Annotated[str, typer.Option(help="Log level.")] = "info",
|
||||
) -> None:
|
||||
"""Start the sd-server wrapper API (transparent proxy with hot reload)."""
|
||||
"""Start the sd-server wrapper API (proxies to external sd-server)."""
|
||||
try:
|
||||
import uvicorn # noqa: PLC0415
|
||||
|
||||
from tensors.server import ServerConfig, create_app # noqa: PLC0415
|
||||
from tensors.server import create_app # noqa: PLC0415
|
||||
except ImportError:
|
||||
console.print("[red]Missing server dependencies. Install with:[/red]")
|
||||
console.print(" pip install tensors[server]")
|
||||
raise typer.Exit(1) from None
|
||||
|
||||
config = ServerConfig(model=model, port=sd_port)
|
||||
uvicorn.run(create_app(config), host=host, port=port, log_level=log_level)
|
||||
uvicorn.run(create_app(sd_server_url=sd_server), host=host, port=port, log_level=log_level)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@@ -243,3 +243,34 @@ def set_default_remote(name: str | None) -> None:
|
||||
else:
|
||||
config["default_remote"] = name
|
||||
save_config(config)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SD Server Configuration
|
||||
# ============================================================================
|
||||
|
||||
SD_SERVER_DEFAULT_URL = "http://localhost:1234"
|
||||
|
||||
|
||||
def get_sd_server_url() -> str:
|
||||
"""Get the sd-server URL.
|
||||
|
||||
Resolution order:
|
||||
1. SD_SERVER_URL environment variable
|
||||
2. config.toml [server].sd_server_url
|
||||
3. Default: http://localhost:1234
|
||||
"""
|
||||
# Check environment variable first
|
||||
env_url = os.environ.get("SD_SERVER_URL")
|
||||
if env_url:
|
||||
return env_url
|
||||
|
||||
# Check config file
|
||||
config = load_config()
|
||||
server_config = config.get("server", {})
|
||||
if isinstance(server_config, dict):
|
||||
url = server_config.get("sd_server_url")
|
||||
if url:
|
||||
return str(url)
|
||||
|
||||
return SD_SERVER_DEFAULT_URL
|
||||
|
||||
+16
-20
@@ -1,4 +1,4 @@
|
||||
"""sd-server wrapper — FastAPI app for managing and proxying to sd-server."""
|
||||
"""sd-server wrapper — FastAPI app for proxying to an external sd-server."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -12,42 +12,39 @@ from fastapi import FastAPI
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from tensors.config import get_sd_server_url
|
||||
from tensors.server.civitai_routes import create_civitai_router
|
||||
from tensors.server.db_routes import create_db_router
|
||||
from tensors.server.download_routes import create_download_router
|
||||
from tensors.server.gallery_routes import create_gallery_router
|
||||
from tensors.server.generate_routes import create_generate_router
|
||||
from tensors.server.models import ServerConfig
|
||||
from tensors.server.models_routes import create_models_router
|
||||
from tensors.server.process import ProcessManager
|
||||
from tensors.server.routes import create_router
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
__all__ = ["ProcessManager", "ServerConfig", "app", "create_app"]
|
||||
__all__ = ["app", "create_app"]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_app(config: ServerConfig | None = None) -> FastAPI:
|
||||
"""Build the FastAPI application with process manager and proxy client."""
|
||||
pm = ProcessManager()
|
||||
def create_app(sd_server_url: str | None = None) -> FastAPI:
|
||||
"""Build the FastAPI application that proxies to an external sd-server.
|
||||
|
||||
Args:
|
||||
sd_server_url: URL of the sd-server to proxy to. If None, uses
|
||||
get_sd_server_url() to resolve from env/config.
|
||||
"""
|
||||
backend_url = sd_server_url or get_sd_server_url()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
||||
_app.state.sd_server_url = backend_url
|
||||
logger.info(f"Proxying to sd-server at: {backend_url}")
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
_app.state.client = client
|
||||
if config is not None:
|
||||
pm.start(config)
|
||||
logger.info("waiting for sd-server to become ready...")
|
||||
ready = await pm.wait_ready()
|
||||
if ready:
|
||||
logger.info("sd-server is ready")
|
||||
else:
|
||||
logger.warning("sd-server did not become ready in time")
|
||||
yield
|
||||
pm.stop()
|
||||
|
||||
app = FastAPI(title="sd-server wrapper", lifespan=lifespan)
|
||||
|
||||
@@ -68,11 +65,10 @@ def create_app(config: ServerConfig | None = None) -> FastAPI:
|
||||
app.include_router(create_civitai_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_db_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_gallery_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_models_router(pm)) # Must be before catch-all proxy
|
||||
app.include_router(create_models_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_download_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_generate_router(pm)) # Must be before catch-all proxy
|
||||
app.include_router(create_router(pm))
|
||||
app.state.pm = pm
|
||||
app.include_router(create_generate_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_router())
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@@ -5,18 +5,15 @@ from __future__ import annotations
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from tensors.server.gallery import Gallery
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensors.server.process import ProcessManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -113,38 +110,30 @@ def _process_image(
|
||||
return image_info
|
||||
|
||||
|
||||
def _check_server_running(pm: ProcessManager) -> None:
|
||||
"""Check if sd-server is running, raise HTTPException if not."""
|
||||
if pm.proc is None or pm.proc.poll() is not None:
|
||||
raise HTTPException(status_code=503, detail="sd-server is not running")
|
||||
if pm.config is None:
|
||||
raise HTTPException(status_code=503, detail="sd-server not configured")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Router Factory
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_generate_router(pm: ProcessManager) -> APIRouter:
|
||||
def create_generate_router() -> APIRouter:
|
||||
"""Build a router with /api/generate endpoint."""
|
||||
router = APIRouter(prefix="/api", tags=["generate"])
|
||||
gallery = Gallery()
|
||||
|
||||
@router.post("/generate")
|
||||
async def generate(req: GenerateRequest) -> dict[str, Any]:
|
||||
async def generate(request: Request, req: GenerateRequest) -> dict[str, Any]:
|
||||
"""Generate images with gallery integration."""
|
||||
_check_server_running(pm)
|
||||
assert pm.config is not None # Verified by _check_server_running
|
||||
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
body = _build_sd_request(req)
|
||||
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/txt2img"
|
||||
url = f"{sd_server_url}/sdapi/v1/txt2img"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
response = await client.post(url, json=body)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
except httpx.ConnectError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
|
||||
except httpx.HTTPError as e:
|
||||
logger.exception("Generation failed")
|
||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||
@@ -153,8 +142,11 @@ def create_generate_router(pm: ProcessManager) -> APIRouter:
|
||||
info = _parse_info(result.get("info", {}))
|
||||
all_seeds = info.get("all_seeds", [req.seed] * len(images_data))
|
||||
|
||||
# Get model info from sd-server response if available
|
||||
model_name = info.get("sd_model_name") or info.get("model")
|
||||
|
||||
output_images = [
|
||||
_process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, pm.config.model)
|
||||
_process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, model_name)
|
||||
for i, img_b64 in enumerate(images_data)
|
||||
]
|
||||
|
||||
@@ -167,32 +159,34 @@ def create_generate_router(pm: ProcessManager) -> APIRouter:
|
||||
}
|
||||
|
||||
@router.get("/samplers")
|
||||
async def list_samplers() -> dict[str, Any]:
|
||||
async def list_samplers(request: Request) -> dict[str, Any]:
|
||||
"""List available samplers from sd-server."""
|
||||
_check_server_running(pm)
|
||||
assert pm.config is not None
|
||||
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/samplers"
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
url = f"{sd_server_url}/sdapi/v1/samplers"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
return {"samplers": response.json()}
|
||||
except httpx.ConnectError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||
|
||||
@router.get("/schedulers")
|
||||
async def list_schedulers() -> dict[str, Any]:
|
||||
async def list_schedulers(request: Request) -> dict[str, Any]:
|
||||
"""List available schedulers from sd-server."""
|
||||
_check_server_running(pm)
|
||||
assert pm.config is not None
|
||||
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/schedulers"
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
url = f"{sd_server_url}/sdapi/v1/schedulers"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
return {"schedulers": response.json()}
|
||||
except httpx.ConnectError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||
|
||||
|
||||
@@ -2,16 +2,5 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
DEFAULT_PORT = 1234
|
||||
|
||||
|
||||
class ReloadRequest(BaseModel):
|
||||
model: str
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
model: str
|
||||
port: int = DEFAULT_PORT
|
||||
args: list[str] = []
|
||||
# Note: ServerConfig and ReloadRequest were removed since we no longer manage
|
||||
# sd-server processes internally. The wrapper now proxies to an external sd-server.
|
||||
|
||||
@@ -3,30 +3,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from fastapi import APIRouter, Request
|
||||
|
||||
from tensors.config import MODELS_DIR
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensors.server.process import ProcessManager
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Request/Response Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SwitchModelRequest(PydanticBaseModel):
|
||||
"""Request body for switching models."""
|
||||
|
||||
model: str # Path to model file
|
||||
_HTTP_OK = 200
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -76,7 +64,7 @@ def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_models_router(pm: ProcessManager) -> APIRouter:
|
||||
def create_models_router() -> APIRouter:
|
||||
"""Build a router with /api/models/* endpoints."""
|
||||
router = APIRouter(prefix="/api/models", tags=["models"])
|
||||
|
||||
@@ -90,62 +78,34 @@ def create_models_router(pm: ProcessManager) -> APIRouter:
|
||||
}
|
||||
|
||||
@router.get("/active")
|
||||
def get_active_model() -> dict[str, Any]:
|
||||
"""Get information about the currently loaded model."""
|
||||
status = pm.status()
|
||||
config = pm.config
|
||||
async def get_active_model(request: Request) -> dict[str, Any]:
|
||||
"""Get information about the currently loaded model from sd-server."""
|
||||
import httpx # noqa: PLC0415
|
||||
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
|
||||
# Try to get current model from sd-server's options endpoint
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.get(f"{sd_server_url}/sdapi/v1/options")
|
||||
if response.status_code == _HTTP_OK:
|
||||
options = response.json()
|
||||
model_name = options.get("sd_model_checkpoint")
|
||||
return {
|
||||
"loaded": True,
|
||||
"model": model_name,
|
||||
"sd_server_url": sd_server_url,
|
||||
}
|
||||
except httpx.HTTPError:
|
||||
pass
|
||||
|
||||
if config is None:
|
||||
return {
|
||||
"loaded": False,
|
||||
"model": None,
|
||||
"status": status.get("status"),
|
||||
"sd_server_url": sd_server_url,
|
||||
"error": "Cannot connect to sd-server",
|
||||
}
|
||||
|
||||
return {
|
||||
"loaded": status.get("status") == "running",
|
||||
"model": config.model,
|
||||
"pid": status.get("pid"),
|
||||
"port": config.port,
|
||||
"status": status.get("status"),
|
||||
}
|
||||
|
||||
@router.post("/switch")
|
||||
async def switch_model(req: SwitchModelRequest) -> JSONResponse:
|
||||
"""Switch to a different model (hot reload)."""
|
||||
model_path = Path(req.model)
|
||||
|
||||
# Validate model exists
|
||||
if not model_path.exists():
|
||||
raise HTTPException(status_code=400, detail=f"Model not found: {req.model}")
|
||||
|
||||
# Use existing reload logic
|
||||
from tensors.server.models import ServerConfig # noqa: PLC0415
|
||||
|
||||
new_config = ServerConfig(
|
||||
model=req.model,
|
||||
port=pm.config.port if pm.config else 1234,
|
||||
args=pm.config.args if pm.config else [],
|
||||
)
|
||||
|
||||
pm.stop()
|
||||
pm.start(new_config)
|
||||
ready = await pm.wait_ready()
|
||||
|
||||
if not ready:
|
||||
return JSONResponse(
|
||||
{"error": "sd-server failed to become ready", "model": req.model},
|
||||
status_code=503,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"ok": True,
|
||||
"model": req.model,
|
||||
"pid": pm.proc.pid if pm.proc else None,
|
||||
}
|
||||
)
|
||||
|
||||
@router.get("/loras")
|
||||
def list_loras() -> dict[str, Any]:
|
||||
"""List available LoRA files."""
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
"""sd-server process lifecycle management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensors.server.models import ServerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_HTTP_OK = 200
|
||||
|
||||
SD_SERVER_BIN = shutil.which("sd-server") or "sd-server"
|
||||
|
||||
|
||||
class ProcessManager:
|
||||
def __init__(self) -> None:
|
||||
self.proc: subprocess.Popen[bytes] | None = None
|
||||
self.config: ServerConfig | None = None
|
||||
|
||||
def build_cmd(self) -> list[str]:
|
||||
if self.config is None:
|
||||
raise RuntimeError("No config set")
|
||||
cmd = [SD_SERVER_BIN, "-m", self.config.model, "--listen-port", str(self.config.port)]
|
||||
cmd.extend(self.config.args)
|
||||
return cmd
|
||||
|
||||
def start(self, config: ServerConfig) -> None:
|
||||
if self.proc is not None and self.proc.poll() is None:
|
||||
raise RuntimeError("Server already running — stop it first")
|
||||
self.config = config
|
||||
cmd = self.build_cmd()
|
||||
# Inherit environment (important for HSA_OVERRIDE_GFX_VERSION on ROCm)
|
||||
self.proc = subprocess.Popen(cmd, env=os.environ.copy())
|
||||
logger.info("started sd-server pid=%d cmd=%s", self.proc.pid, cmd)
|
||||
|
||||
def stop(self) -> bool:
|
||||
if self.proc is None or self.proc.poll() is not None:
|
||||
self.proc = None
|
||||
return False
|
||||
self.proc.send_signal(signal.SIGTERM)
|
||||
try:
|
||||
self.proc.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.proc.kill()
|
||||
self.proc.wait(timeout=5)
|
||||
logger.info("stopped sd-server")
|
||||
self.proc = None
|
||||
return True
|
||||
|
||||
def status(self) -> dict[str, Any]:
|
||||
if self.proc is None:
|
||||
return {"running": False}
|
||||
rc = self.proc.poll()
|
||||
if rc is not None:
|
||||
return {"running": False, "exit_code": rc}
|
||||
return {
|
||||
"running": True,
|
||||
"pid": self.proc.pid,
|
||||
"model": self.config.model if self.config else None,
|
||||
"port": self.config.port if self.config else None,
|
||||
"bind": f"http://127.0.0.1:{self.config.port}" if self.config else None,
|
||||
"cmd": self.build_cmd(),
|
||||
}
|
||||
|
||||
async def wait_ready(self, timeout: float = 120) -> bool:
|
||||
"""Poll sd-server root endpoint until it responds or timeout."""
|
||||
if self.config is None:
|
||||
return False
|
||||
url = f"http://127.0.0.1:{self.config.port}/"
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
async with httpx.AsyncClient() as client:
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
if self.proc is not None and self.proc.poll() is not None:
|
||||
return False
|
||||
try:
|
||||
r = await client.get(url, timeout=2)
|
||||
if r.status_code == _HTTP_OK:
|
||||
return True
|
||||
except (httpx.ConnectError, httpx.TimeoutException, httpx.ReadError):
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
+37
-28
@@ -3,53 +3,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Request, Response
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from tensors.server.models import ReloadRequest, ServerConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensors.server.process import ProcessManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_router(pm: ProcessManager) -> APIRouter:
|
||||
"""Build a router with /status, /reload, and catch-all proxy."""
|
||||
def create_router() -> APIRouter:
|
||||
"""Build a router with /status and catch-all proxy."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/status")
|
||||
def status() -> dict[str, Any]:
|
||||
return pm.status()
|
||||
|
||||
@router.post("/reload")
|
||||
async def reload(req: ReloadRequest) -> Response:
|
||||
new_config = ServerConfig(
|
||||
model=req.model,
|
||||
port=pm.config.port if pm.config else 1234,
|
||||
args=pm.config.args if pm.config else [],
|
||||
)
|
||||
pm.stop()
|
||||
pm.start(new_config)
|
||||
ready = await pm.wait_ready()
|
||||
if not ready:
|
||||
return JSONResponse({"error": "sd-server failed to become ready", "model": req.model}, status_code=503)
|
||||
return JSONResponse({"ok": True, "model": req.model, "pid": pm.proc.pid if pm.proc else None})
|
||||
async def status(request: Request) -> dict[str, Any]:
|
||||
"""Check if the external sd-server is reachable."""
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
r = await client.get(sd_server_url)
|
||||
return {
|
||||
"status": "ok",
|
||||
"sd_server_url": sd_server_url,
|
||||
"sd_server_status": r.status_code,
|
||||
}
|
||||
except httpx.HTTPError as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"sd_server_url": sd_server_url,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
|
||||
async def proxy(request: Request, path: str) -> Response:
|
||||
if pm.proc is None or pm.proc.poll() is not None:
|
||||
return JSONResponse({"error": "sd-server is not running"}, status_code=503)
|
||||
assert pm.config is not None
|
||||
url = f"http://127.0.0.1:{pm.config.port}/{path}"
|
||||
"""Proxy all requests to the external sd-server."""
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
url = f"{sd_server_url}/{path}"
|
||||
if request.url.query:
|
||||
url = f"{url}?{request.url.query}"
|
||||
|
||||
body = await request.body()
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
client = request.app.state.client
|
||||
|
||||
try:
|
||||
upstream = await client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
@@ -62,5 +61,15 @@ def create_router(pm: ProcessManager) -> APIRouter:
|
||||
status_code=upstream.status_code,
|
||||
headers=dict(upstream.headers),
|
||||
)
|
||||
except httpx.ConnectError:
|
||||
return JSONResponse(
|
||||
{"error": f"Cannot connect to sd-server at {sd_server_url}"},
|
||||
status_code=503,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
return JSONResponse(
|
||||
{"error": f"Timeout connecting to sd-server at {sd_server_url}"},
|
||||
status_code=504,
|
||||
)
|
||||
|
||||
return router
|
||||
|
||||
+45
-117
@@ -2,110 +2,51 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from tensors.server import create_app
|
||||
from tensors.server.models import ServerConfig
|
||||
from tensors.server.process import ProcessManager
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def pm() -> ProcessManager:
|
||||
return ProcessManager()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def api() -> TestClient:
|
||||
return TestClient(create_app())
|
||||
|
||||
|
||||
def _get_pm(api: TestClient) -> ProcessManager:
|
||||
return api.app.state.pm # type: ignore[no-any-return, attr-defined]
|
||||
"""Create test client with mock sd-server URL."""
|
||||
return TestClient(create_app(sd_server_url="http://mock-sd-server:1234"))
|
||||
|
||||
|
||||
class TestStatus:
|
||||
def test_not_running(self, api: TestClient) -> None:
|
||||
r = api.get("/status")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["running"] is False
|
||||
@respx.mock
|
||||
def test_status_when_backend_reachable(self) -> None:
|
||||
"""Test status endpoint when sd-server is reachable."""
|
||||
respx.get("http://mock-sd-server:1234/").mock(return_value=httpx.Response(200))
|
||||
|
||||
def test_running(self, api: TestClient) -> None:
|
||||
pm = _get_pm(api)
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.return_value = None
|
||||
mock_proc.pid = 999
|
||||
pm.proc = mock_proc
|
||||
pm.config = ServerConfig(model="/m.safetensors")
|
||||
r = api.get("/status")
|
||||
data = r.json()
|
||||
assert data["running"] is True
|
||||
assert data["pid"] == 999
|
||||
|
||||
def test_exited(self, api: TestClient) -> None:
|
||||
pm = _get_pm(api)
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.return_value = 1
|
||||
pm.proc = mock_proc
|
||||
r = api.get("/status")
|
||||
data = r.json()
|
||||
assert data["running"] is False
|
||||
assert data["exit_code"] == 1
|
||||
|
||||
|
||||
class TestReload:
|
||||
@patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=True)
|
||||
@patch("tensors.server.process.subprocess.Popen")
|
||||
def test_reload_swaps_model(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None:
|
||||
pm = _get_pm(api)
|
||||
pm.config = ServerConfig(model="/old.gguf", port=5555, args=["--fa"])
|
||||
mock_popen.return_value.pid = 42
|
||||
mock_popen.return_value.poll.return_value = None
|
||||
r = api.post("/reload", json={"model": "/new.gguf"})
|
||||
with TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) as client:
|
||||
r = client.get("/status")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["model"] == "/new.gguf"
|
||||
assert data["pid"] == 42
|
||||
# Verify new config preserved port and args from previous config
|
||||
assert pm.config is not None
|
||||
assert pm.config.port == 5555
|
||||
assert pm.config.args == ["--fa"]
|
||||
assert pm.config.model == "/new.gguf"
|
||||
assert data["status"] == "ok"
|
||||
assert data["sd_server_url"] == "http://mock-sd-server:1234"
|
||||
|
||||
@patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=False)
|
||||
@patch("tensors.server.process.subprocess.Popen")
|
||||
def test_reload_fails_when_not_ready(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None:
|
||||
pm = _get_pm(api)
|
||||
pm.config = ServerConfig(model="/old.gguf")
|
||||
mock_popen.return_value.pid = 43
|
||||
mock_popen.return_value.poll.return_value = None
|
||||
r = api.post("/reload", json={"model": "/bad.gguf"})
|
||||
assert r.status_code == 503
|
||||
assert "failed" in r.json()["error"]
|
||||
@respx.mock
|
||||
def test_status_when_backend_unreachable(self) -> None:
|
||||
"""Test status endpoint when sd-server is not reachable."""
|
||||
respx.get("http://mock-sd-server:1234/").mock(side_effect=httpx.ConnectError("Connection refused"))
|
||||
|
||||
def test_reload_requires_model(self, api: TestClient) -> None:
|
||||
r = api.post("/reload", json={})
|
||||
assert r.status_code == 422
|
||||
with TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) as client:
|
||||
r = client.get("/status")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["status"] == "error"
|
||||
assert "Connection refused" in data["error"]
|
||||
|
||||
|
||||
class TestProxy:
|
||||
def test_proxy_503_when_not_running(self, api: TestClient) -> None:
|
||||
r = api.get("/v1/models")
|
||||
assert r.status_code == 503
|
||||
assert "not running" in r.json()["error"]
|
||||
|
||||
def test_proxy_forwards_request(self, api: TestClient) -> None:
|
||||
pm = _get_pm(api)
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.return_value = None
|
||||
mock_proc.pid = 100
|
||||
pm.proc = mock_proc
|
||||
pm.config = ServerConfig(model="/m.gguf", port=1234)
|
||||
|
||||
"""Test proxy forwards GET requests to backend."""
|
||||
upstream_response = httpx.Response(
|
||||
200,
|
||||
json={"data": [{"id": "model-1"}]},
|
||||
@@ -114,6 +55,7 @@ class TestProxy:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = upstream_response
|
||||
api.app.state.client = mock_client # type: ignore[attr-defined]
|
||||
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
|
||||
|
||||
r = api.get("/v1/models")
|
||||
assert r.status_code == 200
|
||||
@@ -121,52 +63,38 @@ class TestProxy:
|
||||
mock_client.request.assert_called_once()
|
||||
|
||||
def test_proxy_forwards_post_with_body(self, api: TestClient) -> None:
|
||||
pm = _get_pm(api)
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.return_value = None
|
||||
mock_proc.pid = 100
|
||||
pm.proc = mock_proc
|
||||
pm.config = ServerConfig(model="/m.gguf", port=1234)
|
||||
|
||||
"""Test proxy forwards POST requests with body."""
|
||||
upstream_response = httpx.Response(200, json={"ok": True})
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.return_value = upstream_response
|
||||
api.app.state.client = mock_client # type: ignore[attr-defined]
|
||||
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
|
||||
|
||||
r = api.post("/v1/chat/completions", json={"prompt": "hello"})
|
||||
r = api.post("/sdapi/v1/txt2img", json={"prompt": "hello"})
|
||||
assert r.status_code == 200
|
||||
mock_client.request.assert_called_once()
|
||||
|
||||
def test_proxy_503_on_connect_error(self, api: TestClient) -> None:
|
||||
"""Test proxy returns 503 when backend is unreachable."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.side_effect = httpx.ConnectError("Connection refused")
|
||||
api.app.state.client = mock_client # type: ignore[attr-defined]
|
||||
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
|
||||
|
||||
class TestProcessManager:
|
||||
def test_status_not_running(self, pm: ProcessManager) -> None:
|
||||
assert pm.status() == {"running": False}
|
||||
r = api.get("/v1/models")
|
||||
assert r.status_code == 503
|
||||
assert "Cannot connect" in r.json()["error"]
|
||||
|
||||
def test_build_cmd(self, pm: ProcessManager) -> None:
|
||||
pm.config = ServerConfig(model="/m.gguf", port=1234, args=["--fa"])
|
||||
cmd = pm.build_cmd()
|
||||
assert "/m.gguf" in cmd
|
||||
assert "--fa" in cmd
|
||||
assert "1234" in cmd
|
||||
def test_proxy_504_on_timeout(self, api: TestClient) -> None:
|
||||
"""Test proxy returns 504 on timeout."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.side_effect = httpx.TimeoutException("Timeout")
|
||||
api.app.state.client = mock_client # type: ignore[attr-defined]
|
||||
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
|
||||
|
||||
def test_build_cmd_no_config(self, pm: ProcessManager) -> None:
|
||||
with pytest.raises(RuntimeError, match="No config"):
|
||||
pm.build_cmd()
|
||||
|
||||
@patch("tensors.server.process.subprocess.Popen")
|
||||
def test_start_and_stop(self, mock_popen: MagicMock, pm: ProcessManager) -> None:
|
||||
mock_popen.return_value.pid = 77
|
||||
mock_popen.return_value.poll.return_value = None
|
||||
mock_popen.return_value.wait.return_value = 0
|
||||
pm.start(ServerConfig(model="/m.gguf"))
|
||||
assert pm.proc is not None
|
||||
assert pm.stop() is True
|
||||
assert pm.proc is None
|
||||
|
||||
def test_server_config_defaults(self) -> None:
|
||||
cfg = ServerConfig(model="/m.gguf")
|
||||
assert cfg.port == 1234
|
||||
assert cfg.args == []
|
||||
r = api.get("/v1/models")
|
||||
assert r.status_code == 504
|
||||
assert "Timeout" in r.json()["error"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
Reference in New Issue
Block a user