From 61e68fee15605274d47e03e19fba71db0237974e Mon Sep 17 00:00:00 2001 From: Adam Ladachowski Date: Thu, 12 Feb 2026 20:51:16 +0000 Subject: [PATCH] Turn server wrapper into transparent proxy with /reload Replace management endpoints (/start, /stop, /restart) with a transparent reverse proxy and hot-reload architecture. The wrapper now sits in front of sd-server, forwarding all requests and adding a /reload endpoint for model swapping without restarting the wrapper itself. Co-Authored-By: Claude Opus 4.6 --- tensors/cli.py | 9 ++- tensors/server/__init__.py | 25 +++++-- tensors/server/models.py | 14 ++-- tensors/server/process.py | 46 +++++++++--- tensors/server/routes.py | 83 ++++++++++---------- tests/test_server.py | 150 ++++++++++++++++++++++--------------- 6 files changed, 203 insertions(+), 124 deletions(-) diff --git a/tensors/cli.py b/tensors/cli.py index 30e1a68..52fb815 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -439,21 +439,24 @@ def generate( @app.command() def serve( + model: Annotated[str, typer.Option(help="Path to model file for sd-server.")], host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1", port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080, + sd_port: Annotated[int, typer.Option(help="sd-server listen port.")] = 1234, log_level: Annotated[str, typer.Option(help="Log level.")] = "info", ) -> None: - """Start the sd-server wrapper API.""" + """Start the sd-server wrapper API (transparent proxy with hot reload).""" try: import uvicorn # noqa: PLC0415 - from tensors.server import create_app # noqa: PLC0415 + from tensors.server import ServerConfig, create_app # noqa: PLC0415 except ImportError: console.print("[red]Missing server dependencies. Install with:[/red]") console.print(" pip install tensors[server]") raise typer.Exit(1) from None - uvicorn.run(create_app(), host=host, port=port, log_level=log_level) + config = ServerConfig(model=model, port=sd_port) + uvicorn.run(create_app(config), host=host, port=port, log_level=log_level) def main() -> int: diff --git a/tensors/server/__init__.py b/tensors/server/__init__.py index d539965..a197b1e 100644 --- a/tensors/server/__init__.py +++ b/tensors/server/__init__.py @@ -1,28 +1,43 @@ -"""sd-server wrapper — FastAPI app for managing sd-server process.""" +"""sd-server wrapper — FastAPI app for managing and proxying to sd-server.""" from __future__ import annotations +import logging from contextlib import asynccontextmanager from typing import TYPE_CHECKING +import httpx from fastapi import FastAPI +from tensors.server.models import ServerConfig from tensors.server.process import ProcessManager from tensors.server.routes import create_router if TYPE_CHECKING: from collections.abc import AsyncIterator -__all__ = ["ProcessManager", "create_app"] +__all__ = ["ProcessManager", "ServerConfig", "create_app"] + +logger = logging.getLogger(__name__) -def create_app() -> FastAPI: - """Build the FastAPI application with process manager.""" +def create_app(config: ServerConfig | None = None) -> FastAPI: + """Build the FastAPI application with process manager and proxy client.""" pm = ProcessManager() @asynccontextmanager async def lifespan(_app: FastAPI) -> AsyncIterator[None]: - yield + async with httpx.AsyncClient(timeout=300) as client: + _app.state.client = client + if config is not None: + pm.start(config) + logger.info("waiting for sd-server to become ready...") + ready = await pm.wait_ready() + if ready: + logger.info("sd-server is ready") + else: + logger.warning("sd-server did not become ready in time") + yield pm.stop() app = FastAPI(title="sd-server wrapper", lifespan=lifespan) diff --git a/tensors/server/models.py b/tensors/server/models.py index 330bae1..b06f0b1 100644 --- a/tensors/server/models.py +++ b/tensors/server/models.py @@ -1,4 +1,4 @@ -"""Pydantic request models for the wrapper API.""" +"""Pydantic models for the sd-server wrapper API.""" from __future__ import annotations @@ -7,13 +7,11 @@ from pydantic import BaseModel DEFAULT_PORT = 1234 -class StartRequest(BaseModel): +class ReloadRequest(BaseModel): + model: str + + +class ServerConfig(BaseModel): model: str port: int = DEFAULT_PORT args: list[str] = [] - - -class RestartRequest(BaseModel): - model: str | None = None - port: int | None = None - args: list[str] | None = None diff --git a/tensors/server/process.py b/tensors/server/process.py index 1898616..7d0cb51 100644 --- a/tensors/server/process.py +++ b/tensors/server/process.py @@ -2,33 +2,42 @@ from __future__ import annotations +import asyncio import logging import shutil import signal import subprocess -from typing import Any +from typing import TYPE_CHECKING, Any + +import httpx + +if TYPE_CHECKING: + from tensors.server.models import ServerConfig logger = logging.getLogger(__name__) +_HTTP_OK = 200 + SD_SERVER_BIN = shutil.which("sd-server") or "sd-server" class ProcessManager: def __init__(self) -> None: self.proc: subprocess.Popen[bytes] | None = None - self.config: dict[str, Any] = {} + self.config: ServerConfig | None = None - def build_cmd(self, config: dict[str, Any] | None = None) -> list[str]: - cfg = config or self.config - cmd = [SD_SERVER_BIN, "-m", cfg["model"], "--listen-port", str(cfg["port"])] - cmd.extend(cfg.get("args", [])) + def build_cmd(self) -> list[str]: + if self.config is None: + raise RuntimeError("No config set") + cmd = [SD_SERVER_BIN, "-m", self.config.model, "--port", str(self.config.port)] + cmd.extend(self.config.args) return cmd - def start(self, config: dict[str, Any]) -> None: + def start(self, config: ServerConfig) -> None: if self.proc is not None and self.proc.poll() is None: raise RuntimeError("Server already running — stop it first") self.config = config - cmd = self.build_cmd(config) + cmd = self.build_cmd() self.proc = subprocess.Popen(cmd) logger.info("started sd-server pid=%d cmd=%s", self.proc.pid, cmd) @@ -55,6 +64,25 @@ class ProcessManager: return { "running": True, "pid": self.proc.pid, - "config": self.config, + "model": self.config.model if self.config else None, "cmd": self.build_cmd(), } + + async def wait_ready(self, timeout: float = 120) -> bool: + """Poll sd-server /health until it responds or timeout.""" + if self.config is None: + return False + url = f"http://127.0.0.1:{self.config.port}/health" + deadline = asyncio.get_event_loop().time() + timeout + async with httpx.AsyncClient() as client: + while asyncio.get_event_loop().time() < deadline: + if self.proc is not None and self.proc.poll() is not None: + return False + try: + r = await client.get(url, timeout=2) + if r.status_code == _HTTP_OK: + return True + except httpx.ConnectError: + pass + await asyncio.sleep(1) + return False diff --git a/tensors/server/routes.py b/tensors/server/routes.py index 6074929..214ab62 100644 --- a/tensors/server/routes.py +++ b/tensors/server/routes.py @@ -1,59 +1,66 @@ -"""FastAPI route handlers for the wrapper API.""" +"""FastAPI route handlers for the sd-server wrapper API.""" from __future__ import annotations +import logging from typing import TYPE_CHECKING, Any -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, Request, Response +from fastapi.responses import JSONResponse, StreamingResponse -from tensors.server.models import RestartRequest, StartRequest # noqa: TC001 +from tensors.server.models import ReloadRequest, ServerConfig if TYPE_CHECKING: from tensors.server.process import ProcessManager +logger = logging.getLogger(__name__) + def create_router(pm: ProcessManager) -> APIRouter: - """Build a new router bound to the given ProcessManager.""" + """Build a router with /status, /reload, and catch-all proxy.""" router = APIRouter() @router.get("/status") def status() -> dict[str, Any]: return pm.status() - @router.post("/start") - def start(req: StartRequest) -> dict[str, Any]: - if pm.proc is not None and pm.proc.poll() is None: - raise HTTPException(409, "Server already running — use /restart or /stop first") - config = {"model": req.model, "port": req.port, "args": req.args} - pm.start(config) - assert pm.proc is not None - return {"started": True, "pid": pm.proc.pid, "cmd": pm.build_cmd(config)} + @router.post("/reload") + async def reload(req: ReloadRequest) -> Response: + new_config = ServerConfig( + model=req.model, + port=pm.config.port if pm.config else 1234, + args=pm.config.args if pm.config else [], + ) + pm.stop() + pm.start(new_config) + ready = await pm.wait_ready() + if not ready: + return JSONResponse({"error": "sd-server failed to become ready", "model": req.model}, status_code=503) + return JSONResponse({"ok": True, "model": req.model, "pid": pm.proc.pid if pm.proc else None}) - @router.post("/stop") - def stop() -> dict[str, Any]: - if not pm.stop(): - raise HTTPException(409, "Server is not running") - return {"stopped": True} - - @router.post("/restart") - def restart(req: RestartRequest) -> dict[str, Any]: - if not pm.config and req.model is None: - raise HTTPException(400, "No previous config — provide at least 'model'") - config = dict(pm.config) - if req.model is not None: - config["model"] = req.model - if req.port is not None: - config["port"] = req.port - if req.args is not None: - config["args"] = req.args - was_running = pm.stop() - pm.start(config) - assert pm.proc is not None - return { - "restarted": True, - "was_running": was_running, - "pid": pm.proc.pid, - "cmd": pm.build_cmd(config), - } + @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]) + async def proxy(request: Request, path: str) -> Response: + if pm.proc is None or pm.proc.poll() is not None: + return JSONResponse({"error": "sd-server is not running"}, status_code=503) + assert pm.config is not None + url = f"http://127.0.0.1:{pm.config.port}/{path}" + if request.url.query: + url = f"{url}?{request.url.query}" + body = await request.body() + headers = dict(request.headers) + headers.pop("host", None) + client = request.app.state.client + upstream = await client.request( + method=request.method, + url=url, + headers=headers, + content=body, + timeout=300, + ) + return StreamingResponse( + content=upstream.iter_bytes(), + status_code=upstream.status_code, + headers=dict(upstream.headers), + ) return router diff --git a/tests/test_server.py b/tests/test_server.py index 090c649..f9f8db4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,13 +1,15 @@ -"""Tests for tensors.server package (FastAPI sd-server manager).""" +"""Tests for tensors.server package (FastAPI sd-server proxy wrapper).""" from __future__ import annotations -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch +import httpx import pytest from fastapi.testclient import TestClient from tensors.server import create_app +from tensors.server.models import ServerConfig from tensors.server.process import ProcessManager @@ -22,7 +24,7 @@ def api() -> TestClient: def _get_pm(api: TestClient) -> ProcessManager: - return api.app.state.pm # type: ignore[union-attr] + return api.app.state.pm # type: ignore[no-any-return, attr-defined] class TestStatus: @@ -37,7 +39,7 @@ class TestStatus: mock_proc.poll.return_value = None mock_proc.pid = 999 pm.proc = mock_proc - pm.config = {"model": "/m.safetensors", "port": 1234, "args": []} + pm.config = ServerConfig(model="/m.safetensors") r = api.get("/status") data = r.json() assert data["running"] is True @@ -54,69 +56,86 @@ class TestStatus: assert data["exit_code"] == 1 -class TestStart: +class TestReload: + @patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=True) @patch("tensors.server.process.subprocess.Popen") - def test_start_success(self, mock_popen: MagicMock, api: TestClient) -> None: + def test_reload_swaps_model(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None: + pm = _get_pm(api) + pm.config = ServerConfig(model="/old.gguf", port=5555, args=["--fa"]) mock_popen.return_value.pid = 42 mock_popen.return_value.poll.return_value = None - r = api.post("/start", json={"model": "/m.safetensors"}) - assert r.status_code == 200 - assert r.json()["started"] is True - assert r.json()["pid"] == 42 - - @patch("tensors.server.process.subprocess.Popen") - def test_start_already_running(self, mock_popen: MagicMock, api: TestClient) -> None: - pm = _get_pm(api) - mock_proc = MagicMock() - mock_proc.poll.return_value = None - pm.proc = mock_proc - r = api.post("/start", json={"model": "/m.safetensors"}) - assert r.status_code == 409 - - -class TestStop: - def test_stop_not_running(self, api: TestClient) -> None: - r = api.post("/stop") - assert r.status_code == 409 - - def test_stop_running(self, api: TestClient) -> None: - pm = _get_pm(api) - mock_proc = MagicMock() - mock_proc.poll.return_value = None - mock_proc.wait.return_value = 0 - pm.proc = mock_proc - r = api.post("/stop") - assert r.status_code == 200 - assert r.json()["stopped"] is True - mock_proc.send_signal.assert_called_once() - - -class TestRestart: - def test_restart_no_config_no_model(self, api: TestClient) -> None: - r = api.post("/restart", json={}) - assert r.status_code == 400 - - @patch("tensors.server.process.subprocess.Popen") - def test_restart_with_new_model(self, mock_popen: MagicMock, api: TestClient) -> None: - mock_popen.return_value.pid = 100 - mock_popen.return_value.poll.return_value = None - pm = _get_pm(api) - pm.config = {"model": "/old.safetensors", "port": 1234, "args": []} - r = api.post("/restart", json={"model": "/new.safetensors"}) + r = api.post("/reload", json={"model": "/new.gguf"}) assert r.status_code == 200 data = r.json() - assert data["restarted"] is True - assert "/new.safetensors" in str(data["cmd"]) + assert data["ok"] is True + assert data["model"] == "/new.gguf" + assert data["pid"] == 42 + # Verify new config preserved port and args from previous config + assert pm.config is not None + assert pm.config.port == 5555 + assert pm.config.args == ["--fa"] + assert pm.config.model == "/new.gguf" + @patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=False) @patch("tensors.server.process.subprocess.Popen") - def test_restart_keeps_previous_config(self, mock_popen: MagicMock, api: TestClient) -> None: - mock_popen.return_value.pid = 101 - mock_popen.return_value.poll.return_value = None + def test_reload_fails_when_not_ready(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None: pm = _get_pm(api) - pm.config = {"model": "/m.safetensors", "port": 5555, "args": ["--fa"]} - r = api.post("/restart", json={}) + pm.config = ServerConfig(model="/old.gguf") + mock_popen.return_value.pid = 43 + mock_popen.return_value.poll.return_value = None + r = api.post("/reload", json={"model": "/bad.gguf"}) + assert r.status_code == 503 + assert "failed" in r.json()["error"] + + def test_reload_requires_model(self, api: TestClient) -> None: + r = api.post("/reload", json={}) + assert r.status_code == 422 + + +class TestProxy: + def test_proxy_503_when_not_running(self, api: TestClient) -> None: + r = api.get("/v1/models") + assert r.status_code == 503 + assert "not running" in r.json()["error"] + + def test_proxy_forwards_request(self, api: TestClient) -> None: + pm = _get_pm(api) + mock_proc = MagicMock() + mock_proc.poll.return_value = None + mock_proc.pid = 100 + pm.proc = mock_proc + pm.config = ServerConfig(model="/m.gguf", port=1234) + + upstream_response = httpx.Response( + 200, + json={"data": [{"id": "model-1"}]}, + headers={"content-type": "application/json"}, + ) + mock_client = AsyncMock() + mock_client.request.return_value = upstream_response + api.app.state.client = mock_client # type: ignore[attr-defined] + + r = api.get("/v1/models") assert r.status_code == 200 - assert "5555" in str(r.json()["cmd"]) + assert r.json() == {"data": [{"id": "model-1"}]} + mock_client.request.assert_called_once() + + def test_proxy_forwards_post_with_body(self, api: TestClient) -> None: + pm = _get_pm(api) + mock_proc = MagicMock() + mock_proc.poll.return_value = None + mock_proc.pid = 100 + pm.proc = mock_proc + pm.config = ServerConfig(model="/m.gguf", port=1234) + + upstream_response = httpx.Response(200, json={"ok": True}) + mock_client = AsyncMock() + mock_client.request.return_value = upstream_response + api.app.state.client = mock_client # type: ignore[attr-defined] + + r = api.post("/v1/chat/completions", json={"prompt": "hello"}) + assert r.status_code == 200 + mock_client.request.assert_called_once() class TestProcessManager: @@ -124,18 +143,27 @@ class TestProcessManager: assert pm.status() == {"running": False} def test_build_cmd(self, pm: ProcessManager) -> None: - config = {"model": "/m.gguf", "port": 1234, "args": ["--fa"]} - cmd = pm.build_cmd(config) + pm.config = ServerConfig(model="/m.gguf", port=1234, args=["--fa"]) + cmd = pm.build_cmd() assert "/m.gguf" in cmd assert "--fa" in cmd assert "1234" in cmd + def test_build_cmd_no_config(self, pm: ProcessManager) -> None: + with pytest.raises(RuntimeError, match="No config"): + pm.build_cmd() + @patch("tensors.server.process.subprocess.Popen") def test_start_and_stop(self, mock_popen: MagicMock, pm: ProcessManager) -> None: mock_popen.return_value.pid = 77 mock_popen.return_value.poll.return_value = None mock_popen.return_value.wait.return_value = 0 - pm.start({"model": "/m.gguf", "port": 1234, "args": []}) + pm.start(ServerConfig(model="/m.gguf")) assert pm.proc is not None assert pm.stop() is True assert pm.proc is None + + def test_server_config_defaults(self) -> None: + cfg = ServerConfig(model="/m.gguf") + assert cfg.port == 1234 + assert cfg.args == []