Turn server wrapper into transparent proxy with /reload
Replace management endpoints (/start, /stop, /restart) with a transparent reverse proxy and hot-reload architecture. The wrapper now sits in front of sd-server, forwarding all requests and adding a /reload endpoint for model swapping without restarting the wrapper itself. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+6
-3
@@ -439,21 +439,24 @@ def generate(
|
|||||||
|
|
||||||
@app.command()
def serve(
    model: Annotated[str, typer.Option(help="Path to model file for sd-server.")],
    host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1",
    port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080,
    sd_port: Annotated[int, typer.Option(help="sd-server listen port.")] = 1234,
    log_level: Annotated[str, typer.Option(help="Log level.")] = "info",
) -> None:
    """Start the sd-server wrapper API (transparent proxy with hot reload)."""
    # The server stack is an optional extra, so it is imported lazily and a
    # missing install is reported as a CLI error instead of a traceback.
    try:
        import uvicorn  # noqa: PLC0415

        from tensors.server import ServerConfig, create_app  # noqa: PLC0415
    except ImportError:
        for message in (
            "[red]Missing server dependencies. Install with:[/red]",
            " pip install tensors[server]",
        ):
            console.print(message)
        raise typer.Exit(1) from None

    # `port` is where the wrapper listens; `sd_port` is where the wrapped
    # sd-server child is told to listen.
    wrapper_app = create_app(ServerConfig(model=model, port=sd_port))
    uvicorn.run(wrapper_app, host=host, port=port, log_level=log_level)
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
def main() -> int:
|
||||||
|
|||||||
@@ -1,28 +1,43 @@
|
|||||||
"""sd-server wrapper — FastAPI app for managing sd-server process."""
|
"""sd-server wrapper — FastAPI app for managing and proxying to sd-server."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import httpx
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from tensors.server.models import ServerConfig
|
||||||
from tensors.server.process import ProcessManager
|
from tensors.server.process import ProcessManager
|
||||||
from tensors.server.routes import create_router
|
from tensors.server.routes import create_router
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
__all__ = ["ProcessManager", "create_app"]
|
__all__ = ["ProcessManager", "ServerConfig", "create_app"]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app(config: ServerConfig | None = None) -> FastAPI:
|
||||||
"""Build the FastAPI application with process manager."""
|
"""Build the FastAPI application with process manager and proxy client."""
|
||||||
pm = ProcessManager()
|
pm = ProcessManager()
|
||||||
|
|
||||||
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
    """Own the proxy HTTP client and the sd-server child for the app's lifetime.

    On startup a shared ``httpx.AsyncClient`` is stored on ``app.state.client``
    and, when a ``config`` was supplied to ``create_app``, sd-server is started
    and polled until it reports ready (a failure to become ready is logged,
    not fatal).  On shutdown the client is closed and sd-server is stopped.
    """
    # try/finally guarantees the child process is stopped even when startup
    # raises (e.g. while waiting for readiness) or an exception is thrown
    # into the generator; otherwise sd-server would be left orphaned.
    try:
        async with httpx.AsyncClient(timeout=300) as client:
            _app.state.client = client
            if config is not None:
                pm.start(config)
                logger.info("waiting for sd-server to become ready...")
                if await pm.wait_ready():
                    logger.info("sd-server is ready")
                else:
                    logger.warning("sd-server did not become ready in time")
            yield
    finally:
        pm.stop()
|
||||||
|
|
||||||
app = FastAPI(title="sd-server wrapper", lifespan=lifespan)
|
app = FastAPI(title="sd-server wrapper", lifespan=lifespan)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Pydantic request models for the wrapper API."""
|
"""Pydantic models for the sd-server wrapper API."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -7,13 +7,11 @@ from pydantic import BaseModel
|
|||||||
DEFAULT_PORT = 1234
|
DEFAULT_PORT = 1234
|
||||||
|
|
||||||
|
|
||||||
class StartRequest(BaseModel):
|
class ReloadRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
|
||||||
|
|
||||||
|
class ServerConfig(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
port: int = DEFAULT_PORT
|
port: int = DEFAULT_PORT
|
||||||
args: list[str] = []
|
args: list[str] = []
|
||||||
|
|
||||||
|
|
||||||
class RestartRequest(BaseModel):
|
|
||||||
model: str | None = None
|
|
||||||
port: int | None = None
|
|
||||||
args: list[str] | None = None
|
|
||||||
|
|||||||
@@ -2,33 +2,42 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
import signal
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from tensors.server.models import ServerConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_HTTP_OK = 200
|
||||||
|
|
||||||
SD_SERVER_BIN = shutil.which("sd-server") or "sd-server"
|
SD_SERVER_BIN = shutil.which("sd-server") or "sd-server"
|
||||||
|
|
||||||
|
|
||||||
class ProcessManager:
|
class ProcessManager:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.proc: subprocess.Popen[bytes] | None = None
|
self.proc: subprocess.Popen[bytes] | None = None
|
||||||
self.config: dict[str, Any] = {}
|
self.config: ServerConfig | None = None
|
||||||
|
|
||||||
def build_cmd(self) -> list[str]:
    """Return the sd-server argv derived from the stored config.

    Raises:
        RuntimeError: if no config has been set on this manager.
    """
    cfg = self.config
    if cfg is None:
        raise RuntimeError("No config set")
    # Fixed flags first, then any extra user-supplied arguments verbatim.
    return [SD_SERVER_BIN, "-m", cfg.model, "--port", str(cfg.port), *cfg.args]
|
||||||
|
|
||||||
def start(self, config: ServerConfig) -> None:
    """Spawn sd-server for *config* and remember both the config and the process.

    Raises:
        RuntimeError: if a previously started process has not exited yet.
    """
    # poll() returning None means the child is still alive.
    still_alive = self.proc is not None and self.proc.poll() is None
    if still_alive:
        raise RuntimeError("Server already running — stop it first")

    # The config must be stored first: build_cmd() reads it from self.
    self.config = config
    argv = self.build_cmd()
    self.proc = subprocess.Popen(argv)
    logger.info("started sd-server pid=%d cmd=%s", self.proc.pid, argv)
|
||||||
|
|
||||||
@@ -55,6 +64,25 @@ class ProcessManager:
|
|||||||
return {
|
return {
|
||||||
"running": True,
|
"running": True,
|
||||||
"pid": self.proc.pid,
|
"pid": self.proc.pid,
|
||||||
"config": self.config,
|
"model": self.config.model if self.config else None,
|
||||||
"cmd": self.build_cmd(),
|
"cmd": self.build_cmd(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def wait_ready(self, timeout: float = 120) -> bool:
    """Poll sd-server's ``/health`` endpoint until it answers 200 or *timeout* elapses.

    Args:
        timeout: Maximum number of seconds to keep polling.

    Returns:
        ``True`` as soon as ``/health`` responds with HTTP 200; ``False`` when
        no config is set, the child process exits while waiting, or the
        deadline passes without a successful probe.
    """
    if self.config is None:
        return False
    url = f"http://127.0.0.1:{self.config.port}/health"
    # get_running_loop() is the supported way to reach the loop from inside a
    # coroutine; its monotonic clock is used for the deadline.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    async with httpx.AsyncClient() as client:
        while loop.time() < deadline:
            # Stop early if the child has already exited: it will never become ready.
            if self.proc is not None and self.proc.poll() is not None:
                return False
            try:
                r = await client.get(url, timeout=2)
            except httpx.TransportError:
                # Covers "connection refused" as well as the 2 s probe timeout
                # and dropped connections: a server that has bound its port but
                # is still busy loading must be retried, not treated as fatal.
                pass
            else:
                if r.status_code == _HTTP_OK:
                    return True
            await asyncio.sleep(1)
    return False
|
||||||
|
|||||||
+45
-38
@@ -1,59 +1,66 @@
|
|||||||
"""FastAPI route handlers for the wrapper API."""
|
"""FastAPI route handlers for the sd-server wrapper API."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, Request, Response
|
||||||
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
from tensors.server.models import RestartRequest, StartRequest # noqa: TC001
|
from tensors.server.models import ReloadRequest, ServerConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tensors.server.process import ProcessManager
|
from tensors.server.process import ProcessManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_router(pm: ProcessManager) -> APIRouter:
    """Build a router with /status, /reload, and catch-all proxy.

    Args:
        pm: Process manager that owns the sd-server child; its ``config``
            supplies the upstream port used by the proxy.

    Returns:
        A router exposing ``GET /status``, ``POST /reload`` and a catch-all
        route (registered last) that forwards every other request to
        sd-server on 127.0.0.1.
    """
    # Local import: this module does not import httpx at top level, but the
    # shared client stored on app.state is an httpx client, so its transport
    # errors are what the proxy has to handle.
    import httpx  # noqa: PLC0415

    # Upstream headers that must not be copied onto the proxied response.
    # httpx hands back an already-decoded body, so the original
    # content-encoding/content-length no longer describe it; the rest are
    # hop-by-hop headers that apply to the upstream connection only.
    dropped_response_headers = frozenset(
        {"content-encoding", "content-length", "transfer-encoding", "connection", "keep-alive"}
    )

    router = APIRouter()

    @router.get("/status")
    def status() -> dict[str, Any]:
        """Report the process manager's view of sd-server."""
        return pm.status()

    @router.post("/reload")
    async def reload(req: ReloadRequest) -> Response:
        """Restart sd-server with a new model, keeping the previous port and args."""
        previous = pm.config
        if previous is None:
            # No earlier config: let ServerConfig's own defaults apply rather
            # than repeating the default port here.
            new_config = ServerConfig(model=req.model)
        else:
            new_config = ServerConfig(model=req.model, port=previous.port, args=list(previous.args))
        pm.stop()
        pm.start(new_config)
        if not await pm.wait_ready():
            return JSONResponse({"error": "sd-server failed to become ready", "model": req.model}, status_code=503)
        return JSONResponse({"ok": True, "model": req.model, "pid": pm.proc.pid if pm.proc else None})

    @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
    async def proxy(request: Request, path: str) -> Response:
        """Forward the request to sd-server and relay its response."""
        config = pm.config
        # Explicit guard instead of `assert`, which is stripped under -O.
        if config is None or pm.proc is None or pm.proc.poll() is not None:
            return JSONResponse({"error": "sd-server is not running"}, status_code=503)

        url = f"http://127.0.0.1:{config.port}/{path}"
        if request.url.query:
            url = f"{url}?{request.url.query}"

        headers = dict(request.headers)
        # The Host header names the wrapper; the client derives the upstream one.
        headers.pop("host", None)
        body = await request.body()

        try:
            upstream = await request.app.state.client.request(
                method=request.method,
                url=url,
                headers=headers,
                content=body,
                timeout=300,
            )
        except httpx.TransportError as exc:
            # The process is alive but the request could not be completed
            # (refused, reset, timed out): report a gateway error, not a 500.
            logger.warning("proxy request to %s failed: %s", url, exc)
            return JSONResponse({"error": "sd-server is unreachable"}, status_code=502)

        response_headers = {
            name: value for name, value in upstream.headers.items() if name.lower() not in dropped_response_headers
        }
        # client.request() has already read the whole body, so return it as a
        # plain response; the framework computes a matching content-length.
        return Response(
            content=upstream.content,
            status_code=upstream.status_code,
            headers=response_headers,
        )

    return router
|
||||||
|
|||||||
+89
-61
@@ -1,13 +1,15 @@
|
|||||||
"""Tests for tensors.server package (FastAPI sd-server manager)."""
|
"""Tests for tensors.server package (FastAPI sd-server proxy wrapper)."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from tensors.server import create_app
|
from tensors.server import create_app
|
||||||
|
from tensors.server.models import ServerConfig
|
||||||
from tensors.server.process import ProcessManager
|
from tensors.server.process import ProcessManager
|
||||||
|
|
||||||
|
|
||||||
@@ -22,7 +24,7 @@ def api() -> TestClient:
|
|||||||
|
|
||||||
|
|
||||||
def _get_pm(api: TestClient) -> ProcessManager:
|
def _get_pm(api: TestClient) -> ProcessManager:
|
||||||
return api.app.state.pm # type: ignore[union-attr]
|
return api.app.state.pm # type: ignore[no-any-return, attr-defined]
|
||||||
|
|
||||||
|
|
||||||
class TestStatus:
|
class TestStatus:
|
||||||
@@ -37,7 +39,7 @@ class TestStatus:
|
|||||||
mock_proc.poll.return_value = None
|
mock_proc.poll.return_value = None
|
||||||
mock_proc.pid = 999
|
mock_proc.pid = 999
|
||||||
pm.proc = mock_proc
|
pm.proc = mock_proc
|
||||||
pm.config = {"model": "/m.safetensors", "port": 1234, "args": []}
|
pm.config = ServerConfig(model="/m.safetensors")
|
||||||
r = api.get("/status")
|
r = api.get("/status")
|
||||||
data = r.json()
|
data = r.json()
|
||||||
assert data["running"] is True
|
assert data["running"] is True
|
||||||
@@ -54,69 +56,86 @@ class TestStatus:
|
|||||||
assert data["exit_code"] == 1
|
assert data["exit_code"] == 1
|
||||||
|
|
||||||
|
|
||||||
class TestStart:
|
class TestReload:
|
||||||
|
@patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=True)
|
||||||
@patch("tensors.server.process.subprocess.Popen")
|
@patch("tensors.server.process.subprocess.Popen")
|
||||||
def test_start_success(self, mock_popen: MagicMock, api: TestClient) -> None:
|
def test_reload_swaps_model(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None:
|
||||||
|
pm = _get_pm(api)
|
||||||
|
pm.config = ServerConfig(model="/old.gguf", port=5555, args=["--fa"])
|
||||||
mock_popen.return_value.pid = 42
|
mock_popen.return_value.pid = 42
|
||||||
mock_popen.return_value.poll.return_value = None
|
mock_popen.return_value.poll.return_value = None
|
||||||
r = api.post("/start", json={"model": "/m.safetensors"})
|
r = api.post("/reload", json={"model": "/new.gguf"})
|
||||||
assert r.status_code == 200
|
|
||||||
assert r.json()["started"] is True
|
|
||||||
assert r.json()["pid"] == 42
|
|
||||||
|
|
||||||
@patch("tensors.server.process.subprocess.Popen")
|
|
||||||
def test_start_already_running(self, mock_popen: MagicMock, api: TestClient) -> None:
|
|
||||||
pm = _get_pm(api)
|
|
||||||
mock_proc = MagicMock()
|
|
||||||
mock_proc.poll.return_value = None
|
|
||||||
pm.proc = mock_proc
|
|
||||||
r = api.post("/start", json={"model": "/m.safetensors"})
|
|
||||||
assert r.status_code == 409
|
|
||||||
|
|
||||||
|
|
||||||
class TestStop:
|
|
||||||
def test_stop_not_running(self, api: TestClient) -> None:
|
|
||||||
r = api.post("/stop")
|
|
||||||
assert r.status_code == 409
|
|
||||||
|
|
||||||
def test_stop_running(self, api: TestClient) -> None:
|
|
||||||
pm = _get_pm(api)
|
|
||||||
mock_proc = MagicMock()
|
|
||||||
mock_proc.poll.return_value = None
|
|
||||||
mock_proc.wait.return_value = 0
|
|
||||||
pm.proc = mock_proc
|
|
||||||
r = api.post("/stop")
|
|
||||||
assert r.status_code == 200
|
|
||||||
assert r.json()["stopped"] is True
|
|
||||||
mock_proc.send_signal.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
class TestRestart:
|
|
||||||
def test_restart_no_config_no_model(self, api: TestClient) -> None:
|
|
||||||
r = api.post("/restart", json={})
|
|
||||||
assert r.status_code == 400
|
|
||||||
|
|
||||||
@patch("tensors.server.process.subprocess.Popen")
|
|
||||||
def test_restart_with_new_model(self, mock_popen: MagicMock, api: TestClient) -> None:
|
|
||||||
mock_popen.return_value.pid = 100
|
|
||||||
mock_popen.return_value.poll.return_value = None
|
|
||||||
pm = _get_pm(api)
|
|
||||||
pm.config = {"model": "/old.safetensors", "port": 1234, "args": []}
|
|
||||||
r = api.post("/restart", json={"model": "/new.safetensors"})
|
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
data = r.json()
|
data = r.json()
|
||||||
assert data["restarted"] is True
|
assert data["ok"] is True
|
||||||
assert "/new.safetensors" in str(data["cmd"])
|
assert data["model"] == "/new.gguf"
|
||||||
|
assert data["pid"] == 42
|
||||||
|
# Verify new config preserved port and args from previous config
|
||||||
|
assert pm.config is not None
|
||||||
|
assert pm.config.port == 5555
|
||||||
|
assert pm.config.args == ["--fa"]
|
||||||
|
assert pm.config.model == "/new.gguf"
|
||||||
|
|
||||||
|
@patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=False)
|
||||||
@patch("tensors.server.process.subprocess.Popen")
|
@patch("tensors.server.process.subprocess.Popen")
|
||||||
def test_restart_keeps_previous_config(self, mock_popen: MagicMock, api: TestClient) -> None:
|
def test_reload_fails_when_not_ready(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None:
|
||||||
mock_popen.return_value.pid = 101
|
|
||||||
mock_popen.return_value.poll.return_value = None
|
|
||||||
pm = _get_pm(api)
|
pm = _get_pm(api)
|
||||||
pm.config = {"model": "/m.safetensors", "port": 5555, "args": ["--fa"]}
|
pm.config = ServerConfig(model="/old.gguf")
|
||||||
r = api.post("/restart", json={})
|
mock_popen.return_value.pid = 43
|
||||||
|
mock_popen.return_value.poll.return_value = None
|
||||||
|
r = api.post("/reload", json={"model": "/bad.gguf"})
|
||||||
|
assert r.status_code == 503
|
||||||
|
assert "failed" in r.json()["error"]
|
||||||
|
|
||||||
|
def test_reload_requires_model(self, api: TestClient) -> None:
|
||||||
|
r = api.post("/reload", json={})
|
||||||
|
assert r.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
class TestProxy:
|
||||||
|
def test_proxy_503_when_not_running(self, api: TestClient) -> None:
|
||||||
|
r = api.get("/v1/models")
|
||||||
|
assert r.status_code == 503
|
||||||
|
assert "not running" in r.json()["error"]
|
||||||
|
|
||||||
|
def test_proxy_forwards_request(self, api: TestClient) -> None:
|
||||||
|
pm = _get_pm(api)
|
||||||
|
mock_proc = MagicMock()
|
||||||
|
mock_proc.poll.return_value = None
|
||||||
|
mock_proc.pid = 100
|
||||||
|
pm.proc = mock_proc
|
||||||
|
pm.config = ServerConfig(model="/m.gguf", port=1234)
|
||||||
|
|
||||||
|
upstream_response = httpx.Response(
|
||||||
|
200,
|
||||||
|
json={"data": [{"id": "model-1"}]},
|
||||||
|
headers={"content-type": "application/json"},
|
||||||
|
)
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.request.return_value = upstream_response
|
||||||
|
api.app.state.client = mock_client # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
r = api.get("/v1/models")
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
assert "5555" in str(r.json()["cmd"])
|
assert r.json() == {"data": [{"id": "model-1"}]}
|
||||||
|
mock_client.request.assert_called_once()
|
||||||
|
|
||||||
|
def test_proxy_forwards_post_with_body(self, api: TestClient) -> None:
|
||||||
|
pm = _get_pm(api)
|
||||||
|
mock_proc = MagicMock()
|
||||||
|
mock_proc.poll.return_value = None
|
||||||
|
mock_proc.pid = 100
|
||||||
|
pm.proc = mock_proc
|
||||||
|
pm.config = ServerConfig(model="/m.gguf", port=1234)
|
||||||
|
|
||||||
|
upstream_response = httpx.Response(200, json={"ok": True})
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.request.return_value = upstream_response
|
||||||
|
api.app.state.client = mock_client # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
r = api.post("/v1/chat/completions", json={"prompt": "hello"})
|
||||||
|
assert r.status_code == 200
|
||||||
|
mock_client.request.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
class TestProcessManager:
|
class TestProcessManager:
|
||||||
@@ -124,18 +143,27 @@ class TestProcessManager:
|
|||||||
assert pm.status() == {"running": False}
|
assert pm.status() == {"running": False}
|
||||||
|
|
||||||
def test_build_cmd(self, pm: ProcessManager) -> None:
|
def test_build_cmd(self, pm: ProcessManager) -> None:
|
||||||
config = {"model": "/m.gguf", "port": 1234, "args": ["--fa"]}
|
pm.config = ServerConfig(model="/m.gguf", port=1234, args=["--fa"])
|
||||||
cmd = pm.build_cmd(config)
|
cmd = pm.build_cmd()
|
||||||
assert "/m.gguf" in cmd
|
assert "/m.gguf" in cmd
|
||||||
assert "--fa" in cmd
|
assert "--fa" in cmd
|
||||||
assert "1234" in cmd
|
assert "1234" in cmd
|
||||||
|
|
||||||
|
def test_build_cmd_no_config(self, pm: ProcessManager) -> None:
|
||||||
|
with pytest.raises(RuntimeError, match="No config"):
|
||||||
|
pm.build_cmd()
|
||||||
|
|
||||||
@patch("tensors.server.process.subprocess.Popen")
|
@patch("tensors.server.process.subprocess.Popen")
|
||||||
def test_start_and_stop(self, mock_popen: MagicMock, pm: ProcessManager) -> None:
|
def test_start_and_stop(self, mock_popen: MagicMock, pm: ProcessManager) -> None:
|
||||||
mock_popen.return_value.pid = 77
|
mock_popen.return_value.pid = 77
|
||||||
mock_popen.return_value.poll.return_value = None
|
mock_popen.return_value.poll.return_value = None
|
||||||
mock_popen.return_value.wait.return_value = 0
|
mock_popen.return_value.wait.return_value = 0
|
||||||
pm.start({"model": "/m.gguf", "port": 1234, "args": []})
|
pm.start(ServerConfig(model="/m.gguf"))
|
||||||
assert pm.proc is not None
|
assert pm.proc is not None
|
||||||
assert pm.stop() is True
|
assert pm.stop() is True
|
||||||
assert pm.proc is None
|
assert pm.proc is None
|
||||||
|
|
||||||
|
def test_server_config_defaults(self) -> None:
|
||||||
|
cfg = ServerConfig(model="/m.gguf")
|
||||||
|
assert cfg.port == 1234
|
||||||
|
assert cfg.args == []
|
||||||
|
|||||||
Reference in New Issue
Block a user