diff --git a/.coverage b/.coverage index 5b6e50b..f0c0133 100644 Binary files a/.coverage and b/.coverage differ diff --git a/pyproject.toml b/pyproject.toml index ccd880d..332efba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tensors" -version = "0.1.11" +version = "0.1.12" description = "Read safetensor metadata and fetch CivitAI model information" readme = "README.md" requires-python = ">=3.12" diff --git a/screenshots/sd-server-raw.png b/screenshots/sd-server-raw.png new file mode 100644 index 0000000..847c6ca Binary files /dev/null and b/screenshots/sd-server-raw.png differ diff --git a/screenshots/tensors-debug.png b/screenshots/tensors-debug.png new file mode 100644 index 0000000..0810be1 Binary files /dev/null and b/screenshots/tensors-debug.png differ diff --git a/screenshots/test-cat.png b/screenshots/test-cat.png new file mode 100644 index 0000000..882b5ea Binary files /dev/null and b/screenshots/test-cat.png differ diff --git a/scripts/reinstall.py b/scripts/reinstall.py index ed9b908..7c6b26f 100755 --- a/scripts/reinstall.py +++ b/scripts/reinstall.py @@ -55,9 +55,17 @@ def sync_to_junkpile() -> None: """Sync project to junkpile.""" print("\n[3/4] Syncing to junkpile...") excludes = [ - ".git", ".venv", "__pycache__", "*.pyc", ".mypy_cache", - ".pytest_cache", ".ruff_cache", ".coverage", "*.egg-info", - "node_modules", ".tmp", + ".git", + ".venv", + "__pycache__", + "*.pyc", + ".mypy_cache", + ".pytest_cache", + ".ruff_cache", + ".coverage", + "*.egg-info", + "node_modules", + ".tmp", ] cmd = ["rsync", "-avz", "--delete"] for exc in excludes: diff --git a/tensors/cli.py b/tensors/cli.py index a09c881..d0b3b1c 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -582,24 +582,22 @@ def reload( @app.command() def serve( - model: Annotated[str, typer.Option(help="Path to model file for sd-server.")], host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1", port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080, - sd_port: Annotated[int, typer.Option(help="sd-server listen port.")] = 1234, + sd_server: Annotated[str | None, typer.Option(help="sd-server URL to proxy to.")] = None, log_level: Annotated[str, typer.Option(help="Log level.")] = "info", ) -> None: - """Start the sd-server wrapper API (transparent proxy with hot reload).""" + """Start the sd-server wrapper API (proxies to external sd-server).""" try: import uvicorn # noqa: PLC0415 - from tensors.server import ServerConfig, create_app # noqa: PLC0415 + from tensors.server import create_app # noqa: PLC0415 except ImportError: console.print("[red]Missing server dependencies. Install with:[/red]") console.print(" pip install tensors[server]") raise typer.Exit(1) from None - config = ServerConfig(model=model, port=sd_port) - uvicorn.run(create_app(config), host=host, port=port, log_level=log_level) + uvicorn.run(create_app(sd_server_url=sd_server), host=host, port=port, log_level=log_level) # ============================================================================= diff --git a/tensors/config.py b/tensors/config.py index b7d4f9b..e8a56f7 100644 --- a/tensors/config.py +++ b/tensors/config.py @@ -243,3 +243,34 @@ def set_default_remote(name: str | None) -> None: else: config["default_remote"] = name save_config(config) + + +# ============================================================================ +# SD Server Configuration +# ============================================================================ + +SD_SERVER_DEFAULT_URL = "http://localhost:1234" + + +def get_sd_server_url() -> str: + """Get the sd-server URL. + + Resolution order: + 1. SD_SERVER_URL environment variable + 2. config.toml [server].sd_server_url + 3. Default: http://localhost:1234 + """ + # Check environment variable first + env_url = os.environ.get("SD_SERVER_URL") + if env_url: + return env_url + + # Check config file + config = load_config() + server_config = config.get("server", {}) + if isinstance(server_config, dict): + url = server_config.get("sd_server_url") + if url: + return str(url) + + return SD_SERVER_DEFAULT_URL diff --git a/tensors/server/__init__.py b/tensors/server/__init__.py index 01e8778..141e510 100644 --- a/tensors/server/__init__.py +++ b/tensors/server/__init__.py @@ -1,4 +1,4 @@ -"""sd-server wrapper — FastAPI app for managing and proxying to sd-server.""" +"""sd-server wrapper — FastAPI app for proxying to an external sd-server.""" from __future__ import annotations @@ -12,42 +12,39 @@ from fastapi import FastAPI from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles +from tensors.config import get_sd_server_url from tensors.server.civitai_routes import create_civitai_router from tensors.server.db_routes import create_db_router from tensors.server.download_routes import create_download_router from tensors.server.gallery_routes import create_gallery_router from tensors.server.generate_routes import create_generate_router -from tensors.server.models import ServerConfig from tensors.server.models_routes import create_models_router -from tensors.server.process import ProcessManager from tensors.server.routes import create_router if TYPE_CHECKING: from collections.abc import AsyncIterator -__all__ = ["ProcessManager", "ServerConfig", "app", "create_app"] +__all__ = ["app", "create_app"] logger = logging.getLogger(__name__) -def create_app(config: ServerConfig | None = None) -> FastAPI: - """Build the FastAPI application with process manager and proxy client.""" - pm = ProcessManager() +def create_app(sd_server_url: str | None = None) -> FastAPI: + """Build the FastAPI application that proxies to an external sd-server. + + Args: + sd_server_url: URL of the sd-server to proxy to. If None, uses + get_sd_server_url() to resolve from env/config. + """ + backend_url = sd_server_url or get_sd_server_url() @asynccontextmanager async def lifespan(_app: FastAPI) -> AsyncIterator[None]: + _app.state.sd_server_url = backend_url + logger.info(f"Proxying to sd-server at: {backend_url}") async with httpx.AsyncClient(timeout=300) as client: _app.state.client = client - if config is not None: - pm.start(config) - logger.info("waiting for sd-server to become ready...") - ready = await pm.wait_ready() - if ready: - logger.info("sd-server is ready") - else: - logger.warning("sd-server did not become ready in time") yield - pm.stop() app = FastAPI(title="sd-server wrapper", lifespan=lifespan) @@ -68,11 +65,10 @@ def create_app(config: ServerConfig | None = None) -> FastAPI: app.include_router(create_civitai_router()) # Must be before catch-all proxy app.include_router(create_db_router()) # Must be before catch-all proxy app.include_router(create_gallery_router()) # Must be before catch-all proxy - app.include_router(create_models_router(pm)) # Must be before catch-all proxy + app.include_router(create_models_router()) # Must be before catch-all proxy app.include_router(create_download_router()) # Must be before catch-all proxy - app.include_router(create_generate_router(pm)) # Must be before catch-all proxy - app.include_router(create_router(pm)) - app.state.pm = pm + app.include_router(create_generate_router()) # Must be before catch-all proxy + app.include_router(create_router()) return app diff --git a/tensors/server/generate_routes.py b/tensors/server/generate_routes.py index 863c8ed..a9f5be4 100644 --- a/tensors/server/generate_routes.py +++ b/tensors/server/generate_routes.py @@ -5,18 +5,15 @@ from __future__ import annotations import base64 import logging import time -from typing import TYPE_CHECKING, Any +from typing import Any import httpx -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel as PydanticBaseModel from pydantic import Field from tensors.server.gallery import Gallery -if TYPE_CHECKING: - from tensors.server.process import ProcessManager - logger = logging.getLogger(__name__) @@ -113,38 +110,30 @@ def _process_image( return image_info -def _check_server_running(pm: ProcessManager) -> None: - """Check if sd-server is running, raise HTTPException if not.""" - if pm.proc is None or pm.proc.poll() is not None: - raise HTTPException(status_code=503, detail="sd-server is not running") - if pm.config is None: - raise HTTPException(status_code=503, detail="sd-server not configured") - - # ============================================================================= # Router Factory # ============================================================================= -def create_generate_router(pm: ProcessManager) -> APIRouter: +def create_generate_router() -> APIRouter: """Build a router with /api/generate endpoint.""" router = APIRouter(prefix="/api", tags=["generate"]) gallery = Gallery() @router.post("/generate") - async def generate(req: GenerateRequest) -> dict[str, Any]: + async def generate(request: Request, req: GenerateRequest) -> dict[str, Any]: """Generate images with gallery integration.""" - _check_server_running(pm) - assert pm.config is not None # Verified by _check_server_running - + sd_server_url = request.app.state.sd_server_url body = _build_sd_request(req) - url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/txt2img" + url = f"{sd_server_url}/sdapi/v1/txt2img" try: async with httpx.AsyncClient(timeout=300) as client: response = await client.post(url, json=body) response.raise_for_status() result = response.json() + except httpx.ConnectError as e: + raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e except httpx.HTTPError as e: logger.exception("Generation failed") raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e @@ -153,8 +142,11 @@ def create_generate_router(pm: ProcessManager) -> APIRouter: info = _parse_info(result.get("info", {})) all_seeds = info.get("all_seeds", [req.seed] * len(images_data)) + # Get model info from sd-server response if available + model_name = info.get("sd_model_name") or info.get("model") + output_images = [ - _process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, pm.config.model) + _process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, model_name) for i, img_b64 in enumerate(images_data) ] @@ -167,32 +159,34 @@ def create_generate_router(pm: ProcessManager) -> APIRouter: } @router.get("/samplers") - async def list_samplers() -> dict[str, Any]: + async def list_samplers(request: Request) -> dict[str, Any]: """List available samplers from sd-server.""" - _check_server_running(pm) - assert pm.config is not None - url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/samplers" + sd_server_url = request.app.state.sd_server_url + url = f"{sd_server_url}/sdapi/v1/samplers" try: async with httpx.AsyncClient(timeout=30) as client: response = await client.get(url) response.raise_for_status() return {"samplers": response.json()} + except httpx.ConnectError as e: + raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e except httpx.HTTPError as e: raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e @router.get("/schedulers") - async def list_schedulers() -> dict[str, Any]: + async def list_schedulers(request: Request) -> dict[str, Any]: """List available schedulers from sd-server.""" - _check_server_running(pm) - assert pm.config is not None - url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/schedulers" + sd_server_url = request.app.state.sd_server_url + url = f"{sd_server_url}/sdapi/v1/schedulers" try: async with httpx.AsyncClient(timeout=30) as client: response = await client.get(url) response.raise_for_status() return {"schedulers": response.json()} + except httpx.ConnectError as e: + raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e except httpx.HTTPError as e: raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e diff --git a/tensors/server/models.py b/tensors/server/models.py index b06f0b1..204d7ca 100644 --- a/tensors/server/models.py +++ b/tensors/server/models.py @@ -2,16 +2,5 @@ from __future__ import annotations -from pydantic import BaseModel - -DEFAULT_PORT = 1234 - - -class ReloadRequest(BaseModel): - model: str - - -class ServerConfig(BaseModel): - model: str - port: int = DEFAULT_PORT - args: list[str] = [] +# Note: ServerConfig and ReloadRequest were removed since we no longer manage +# sd-server processes internally. The wrapper now proxies to an external sd-server. diff --git a/tensors/server/models_routes.py b/tensors/server/models_routes.py index 47251c5..f11ba9f 100644 --- a/tensors/server/models_routes.py +++ b/tensors/server/models_routes.py @@ -3,30 +3,18 @@ from __future__ import annotations import logging -from pathlib import Path from typing import TYPE_CHECKING, Any -from fastapi import APIRouter, HTTPException -from fastapi.responses import JSONResponse -from pydantic import BaseModel as PydanticBaseModel +from fastapi import APIRouter, Request from tensors.config import MODELS_DIR if TYPE_CHECKING: - from tensors.server.process import ProcessManager + from pathlib import Path logger = logging.getLogger(__name__) - -# ============================================================================= -# Request/Response Models -# ============================================================================= - - -class SwitchModelRequest(PydanticBaseModel): - """Request body for switching models.""" - - model: str # Path to model file +_HTTP_OK = 200 # ============================================================================= @@ -76,7 +64,7 @@ def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]: # ============================================================================= -def create_models_router(pm: ProcessManager) -> APIRouter: +def create_models_router() -> APIRouter: """Build a router with /api/models/* endpoints.""" router = APIRouter(prefix="/api/models", tags=["models"]) @@ -90,62 +78,34 @@ def create_models_router(pm: ProcessManager) -> APIRouter: } @router.get("/active") - def get_active_model() -> dict[str, Any]: - """Get information about the currently loaded model.""" - status = pm.status() - config = pm.config + async def get_active_model(request: Request) -> dict[str, Any]: + """Get information about the currently loaded model from sd-server.""" + import httpx # noqa: PLC0415 - if config is None: - return { - "loaded": False, - "model": None, - "status": status.get("status"), - } + sd_server_url = request.app.state.sd_server_url + + # Try to get current model from sd-server's options endpoint + try: + async with httpx.AsyncClient(timeout=10) as client: + response = await client.get(f"{sd_server_url}/sdapi/v1/options") + if response.status_code == _HTTP_OK: + options = response.json() + model_name = options.get("sd_model_checkpoint") + return { + "loaded": True, + "model": model_name, + "sd_server_url": sd_server_url, + } + except httpx.HTTPError: + pass return { - "loaded": status.get("status") == "running", - "model": config.model, - "pid": status.get("pid"), - "port": config.port, - "status": status.get("status"), + "loaded": False, + "model": None, + "sd_server_url": sd_server_url, + "error": "Cannot connect to sd-server", } - @router.post("/switch") - async def switch_model(req: SwitchModelRequest) -> JSONResponse: - """Switch to a different model (hot reload).""" - model_path = Path(req.model) - - # Validate model exists - if not model_path.exists(): - raise HTTPException(status_code=400, detail=f"Model not found: {req.model}") - - # Use existing reload logic - from tensors.server.models import ServerConfig # noqa: PLC0415 - - new_config = ServerConfig( - model=req.model, - port=pm.config.port if pm.config else 1234, - args=pm.config.args if pm.config else [], - ) - - pm.stop() - pm.start(new_config) - ready = await pm.wait_ready() - - if not ready: - return JSONResponse( - {"error": "sd-server failed to become ready", "model": req.model}, - status_code=503, - ) - - return JSONResponse( - { - "ok": True, - "model": req.model, - "pid": pm.proc.pid if pm.proc else None, - } - ) - @router.get("/loras") def list_loras() -> dict[str, Any]: """List available LoRA files.""" diff --git a/tensors/server/process.py b/tensors/server/process.py deleted file mode 100644 index 6574c8c..0000000 --- a/tensors/server/process.py +++ /dev/null @@ -1,92 +0,0 @@ -"""sd-server process lifecycle management.""" - -from __future__ import annotations - -import asyncio -import logging -import os -import shutil -import signal -import subprocess -from typing import TYPE_CHECKING, Any - -import httpx - -if TYPE_CHECKING: - from tensors.server.models import ServerConfig - -logger = logging.getLogger(__name__) - -_HTTP_OK = 200 - -SD_SERVER_BIN = shutil.which("sd-server") or "sd-server" - - -class ProcessManager: - def __init__(self) -> None: - self.proc: subprocess.Popen[bytes] | None = None - self.config: ServerConfig | None = None - - def build_cmd(self) -> list[str]: - if self.config is None: - raise RuntimeError("No config set") - cmd = [SD_SERVER_BIN, "-m", self.config.model, "--listen-port", str(self.config.port)] - cmd.extend(self.config.args) - return cmd - - def start(self, config: ServerConfig) -> None: - if self.proc is not None and self.proc.poll() is None: - raise RuntimeError("Server already running — stop it first") - self.config = config - cmd = self.build_cmd() - # Inherit environment (important for HSA_OVERRIDE_GFX_VERSION on ROCm) - self.proc = subprocess.Popen(cmd, env=os.environ.copy()) - logger.info("started sd-server pid=%d cmd=%s", self.proc.pid, cmd) - - def stop(self) -> bool: - if self.proc is None or self.proc.poll() is not None: - self.proc = None - return False - self.proc.send_signal(signal.SIGTERM) - try: - self.proc.wait(timeout=10) - except subprocess.TimeoutExpired: - self.proc.kill() - self.proc.wait(timeout=5) - logger.info("stopped sd-server") - self.proc = None - return True - - def status(self) -> dict[str, Any]: - if self.proc is None: - return {"running": False} - rc = self.proc.poll() - if rc is not None: - return {"running": False, "exit_code": rc} - return { - "running": True, - "pid": self.proc.pid, - "model": self.config.model if self.config else None, - "port": self.config.port if self.config else None, - "bind": f"http://127.0.0.1:{self.config.port}" if self.config else None, - "cmd": self.build_cmd(), - } - - async def wait_ready(self, timeout: float = 120) -> bool: - """Poll sd-server root endpoint until it responds or timeout.""" - if self.config is None: - return False - url = f"http://127.0.0.1:{self.config.port}/" - deadline = asyncio.get_event_loop().time() + timeout - async with httpx.AsyncClient() as client: - while asyncio.get_event_loop().time() < deadline: - if self.proc is not None and self.proc.poll() is not None: - return False - try: - r = await client.get(url, timeout=2) - if r.status_code == _HTTP_OK: - return True - except (httpx.ConnectError, httpx.TimeoutException, httpx.ReadError): - pass - await asyncio.sleep(1) - return False diff --git a/tensors/server/routes.py b/tensors/server/routes.py index 214ab62..c37aa6b 100644 --- a/tensors/server/routes.py +++ b/tensors/server/routes.py @@ -3,64 +3,73 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import Any +import httpx from fastapi import APIRouter, Request, Response from fastapi.responses import JSONResponse, StreamingResponse -from tensors.server.models import ReloadRequest, ServerConfig - -if TYPE_CHECKING: - from tensors.server.process import ProcessManager - logger = logging.getLogger(__name__) -def create_router(pm: ProcessManager) -> APIRouter: - """Build a router with /status, /reload, and catch-all proxy.""" +def create_router() -> APIRouter: + """Build a router with /status and catch-all proxy.""" router = APIRouter() @router.get("/status") - def status() -> dict[str, Any]: - return pm.status() - - @router.post("/reload") - async def reload(req: ReloadRequest) -> Response: - new_config = ServerConfig( - model=req.model, - port=pm.config.port if pm.config else 1234, - args=pm.config.args if pm.config else [], - ) - pm.stop() - pm.start(new_config) - ready = await pm.wait_ready() - if not ready: - return JSONResponse({"error": "sd-server failed to become ready", "model": req.model}, status_code=503) - return JSONResponse({"ok": True, "model": req.model, "pid": pm.proc.pid if pm.proc else None}) + async def status(request: Request) -> dict[str, Any]: + """Check if the external sd-server is reachable.""" + sd_server_url = request.app.state.sd_server_url + try: + async with httpx.AsyncClient(timeout=5) as client: + r = await client.get(sd_server_url) + return { + "status": "ok", + "sd_server_url": sd_server_url, + "sd_server_status": r.status_code, + } + except httpx.HTTPError as e: + return { + "status": "error", + "sd_server_url": sd_server_url, + "error": str(e), + } @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]) async def proxy(request: Request, path: str) -> Response: - if pm.proc is None or pm.proc.poll() is not None: - return JSONResponse({"error": "sd-server is not running"}, status_code=503) - assert pm.config is not None - url = f"http://127.0.0.1:{pm.config.port}/{path}" + """Proxy all requests to the external sd-server.""" + sd_server_url = request.app.state.sd_server_url + url = f"{sd_server_url}/{path}" if request.url.query: url = f"{url}?{request.url.query}" + body = await request.body() headers = dict(request.headers) headers.pop("host", None) client = request.app.state.client - upstream = await client.request( - method=request.method, - url=url, - headers=headers, - content=body, - timeout=300, - ) - return StreamingResponse( - content=upstream.iter_bytes(), - status_code=upstream.status_code, - headers=dict(upstream.headers), - ) + + try: + upstream = await client.request( + method=request.method, + url=url, + headers=headers, + content=body, + timeout=300, + ) + return StreamingResponse( + content=upstream.iter_bytes(), + status_code=upstream.status_code, + headers=dict(upstream.headers), + ) + except httpx.ConnectError: + return JSONResponse( + {"error": f"Cannot connect to sd-server at {sd_server_url}"}, + status_code=503, + ) + except httpx.TimeoutException: + return JSONResponse( + {"error": f"Timeout connecting to sd-server at {sd_server_url}"}, + status_code=504, + ) return router diff --git a/tests/test_server.py b/tests/test_server.py index 17d98cb..8050c49 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -2,110 +2,51 @@ from __future__ import annotations -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock import httpx import pytest +import respx from fastapi.testclient import TestClient from tensors.server import create_app -from tensors.server.models import ServerConfig -from tensors.server.process import ProcessManager - - -@pytest.fixture() -def pm() -> ProcessManager: - return ProcessManager() @pytest.fixture() def api() -> TestClient: - return TestClient(create_app()) - - -def _get_pm(api: TestClient) -> ProcessManager: - return api.app.state.pm # type: ignore[no-any-return, attr-defined] + """Create test client with mock sd-server URL.""" + return TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) class TestStatus: - def test_not_running(self, api: TestClient) -> None: - r = api.get("/status") - assert r.status_code == 200 - assert r.json()["running"] is False + @respx.mock + def test_status_when_backend_reachable(self) -> None: + """Test status endpoint when sd-server is reachable.""" + respx.get("http://mock-sd-server:1234/").mock(return_value=httpx.Response(200)) - def test_running(self, api: TestClient) -> None: - pm = _get_pm(api) - mock_proc = MagicMock() - mock_proc.poll.return_value = None - mock_proc.pid = 999 - pm.proc = mock_proc - pm.config = ServerConfig(model="/m.safetensors") - r = api.get("/status") - data = r.json() - assert data["running"] is True - assert data["pid"] == 999 + with TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) as client: + r = client.get("/status") + assert r.status_code == 200 + data = r.json() + assert data["status"] == "ok" + assert data["sd_server_url"] == "http://mock-sd-server:1234" - def test_exited(self, api: TestClient) -> None: - pm = _get_pm(api) - mock_proc = MagicMock() - mock_proc.poll.return_value = 1 - pm.proc = mock_proc - r = api.get("/status") - data = r.json() - assert data["running"] is False - assert data["exit_code"] == 1 + @respx.mock + def test_status_when_backend_unreachable(self) -> None: + """Test status endpoint when sd-server is not reachable.""" + respx.get("http://mock-sd-server:1234/").mock(side_effect=httpx.ConnectError("Connection refused")) - -class TestReload: - @patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=True) - @patch("tensors.server.process.subprocess.Popen") - def test_reload_swaps_model(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None: - pm = _get_pm(api) - pm.config = ServerConfig(model="/old.gguf", port=5555, args=["--fa"]) - mock_popen.return_value.pid = 42 - mock_popen.return_value.poll.return_value = None - r = api.post("/reload", json={"model": "/new.gguf"}) - assert r.status_code == 200 - data = r.json() - assert data["ok"] is True - assert data["model"] == "/new.gguf" - assert data["pid"] == 42 - # Verify new config preserved port and args from previous config - assert pm.config is not None - assert pm.config.port == 5555 - assert pm.config.args == ["--fa"] - assert pm.config.model == "/new.gguf" - - @patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=False) - @patch("tensors.server.process.subprocess.Popen") - def test_reload_fails_when_not_ready(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None: - pm = _get_pm(api) - pm.config = ServerConfig(model="/old.gguf") - mock_popen.return_value.pid = 43 - mock_popen.return_value.poll.return_value = None - r = api.post("/reload", json={"model": "/bad.gguf"}) - assert r.status_code == 503 - assert "failed" in r.json()["error"] - - def test_reload_requires_model(self, api: TestClient) -> None: - r = api.post("/reload", json={}) - assert r.status_code == 422 + with TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) as client: + r = client.get("/status") + assert r.status_code == 200 + data = r.json() + assert data["status"] == "error" + assert "Connection refused" in data["error"] class TestProxy: - def test_proxy_503_when_not_running(self, api: TestClient) -> None: - r = api.get("/v1/models") - assert r.status_code == 503 - assert "not running" in r.json()["error"] - def test_proxy_forwards_request(self, api: TestClient) -> None: - pm = _get_pm(api) - mock_proc = MagicMock() - mock_proc.poll.return_value = None - mock_proc.pid = 100 - pm.proc = mock_proc - pm.config = ServerConfig(model="/m.gguf", port=1234) - + """Test proxy forwards GET requests to backend.""" upstream_response = httpx.Response( 200, json={"data": [{"id": "model-1"}]}, @@ -114,6 +55,7 @@ class TestProxy: mock_client = AsyncMock() mock_client.request.return_value = upstream_response api.app.state.client = mock_client # type: ignore[attr-defined] + api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined] r = api.get("/v1/models") assert r.status_code == 200 @@ -121,52 +63,38 @@ class TestProxy: mock_client.request.assert_called_once() def test_proxy_forwards_post_with_body(self, api: TestClient) -> None: - pm = _get_pm(api) - mock_proc = MagicMock() - mock_proc.poll.return_value = None - mock_proc.pid = 100 - pm.proc = mock_proc - pm.config = ServerConfig(model="/m.gguf", port=1234) - + """Test proxy forwards POST requests with body.""" upstream_response = httpx.Response(200, json={"ok": True}) mock_client = AsyncMock() mock_client.request.return_value = upstream_response api.app.state.client = mock_client # type: ignore[attr-defined] + api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined] - r = api.post("/v1/chat/completions", json={"prompt": "hello"}) + r = api.post("/sdapi/v1/txt2img", json={"prompt": "hello"}) assert r.status_code == 200 mock_client.request.assert_called_once() + def test_proxy_503_on_connect_error(self, api: TestClient) -> None: + """Test proxy returns 503 when backend is unreachable.""" + mock_client = AsyncMock() + mock_client.request.side_effect = httpx.ConnectError("Connection refused") + api.app.state.client = mock_client # type: ignore[attr-defined] + api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined] -class TestProcessManager: - def test_status_not_running(self, pm: ProcessManager) -> None: - assert pm.status() == {"running": False} + r = api.get("/v1/models") + assert r.status_code == 503 + assert "Cannot connect" in r.json()["error"] - def test_build_cmd(self, pm: ProcessManager) -> None: - pm.config = ServerConfig(model="/m.gguf", port=1234, args=["--fa"]) - cmd = pm.build_cmd() - assert "/m.gguf" in cmd - assert "--fa" in cmd - assert "1234" in cmd + def test_proxy_504_on_timeout(self, api: TestClient) -> None: + """Test proxy returns 504 on timeout.""" + mock_client = AsyncMock() + mock_client.request.side_effect = httpx.TimeoutException("Timeout") + api.app.state.client = mock_client # type: ignore[attr-defined] + api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined] - def test_build_cmd_no_config(self, pm: ProcessManager) -> None: - with pytest.raises(RuntimeError, match="No config"): - pm.build_cmd() - - @patch("tensors.server.process.subprocess.Popen") - def test_start_and_stop(self, mock_popen: MagicMock, pm: ProcessManager) -> None: - mock_popen.return_value.pid = 77 - mock_popen.return_value.poll.return_value = None - mock_popen.return_value.wait.return_value = 0 - pm.start(ServerConfig(model="/m.gguf")) - assert pm.proc is not None - assert pm.stop() is True - assert pm.proc is None - - def test_server_config_defaults(self) -> None: - cfg = ServerConfig(model="/m.gguf") - assert cfg.port == 1234 - assert cfg.args == [] + r = api.get("/v1/models") + assert r.status_code == 504 + assert "Timeout" in r.json()["error"] # ============================================================================= diff --git a/uv.lock b/uv.lock index 4c71c4a..eaed257 100644 --- a/uv.lock +++ b/uv.lock @@ -707,7 +707,7 @@ wheels = [ [[package]] name = "tensors" -version = "0.1.9" +version = "0.1.12" source = { editable = "." } dependencies = [ { name = "httpx" },