Turn server wrapper into transparent proxy with /reload

Replace management endpoints (/start, /stop, /restart) with a transparent
reverse proxy and hot-reload architecture. The wrapper now sits in front of
sd-server, forwarding all requests and adding a /reload endpoint for model
swapping without restarting the wrapper itself.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Adam Ladachowski
2026-02-12 20:51:16 +00:00
parent b33fe120fa
commit 61e68fee15
6 changed files with 203 additions and 124 deletions
+6 -3
View File
@@ -439,21 +439,24 @@ def generate(
@app.command() @app.command()
def serve( def serve(
model: Annotated[str, typer.Option(help="Path to model file for sd-server.")],
host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1", host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080, port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080,
sd_port: Annotated[int, typer.Option(help="sd-server listen port.")] = 1234,
log_level: Annotated[str, typer.Option(help="Log level.")] = "info", log_level: Annotated[str, typer.Option(help="Log level.")] = "info",
) -> None: ) -> None:
"""Start the sd-server wrapper API.""" """Start the sd-server wrapper API (transparent proxy with hot reload)."""
try: try:
import uvicorn # noqa: PLC0415 import uvicorn # noqa: PLC0415
from tensors.server import create_app # noqa: PLC0415 from tensors.server import ServerConfig, create_app # noqa: PLC0415
except ImportError: except ImportError:
console.print("[red]Missing server dependencies. Install with:[/red]") console.print("[red]Missing server dependencies. Install with:[/red]")
console.print(" pip install tensors[server]") console.print(" pip install tensors[server]")
raise typer.Exit(1) from None raise typer.Exit(1) from None
uvicorn.run(create_app(), host=host, port=port, log_level=log_level) config = ServerConfig(model=model, port=sd_port)
uvicorn.run(create_app(config), host=host, port=port, log_level=log_level)
def main() -> int: def main() -> int:
+19 -4
View File
@@ -1,27 +1,42 @@
"""sd-server wrapper — FastAPI app for managing sd-server process.""" """sd-server wrapper — FastAPI app for managing and proxying to sd-server."""
from __future__ import annotations from __future__ import annotations
import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import httpx
from fastapi import FastAPI from fastapi import FastAPI
from tensors.server.models import ServerConfig
from tensors.server.process import ProcessManager from tensors.server.process import ProcessManager
from tensors.server.routes import create_router from tensors.server.routes import create_router
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
__all__ = ["ProcessManager", "create_app"] __all__ = ["ProcessManager", "ServerConfig", "create_app"]
logger = logging.getLogger(__name__)
def create_app() -> FastAPI: def create_app(config: ServerConfig | None = None) -> FastAPI:
"""Build the FastAPI application with process manager.""" """Build the FastAPI application with process manager and proxy client."""
pm = ProcessManager() pm = ProcessManager()
@asynccontextmanager @asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]: async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
async with httpx.AsyncClient(timeout=300) as client:
_app.state.client = client
if config is not None:
pm.start(config)
logger.info("waiting for sd-server to become ready...")
ready = await pm.wait_ready()
if ready:
logger.info("sd-server is ready")
else:
logger.warning("sd-server did not become ready in time")
yield yield
pm.stop() pm.stop()
+6 -8
View File
@@ -1,4 +1,4 @@
"""Pydantic request models for the wrapper API.""" """Pydantic models for the sd-server wrapper API."""
from __future__ import annotations from __future__ import annotations
@@ -7,13 +7,11 @@ from pydantic import BaseModel
DEFAULT_PORT = 1234 DEFAULT_PORT = 1234
class StartRequest(BaseModel): class ReloadRequest(BaseModel):
model: str
class ServerConfig(BaseModel):
model: str model: str
port: int = DEFAULT_PORT port: int = DEFAULT_PORT
args: list[str] = [] args: list[str] = []
class RestartRequest(BaseModel):
model: str | None = None
port: int | None = None
args: list[str] | None = None
+37 -9
View File
@@ -2,33 +2,42 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import shutil import shutil
import signal import signal
import subprocess import subprocess
from typing import Any from typing import TYPE_CHECKING, Any
import httpx
if TYPE_CHECKING:
from tensors.server.models import ServerConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_HTTP_OK = 200
SD_SERVER_BIN = shutil.which("sd-server") or "sd-server" SD_SERVER_BIN = shutil.which("sd-server") or "sd-server"
class ProcessManager: class ProcessManager:
def __init__(self) -> None: def __init__(self) -> None:
self.proc: subprocess.Popen[bytes] | None = None self.proc: subprocess.Popen[bytes] | None = None
self.config: dict[str, Any] = {} self.config: ServerConfig | None = None
def build_cmd(self, config: dict[str, Any] | None = None) -> list[str]: def build_cmd(self) -> list[str]:
cfg = config or self.config if self.config is None:
cmd = [SD_SERVER_BIN, "-m", cfg["model"], "--listen-port", str(cfg["port"])] raise RuntimeError("No config set")
cmd.extend(cfg.get("args", [])) cmd = [SD_SERVER_BIN, "-m", self.config.model, "--port", str(self.config.port)]
cmd.extend(self.config.args)
return cmd return cmd
def start(self, config: dict[str, Any]) -> None: def start(self, config: ServerConfig) -> None:
if self.proc is not None and self.proc.poll() is None: if self.proc is not None and self.proc.poll() is None:
raise RuntimeError("Server already running — stop it first") raise RuntimeError("Server already running — stop it first")
self.config = config self.config = config
cmd = self.build_cmd(config) cmd = self.build_cmd()
self.proc = subprocess.Popen(cmd) self.proc = subprocess.Popen(cmd)
logger.info("started sd-server pid=%d cmd=%s", self.proc.pid, cmd) logger.info("started sd-server pid=%d cmd=%s", self.proc.pid, cmd)
@@ -55,6 +64,25 @@ class ProcessManager:
return { return {
"running": True, "running": True,
"pid": self.proc.pid, "pid": self.proc.pid,
"config": self.config, "model": self.config.model if self.config else None,
"cmd": self.build_cmd(), "cmd": self.build_cmd(),
} }
async def wait_ready(self, timeout: float = 120) -> bool:
"""Poll sd-server /health until it responds or timeout."""
if self.config is None:
return False
url = f"http://127.0.0.1:{self.config.port}/health"
deadline = asyncio.get_event_loop().time() + timeout
async with httpx.AsyncClient() as client:
while asyncio.get_event_loop().time() < deadline:
if self.proc is not None and self.proc.poll() is not None:
return False
try:
r = await client.get(url, timeout=2)
if r.status_code == _HTTP_OK:
return True
except httpx.ConnectError:
pass
await asyncio.sleep(1)
return False
+45 -38
View File
@@ -1,59 +1,66 @@
"""FastAPI route handlers for the wrapper API.""" """FastAPI route handlers for the sd-server wrapper API."""
from __future__ import annotations from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
from tensors.server.models import RestartRequest, StartRequest # noqa: TC001 from tensors.server.models import ReloadRequest, ServerConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from tensors.server.process import ProcessManager from tensors.server.process import ProcessManager
logger = logging.getLogger(__name__)
def create_router(pm: ProcessManager) -> APIRouter: def create_router(pm: ProcessManager) -> APIRouter:
"""Build a new router bound to the given ProcessManager.""" """Build a router with /status, /reload, and catch-all proxy."""
router = APIRouter() router = APIRouter()
@router.get("/status") @router.get("/status")
def status() -> dict[str, Any]: def status() -> dict[str, Any]:
return pm.status() return pm.status()
@router.post("/start") @router.post("/reload")
def start(req: StartRequest) -> dict[str, Any]: async def reload(req: ReloadRequest) -> Response:
if pm.proc is not None and pm.proc.poll() is None: new_config = ServerConfig(
raise HTTPException(409, "Server already running — use /restart or /stop first") model=req.model,
config = {"model": req.model, "port": req.port, "args": req.args} port=pm.config.port if pm.config else 1234,
pm.start(config) args=pm.config.args if pm.config else [],
assert pm.proc is not None )
return {"started": True, "pid": pm.proc.pid, "cmd": pm.build_cmd(config)} pm.stop()
pm.start(new_config)
ready = await pm.wait_ready()
if not ready:
return JSONResponse({"error": "sd-server failed to become ready", "model": req.model}, status_code=503)
return JSONResponse({"ok": True, "model": req.model, "pid": pm.proc.pid if pm.proc else None})
@router.post("/stop") @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
def stop() -> dict[str, Any]: async def proxy(request: Request, path: str) -> Response:
if not pm.stop(): if pm.proc is None or pm.proc.poll() is not None:
raise HTTPException(409, "Server is not running") return JSONResponse({"error": "sd-server is not running"}, status_code=503)
return {"stopped": True} assert pm.config is not None
url = f"http://127.0.0.1:{pm.config.port}/{path}"
@router.post("/restart") if request.url.query:
def restart(req: RestartRequest) -> dict[str, Any]: url = f"{url}?{request.url.query}"
if not pm.config and req.model is None: body = await request.body()
raise HTTPException(400, "No previous config — provide at least 'model'") headers = dict(request.headers)
config = dict(pm.config) headers.pop("host", None)
if req.model is not None: client = request.app.state.client
config["model"] = req.model upstream = await client.request(
if req.port is not None: method=request.method,
config["port"] = req.port url=url,
if req.args is not None: headers=headers,
config["args"] = req.args content=body,
was_running = pm.stop() timeout=300,
pm.start(config) )
assert pm.proc is not None return StreamingResponse(
return { content=upstream.iter_bytes(),
"restarted": True, status_code=upstream.status_code,
"was_running": was_running, headers=dict(upstream.headers),
"pid": pm.proc.pid, )
"cmd": pm.build_cmd(config),
}
return router return router
+89 -61
View File
@@ -1,13 +1,15 @@
"""Tests for tensors.server package (FastAPI sd-server manager).""" """Tests for tensors.server package (FastAPI sd-server proxy wrapper)."""
from __future__ import annotations from __future__ import annotations
from unittest.mock import MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from tensors.server import create_app from tensors.server import create_app
from tensors.server.models import ServerConfig
from tensors.server.process import ProcessManager from tensors.server.process import ProcessManager
@@ -22,7 +24,7 @@ def api() -> TestClient:
def _get_pm(api: TestClient) -> ProcessManager: def _get_pm(api: TestClient) -> ProcessManager:
return api.app.state.pm # type: ignore[union-attr] return api.app.state.pm # type: ignore[no-any-return, attr-defined]
class TestStatus: class TestStatus:
@@ -37,7 +39,7 @@ class TestStatus:
mock_proc.poll.return_value = None mock_proc.poll.return_value = None
mock_proc.pid = 999 mock_proc.pid = 999
pm.proc = mock_proc pm.proc = mock_proc
pm.config = {"model": "/m.safetensors", "port": 1234, "args": []} pm.config = ServerConfig(model="/m.safetensors")
r = api.get("/status") r = api.get("/status")
data = r.json() data = r.json()
assert data["running"] is True assert data["running"] is True
@@ -54,69 +56,86 @@ class TestStatus:
assert data["exit_code"] == 1 assert data["exit_code"] == 1
class TestStart: class TestReload:
@patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=True)
@patch("tensors.server.process.subprocess.Popen") @patch("tensors.server.process.subprocess.Popen")
def test_start_success(self, mock_popen: MagicMock, api: TestClient) -> None: def test_reload_swaps_model(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None:
pm = _get_pm(api)
pm.config = ServerConfig(model="/old.gguf", port=5555, args=["--fa"])
mock_popen.return_value.pid = 42 mock_popen.return_value.pid = 42
mock_popen.return_value.poll.return_value = None mock_popen.return_value.poll.return_value = None
r = api.post("/start", json={"model": "/m.safetensors"}) r = api.post("/reload", json={"model": "/new.gguf"})
assert r.status_code == 200
assert r.json()["started"] is True
assert r.json()["pid"] == 42
@patch("tensors.server.process.subprocess.Popen")
def test_start_already_running(self, mock_popen: MagicMock, api: TestClient) -> None:
pm = _get_pm(api)
mock_proc = MagicMock()
mock_proc.poll.return_value = None
pm.proc = mock_proc
r = api.post("/start", json={"model": "/m.safetensors"})
assert r.status_code == 409
class TestStop:
def test_stop_not_running(self, api: TestClient) -> None:
r = api.post("/stop")
assert r.status_code == 409
def test_stop_running(self, api: TestClient) -> None:
pm = _get_pm(api)
mock_proc = MagicMock()
mock_proc.poll.return_value = None
mock_proc.wait.return_value = 0
pm.proc = mock_proc
r = api.post("/stop")
assert r.status_code == 200
assert r.json()["stopped"] is True
mock_proc.send_signal.assert_called_once()
class TestRestart:
def test_restart_no_config_no_model(self, api: TestClient) -> None:
r = api.post("/restart", json={})
assert r.status_code == 400
@patch("tensors.server.process.subprocess.Popen")
def test_restart_with_new_model(self, mock_popen: MagicMock, api: TestClient) -> None:
mock_popen.return_value.pid = 100
mock_popen.return_value.poll.return_value = None
pm = _get_pm(api)
pm.config = {"model": "/old.safetensors", "port": 1234, "args": []}
r = api.post("/restart", json={"model": "/new.safetensors"})
assert r.status_code == 200 assert r.status_code == 200
data = r.json() data = r.json()
assert data["restarted"] is True assert data["ok"] is True
assert "/new.safetensors" in str(data["cmd"]) assert data["model"] == "/new.gguf"
assert data["pid"] == 42
# Verify new config preserved port and args from previous config
assert pm.config is not None
assert pm.config.port == 5555
assert pm.config.args == ["--fa"]
assert pm.config.model == "/new.gguf"
@patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=False)
@patch("tensors.server.process.subprocess.Popen") @patch("tensors.server.process.subprocess.Popen")
def test_restart_keeps_previous_config(self, mock_popen: MagicMock, api: TestClient) -> None: def test_reload_fails_when_not_ready(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None:
mock_popen.return_value.pid = 101
mock_popen.return_value.poll.return_value = None
pm = _get_pm(api) pm = _get_pm(api)
pm.config = {"model": "/m.safetensors", "port": 5555, "args": ["--fa"]} pm.config = ServerConfig(model="/old.gguf")
r = api.post("/restart", json={}) mock_popen.return_value.pid = 43
mock_popen.return_value.poll.return_value = None
r = api.post("/reload", json={"model": "/bad.gguf"})
assert r.status_code == 503
assert "failed" in r.json()["error"]
def test_reload_requires_model(self, api: TestClient) -> None:
r = api.post("/reload", json={})
assert r.status_code == 422
class TestProxy:
def test_proxy_503_when_not_running(self, api: TestClient) -> None:
r = api.get("/v1/models")
assert r.status_code == 503
assert "not running" in r.json()["error"]
def test_proxy_forwards_request(self, api: TestClient) -> None:
pm = _get_pm(api)
mock_proc = MagicMock()
mock_proc.poll.return_value = None
mock_proc.pid = 100
pm.proc = mock_proc
pm.config = ServerConfig(model="/m.gguf", port=1234)
upstream_response = httpx.Response(
200,
json={"data": [{"id": "model-1"}]},
headers={"content-type": "application/json"},
)
mock_client = AsyncMock()
mock_client.request.return_value = upstream_response
api.app.state.client = mock_client # type: ignore[attr-defined]
r = api.get("/v1/models")
assert r.status_code == 200 assert r.status_code == 200
assert "5555" in str(r.json()["cmd"]) assert r.json() == {"data": [{"id": "model-1"}]}
mock_client.request.assert_called_once()
def test_proxy_forwards_post_with_body(self, api: TestClient) -> None:
pm = _get_pm(api)
mock_proc = MagicMock()
mock_proc.poll.return_value = None
mock_proc.pid = 100
pm.proc = mock_proc
pm.config = ServerConfig(model="/m.gguf", port=1234)
upstream_response = httpx.Response(200, json={"ok": True})
mock_client = AsyncMock()
mock_client.request.return_value = upstream_response
api.app.state.client = mock_client # type: ignore[attr-defined]
r = api.post("/v1/chat/completions", json={"prompt": "hello"})
assert r.status_code == 200
mock_client.request.assert_called_once()
class TestProcessManager: class TestProcessManager:
@@ -124,18 +143,27 @@ class TestProcessManager:
assert pm.status() == {"running": False} assert pm.status() == {"running": False}
def test_build_cmd(self, pm: ProcessManager) -> None: def test_build_cmd(self, pm: ProcessManager) -> None:
config = {"model": "/m.gguf", "port": 1234, "args": ["--fa"]} pm.config = ServerConfig(model="/m.gguf", port=1234, args=["--fa"])
cmd = pm.build_cmd(config) cmd = pm.build_cmd()
assert "/m.gguf" in cmd assert "/m.gguf" in cmd
assert "--fa" in cmd assert "--fa" in cmd
assert "1234" in cmd assert "1234" in cmd
def test_build_cmd_no_config(self, pm: ProcessManager) -> None:
with pytest.raises(RuntimeError, match="No config"):
pm.build_cmd()
@patch("tensors.server.process.subprocess.Popen") @patch("tensors.server.process.subprocess.Popen")
def test_start_and_stop(self, mock_popen: MagicMock, pm: ProcessManager) -> None: def test_start_and_stop(self, mock_popen: MagicMock, pm: ProcessManager) -> None:
mock_popen.return_value.pid = 77 mock_popen.return_value.pid = 77
mock_popen.return_value.poll.return_value = None mock_popen.return_value.poll.return_value = None
mock_popen.return_value.wait.return_value = 0 mock_popen.return_value.wait.return_value = 0
pm.start({"model": "/m.gguf", "port": 1234, "args": []}) pm.start(ServerConfig(model="/m.gguf"))
assert pm.proc is not None assert pm.proc is not None
assert pm.stop() is True assert pm.stop() is True
assert pm.proc is None assert pm.proc is None
def test_server_config_defaults(self) -> None:
cfg = ServerConfig(model="/m.gguf")
assert cfg.port == 1234
assert cfg.args == []