Remove internal sd-server management, proxy to external sd-server
- Remove ProcessManager and process.py
- Add get_sd_server_url() config (env/config/default)
- Update routes to proxy to external sd-server URL
- Remove model switching (handled by external sd-server)
- Update CLI serve command
- Update tests for new architecture

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
+1
-1
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "tensors"
|
name = "tensors"
|
||||||
version = "0.1.11"
|
version = "0.1.12"
|
||||||
description = "Read safetensor metadata and fetch CivitAI model information"
|
description = "Read safetensor metadata and fetch CivitAI model information"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
|
|||||||
Binary file not shown.
|
After Width: | Height: | Size: 23 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 54 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 486 KiB |
+11
-3
@@ -55,9 +55,17 @@ def sync_to_junkpile() -> None:
|
|||||||
"""Sync project to junkpile."""
|
"""Sync project to junkpile."""
|
||||||
print("\n[3/4] Syncing to junkpile...")
|
print("\n[3/4] Syncing to junkpile...")
|
||||||
excludes = [
|
excludes = [
|
||||||
".git", ".venv", "__pycache__", "*.pyc", ".mypy_cache",
|
".git",
|
||||||
".pytest_cache", ".ruff_cache", ".coverage", "*.egg-info",
|
".venv",
|
||||||
"node_modules", ".tmp",
|
"__pycache__",
|
||||||
|
"*.pyc",
|
||||||
|
".mypy_cache",
|
||||||
|
".pytest_cache",
|
||||||
|
".ruff_cache",
|
||||||
|
".coverage",
|
||||||
|
"*.egg-info",
|
||||||
|
"node_modules",
|
||||||
|
".tmp",
|
||||||
]
|
]
|
||||||
cmd = ["rsync", "-avz", "--delete"]
|
cmd = ["rsync", "-avz", "--delete"]
|
||||||
for exc in excludes:
|
for exc in excludes:
|
||||||
|
|||||||
+4
-6
@@ -582,24 +582,22 @@ def reload(
|
|||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def serve(
|
def serve(
|
||||||
model: Annotated[str, typer.Option(help="Path to model file for sd-server.")],
|
|
||||||
host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1",
|
host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1",
|
||||||
port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080,
|
port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080,
|
||||||
sd_port: Annotated[int, typer.Option(help="sd-server listen port.")] = 1234,
|
sd_server: Annotated[str | None, typer.Option(help="sd-server URL to proxy to.")] = None,
|
||||||
log_level: Annotated[str, typer.Option(help="Log level.")] = "info",
|
log_level: Annotated[str, typer.Option(help="Log level.")] = "info",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start the sd-server wrapper API (transparent proxy with hot reload)."""
|
"""Start the sd-server wrapper API (proxies to external sd-server)."""
|
||||||
try:
|
try:
|
||||||
import uvicorn # noqa: PLC0415
|
import uvicorn # noqa: PLC0415
|
||||||
|
|
||||||
from tensors.server import ServerConfig, create_app # noqa: PLC0415
|
from tensors.server import create_app # noqa: PLC0415
|
||||||
except ImportError:
|
except ImportError:
|
||||||
console.print("[red]Missing server dependencies. Install with:[/red]")
|
console.print("[red]Missing server dependencies. Install with:[/red]")
|
||||||
console.print(" pip install tensors[server]")
|
console.print(" pip install tensors[server]")
|
||||||
raise typer.Exit(1) from None
|
raise typer.Exit(1) from None
|
||||||
|
|
||||||
config = ServerConfig(model=model, port=sd_port)
|
uvicorn.run(create_app(sd_server_url=sd_server), host=host, port=port, log_level=log_level)
|
||||||
uvicorn.run(create_app(config), host=host, port=port, log_level=log_level)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
@@ -243,3 +243,34 @@ def set_default_remote(name: str | None) -> None:
|
|||||||
else:
|
else:
|
||||||
config["default_remote"] = name
|
config["default_remote"] = name
|
||||||
save_config(config)
|
save_config(config)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# SD Server Configuration
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
SD_SERVER_DEFAULT_URL = "http://localhost:1234"
|
||||||
|
|
||||||
|
|
||||||
|
def get_sd_server_url() -> str:
    """Return the base URL of the external sd-server.

    Resolution order:
    1. SD_SERVER_URL environment variable
    2. config.toml [server].sd_server_url
    3. Default: http://localhost:1234

    The resolved value is normalized: surrounding whitespace and trailing
    slashes are removed. Callers build request URLs as
    ``f"{sd_server_url}/{path}"``, so a configured trailing slash would
    otherwise produce ``//`` in the request path. A value that is empty
    after normalization is treated as unset and resolution continues with
    the next source.

    Returns:
        The sd-server base URL without a trailing slash.
    """
    # 1. Environment variable takes precedence over everything else.
    env_url = os.environ.get("SD_SERVER_URL", "").strip().rstrip("/")
    if env_url:
        return env_url

    # 2. Config file; tolerate a malformed (non-table) [server] entry.
    server_config = load_config().get("server", {})
    if isinstance(server_config, dict):
        url = server_config.get("sd_server_url")
        if url:
            # str() keeps the original coercion for non-string config values.
            config_url = str(url).strip().rstrip("/")
            if config_url:
                return config_url

    # 3. Built-in default.
    return SD_SERVER_DEFAULT_URL
|
||||||
|
|||||||
+16
-20
@@ -1,4 +1,4 @@
|
|||||||
"""sd-server wrapper — FastAPI app for managing and proxying to sd-server."""
|
"""sd-server wrapper — FastAPI app for proxying to an external sd-server."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -12,42 +12,39 @@ from fastapi import FastAPI
|
|||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
from tensors.config import get_sd_server_url
|
||||||
from tensors.server.civitai_routes import create_civitai_router
|
from tensors.server.civitai_routes import create_civitai_router
|
||||||
from tensors.server.db_routes import create_db_router
|
from tensors.server.db_routes import create_db_router
|
||||||
from tensors.server.download_routes import create_download_router
|
from tensors.server.download_routes import create_download_router
|
||||||
from tensors.server.gallery_routes import create_gallery_router
|
from tensors.server.gallery_routes import create_gallery_router
|
||||||
from tensors.server.generate_routes import create_generate_router
|
from tensors.server.generate_routes import create_generate_router
|
||||||
from tensors.server.models import ServerConfig
|
|
||||||
from tensors.server.models_routes import create_models_router
|
from tensors.server.models_routes import create_models_router
|
||||||
from tensors.server.process import ProcessManager
|
|
||||||
from tensors.server.routes import create_router
|
from tensors.server.routes import create_router
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
__all__ = ["ProcessManager", "ServerConfig", "app", "create_app"]
|
__all__ = ["app", "create_app"]
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_app(config: ServerConfig | None = None) -> FastAPI:
|
def create_app(sd_server_url: str | None = None) -> FastAPI:
|
||||||
"""Build the FastAPI application with process manager and proxy client."""
|
"""Build the FastAPI application that proxies to an external sd-server.
|
||||||
pm = ProcessManager()
|
|
||||||
|
Args:
|
||||||
|
sd_server_url: URL of the sd-server to proxy to. If None, uses
|
||||||
|
get_sd_server_url() to resolve from env/config.
|
||||||
|
"""
|
||||||
|
backend_url = sd_server_url or get_sd_server_url()
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
||||||
|
_app.state.sd_server_url = backend_url
|
||||||
|
logger.info(f"Proxying to sd-server at: {backend_url}")
|
||||||
async with httpx.AsyncClient(timeout=300) as client:
|
async with httpx.AsyncClient(timeout=300) as client:
|
||||||
_app.state.client = client
|
_app.state.client = client
|
||||||
if config is not None:
|
|
||||||
pm.start(config)
|
|
||||||
logger.info("waiting for sd-server to become ready...")
|
|
||||||
ready = await pm.wait_ready()
|
|
||||||
if ready:
|
|
||||||
logger.info("sd-server is ready")
|
|
||||||
else:
|
|
||||||
logger.warning("sd-server did not become ready in time")
|
|
||||||
yield
|
yield
|
||||||
pm.stop()
|
|
||||||
|
|
||||||
app = FastAPI(title="sd-server wrapper", lifespan=lifespan)
|
app = FastAPI(title="sd-server wrapper", lifespan=lifespan)
|
||||||
|
|
||||||
@@ -68,11 +65,10 @@ def create_app(config: ServerConfig | None = None) -> FastAPI:
|
|||||||
app.include_router(create_civitai_router()) # Must be before catch-all proxy
|
app.include_router(create_civitai_router()) # Must be before catch-all proxy
|
||||||
app.include_router(create_db_router()) # Must be before catch-all proxy
|
app.include_router(create_db_router()) # Must be before catch-all proxy
|
||||||
app.include_router(create_gallery_router()) # Must be before catch-all proxy
|
app.include_router(create_gallery_router()) # Must be before catch-all proxy
|
||||||
app.include_router(create_models_router(pm)) # Must be before catch-all proxy
|
app.include_router(create_models_router()) # Must be before catch-all proxy
|
||||||
app.include_router(create_download_router()) # Must be before catch-all proxy
|
app.include_router(create_download_router()) # Must be before catch-all proxy
|
||||||
app.include_router(create_generate_router(pm)) # Must be before catch-all proxy
|
app.include_router(create_generate_router()) # Must be before catch-all proxy
|
||||||
app.include_router(create_router(pm))
|
app.include_router(create_router())
|
||||||
app.state.pm = pm
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,18 +5,15 @@ from __future__ import annotations
|
|||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from pydantic import BaseModel as PydanticBaseModel
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from tensors.server.gallery import Gallery
|
from tensors.server.gallery import Gallery
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from tensors.server.process import ProcessManager
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -113,38 +110,30 @@ def _process_image(
|
|||||||
return image_info
|
return image_info
|
||||||
|
|
||||||
|
|
||||||
def _check_server_running(pm: ProcessManager) -> None:
|
|
||||||
"""Check if sd-server is running, raise HTTPException if not."""
|
|
||||||
if pm.proc is None or pm.proc.poll() is not None:
|
|
||||||
raise HTTPException(status_code=503, detail="sd-server is not running")
|
|
||||||
if pm.config is None:
|
|
||||||
raise HTTPException(status_code=503, detail="sd-server not configured")
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Router Factory
|
# Router Factory
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def create_generate_router(pm: ProcessManager) -> APIRouter:
|
def create_generate_router() -> APIRouter:
|
||||||
"""Build a router with /api/generate endpoint."""
|
"""Build a router with /api/generate endpoint."""
|
||||||
router = APIRouter(prefix="/api", tags=["generate"])
|
router = APIRouter(prefix="/api", tags=["generate"])
|
||||||
gallery = Gallery()
|
gallery = Gallery()
|
||||||
|
|
||||||
@router.post("/generate")
|
@router.post("/generate")
|
||||||
async def generate(req: GenerateRequest) -> dict[str, Any]:
|
async def generate(request: Request, req: GenerateRequest) -> dict[str, Any]:
|
||||||
"""Generate images with gallery integration."""
|
"""Generate images with gallery integration."""
|
||||||
_check_server_running(pm)
|
sd_server_url = request.app.state.sd_server_url
|
||||||
assert pm.config is not None # Verified by _check_server_running
|
|
||||||
|
|
||||||
body = _build_sd_request(req)
|
body = _build_sd_request(req)
|
||||||
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/txt2img"
|
url = f"{sd_server_url}/sdapi/v1/txt2img"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=300) as client:
|
async with httpx.AsyncClient(timeout=300) as client:
|
||||||
response = await client.post(url, json=body)
|
response = await client.post(url, json=body)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()
|
result = response.json()
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
|
||||||
except httpx.HTTPError as e:
|
except httpx.HTTPError as e:
|
||||||
logger.exception("Generation failed")
|
logger.exception("Generation failed")
|
||||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||||
@@ -153,8 +142,11 @@ def create_generate_router(pm: ProcessManager) -> APIRouter:
|
|||||||
info = _parse_info(result.get("info", {}))
|
info = _parse_info(result.get("info", {}))
|
||||||
all_seeds = info.get("all_seeds", [req.seed] * len(images_data))
|
all_seeds = info.get("all_seeds", [req.seed] * len(images_data))
|
||||||
|
|
||||||
|
# Get model info from sd-server response if available
|
||||||
|
model_name = info.get("sd_model_name") or info.get("model")
|
||||||
|
|
||||||
output_images = [
|
output_images = [
|
||||||
_process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, pm.config.model)
|
_process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, model_name)
|
||||||
for i, img_b64 in enumerate(images_data)
|
for i, img_b64 in enumerate(images_data)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -167,32 +159,34 @@ def create_generate_router(pm: ProcessManager) -> APIRouter:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@router.get("/samplers")
|
@router.get("/samplers")
|
||||||
async def list_samplers() -> dict[str, Any]:
|
async def list_samplers(request: Request) -> dict[str, Any]:
|
||||||
"""List available samplers from sd-server."""
|
"""List available samplers from sd-server."""
|
||||||
_check_server_running(pm)
|
sd_server_url = request.app.state.sd_server_url
|
||||||
assert pm.config is not None
|
url = f"{sd_server_url}/sdapi/v1/samplers"
|
||||||
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/samplers"
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=30) as client:
|
async with httpx.AsyncClient(timeout=30) as client:
|
||||||
response = await client.get(url)
|
response = await client.get(url)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return {"samplers": response.json()}
|
return {"samplers": response.json()}
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
|
||||||
except httpx.HTTPError as e:
|
except httpx.HTTPError as e:
|
||||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||||
|
|
||||||
@router.get("/schedulers")
|
@router.get("/schedulers")
|
||||||
async def list_schedulers() -> dict[str, Any]:
|
async def list_schedulers(request: Request) -> dict[str, Any]:
|
||||||
"""List available schedulers from sd-server."""
|
"""List available schedulers from sd-server."""
|
||||||
_check_server_running(pm)
|
sd_server_url = request.app.state.sd_server_url
|
||||||
assert pm.config is not None
|
url = f"{sd_server_url}/sdapi/v1/schedulers"
|
||||||
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/schedulers"
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=30) as client:
|
async with httpx.AsyncClient(timeout=30) as client:
|
||||||
response = await client.get(url)
|
response = await client.get(url)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return {"schedulers": response.json()}
|
return {"schedulers": response.json()}
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
|
||||||
except httpx.HTTPError as e:
|
except httpx.HTTPError as e:
|
||||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||||
|
|
||||||
|
|||||||
@@ -2,16 +2,5 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pydantic import BaseModel
|
# Note: ServerConfig and ReloadRequest were removed since we no longer manage
|
||||||
|
# sd-server processes internally. The wrapper now proxies to an external sd-server.
|
||||||
DEFAULT_PORT = 1234
|
|
||||||
|
|
||||||
|
|
||||||
class ReloadRequest(BaseModel):
|
|
||||||
model: str
|
|
||||||
|
|
||||||
|
|
||||||
class ServerConfig(BaseModel):
|
|
||||||
model: str
|
|
||||||
port: int = DEFAULT_PORT
|
|
||||||
args: list[str] = []
|
|
||||||
|
|||||||
@@ -3,30 +3,18 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, Request
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from pydantic import BaseModel as PydanticBaseModel
|
|
||||||
|
|
||||||
from tensors.config import MODELS_DIR
|
from tensors.config import MODELS_DIR
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tensors.server.process import ProcessManager
|
from pathlib import Path
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_HTTP_OK = 200
|
||||||
# =============================================================================
|
|
||||||
# Request/Response Models
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class SwitchModelRequest(PydanticBaseModel):
|
|
||||||
"""Request body for switching models."""
|
|
||||||
|
|
||||||
model: str # Path to model file
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -76,7 +64,7 @@ def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def create_models_router(pm: ProcessManager) -> APIRouter:
|
def create_models_router() -> APIRouter:
|
||||||
"""Build a router with /api/models/* endpoints."""
|
"""Build a router with /api/models/* endpoints."""
|
||||||
router = APIRouter(prefix="/api/models", tags=["models"])
|
router = APIRouter(prefix="/api/models", tags=["models"])
|
||||||
|
|
||||||
@@ -90,62 +78,34 @@ def create_models_router(pm: ProcessManager) -> APIRouter:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@router.get("/active")
|
@router.get("/active")
|
||||||
def get_active_model() -> dict[str, Any]:
|
async def get_active_model(request: Request) -> dict[str, Any]:
|
||||||
"""Get information about the currently loaded model."""
|
"""Get information about the currently loaded model from sd-server."""
|
||||||
status = pm.status()
|
import httpx # noqa: PLC0415
|
||||||
config = pm.config
|
|
||||||
|
sd_server_url = request.app.state.sd_server_url
|
||||||
|
|
||||||
|
# Try to get current model from sd-server's options endpoint
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10) as client:
|
||||||
|
response = await client.get(f"{sd_server_url}/sdapi/v1/options")
|
||||||
|
if response.status_code == _HTTP_OK:
|
||||||
|
options = response.json()
|
||||||
|
model_name = options.get("sd_model_checkpoint")
|
||||||
|
return {
|
||||||
|
"loaded": True,
|
||||||
|
"model": model_name,
|
||||||
|
"sd_server_url": sd_server_url,
|
||||||
|
}
|
||||||
|
except httpx.HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
if config is None:
|
|
||||||
return {
|
return {
|
||||||
"loaded": False,
|
"loaded": False,
|
||||||
"model": None,
|
"model": None,
|
||||||
"status": status.get("status"),
|
"sd_server_url": sd_server_url,
|
||||||
|
"error": "Cannot connect to sd-server",
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
|
||||||
"loaded": status.get("status") == "running",
|
|
||||||
"model": config.model,
|
|
||||||
"pid": status.get("pid"),
|
|
||||||
"port": config.port,
|
|
||||||
"status": status.get("status"),
|
|
||||||
}
|
|
||||||
|
|
||||||
@router.post("/switch")
|
|
||||||
async def switch_model(req: SwitchModelRequest) -> JSONResponse:
|
|
||||||
"""Switch to a different model (hot reload)."""
|
|
||||||
model_path = Path(req.model)
|
|
||||||
|
|
||||||
# Validate model exists
|
|
||||||
if not model_path.exists():
|
|
||||||
raise HTTPException(status_code=400, detail=f"Model not found: {req.model}")
|
|
||||||
|
|
||||||
# Use existing reload logic
|
|
||||||
from tensors.server.models import ServerConfig # noqa: PLC0415
|
|
||||||
|
|
||||||
new_config = ServerConfig(
|
|
||||||
model=req.model,
|
|
||||||
port=pm.config.port if pm.config else 1234,
|
|
||||||
args=pm.config.args if pm.config else [],
|
|
||||||
)
|
|
||||||
|
|
||||||
pm.stop()
|
|
||||||
pm.start(new_config)
|
|
||||||
ready = await pm.wait_ready()
|
|
||||||
|
|
||||||
if not ready:
|
|
||||||
return JSONResponse(
|
|
||||||
{"error": "sd-server failed to become ready", "model": req.model},
|
|
||||||
status_code=503,
|
|
||||||
)
|
|
||||||
|
|
||||||
return JSONResponse(
|
|
||||||
{
|
|
||||||
"ok": True,
|
|
||||||
"model": req.model,
|
|
||||||
"pid": pm.proc.pid if pm.proc else None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/loras")
|
@router.get("/loras")
|
||||||
def list_loras() -> dict[str, Any]:
|
def list_loras() -> dict[str, Any]:
|
||||||
"""List available LoRA files."""
|
"""List available LoRA files."""
|
||||||
|
|||||||
@@ -1,92 +0,0 @@
|
|||||||
"""sd-server process lifecycle management."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import signal
|
|
||||||
import subprocess
|
|
||||||
from typing import TYPE_CHECKING, Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from tensors.server.models import ServerConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_HTTP_OK = 200
|
|
||||||
|
|
||||||
SD_SERVER_BIN = shutil.which("sd-server") or "sd-server"
|
|
||||||
|
|
||||||
|
|
||||||
class ProcessManager:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.proc: subprocess.Popen[bytes] | None = None
|
|
||||||
self.config: ServerConfig | None = None
|
|
||||||
|
|
||||||
def build_cmd(self) -> list[str]:
|
|
||||||
if self.config is None:
|
|
||||||
raise RuntimeError("No config set")
|
|
||||||
cmd = [SD_SERVER_BIN, "-m", self.config.model, "--listen-port", str(self.config.port)]
|
|
||||||
cmd.extend(self.config.args)
|
|
||||||
return cmd
|
|
||||||
|
|
||||||
def start(self, config: ServerConfig) -> None:
|
|
||||||
if self.proc is not None and self.proc.poll() is None:
|
|
||||||
raise RuntimeError("Server already running — stop it first")
|
|
||||||
self.config = config
|
|
||||||
cmd = self.build_cmd()
|
|
||||||
# Inherit environment (important for HSA_OVERRIDE_GFX_VERSION on ROCm)
|
|
||||||
self.proc = subprocess.Popen(cmd, env=os.environ.copy())
|
|
||||||
logger.info("started sd-server pid=%d cmd=%s", self.proc.pid, cmd)
|
|
||||||
|
|
||||||
def stop(self) -> bool:
|
|
||||||
if self.proc is None or self.proc.poll() is not None:
|
|
||||||
self.proc = None
|
|
||||||
return False
|
|
||||||
self.proc.send_signal(signal.SIGTERM)
|
|
||||||
try:
|
|
||||||
self.proc.wait(timeout=10)
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
self.proc.kill()
|
|
||||||
self.proc.wait(timeout=5)
|
|
||||||
logger.info("stopped sd-server")
|
|
||||||
self.proc = None
|
|
||||||
return True
|
|
||||||
|
|
||||||
def status(self) -> dict[str, Any]:
|
|
||||||
if self.proc is None:
|
|
||||||
return {"running": False}
|
|
||||||
rc = self.proc.poll()
|
|
||||||
if rc is not None:
|
|
||||||
return {"running": False, "exit_code": rc}
|
|
||||||
return {
|
|
||||||
"running": True,
|
|
||||||
"pid": self.proc.pid,
|
|
||||||
"model": self.config.model if self.config else None,
|
|
||||||
"port": self.config.port if self.config else None,
|
|
||||||
"bind": f"http://127.0.0.1:{self.config.port}" if self.config else None,
|
|
||||||
"cmd": self.build_cmd(),
|
|
||||||
}
|
|
||||||
|
|
||||||
async def wait_ready(self, timeout: float = 120) -> bool:
|
|
||||||
"""Poll sd-server root endpoint until it responds or timeout."""
|
|
||||||
if self.config is None:
|
|
||||||
return False
|
|
||||||
url = f"http://127.0.0.1:{self.config.port}/"
|
|
||||||
deadline = asyncio.get_event_loop().time() + timeout
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
while asyncio.get_event_loop().time() < deadline:
|
|
||||||
if self.proc is not None and self.proc.poll() is not None:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
r = await client.get(url, timeout=2)
|
|
||||||
if r.status_code == _HTTP_OK:
|
|
||||||
return True
|
|
||||||
except (httpx.ConnectError, httpx.TimeoutException, httpx.ReadError):
|
|
||||||
pass
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
return False
|
|
||||||
+37
-28
@@ -3,53 +3,52 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
from fastapi import APIRouter, Request, Response
|
from fastapi import APIRouter, Request, Response
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
from tensors.server.models import ReloadRequest, ServerConfig
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from tensors.server.process import ProcessManager
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_router(pm: ProcessManager) -> APIRouter:
|
def create_router() -> APIRouter:
|
||||||
"""Build a router with /status, /reload, and catch-all proxy."""
|
"""Build a router with /status and catch-all proxy."""
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@router.get("/status")
|
@router.get("/status")
|
||||||
def status() -> dict[str, Any]:
|
async def status(request: Request) -> dict[str, Any]:
|
||||||
return pm.status()
|
"""Check if the external sd-server is reachable."""
|
||||||
|
sd_server_url = request.app.state.sd_server_url
|
||||||
@router.post("/reload")
|
try:
|
||||||
async def reload(req: ReloadRequest) -> Response:
|
async with httpx.AsyncClient(timeout=5) as client:
|
||||||
new_config = ServerConfig(
|
r = await client.get(sd_server_url)
|
||||||
model=req.model,
|
return {
|
||||||
port=pm.config.port if pm.config else 1234,
|
"status": "ok",
|
||||||
args=pm.config.args if pm.config else [],
|
"sd_server_url": sd_server_url,
|
||||||
)
|
"sd_server_status": r.status_code,
|
||||||
pm.stop()
|
}
|
||||||
pm.start(new_config)
|
except httpx.HTTPError as e:
|
||||||
ready = await pm.wait_ready()
|
return {
|
||||||
if not ready:
|
"status": "error",
|
||||||
return JSONResponse({"error": "sd-server failed to become ready", "model": req.model}, status_code=503)
|
"sd_server_url": sd_server_url,
|
||||||
return JSONResponse({"ok": True, "model": req.model, "pid": pm.proc.pid if pm.proc else None})
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
|
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
|
||||||
async def proxy(request: Request, path: str) -> Response:
|
async def proxy(request: Request, path: str) -> Response:
|
||||||
if pm.proc is None or pm.proc.poll() is not None:
|
"""Proxy all requests to the external sd-server."""
|
||||||
return JSONResponse({"error": "sd-server is not running"}, status_code=503)
|
sd_server_url = request.app.state.sd_server_url
|
||||||
assert pm.config is not None
|
url = f"{sd_server_url}/{path}"
|
||||||
url = f"http://127.0.0.1:{pm.config.port}/{path}"
|
|
||||||
if request.url.query:
|
if request.url.query:
|
||||||
url = f"{url}?{request.url.query}"
|
url = f"{url}?{request.url.query}"
|
||||||
|
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
headers = dict(request.headers)
|
headers = dict(request.headers)
|
||||||
headers.pop("host", None)
|
headers.pop("host", None)
|
||||||
client = request.app.state.client
|
client = request.app.state.client
|
||||||
|
|
||||||
|
try:
|
||||||
upstream = await client.request(
|
upstream = await client.request(
|
||||||
method=request.method,
|
method=request.method,
|
||||||
url=url,
|
url=url,
|
||||||
@@ -62,5 +61,15 @@ def create_router(pm: ProcessManager) -> APIRouter:
|
|||||||
status_code=upstream.status_code,
|
status_code=upstream.status_code,
|
||||||
headers=dict(upstream.headers),
|
headers=dict(upstream.headers),
|
||||||
)
|
)
|
||||||
|
except httpx.ConnectError:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": f"Cannot connect to sd-server at {sd_server_url}"},
|
||||||
|
status_code=503,
|
||||||
|
)
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": f"Timeout connecting to sd-server at {sd_server_url}"},
|
||||||
|
status_code=504,
|
||||||
|
)
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|||||||
+45
-117
@@ -2,110 +2,51 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
import respx
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from tensors.server import create_app
|
from tensors.server import create_app
|
||||||
from tensors.server.models import ServerConfig
|
|
||||||
from tensors.server.process import ProcessManager
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def pm() -> ProcessManager:
|
|
||||||
return ProcessManager()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def api() -> TestClient:
|
def api() -> TestClient:
|
||||||
return TestClient(create_app())
|
"""Create test client with mock sd-server URL."""
|
||||||
|
return TestClient(create_app(sd_server_url="http://mock-sd-server:1234"))
|
||||||
|
|
||||||
def _get_pm(api: TestClient) -> ProcessManager:
|
|
||||||
return api.app.state.pm # type: ignore[no-any-return, attr-defined]
|
|
||||||
|
|
||||||
|
|
||||||
class TestStatus:
|
class TestStatus:
|
||||||
def test_not_running(self, api: TestClient) -> None:
|
@respx.mock
|
||||||
r = api.get("/status")
|
def test_status_when_backend_reachable(self) -> None:
|
||||||
assert r.status_code == 200
|
"""Test status endpoint when sd-server is reachable."""
|
||||||
assert r.json()["running"] is False
|
respx.get("http://mock-sd-server:1234/").mock(return_value=httpx.Response(200))
|
||||||
|
|
||||||
def test_running(self, api: TestClient) -> None:
|
with TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) as client:
|
||||||
pm = _get_pm(api)
|
r = client.get("/status")
|
||||||
mock_proc = MagicMock()
|
|
||||||
mock_proc.poll.return_value = None
|
|
||||||
mock_proc.pid = 999
|
|
||||||
pm.proc = mock_proc
|
|
||||||
pm.config = ServerConfig(model="/m.safetensors")
|
|
||||||
r = api.get("/status")
|
|
||||||
data = r.json()
|
|
||||||
assert data["running"] is True
|
|
||||||
assert data["pid"] == 999
|
|
||||||
|
|
||||||
def test_exited(self, api: TestClient) -> None:
|
|
||||||
pm = _get_pm(api)
|
|
||||||
mock_proc = MagicMock()
|
|
||||||
mock_proc.poll.return_value = 1
|
|
||||||
pm.proc = mock_proc
|
|
||||||
r = api.get("/status")
|
|
||||||
data = r.json()
|
|
||||||
assert data["running"] is False
|
|
||||||
assert data["exit_code"] == 1
|
|
||||||
|
|
||||||
|
|
||||||
class TestReload:
|
|
||||||
@patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=True)
|
|
||||||
@patch("tensors.server.process.subprocess.Popen")
|
|
||||||
def test_reload_swaps_model(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None:
|
|
||||||
pm = _get_pm(api)
|
|
||||||
pm.config = ServerConfig(model="/old.gguf", port=5555, args=["--fa"])
|
|
||||||
mock_popen.return_value.pid = 42
|
|
||||||
mock_popen.return_value.poll.return_value = None
|
|
||||||
r = api.post("/reload", json={"model": "/new.gguf"})
|
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
data = r.json()
|
data = r.json()
|
||||||
assert data["ok"] is True
|
assert data["status"] == "ok"
|
||||||
assert data["model"] == "/new.gguf"
|
assert data["sd_server_url"] == "http://mock-sd-server:1234"
|
||||||
assert data["pid"] == 42
|
|
||||||
# Verify new config preserved port and args from previous config
|
|
||||||
assert pm.config is not None
|
|
||||||
assert pm.config.port == 5555
|
|
||||||
assert pm.config.args == ["--fa"]
|
|
||||||
assert pm.config.model == "/new.gguf"
|
|
||||||
|
|
||||||
@patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=False)
|
@respx.mock
|
||||||
@patch("tensors.server.process.subprocess.Popen")
|
def test_status_when_backend_unreachable(self) -> None:
|
||||||
def test_reload_fails_when_not_ready(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None:
|
"""Test status endpoint when sd-server is not reachable."""
|
||||||
pm = _get_pm(api)
|
respx.get("http://mock-sd-server:1234/").mock(side_effect=httpx.ConnectError("Connection refused"))
|
||||||
pm.config = ServerConfig(model="/old.gguf")
|
|
||||||
mock_popen.return_value.pid = 43
|
|
||||||
mock_popen.return_value.poll.return_value = None
|
|
||||||
r = api.post("/reload", json={"model": "/bad.gguf"})
|
|
||||||
assert r.status_code == 503
|
|
||||||
assert "failed" in r.json()["error"]
|
|
||||||
|
|
||||||
def test_reload_requires_model(self, api: TestClient) -> None:
|
with TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) as client:
|
||||||
r = api.post("/reload", json={})
|
r = client.get("/status")
|
||||||
assert r.status_code == 422
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert data["status"] == "error"
|
||||||
|
assert "Connection refused" in data["error"]
|
||||||
|
|
||||||
|
|
||||||
class TestProxy:
|
class TestProxy:
|
||||||
def test_proxy_503_when_not_running(self, api: TestClient) -> None:
|
|
||||||
r = api.get("/v1/models")
|
|
||||||
assert r.status_code == 503
|
|
||||||
assert "not running" in r.json()["error"]
|
|
||||||
|
|
||||||
def test_proxy_forwards_request(self, api: TestClient) -> None:
|
def test_proxy_forwards_request(self, api: TestClient) -> None:
|
||||||
pm = _get_pm(api)
|
"""Test proxy forwards GET requests to backend."""
|
||||||
mock_proc = MagicMock()
|
|
||||||
mock_proc.poll.return_value = None
|
|
||||||
mock_proc.pid = 100
|
|
||||||
pm.proc = mock_proc
|
|
||||||
pm.config = ServerConfig(model="/m.gguf", port=1234)
|
|
||||||
|
|
||||||
upstream_response = httpx.Response(
|
upstream_response = httpx.Response(
|
||||||
200,
|
200,
|
||||||
json={"data": [{"id": "model-1"}]},
|
json={"data": [{"id": "model-1"}]},
|
||||||
@@ -114,6 +55,7 @@ class TestProxy:
|
|||||||
mock_client = AsyncMock()
|
mock_client = AsyncMock()
|
||||||
mock_client.request.return_value = upstream_response
|
mock_client.request.return_value = upstream_response
|
||||||
api.app.state.client = mock_client # type: ignore[attr-defined]
|
api.app.state.client = mock_client # type: ignore[attr-defined]
|
||||||
|
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
|
||||||
|
|
||||||
r = api.get("/v1/models")
|
r = api.get("/v1/models")
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
@@ -121,52 +63,38 @@ class TestProxy:
|
|||||||
mock_client.request.assert_called_once()
|
mock_client.request.assert_called_once()
|
||||||
|
|
||||||
def test_proxy_forwards_post_with_body(self, api: TestClient) -> None:
|
def test_proxy_forwards_post_with_body(self, api: TestClient) -> None:
|
||||||
pm = _get_pm(api)
|
"""Test proxy forwards POST requests with body."""
|
||||||
mock_proc = MagicMock()
|
|
||||||
mock_proc.poll.return_value = None
|
|
||||||
mock_proc.pid = 100
|
|
||||||
pm.proc = mock_proc
|
|
||||||
pm.config = ServerConfig(model="/m.gguf", port=1234)
|
|
||||||
|
|
||||||
upstream_response = httpx.Response(200, json={"ok": True})
|
upstream_response = httpx.Response(200, json={"ok": True})
|
||||||
mock_client = AsyncMock()
|
mock_client = AsyncMock()
|
||||||
mock_client.request.return_value = upstream_response
|
mock_client.request.return_value = upstream_response
|
||||||
api.app.state.client = mock_client # type: ignore[attr-defined]
|
api.app.state.client = mock_client # type: ignore[attr-defined]
|
||||||
|
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
|
||||||
|
|
||||||
r = api.post("/v1/chat/completions", json={"prompt": "hello"})
|
r = api.post("/sdapi/v1/txt2img", json={"prompt": "hello"})
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
mock_client.request.assert_called_once()
|
mock_client.request.assert_called_once()
|
||||||
|
|
||||||
|
def test_proxy_503_on_connect_error(self, api: TestClient) -> None:
|
||||||
|
"""Test proxy returns 503 when backend is unreachable."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.request.side_effect = httpx.ConnectError("Connection refused")
|
||||||
|
api.app.state.client = mock_client # type: ignore[attr-defined]
|
||||||
|
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
|
||||||
|
|
||||||
class TestProcessManager:
|
r = api.get("/v1/models")
|
||||||
def test_status_not_running(self, pm: ProcessManager) -> None:
|
assert r.status_code == 503
|
||||||
assert pm.status() == {"running": False}
|
assert "Cannot connect" in r.json()["error"]
|
||||||
|
|
||||||
def test_build_cmd(self, pm: ProcessManager) -> None:
|
def test_proxy_504_on_timeout(self, api: TestClient) -> None:
|
||||||
pm.config = ServerConfig(model="/m.gguf", port=1234, args=["--fa"])
|
"""Test proxy returns 504 on timeout."""
|
||||||
cmd = pm.build_cmd()
|
mock_client = AsyncMock()
|
||||||
assert "/m.gguf" in cmd
|
mock_client.request.side_effect = httpx.TimeoutException("Timeout")
|
||||||
assert "--fa" in cmd
|
api.app.state.client = mock_client # type: ignore[attr-defined]
|
||||||
assert "1234" in cmd
|
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
|
||||||
|
|
||||||
def test_build_cmd_no_config(self, pm: ProcessManager) -> None:
|
r = api.get("/v1/models")
|
||||||
with pytest.raises(RuntimeError, match="No config"):
|
assert r.status_code == 504
|
||||||
pm.build_cmd()
|
assert "Timeout" in r.json()["error"]
|
||||||
|
|
||||||
@patch("tensors.server.process.subprocess.Popen")
|
|
||||||
def test_start_and_stop(self, mock_popen: MagicMock, pm: ProcessManager) -> None:
|
|
||||||
mock_popen.return_value.pid = 77
|
|
||||||
mock_popen.return_value.poll.return_value = None
|
|
||||||
mock_popen.return_value.wait.return_value = 0
|
|
||||||
pm.start(ServerConfig(model="/m.gguf"))
|
|
||||||
assert pm.proc is not None
|
|
||||||
assert pm.stop() is True
|
|
||||||
assert pm.proc is None
|
|
||||||
|
|
||||||
def test_server_config_defaults(self) -> None:
|
|
||||||
cfg = ServerConfig(model="/m.gguf")
|
|
||||||
assert cfg.port == 1234
|
|
||||||
assert cfg.args == []
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
Reference in New Issue
Block a user