Remove internal sd-server management, proxy to external sd-server
- Remove ProcessManager and process.py - Add get_sd_server_url() config (env/config/default) - Update routes to proxy to external sd-server URL - Remove model switching (handled by external sd-server) - Update CLI serve command - Update tests for new architecture Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
+4
-6
@@ -582,24 +582,22 @@ def reload(
|
||||
|
||||
@app.command()
|
||||
def serve(
|
||||
model: Annotated[str, typer.Option(help="Path to model file for sd-server.")],
|
||||
host: Annotated[str, typer.Option(help="Wrapper API listen address.")] = "127.0.0.1",
|
||||
port: Annotated[int, typer.Option(help="Wrapper API listen port.")] = 8080,
|
||||
sd_port: Annotated[int, typer.Option(help="sd-server listen port.")] = 1234,
|
||||
sd_server: Annotated[str | None, typer.Option(help="sd-server URL to proxy to.")] = None,
|
||||
log_level: Annotated[str, typer.Option(help="Log level.")] = "info",
|
||||
) -> None:
|
||||
"""Start the sd-server wrapper API (transparent proxy with hot reload)."""
|
||||
"""Start the sd-server wrapper API (proxies to external sd-server)."""
|
||||
try:
|
||||
import uvicorn # noqa: PLC0415
|
||||
|
||||
from tensors.server import ServerConfig, create_app # noqa: PLC0415
|
||||
from tensors.server import create_app # noqa: PLC0415
|
||||
except ImportError:
|
||||
console.print("[red]Missing server dependencies. Install with:[/red]")
|
||||
console.print(" pip install tensors[server]")
|
||||
raise typer.Exit(1) from None
|
||||
|
||||
config = ServerConfig(model=model, port=sd_port)
|
||||
uvicorn.run(create_app(config), host=host, port=port, log_level=log_level)
|
||||
uvicorn.run(create_app(sd_server_url=sd_server), host=host, port=port, log_level=log_level)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@@ -243,3 +243,34 @@ def set_default_remote(name: str | None) -> None:
|
||||
else:
|
||||
config["default_remote"] = name
|
||||
save_config(config)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SD Server Configuration
|
||||
# ============================================================================
|
||||
|
||||
SD_SERVER_DEFAULT_URL = "http://localhost:1234"
|
||||
|
||||
|
||||
def get_sd_server_url() -> str:
|
||||
"""Get the sd-server URL.
|
||||
|
||||
Resolution order:
|
||||
1. SD_SERVER_URL environment variable
|
||||
2. config.toml [server].sd_server_url
|
||||
3. Default: http://localhost:1234
|
||||
"""
|
||||
# Check environment variable first
|
||||
env_url = os.environ.get("SD_SERVER_URL")
|
||||
if env_url:
|
||||
return env_url
|
||||
|
||||
# Check config file
|
||||
config = load_config()
|
||||
server_config = config.get("server", {})
|
||||
if isinstance(server_config, dict):
|
||||
url = server_config.get("sd_server_url")
|
||||
if url:
|
||||
return str(url)
|
||||
|
||||
return SD_SERVER_DEFAULT_URL
|
||||
|
||||
+16
-20
@@ -1,4 +1,4 @@
|
||||
"""sd-server wrapper — FastAPI app for managing and proxying to sd-server."""
|
||||
"""sd-server wrapper — FastAPI app for proxying to an external sd-server."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -12,42 +12,39 @@ from fastapi import FastAPI
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from tensors.config import get_sd_server_url
|
||||
from tensors.server.civitai_routes import create_civitai_router
|
||||
from tensors.server.db_routes import create_db_router
|
||||
from tensors.server.download_routes import create_download_router
|
||||
from tensors.server.gallery_routes import create_gallery_router
|
||||
from tensors.server.generate_routes import create_generate_router
|
||||
from tensors.server.models import ServerConfig
|
||||
from tensors.server.models_routes import create_models_router
|
||||
from tensors.server.process import ProcessManager
|
||||
from tensors.server.routes import create_router
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
__all__ = ["ProcessManager", "ServerConfig", "app", "create_app"]
|
||||
__all__ = ["app", "create_app"]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_app(config: ServerConfig | None = None) -> FastAPI:
|
||||
"""Build the FastAPI application with process manager and proxy client."""
|
||||
pm = ProcessManager()
|
||||
def create_app(sd_server_url: str | None = None) -> FastAPI:
|
||||
"""Build the FastAPI application that proxies to an external sd-server.
|
||||
|
||||
Args:
|
||||
sd_server_url: URL of the sd-server to proxy to. If None, uses
|
||||
get_sd_server_url() to resolve from env/config.
|
||||
"""
|
||||
backend_url = sd_server_url or get_sd_server_url()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
||||
_app.state.sd_server_url = backend_url
|
||||
logger.info(f"Proxying to sd-server at: {backend_url}")
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
_app.state.client = client
|
||||
if config is not None:
|
||||
pm.start(config)
|
||||
logger.info("waiting for sd-server to become ready...")
|
||||
ready = await pm.wait_ready()
|
||||
if ready:
|
||||
logger.info("sd-server is ready")
|
||||
else:
|
||||
logger.warning("sd-server did not become ready in time")
|
||||
yield
|
||||
pm.stop()
|
||||
|
||||
app = FastAPI(title="sd-server wrapper", lifespan=lifespan)
|
||||
|
||||
@@ -68,11 +65,10 @@ def create_app(config: ServerConfig | None = None) -> FastAPI:
|
||||
app.include_router(create_civitai_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_db_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_gallery_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_models_router(pm)) # Must be before catch-all proxy
|
||||
app.include_router(create_models_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_download_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_generate_router(pm)) # Must be before catch-all proxy
|
||||
app.include_router(create_router(pm))
|
||||
app.state.pm = pm
|
||||
app.include_router(create_generate_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_router())
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@@ -5,18 +5,15 @@ from __future__ import annotations
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from tensors.server.gallery import Gallery
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensors.server.process import ProcessManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -113,38 +110,30 @@ def _process_image(
|
||||
return image_info
|
||||
|
||||
|
||||
def _check_server_running(pm: ProcessManager) -> None:
|
||||
"""Check if sd-server is running, raise HTTPException if not."""
|
||||
if pm.proc is None or pm.proc.poll() is not None:
|
||||
raise HTTPException(status_code=503, detail="sd-server is not running")
|
||||
if pm.config is None:
|
||||
raise HTTPException(status_code=503, detail="sd-server not configured")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Router Factory
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_generate_router(pm: ProcessManager) -> APIRouter:
|
||||
def create_generate_router() -> APIRouter:
|
||||
"""Build a router with /api/generate endpoint."""
|
||||
router = APIRouter(prefix="/api", tags=["generate"])
|
||||
gallery = Gallery()
|
||||
|
||||
@router.post("/generate")
|
||||
async def generate(req: GenerateRequest) -> dict[str, Any]:
|
||||
async def generate(request: Request, req: GenerateRequest) -> dict[str, Any]:
|
||||
"""Generate images with gallery integration."""
|
||||
_check_server_running(pm)
|
||||
assert pm.config is not None # Verified by _check_server_running
|
||||
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
body = _build_sd_request(req)
|
||||
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/txt2img"
|
||||
url = f"{sd_server_url}/sdapi/v1/txt2img"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
response = await client.post(url, json=body)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
except httpx.ConnectError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
|
||||
except httpx.HTTPError as e:
|
||||
logger.exception("Generation failed")
|
||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||
@@ -153,8 +142,11 @@ def create_generate_router(pm: ProcessManager) -> APIRouter:
|
||||
info = _parse_info(result.get("info", {}))
|
||||
all_seeds = info.get("all_seeds", [req.seed] * len(images_data))
|
||||
|
||||
# Get model info from sd-server response if available
|
||||
model_name = info.get("sd_model_name") or info.get("model")
|
||||
|
||||
output_images = [
|
||||
_process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, pm.config.model)
|
||||
_process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, model_name)
|
||||
for i, img_b64 in enumerate(images_data)
|
||||
]
|
||||
|
||||
@@ -167,32 +159,34 @@ def create_generate_router(pm: ProcessManager) -> APIRouter:
|
||||
}
|
||||
|
||||
@router.get("/samplers")
|
||||
async def list_samplers() -> dict[str, Any]:
|
||||
async def list_samplers(request: Request) -> dict[str, Any]:
|
||||
"""List available samplers from sd-server."""
|
||||
_check_server_running(pm)
|
||||
assert pm.config is not None
|
||||
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/samplers"
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
url = f"{sd_server_url}/sdapi/v1/samplers"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
return {"samplers": response.json()}
|
||||
except httpx.ConnectError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||
|
||||
@router.get("/schedulers")
|
||||
async def list_schedulers() -> dict[str, Any]:
|
||||
async def list_schedulers(request: Request) -> dict[str, Any]:
|
||||
"""List available schedulers from sd-server."""
|
||||
_check_server_running(pm)
|
||||
assert pm.config is not None
|
||||
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/schedulers"
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
url = f"{sd_server_url}/sdapi/v1/schedulers"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
return {"schedulers": response.json()}
|
||||
except httpx.ConnectError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||
|
||||
|
||||
@@ -2,16 +2,5 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
DEFAULT_PORT = 1234
|
||||
|
||||
|
||||
class ReloadRequest(BaseModel):
|
||||
model: str
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
model: str
|
||||
port: int = DEFAULT_PORT
|
||||
args: list[str] = []
|
||||
# Note: ServerConfig and ReloadRequest were removed since we no longer manage
|
||||
# sd-server processes internally. The wrapper now proxies to an external sd-server.
|
||||
|
||||
@@ -3,30 +3,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from fastapi import APIRouter, Request
|
||||
|
||||
from tensors.config import MODELS_DIR
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensors.server.process import ProcessManager
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Request/Response Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SwitchModelRequest(PydanticBaseModel):
|
||||
"""Request body for switching models."""
|
||||
|
||||
model: str # Path to model file
|
||||
_HTTP_OK = 200
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -76,7 +64,7 @@ def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_models_router(pm: ProcessManager) -> APIRouter:
|
||||
def create_models_router() -> APIRouter:
|
||||
"""Build a router with /api/models/* endpoints."""
|
||||
router = APIRouter(prefix="/api/models", tags=["models"])
|
||||
|
||||
@@ -90,62 +78,34 @@ def create_models_router(pm: ProcessManager) -> APIRouter:
|
||||
}
|
||||
|
||||
@router.get("/active")
|
||||
def get_active_model() -> dict[str, Any]:
|
||||
"""Get information about the currently loaded model."""
|
||||
status = pm.status()
|
||||
config = pm.config
|
||||
async def get_active_model(request: Request) -> dict[str, Any]:
|
||||
"""Get information about the currently loaded model from sd-server."""
|
||||
import httpx # noqa: PLC0415
|
||||
|
||||
if config is None:
|
||||
return {
|
||||
"loaded": False,
|
||||
"model": None,
|
||||
"status": status.get("status"),
|
||||
}
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
|
||||
# Try to get current model from sd-server's options endpoint
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.get(f"{sd_server_url}/sdapi/v1/options")
|
||||
if response.status_code == _HTTP_OK:
|
||||
options = response.json()
|
||||
model_name = options.get("sd_model_checkpoint")
|
||||
return {
|
||||
"loaded": True,
|
||||
"model": model_name,
|
||||
"sd_server_url": sd_server_url,
|
||||
}
|
||||
except httpx.HTTPError:
|
||||
pass
|
||||
|
||||
return {
|
||||
"loaded": status.get("status") == "running",
|
||||
"model": config.model,
|
||||
"pid": status.get("pid"),
|
||||
"port": config.port,
|
||||
"status": status.get("status"),
|
||||
"loaded": False,
|
||||
"model": None,
|
||||
"sd_server_url": sd_server_url,
|
||||
"error": "Cannot connect to sd-server",
|
||||
}
|
||||
|
||||
@router.post("/switch")
|
||||
async def switch_model(req: SwitchModelRequest) -> JSONResponse:
|
||||
"""Switch to a different model (hot reload)."""
|
||||
model_path = Path(req.model)
|
||||
|
||||
# Validate model exists
|
||||
if not model_path.exists():
|
||||
raise HTTPException(status_code=400, detail=f"Model not found: {req.model}")
|
||||
|
||||
# Use existing reload logic
|
||||
from tensors.server.models import ServerConfig # noqa: PLC0415
|
||||
|
||||
new_config = ServerConfig(
|
||||
model=req.model,
|
||||
port=pm.config.port if pm.config else 1234,
|
||||
args=pm.config.args if pm.config else [],
|
||||
)
|
||||
|
||||
pm.stop()
|
||||
pm.start(new_config)
|
||||
ready = await pm.wait_ready()
|
||||
|
||||
if not ready:
|
||||
return JSONResponse(
|
||||
{"error": "sd-server failed to become ready", "model": req.model},
|
||||
status_code=503,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"ok": True,
|
||||
"model": req.model,
|
||||
"pid": pm.proc.pid if pm.proc else None,
|
||||
}
|
||||
)
|
||||
|
||||
@router.get("/loras")
|
||||
def list_loras() -> dict[str, Any]:
|
||||
"""List available LoRA files."""
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
"""sd-server process lifecycle management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensors.server.models import ServerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_HTTP_OK = 200
|
||||
|
||||
SD_SERVER_BIN = shutil.which("sd-server") or "sd-server"
|
||||
|
||||
|
||||
class ProcessManager:
|
||||
def __init__(self) -> None:
|
||||
self.proc: subprocess.Popen[bytes] | None = None
|
||||
self.config: ServerConfig | None = None
|
||||
|
||||
def build_cmd(self) -> list[str]:
|
||||
if self.config is None:
|
||||
raise RuntimeError("No config set")
|
||||
cmd = [SD_SERVER_BIN, "-m", self.config.model, "--listen-port", str(self.config.port)]
|
||||
cmd.extend(self.config.args)
|
||||
return cmd
|
||||
|
||||
def start(self, config: ServerConfig) -> None:
|
||||
if self.proc is not None and self.proc.poll() is None:
|
||||
raise RuntimeError("Server already running — stop it first")
|
||||
self.config = config
|
||||
cmd = self.build_cmd()
|
||||
# Inherit environment (important for HSA_OVERRIDE_GFX_VERSION on ROCm)
|
||||
self.proc = subprocess.Popen(cmd, env=os.environ.copy())
|
||||
logger.info("started sd-server pid=%d cmd=%s", self.proc.pid, cmd)
|
||||
|
||||
def stop(self) -> bool:
|
||||
if self.proc is None or self.proc.poll() is not None:
|
||||
self.proc = None
|
||||
return False
|
||||
self.proc.send_signal(signal.SIGTERM)
|
||||
try:
|
||||
self.proc.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.proc.kill()
|
||||
self.proc.wait(timeout=5)
|
||||
logger.info("stopped sd-server")
|
||||
self.proc = None
|
||||
return True
|
||||
|
||||
def status(self) -> dict[str, Any]:
|
||||
if self.proc is None:
|
||||
return {"running": False}
|
||||
rc = self.proc.poll()
|
||||
if rc is not None:
|
||||
return {"running": False, "exit_code": rc}
|
||||
return {
|
||||
"running": True,
|
||||
"pid": self.proc.pid,
|
||||
"model": self.config.model if self.config else None,
|
||||
"port": self.config.port if self.config else None,
|
||||
"bind": f"http://127.0.0.1:{self.config.port}" if self.config else None,
|
||||
"cmd": self.build_cmd(),
|
||||
}
|
||||
|
||||
async def wait_ready(self, timeout: float = 120) -> bool:
|
||||
"""Poll sd-server root endpoint until it responds or timeout."""
|
||||
if self.config is None:
|
||||
return False
|
||||
url = f"http://127.0.0.1:{self.config.port}/"
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
async with httpx.AsyncClient() as client:
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
if self.proc is not None and self.proc.poll() is not None:
|
||||
return False
|
||||
try:
|
||||
r = await client.get(url, timeout=2)
|
||||
if r.status_code == _HTTP_OK:
|
||||
return True
|
||||
except (httpx.ConnectError, httpx.TimeoutException, httpx.ReadError):
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
+49
-40
@@ -3,64 +3,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Request, Response
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from tensors.server.models import ReloadRequest, ServerConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensors.server.process import ProcessManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_router(pm: ProcessManager) -> APIRouter:
|
||||
"""Build a router with /status, /reload, and catch-all proxy."""
|
||||
def create_router() -> APIRouter:
|
||||
"""Build a router with /status and catch-all proxy."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/status")
|
||||
def status() -> dict[str, Any]:
|
||||
return pm.status()
|
||||
|
||||
@router.post("/reload")
|
||||
async def reload(req: ReloadRequest) -> Response:
|
||||
new_config = ServerConfig(
|
||||
model=req.model,
|
||||
port=pm.config.port if pm.config else 1234,
|
||||
args=pm.config.args if pm.config else [],
|
||||
)
|
||||
pm.stop()
|
||||
pm.start(new_config)
|
||||
ready = await pm.wait_ready()
|
||||
if not ready:
|
||||
return JSONResponse({"error": "sd-server failed to become ready", "model": req.model}, status_code=503)
|
||||
return JSONResponse({"ok": True, "model": req.model, "pid": pm.proc.pid if pm.proc else None})
|
||||
async def status(request: Request) -> dict[str, Any]:
|
||||
"""Check if the external sd-server is reachable."""
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
r = await client.get(sd_server_url)
|
||||
return {
|
||||
"status": "ok",
|
||||
"sd_server_url": sd_server_url,
|
||||
"sd_server_status": r.status_code,
|
||||
}
|
||||
except httpx.HTTPError as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"sd_server_url": sd_server_url,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
|
||||
async def proxy(request: Request, path: str) -> Response:
|
||||
if pm.proc is None or pm.proc.poll() is not None:
|
||||
return JSONResponse({"error": "sd-server is not running"}, status_code=503)
|
||||
assert pm.config is not None
|
||||
url = f"http://127.0.0.1:{pm.config.port}/{path}"
|
||||
"""Proxy all requests to the external sd-server."""
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
url = f"{sd_server_url}/{path}"
|
||||
if request.url.query:
|
||||
url = f"{url}?{request.url.query}"
|
||||
|
||||
body = await request.body()
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
client = request.app.state.client
|
||||
upstream = await client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
timeout=300,
|
||||
)
|
||||
return StreamingResponse(
|
||||
content=upstream.iter_bytes(),
|
||||
status_code=upstream.status_code,
|
||||
headers=dict(upstream.headers),
|
||||
)
|
||||
|
||||
try:
|
||||
upstream = await client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
timeout=300,
|
||||
)
|
||||
return StreamingResponse(
|
||||
content=upstream.iter_bytes(),
|
||||
status_code=upstream.status_code,
|
||||
headers=dict(upstream.headers),
|
||||
)
|
||||
except httpx.ConnectError:
|
||||
return JSONResponse(
|
||||
{"error": f"Cannot connect to sd-server at {sd_server_url}"},
|
||||
status_code=503,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
return JSONResponse(
|
||||
{"error": f"Timeout connecting to sd-server at {sd_server_url}"},
|
||||
status_code=504,
|
||||
)
|
||||
|
||||
return router
|
||||
|
||||
Reference in New Issue
Block a user