Remove internal sd-server management, proxy to external sd-server

- Remove ProcessManager and process.py
- Add get_sd_server_url() config (env/config/default)
- Update routes to proxy to external sd-server URL
- Remove model switching (handled by external sd-server)
- Update CLI serve command
- Update tests for new architecture

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Adam Ladachowski
2026-02-14 06:39:35 +01:00
parent be7cf0b6e7
commit e9480a18c2
16 changed files with 211 additions and 390 deletions
BIN
View File
Binary file not shown.
+1 -1
View File
@@ -1,6 +1,6 @@
[project] [project]
name = "tensors" name = "tensors"
version = "0.1.11" version = "0.1.12"
description = "Read safetensor metadata and fetch CivitAI model information" description = "Read safetensor metadata and fetch CivitAI model information"
readme = "README.md" readme = "README.md"
requires-python = ">=3.12" requires-python = ">=3.12"
Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 486 KiB

+11 -3
View File
@@ -55,9 +55,17 @@ def sync_to_junkpile() -> None:
"""Sync project to junkpile.""" """Sync project to junkpile."""
print("\n[3/4] Syncing to junkpile...") print("\n[3/4] Syncing to junkpile...")
excludes = [ excludes = [
".git", ".venv", "__pycache__", "*.pyc", ".mypy_cache", ".git",
".pytest_cache", ".ruff_cache", ".coverage", "*.egg-info", ".venv",
"node_modules", ".tmp", "__pycache__",
"*.pyc",
".mypy_cache",
".pytest_cache",
".ruff_cache",
".coverage",
"*.egg-info",
"node_modules",
".tmp",
] ]
cmd = ["rsync", "-avz", "--delete"] cmd = ["rsync", "-avz", "--delete"]
for exc in excludes: for exc in excludes:
+4 -6
View File
@@ -582,24 +582,22 @@ def reload(
@app.command() @app.command()
def serve( def serve(
model: Annotated[str, typer.Option(help="Path to model file for sd-server.")],
host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1", host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080, port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080,
sd_port: Annotated[int, typer.Option(help="sd-server listen port.")] = 1234, sd_server: Annotated[str | None, typer.Option(help="sd-server URL to proxy to.")] = None,
log_level: Annotated[str, typer.Option(help="Log level.")] = "info", log_level: Annotated[str, typer.Option(help="Log level.")] = "info",
) -> None: ) -> None:
"""Start the sd-server wrapper API (transparent proxy with hot reload).""" """Start the sd-server wrapper API (proxies to external sd-server)."""
try: try:
import uvicorn # noqa: PLC0415 import uvicorn # noqa: PLC0415
from tensors.server import ServerConfig, create_app # noqa: PLC0415 from tensors.server import create_app # noqa: PLC0415
except ImportError: except ImportError:
console.print("[red]Missing server dependencies. Install with:[/red]") console.print("[red]Missing server dependencies. Install with:[/red]")
console.print(" pip install tensors[server]") console.print(" pip install tensors[server]")
raise typer.Exit(1) from None raise typer.Exit(1) from None
config = ServerConfig(model=model, port=sd_port) uvicorn.run(create_app(sd_server_url=sd_server), host=host, port=port, log_level=log_level)
uvicorn.run(create_app(config), host=host, port=port, log_level=log_level)
# ============================================================================= # =============================================================================
+31
View File
@@ -243,3 +243,34 @@ def set_default_remote(name: str | None) -> None:
else: else:
config["default_remote"] = name config["default_remote"] = name
save_config(config) save_config(config)
# ============================================================================
# SD Server Configuration
# ============================================================================
SD_SERVER_DEFAULT_URL = "http://localhost:1234"
def get_sd_server_url() -> str:
"""Get the sd-server URL.
Resolution order:
1. SD_SERVER_URL environment variable
2. config.toml [server].sd_server_url
3. Default: http://localhost:1234
"""
# Check environment variable first
env_url = os.environ.get("SD_SERVER_URL")
if env_url:
return env_url
# Check config file
config = load_config()
server_config = config.get("server", {})
if isinstance(server_config, dict):
url = server_config.get("sd_server_url")
if url:
return str(url)
return SD_SERVER_DEFAULT_URL
+16 -20
View File
@@ -1,4 +1,4 @@
"""sd-server wrapper — FastAPI app for managing and proxying to sd-server.""" """sd-server wrapper — FastAPI app for proxying to an external sd-server."""
from __future__ import annotations from __future__ import annotations
@@ -12,42 +12,39 @@ from fastapi import FastAPI
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from tensors.config import get_sd_server_url
from tensors.server.civitai_routes import create_civitai_router from tensors.server.civitai_routes import create_civitai_router
from tensors.server.db_routes import create_db_router from tensors.server.db_routes import create_db_router
from tensors.server.download_routes import create_download_router from tensors.server.download_routes import create_download_router
from tensors.server.gallery_routes import create_gallery_router from tensors.server.gallery_routes import create_gallery_router
from tensors.server.generate_routes import create_generate_router from tensors.server.generate_routes import create_generate_router
from tensors.server.models import ServerConfig
from tensors.server.models_routes import create_models_router from tensors.server.models_routes import create_models_router
from tensors.server.process import ProcessManager
from tensors.server.routes import create_router from tensors.server.routes import create_router
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
__all__ = ["ProcessManager", "ServerConfig", "app", "create_app"] __all__ = ["app", "create_app"]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def create_app(config: ServerConfig | None = None) -> FastAPI: def create_app(sd_server_url: str | None = None) -> FastAPI:
"""Build the FastAPI application with process manager and proxy client.""" """Build the FastAPI application that proxies to an external sd-server.
pm = ProcessManager()
Args:
sd_server_url: URL of the sd-server to proxy to. If None, uses
get_sd_server_url() to resolve from env/config.
"""
backend_url = sd_server_url or get_sd_server_url()
@asynccontextmanager @asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]: async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
_app.state.sd_server_url = backend_url
logger.info(f"Proxying to sd-server at: {backend_url}")
async with httpx.AsyncClient(timeout=300) as client: async with httpx.AsyncClient(timeout=300) as client:
_app.state.client = client _app.state.client = client
if config is not None:
pm.start(config)
logger.info("waiting for sd-server to become ready...")
ready = await pm.wait_ready()
if ready:
logger.info("sd-server is ready")
else:
logger.warning("sd-server did not become ready in time")
yield yield
pm.stop()
app = FastAPI(title="sd-server wrapper", lifespan=lifespan) app = FastAPI(title="sd-server wrapper", lifespan=lifespan)
@@ -68,11 +65,10 @@ def create_app(config: ServerConfig | None = None) -> FastAPI:
app.include_router(create_civitai_router()) # Must be before catch-all proxy app.include_router(create_civitai_router()) # Must be before catch-all proxy
app.include_router(create_db_router()) # Must be before catch-all proxy app.include_router(create_db_router()) # Must be before catch-all proxy
app.include_router(create_gallery_router()) # Must be before catch-all proxy app.include_router(create_gallery_router()) # Must be before catch-all proxy
app.include_router(create_models_router(pm)) # Must be before catch-all proxy app.include_router(create_models_router()) # Must be before catch-all proxy
app.include_router(create_download_router()) # Must be before catch-all proxy app.include_router(create_download_router()) # Must be before catch-all proxy
app.include_router(create_generate_router(pm)) # Must be before catch-all proxy app.include_router(create_generate_router()) # Must be before catch-all proxy
app.include_router(create_router(pm)) app.include_router(create_router())
app.state.pm = pm
return app return app
+22 -28
View File
@@ -5,18 +5,15 @@ from __future__ import annotations
import base64 import base64
import logging import logging
import time import time
from typing import TYPE_CHECKING, Any from typing import Any
import httpx import httpx
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel as PydanticBaseModel from pydantic import BaseModel as PydanticBaseModel
from pydantic import Field from pydantic import Field
from tensors.server.gallery import Gallery from tensors.server.gallery import Gallery
if TYPE_CHECKING:
from tensors.server.process import ProcessManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -113,38 +110,30 @@ def _process_image(
return image_info return image_info
def _check_server_running(pm: ProcessManager) -> None:
"""Check if sd-server is running, raise HTTPException if not."""
if pm.proc is None or pm.proc.poll() is not None:
raise HTTPException(status_code=503, detail="sd-server is not running")
if pm.config is None:
raise HTTPException(status_code=503, detail="sd-server not configured")
# ============================================================================= # =============================================================================
# Router Factory # Router Factory
# ============================================================================= # =============================================================================
def create_generate_router(pm: ProcessManager) -> APIRouter: def create_generate_router() -> APIRouter:
"""Build a router with /api/generate endpoint.""" """Build a router with /api/generate endpoint."""
router = APIRouter(prefix="/api", tags=["generate"]) router = APIRouter(prefix="/api", tags=["generate"])
gallery = Gallery() gallery = Gallery()
@router.post("/generate") @router.post("/generate")
async def generate(req: GenerateRequest) -> dict[str, Any]: async def generate(request: Request, req: GenerateRequest) -> dict[str, Any]:
"""Generate images with gallery integration.""" """Generate images with gallery integration."""
_check_server_running(pm) sd_server_url = request.app.state.sd_server_url
assert pm.config is not None # Verified by _check_server_running
body = _build_sd_request(req) body = _build_sd_request(req)
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/txt2img" url = f"{sd_server_url}/sdapi/v1/txt2img"
try: try:
async with httpx.AsyncClient(timeout=300) as client: async with httpx.AsyncClient(timeout=300) as client:
response = await client.post(url, json=body) response = await client.post(url, json=body)
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
except httpx.ConnectError as e:
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
except httpx.HTTPError as e: except httpx.HTTPError as e:
logger.exception("Generation failed") logger.exception("Generation failed")
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
@@ -153,8 +142,11 @@ def create_generate_router(pm: ProcessManager) -> APIRouter:
info = _parse_info(result.get("info", {})) info = _parse_info(result.get("info", {}))
all_seeds = info.get("all_seeds", [req.seed] * len(images_data)) all_seeds = info.get("all_seeds", [req.seed] * len(images_data))
# Get model info from sd-server response if available
model_name = info.get("sd_model_name") or info.get("model")
output_images = [ output_images = [
_process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, pm.config.model) _process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, model_name)
for i, img_b64 in enumerate(images_data) for i, img_b64 in enumerate(images_data)
] ]
@@ -167,32 +159,34 @@ def create_generate_router(pm: ProcessManager) -> APIRouter:
} }
@router.get("/samplers") @router.get("/samplers")
async def list_samplers() -> dict[str, Any]: async def list_samplers(request: Request) -> dict[str, Any]:
"""List available samplers from sd-server.""" """List available samplers from sd-server."""
_check_server_running(pm) sd_server_url = request.app.state.sd_server_url
assert pm.config is not None url = f"{sd_server_url}/sdapi/v1/samplers"
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/samplers"
try: try:
async with httpx.AsyncClient(timeout=30) as client: async with httpx.AsyncClient(timeout=30) as client:
response = await client.get(url) response = await client.get(url)
response.raise_for_status() response.raise_for_status()
return {"samplers": response.json()} return {"samplers": response.json()}
except httpx.ConnectError as e:
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
except httpx.HTTPError as e: except httpx.HTTPError as e:
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
@router.get("/schedulers") @router.get("/schedulers")
async def list_schedulers() -> dict[str, Any]: async def list_schedulers(request: Request) -> dict[str, Any]:
"""List available schedulers from sd-server.""" """List available schedulers from sd-server."""
_check_server_running(pm) sd_server_url = request.app.state.sd_server_url
assert pm.config is not None url = f"{sd_server_url}/sdapi/v1/schedulers"
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/schedulers"
try: try:
async with httpx.AsyncClient(timeout=30) as client: async with httpx.AsyncClient(timeout=30) as client:
response = await client.get(url) response = await client.get(url)
response.raise_for_status() response.raise_for_status()
return {"schedulers": response.json()} return {"schedulers": response.json()}
except httpx.ConnectError as e:
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
except httpx.HTTPError as e: except httpx.HTTPError as e:
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
+2 -13
View File
@@ -2,16 +2,5 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel # Note: ServerConfig and ReloadRequest were removed since we no longer manage
# sd-server processes internally. The wrapper now proxies to an external sd-server.
DEFAULT_PORT = 1234
class ReloadRequest(BaseModel):
model: str
class ServerConfig(BaseModel):
model: str
port: int = DEFAULT_PORT
args: list[str] = []
+27 -67
View File
@@ -3,30 +3,18 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel as PydanticBaseModel
from tensors.config import MODELS_DIR from tensors.config import MODELS_DIR
if TYPE_CHECKING: if TYPE_CHECKING:
from tensors.server.process import ProcessManager from pathlib import Path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_HTTP_OK = 200
# =============================================================================
# Request/Response Models
# =============================================================================
class SwitchModelRequest(PydanticBaseModel):
"""Request body for switching models."""
model: str # Path to model file
# ============================================================================= # =============================================================================
@@ -76,7 +64,7 @@ def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]:
# ============================================================================= # =============================================================================
def create_models_router(pm: ProcessManager) -> APIRouter: def create_models_router() -> APIRouter:
"""Build a router with /api/models/* endpoints.""" """Build a router with /api/models/* endpoints."""
router = APIRouter(prefix="/api/models", tags=["models"]) router = APIRouter(prefix="/api/models", tags=["models"])
@@ -90,62 +78,34 @@ def create_models_router(pm: ProcessManager) -> APIRouter:
} }
@router.get("/active") @router.get("/active")
def get_active_model() -> dict[str, Any]: async def get_active_model(request: Request) -> dict[str, Any]:
"""Get information about the currently loaded model.""" """Get information about the currently loaded model from sd-server."""
status = pm.status() import httpx # noqa: PLC0415
config = pm.config
if config is None: sd_server_url = request.app.state.sd_server_url
return {
"loaded": False, # Try to get current model from sd-server's options endpoint
"model": None, try:
"status": status.get("status"), async with httpx.AsyncClient(timeout=10) as client:
} response = await client.get(f"{sd_server_url}/sdapi/v1/options")
if response.status_code == _HTTP_OK:
options = response.json()
model_name = options.get("sd_model_checkpoint")
return {
"loaded": True,
"model": model_name,
"sd_server_url": sd_server_url,
}
except httpx.HTTPError:
pass
return { return {
"loaded": status.get("status") == "running", "loaded": False,
"model": config.model, "model": None,
"pid": status.get("pid"), "sd_server_url": sd_server_url,
"port": config.port, "error": "Cannot connect to sd-server",
"status": status.get("status"),
} }
@router.post("/switch")
async def switch_model(req: SwitchModelRequest) -> JSONResponse:
"""Switch to a different model (hot reload)."""
model_path = Path(req.model)
# Validate model exists
if not model_path.exists():
raise HTTPException(status_code=400, detail=f"Model not found: {req.model}")
# Use existing reload logic
from tensors.server.models import ServerConfig # noqa: PLC0415
new_config = ServerConfig(
model=req.model,
port=pm.config.port if pm.config else 1234,
args=pm.config.args if pm.config else [],
)
pm.stop()
pm.start(new_config)
ready = await pm.wait_ready()
if not ready:
return JSONResponse(
{"error": "sd-server failed to become ready", "model": req.model},
status_code=503,
)
return JSONResponse(
{
"ok": True,
"model": req.model,
"pid": pm.proc.pid if pm.proc else None,
}
)
@router.get("/loras") @router.get("/loras")
def list_loras() -> dict[str, Any]: def list_loras() -> dict[str, Any]:
"""List available LoRA files.""" """List available LoRA files."""
-92
View File
@@ -1,92 +0,0 @@
"""sd-server process lifecycle management."""
from __future__ import annotations
import asyncio
import logging
import os
import shutil
import signal
import subprocess
from typing import TYPE_CHECKING, Any
import httpx
if TYPE_CHECKING:
from tensors.server.models import ServerConfig
logger = logging.getLogger(__name__)
_HTTP_OK = 200
SD_SERVER_BIN = shutil.which("sd-server") or "sd-server"
class ProcessManager:
def __init__(self) -> None:
self.proc: subprocess.Popen[bytes] | None = None
self.config: ServerConfig | None = None
def build_cmd(self) -> list[str]:
if self.config is None:
raise RuntimeError("No config set")
cmd = [SD_SERVER_BIN, "-m", self.config.model, "--listen-port", str(self.config.port)]
cmd.extend(self.config.args)
return cmd
def start(self, config: ServerConfig) -> None:
if self.proc is not None and self.proc.poll() is None:
raise RuntimeError("Server already running — stop it first")
self.config = config
cmd = self.build_cmd()
# Inherit environment (important for HSA_OVERRIDE_GFX_VERSION on ROCm)
self.proc = subprocess.Popen(cmd, env=os.environ.copy())
logger.info("started sd-server pid=%d cmd=%s", self.proc.pid, cmd)
def stop(self) -> bool:
if self.proc is None or self.proc.poll() is not None:
self.proc = None
return False
self.proc.send_signal(signal.SIGTERM)
try:
self.proc.wait(timeout=10)
except subprocess.TimeoutExpired:
self.proc.kill()
self.proc.wait(timeout=5)
logger.info("stopped sd-server")
self.proc = None
return True
def status(self) -> dict[str, Any]:
if self.proc is None:
return {"running": False}
rc = self.proc.poll()
if rc is not None:
return {"running": False, "exit_code": rc}
return {
"running": True,
"pid": self.proc.pid,
"model": self.config.model if self.config else None,
"port": self.config.port if self.config else None,
"bind": f"http://127.0.0.1:{self.config.port}" if self.config else None,
"cmd": self.build_cmd(),
}
async def wait_ready(self, timeout: float = 120) -> bool:
"""Poll sd-server root endpoint until it responds or timeout."""
if self.config is None:
return False
url = f"http://127.0.0.1:{self.config.port}/"
deadline = asyncio.get_event_loop().time() + timeout
async with httpx.AsyncClient() as client:
while asyncio.get_event_loop().time() < deadline:
if self.proc is not None and self.proc.poll() is not None:
return False
try:
r = await client.get(url, timeout=2)
if r.status_code == _HTTP_OK:
return True
except (httpx.ConnectError, httpx.TimeoutException, httpx.ReadError):
pass
await asyncio.sleep(1)
return False
+49 -40
View File
@@ -3,64 +3,73 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, Any from typing import Any
import httpx
from fastapi import APIRouter, Request, Response from fastapi import APIRouter, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from tensors.server.models import ReloadRequest, ServerConfig
if TYPE_CHECKING:
from tensors.server.process import ProcessManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def create_router(pm: ProcessManager) -> APIRouter: def create_router() -> APIRouter:
"""Build a router with /status, /reload, and catch-all proxy.""" """Build a router with /status and catch-all proxy."""
router = APIRouter() router = APIRouter()
@router.get("/status") @router.get("/status")
def status() -> dict[str, Any]: async def status(request: Request) -> dict[str, Any]:
return pm.status() """Check if the external sd-server is reachable."""
sd_server_url = request.app.state.sd_server_url
@router.post("/reload") try:
async def reload(req: ReloadRequest) -> Response: async with httpx.AsyncClient(timeout=5) as client:
new_config = ServerConfig( r = await client.get(sd_server_url)
model=req.model, return {
port=pm.config.port if pm.config else 1234, "status": "ok",
args=pm.config.args if pm.config else [], "sd_server_url": sd_server_url,
) "sd_server_status": r.status_code,
pm.stop() }
pm.start(new_config) except httpx.HTTPError as e:
ready = await pm.wait_ready() return {
if not ready: "status": "error",
return JSONResponse({"error": "sd-server failed to become ready", "model": req.model}, status_code=503) "sd_server_url": sd_server_url,
return JSONResponse({"ok": True, "model": req.model, "pid": pm.proc.pid if pm.proc else None}) "error": str(e),
}
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]) @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
async def proxy(request: Request, path: str) -> Response: async def proxy(request: Request, path: str) -> Response:
if pm.proc is None or pm.proc.poll() is not None: """Proxy all requests to the external sd-server."""
return JSONResponse({"error": "sd-server is not running"}, status_code=503) sd_server_url = request.app.state.sd_server_url
assert pm.config is not None url = f"{sd_server_url}/{path}"
url = f"http://127.0.0.1:{pm.config.port}/{path}"
if request.url.query: if request.url.query:
url = f"{url}?{request.url.query}" url = f"{url}?{request.url.query}"
body = await request.body() body = await request.body()
headers = dict(request.headers) headers = dict(request.headers)
headers.pop("host", None) headers.pop("host", None)
client = request.app.state.client client = request.app.state.client
upstream = await client.request(
method=request.method, try:
url=url, upstream = await client.request(
headers=headers, method=request.method,
content=body, url=url,
timeout=300, headers=headers,
) content=body,
return StreamingResponse( timeout=300,
content=upstream.iter_bytes(), )
status_code=upstream.status_code, return StreamingResponse(
headers=dict(upstream.headers), content=upstream.iter_bytes(),
) status_code=upstream.status_code,
headers=dict(upstream.headers),
)
except httpx.ConnectError:
return JSONResponse(
{"error": f"Cannot connect to sd-server at {sd_server_url}"},
status_code=503,
)
except httpx.TimeoutException:
return JSONResponse(
{"error": f"Timeout connecting to sd-server at {sd_server_url}"},
status_code=504,
)
return router return router
+47 -119
View File
@@ -2,110 +2,51 @@
from __future__ import annotations from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock
import httpx import httpx
import pytest import pytest
import respx
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from tensors.server import create_app from tensors.server import create_app
from tensors.server.models import ServerConfig
from tensors.server.process import ProcessManager
@pytest.fixture()
def pm() -> ProcessManager:
return ProcessManager()
@pytest.fixture() @pytest.fixture()
def api() -> TestClient: def api() -> TestClient:
return TestClient(create_app()) """Create test client with mock sd-server URL."""
return TestClient(create_app(sd_server_url="http://mock-sd-server:1234"))
def _get_pm(api: TestClient) -> ProcessManager:
return api.app.state.pm # type: ignore[no-any-return, attr-defined]
class TestStatus: class TestStatus:
def test_not_running(self, api: TestClient) -> None: @respx.mock
r = api.get("/status") def test_status_when_backend_reachable(self) -> None:
assert r.status_code == 200 """Test status endpoint when sd-server is reachable."""
assert r.json()["running"] is False respx.get("http://mock-sd-server:1234/").mock(return_value=httpx.Response(200))
def test_running(self, api: TestClient) -> None: with TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) as client:
pm = _get_pm(api) r = client.get("/status")
mock_proc = MagicMock() assert r.status_code == 200
mock_proc.poll.return_value = None data = r.json()
mock_proc.pid = 999 assert data["status"] == "ok"
pm.proc = mock_proc assert data["sd_server_url"] == "http://mock-sd-server:1234"
pm.config = ServerConfig(model="/m.safetensors")
r = api.get("/status")
data = r.json()
assert data["running"] is True
assert data["pid"] == 999
def test_exited(self, api: TestClient) -> None: @respx.mock
pm = _get_pm(api) def test_status_when_backend_unreachable(self) -> None:
mock_proc = MagicMock() """Test status endpoint when sd-server is not reachable."""
mock_proc.poll.return_value = 1 respx.get("http://mock-sd-server:1234/").mock(side_effect=httpx.ConnectError("Connection refused"))
pm.proc = mock_proc
r = api.get("/status")
data = r.json()
assert data["running"] is False
assert data["exit_code"] == 1
with TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) as client:
class TestReload: r = client.get("/status")
@patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=True) assert r.status_code == 200
@patch("tensors.server.process.subprocess.Popen") data = r.json()
def test_reload_swaps_model(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None: assert data["status"] == "error"
pm = _get_pm(api) assert "Connection refused" in data["error"]
pm.config = ServerConfig(model="/old.gguf", port=5555, args=["--fa"])
mock_popen.return_value.pid = 42
mock_popen.return_value.poll.return_value = None
r = api.post("/reload", json={"model": "/new.gguf"})
assert r.status_code == 200
data = r.json()
assert data["ok"] is True
assert data["model"] == "/new.gguf"
assert data["pid"] == 42
# Verify new config preserved port and args from previous config
assert pm.config is not None
assert pm.config.port == 5555
assert pm.config.args == ["--fa"]
assert pm.config.model == "/new.gguf"
@patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=False)
@patch("tensors.server.process.subprocess.Popen")
def test_reload_fails_when_not_ready(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None:
pm = _get_pm(api)
pm.config = ServerConfig(model="/old.gguf")
mock_popen.return_value.pid = 43
mock_popen.return_value.poll.return_value = None
r = api.post("/reload", json={"model": "/bad.gguf"})
assert r.status_code == 503
assert "failed" in r.json()["error"]
def test_reload_requires_model(self, api: TestClient) -> None:
r = api.post("/reload", json={})
assert r.status_code == 422
class TestProxy: class TestProxy:
def test_proxy_503_when_not_running(self, api: TestClient) -> None:
r = api.get("/v1/models")
assert r.status_code == 503
assert "not running" in r.json()["error"]
def test_proxy_forwards_request(self, api: TestClient) -> None: def test_proxy_forwards_request(self, api: TestClient) -> None:
pm = _get_pm(api) """Test proxy forwards GET requests to backend."""
mock_proc = MagicMock()
mock_proc.poll.return_value = None
mock_proc.pid = 100
pm.proc = mock_proc
pm.config = ServerConfig(model="/m.gguf", port=1234)
upstream_response = httpx.Response( upstream_response = httpx.Response(
200, 200,
json={"data": [{"id": "model-1"}]}, json={"data": [{"id": "model-1"}]},
@@ -114,6 +55,7 @@ class TestProxy:
mock_client = AsyncMock() mock_client = AsyncMock()
mock_client.request.return_value = upstream_response mock_client.request.return_value = upstream_response
api.app.state.client = mock_client # type: ignore[attr-defined] api.app.state.client = mock_client # type: ignore[attr-defined]
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
r = api.get("/v1/models") r = api.get("/v1/models")
assert r.status_code == 200 assert r.status_code == 200
@@ -121,52 +63,38 @@ class TestProxy:
mock_client.request.assert_called_once() mock_client.request.assert_called_once()
def test_proxy_forwards_post_with_body(self, api: TestClient) -> None: def test_proxy_forwards_post_with_body(self, api: TestClient) -> None:
pm = _get_pm(api) """Test proxy forwards POST requests with body."""
mock_proc = MagicMock()
mock_proc.poll.return_value = None
mock_proc.pid = 100
pm.proc = mock_proc
pm.config = ServerConfig(model="/m.gguf", port=1234)
upstream_response = httpx.Response(200, json={"ok": True}) upstream_response = httpx.Response(200, json={"ok": True})
mock_client = AsyncMock() mock_client = AsyncMock()
mock_client.request.return_value = upstream_response mock_client.request.return_value = upstream_response
api.app.state.client = mock_client # type: ignore[attr-defined] api.app.state.client = mock_client # type: ignore[attr-defined]
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
r = api.post("/v1/chat/completions", json={"prompt": "hello"}) r = api.post("/sdapi/v1/txt2img", json={"prompt": "hello"})
assert r.status_code == 200 assert r.status_code == 200
mock_client.request.assert_called_once() mock_client.request.assert_called_once()
def test_proxy_503_on_connect_error(self, api: TestClient) -> None:
"""Test proxy returns 503 when backend is unreachable."""
mock_client = AsyncMock()
mock_client.request.side_effect = httpx.ConnectError("Connection refused")
api.app.state.client = mock_client # type: ignore[attr-defined]
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
class TestProcessManager: r = api.get("/v1/models")
def test_status_not_running(self, pm: ProcessManager) -> None: assert r.status_code == 503
assert pm.status() == {"running": False} assert "Cannot connect" in r.json()["error"]
def test_build_cmd(self, pm: ProcessManager) -> None: def test_proxy_504_on_timeout(self, api: TestClient) -> None:
pm.config = ServerConfig(model="/m.gguf", port=1234, args=["--fa"]) """Test proxy returns 504 on timeout."""
cmd = pm.build_cmd() mock_client = AsyncMock()
assert "/m.gguf" in cmd mock_client.request.side_effect = httpx.TimeoutException("Timeout")
assert "--fa" in cmd api.app.state.client = mock_client # type: ignore[attr-defined]
assert "1234" in cmd api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
def test_build_cmd_no_config(self, pm: ProcessManager) -> None: r = api.get("/v1/models")
with pytest.raises(RuntimeError, match="No config"): assert r.status_code == 504
pm.build_cmd() assert "Timeout" in r.json()["error"]
@patch("tensors.server.process.subprocess.Popen")
def test_start_and_stop(self, mock_popen: MagicMock, pm: ProcessManager) -> None:
mock_popen.return_value.pid = 77
mock_popen.return_value.poll.return_value = None
mock_popen.return_value.wait.return_value = 0
pm.start(ServerConfig(model="/m.gguf"))
assert pm.proc is not None
assert pm.stop() is True
assert pm.proc is None
def test_server_config_defaults(self) -> None:
cfg = ServerConfig(model="/m.gguf")
assert cfg.port == 1234
assert cfg.args == []
# ============================================================================= # =============================================================================
Generated
+1 -1
View File
@@ -707,7 +707,7 @@ wheels = [
[[package]] [[package]]
name = "tensors" name = "tensors"
version = "0.1.9" version = "0.1.12"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "httpx" }, { name = "httpx" },