Remove internal sd-server management, proxy to external sd-server

- Remove ProcessManager and process.py
- Add get_sd_server_url() config (env/config/default)
- Update routes to proxy to external sd-server URL
- Remove model switching (handled by external sd-server)
- Update CLI serve command
- Update tests for new architecture

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Adam Ladachowski
2026-02-14 06:39:35 +01:00
parent be7cf0b6e7
commit e9480a18c2
16 changed files with 211 additions and 390 deletions
+4 -6
View File
@@ -582,24 +582,22 @@ def reload(
@app.command()
def serve(
model: Annotated[str, typer.Option(help="Path to model file for sd-server.")],
host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080,
sd_port: Annotated[int, typer.Option(help="sd-server listen port.")] = 1234,
sd_server: Annotated[str | None, typer.Option(help="sd-server URL to proxy to.")] = None,
log_level: Annotated[str, typer.Option(help="Log level.")] = "info",
) -> None:
"""Start the sd-server wrapper API (transparent proxy with hot reload)."""
"""Start the sd-server wrapper API (proxies to external sd-server)."""
try:
import uvicorn # noqa: PLC0415
from tensors.server import ServerConfig, create_app # noqa: PLC0415
from tensors.server import create_app # noqa: PLC0415
except ImportError:
console.print("[red]Missing server dependencies. Install with:[/red]")
console.print(" pip install tensors[server]")
raise typer.Exit(1) from None
config = ServerConfig(model=model, port=sd_port)
uvicorn.run(create_app(config), host=host, port=port, log_level=log_level)
uvicorn.run(create_app(sd_server_url=sd_server), host=host, port=port, log_level=log_level)
# =============================================================================
+31
View File
@@ -243,3 +243,34 @@ def set_default_remote(name: str | None) -> None:
else:
config["default_remote"] = name
save_config(config)
# ============================================================================
# SD Server Configuration
# ============================================================================
SD_SERVER_DEFAULT_URL = "http://localhost:1234"
def get_sd_server_url() -> str:
"""Get the sd-server URL.
Resolution order:
1. SD_SERVER_URL environment variable
2. config.toml [server].sd_server_url
3. Default: http://localhost:1234
"""
# Check environment variable first
env_url = os.environ.get("SD_SERVER_URL")
if env_url:
return env_url
# Check config file
config = load_config()
server_config = config.get("server", {})
if isinstance(server_config, dict):
url = server_config.get("sd_server_url")
if url:
return str(url)
return SD_SERVER_DEFAULT_URL
+16 -20
View File
@@ -1,4 +1,4 @@
"""sd-server wrapper — FastAPI app for managing and proxying to sd-server."""
"""sd-server wrapper — FastAPI app for proxying to an external sd-server."""
from __future__ import annotations
@@ -12,42 +12,39 @@ from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from tensors.config import get_sd_server_url
from tensors.server.civitai_routes import create_civitai_router
from tensors.server.db_routes import create_db_router
from tensors.server.download_routes import create_download_router
from tensors.server.gallery_routes import create_gallery_router
from tensors.server.generate_routes import create_generate_router
from tensors.server.models import ServerConfig
from tensors.server.models_routes import create_models_router
from tensors.server.process import ProcessManager
from tensors.server.routes import create_router
if TYPE_CHECKING:
from collections.abc import AsyncIterator
__all__ = ["ProcessManager", "ServerConfig", "app", "create_app"]
__all__ = ["app", "create_app"]
logger = logging.getLogger(__name__)
def create_app(config: ServerConfig | None = None) -> FastAPI:
"""Build the FastAPI application with process manager and proxy client."""
pm = ProcessManager()
def create_app(sd_server_url: str | None = None) -> FastAPI:
"""Build the FastAPI application that proxies to an external sd-server.
Args:
sd_server_url: URL of the sd-server to proxy to. If None, uses
get_sd_server_url() to resolve from env/config.
"""
backend_url = sd_server_url or get_sd_server_url()
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
_app.state.sd_server_url = backend_url
logger.info(f"Proxying to sd-server at: {backend_url}")
async with httpx.AsyncClient(timeout=300) as client:
_app.state.client = client
if config is not None:
pm.start(config)
logger.info("waiting for sd-server to become ready...")
ready = await pm.wait_ready()
if ready:
logger.info("sd-server is ready")
else:
logger.warning("sd-server did not become ready in time")
yield
pm.stop()
app = FastAPI(title="sd-server wrapper", lifespan=lifespan)
@@ -68,11 +65,10 @@ def create_app(config: ServerConfig | None = None) -> FastAPI:
app.include_router(create_civitai_router()) # Must be before catch-all proxy
app.include_router(create_db_router()) # Must be before catch-all proxy
app.include_router(create_gallery_router()) # Must be before catch-all proxy
app.include_router(create_models_router(pm)) # Must be before catch-all proxy
app.include_router(create_models_router()) # Must be before catch-all proxy
app.include_router(create_download_router()) # Must be before catch-all proxy
app.include_router(create_generate_router(pm)) # Must be before catch-all proxy
app.include_router(create_router(pm))
app.state.pm = pm
app.include_router(create_generate_router()) # Must be before catch-all proxy
app.include_router(create_router())
return app
+22 -28
View File
@@ -5,18 +5,15 @@ from __future__ import annotations
import base64
import logging
import time
from typing import TYPE_CHECKING, Any
from typing import Any
import httpx
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel as PydanticBaseModel
from pydantic import Field
from tensors.server.gallery import Gallery
if TYPE_CHECKING:
from tensors.server.process import ProcessManager
logger = logging.getLogger(__name__)
@@ -113,38 +110,30 @@ def _process_image(
return image_info
def _check_server_running(pm: ProcessManager) -> None:
"""Check if sd-server is running, raise HTTPException if not."""
if pm.proc is None or pm.proc.poll() is not None:
raise HTTPException(status_code=503, detail="sd-server is not running")
if pm.config is None:
raise HTTPException(status_code=503, detail="sd-server not configured")
# =============================================================================
# Router Factory
# =============================================================================
def create_generate_router(pm: ProcessManager) -> APIRouter:
def create_generate_router() -> APIRouter:
"""Build a router with /api/generate endpoint."""
router = APIRouter(prefix="/api", tags=["generate"])
gallery = Gallery()
@router.post("/generate")
async def generate(req: GenerateRequest) -> dict[str, Any]:
async def generate(request: Request, req: GenerateRequest) -> dict[str, Any]:
"""Generate images with gallery integration."""
_check_server_running(pm)
assert pm.config is not None # Verified by _check_server_running
sd_server_url = request.app.state.sd_server_url
body = _build_sd_request(req)
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/txt2img"
url = f"{sd_server_url}/sdapi/v1/txt2img"
try:
async with httpx.AsyncClient(timeout=300) as client:
response = await client.post(url, json=body)
response.raise_for_status()
result = response.json()
except httpx.ConnectError as e:
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
except httpx.HTTPError as e:
logger.exception("Generation failed")
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
@@ -153,8 +142,11 @@ def create_generate_router(pm: ProcessManager) -> APIRouter:
info = _parse_info(result.get("info", {}))
all_seeds = info.get("all_seeds", [req.seed] * len(images_data))
# Get model info from sd-server response if available
model_name = info.get("sd_model_name") or info.get("model")
output_images = [
_process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, pm.config.model)
_process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, model_name)
for i, img_b64 in enumerate(images_data)
]
@@ -167,32 +159,34 @@ def create_generate_router(pm: ProcessManager) -> APIRouter:
}
@router.get("/samplers")
async def list_samplers() -> dict[str, Any]:
async def list_samplers(request: Request) -> dict[str, Any]:
"""List available samplers from sd-server."""
_check_server_running(pm)
assert pm.config is not None
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/samplers"
sd_server_url = request.app.state.sd_server_url
url = f"{sd_server_url}/sdapi/v1/samplers"
try:
async with httpx.AsyncClient(timeout=30) as client:
response = await client.get(url)
response.raise_for_status()
return {"samplers": response.json()}
except httpx.ConnectError as e:
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
except httpx.HTTPError as e:
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
@router.get("/schedulers")
async def list_schedulers() -> dict[str, Any]:
async def list_schedulers(request: Request) -> dict[str, Any]:
"""List available schedulers from sd-server."""
_check_server_running(pm)
assert pm.config is not None
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/schedulers"
sd_server_url = request.app.state.sd_server_url
url = f"{sd_server_url}/sdapi/v1/schedulers"
try:
async with httpx.AsyncClient(timeout=30) as client:
response = await client.get(url)
response.raise_for_status()
return {"schedulers": response.json()}
except httpx.ConnectError as e:
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
except httpx.HTTPError as e:
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
+2 -13
View File
@@ -2,16 +2,5 @@
from __future__ import annotations
from pydantic import BaseModel
DEFAULT_PORT = 1234
class ReloadRequest(BaseModel):
model: str
class ServerConfig(BaseModel):
model: str
port: int = DEFAULT_PORT
args: list[str] = []
# Note: ServerConfig and ReloadRequest were removed since we no longer manage
# sd-server processes internally. The wrapper now proxies to an external sd-server.
+27 -67
View File
@@ -3,30 +3,18 @@
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any
from fastapi import APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel as PydanticBaseModel
from fastapi import APIRouter, Request
from tensors.config import MODELS_DIR
if TYPE_CHECKING:
from tensors.server.process import ProcessManager
from pathlib import Path
logger = logging.getLogger(__name__)
# =============================================================================
# Request/Response Models
# =============================================================================
class SwitchModelRequest(PydanticBaseModel):
"""Request body for switching models."""
model: str # Path to model file
_HTTP_OK = 200
# =============================================================================
@@ -76,7 +64,7 @@ def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]:
# =============================================================================
def create_models_router(pm: ProcessManager) -> APIRouter:
def create_models_router() -> APIRouter:
"""Build a router with /api/models/* endpoints."""
router = APIRouter(prefix="/api/models", tags=["models"])
@@ -90,62 +78,34 @@ def create_models_router(pm: ProcessManager) -> APIRouter:
}
@router.get("/active")
def get_active_model() -> dict[str, Any]:
"""Get information about the currently loaded model."""
status = pm.status()
config = pm.config
async def get_active_model(request: Request) -> dict[str, Any]:
"""Get information about the currently loaded model from sd-server."""
import httpx # noqa: PLC0415
if config is None:
return {
"loaded": False,
"model": None,
"status": status.get("status"),
}
sd_server_url = request.app.state.sd_server_url
# Try to get current model from sd-server's options endpoint
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(f"{sd_server_url}/sdapi/v1/options")
if response.status_code == _HTTP_OK:
options = response.json()
model_name = options.get("sd_model_checkpoint")
return {
"loaded": True,
"model": model_name,
"sd_server_url": sd_server_url,
}
except httpx.HTTPError:
pass
return {
"loaded": status.get("status") == "running",
"model": config.model,
"pid": status.get("pid"),
"port": config.port,
"status": status.get("status"),
"loaded": False,
"model": None,
"sd_server_url": sd_server_url,
"error": "Cannot connect to sd-server",
}
@router.post("/switch")
async def switch_model(req: SwitchModelRequest) -> JSONResponse:
"""Switch to a different model (hot reload)."""
model_path = Path(req.model)
# Validate model exists
if not model_path.exists():
raise HTTPException(status_code=400, detail=f"Model not found: {req.model}")
# Use existing reload logic
from tensors.server.models import ServerConfig # noqa: PLC0415
new_config = ServerConfig(
model=req.model,
port=pm.config.port if pm.config else 1234,
args=pm.config.args if pm.config else [],
)
pm.stop()
pm.start(new_config)
ready = await pm.wait_ready()
if not ready:
return JSONResponse(
{"error": "sd-server failed to become ready", "model": req.model},
status_code=503,
)
return JSONResponse(
{
"ok": True,
"model": req.model,
"pid": pm.proc.pid if pm.proc else None,
}
)
@router.get("/loras")
def list_loras() -> dict[str, Any]:
"""List available LoRA files."""
-92
View File
@@ -1,92 +0,0 @@
"""sd-server process lifecycle management."""
from __future__ import annotations
import asyncio
import logging
import os
import shutil
import signal
import subprocess
from typing import TYPE_CHECKING, Any
import httpx
if TYPE_CHECKING:
from tensors.server.models import ServerConfig
logger = logging.getLogger(__name__)
_HTTP_OK = 200
SD_SERVER_BIN = shutil.which("sd-server") or "sd-server"
class ProcessManager:
def __init__(self) -> None:
self.proc: subprocess.Popen[bytes] | None = None
self.config: ServerConfig | None = None
def build_cmd(self) -> list[str]:
if self.config is None:
raise RuntimeError("No config set")
cmd = [SD_SERVER_BIN, "-m", self.config.model, "--listen-port", str(self.config.port)]
cmd.extend(self.config.args)
return cmd
def start(self, config: ServerConfig) -> None:
if self.proc is not None and self.proc.poll() is None:
raise RuntimeError("Server already running — stop it first")
self.config = config
cmd = self.build_cmd()
# Inherit environment (important for HSA_OVERRIDE_GFX_VERSION on ROCm)
self.proc = subprocess.Popen(cmd, env=os.environ.copy())
logger.info("started sd-server pid=%d cmd=%s", self.proc.pid, cmd)
def stop(self) -> bool:
if self.proc is None or self.proc.poll() is not None:
self.proc = None
return False
self.proc.send_signal(signal.SIGTERM)
try:
self.proc.wait(timeout=10)
except subprocess.TimeoutExpired:
self.proc.kill()
self.proc.wait(timeout=5)
logger.info("stopped sd-server")
self.proc = None
return True
def status(self) -> dict[str, Any]:
if self.proc is None:
return {"running": False}
rc = self.proc.poll()
if rc is not None:
return {"running": False, "exit_code": rc}
return {
"running": True,
"pid": self.proc.pid,
"model": self.config.model if self.config else None,
"port": self.config.port if self.config else None,
"bind": f"http://127.0.0.1:{self.config.port}" if self.config else None,
"cmd": self.build_cmd(),
}
async def wait_ready(self, timeout: float = 120) -> bool:
"""Poll sd-server root endpoint until it responds or timeout."""
if self.config is None:
return False
url = f"http://127.0.0.1:{self.config.port}/"
deadline = asyncio.get_event_loop().time() + timeout
async with httpx.AsyncClient() as client:
while asyncio.get_event_loop().time() < deadline:
if self.proc is not None and self.proc.poll() is not None:
return False
try:
r = await client.get(url, timeout=2)
if r.status_code == _HTTP_OK:
return True
except (httpx.ConnectError, httpx.TimeoutException, httpx.ReadError):
pass
await asyncio.sleep(1)
return False
+49 -40
View File
@@ -3,64 +3,73 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from typing import Any
import httpx
from fastapi import APIRouter, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
from tensors.server.models import ReloadRequest, ServerConfig
if TYPE_CHECKING:
from tensors.server.process import ProcessManager
logger = logging.getLogger(__name__)
def create_router(pm: ProcessManager) -> APIRouter:
"""Build a router with /status, /reload, and catch-all proxy."""
def create_router() -> APIRouter:
"""Build a router with /status and catch-all proxy."""
router = APIRouter()
@router.get("/status")
def status() -> dict[str, Any]:
return pm.status()
@router.post("/reload")
async def reload(req: ReloadRequest) -> Response:
new_config = ServerConfig(
model=req.model,
port=pm.config.port if pm.config else 1234,
args=pm.config.args if pm.config else [],
)
pm.stop()
pm.start(new_config)
ready = await pm.wait_ready()
if not ready:
return JSONResponse({"error": "sd-server failed to become ready", "model": req.model}, status_code=503)
return JSONResponse({"ok": True, "model": req.model, "pid": pm.proc.pid if pm.proc else None})
async def status(request: Request) -> dict[str, Any]:
"""Check if the external sd-server is reachable."""
sd_server_url = request.app.state.sd_server_url
try:
async with httpx.AsyncClient(timeout=5) as client:
r = await client.get(sd_server_url)
return {
"status": "ok",
"sd_server_url": sd_server_url,
"sd_server_status": r.status_code,
}
except httpx.HTTPError as e:
return {
"status": "error",
"sd_server_url": sd_server_url,
"error": str(e),
}
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
async def proxy(request: Request, path: str) -> Response:
if pm.proc is None or pm.proc.poll() is not None:
return JSONResponse({"error": "sd-server is not running"}, status_code=503)
assert pm.config is not None
url = f"http://127.0.0.1:{pm.config.port}/{path}"
"""Proxy all requests to the external sd-server."""
sd_server_url = request.app.state.sd_server_url
url = f"{sd_server_url}/{path}"
if request.url.query:
url = f"{url}?{request.url.query}"
body = await request.body()
headers = dict(request.headers)
headers.pop("host", None)
client = request.app.state.client
upstream = await client.request(
method=request.method,
url=url,
headers=headers,
content=body,
timeout=300,
)
return StreamingResponse(
content=upstream.iter_bytes(),
status_code=upstream.status_code,
headers=dict(upstream.headers),
)
try:
upstream = await client.request(
method=request.method,
url=url,
headers=headers,
content=body,
timeout=300,
)
return StreamingResponse(
content=upstream.iter_bytes(),
status_code=upstream.status_code,
headers=dict(upstream.headers),
)
except httpx.ConnectError:
return JSONResponse(
{"error": f"Cannot connect to sd-server at {sd_server_url}"},
status_code=503,
)
except httpx.TimeoutException:
return JSONResponse(
{"error": f"Timeout connecting to sd-server at {sd_server_url}"},
status_code=504,
)
return router