diff --git a/pyproject.toml b/pyproject.toml
index 31c1ffe..6fed577 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ dependencies = [
]
[project.optional-dependencies]
-server = ["fastapi>=0.115", "uvicorn>=0.30", "scalar-fastapi>=1.6"]
+server = ["fastapi>=0.115", "uvicorn>=0.30", "scalar-fastapi>=1.6", "websockets>=12.0", "python-multipart>=0.0.9"]
[project.scripts]
tsr = "tensors:main"
@@ -39,6 +39,8 @@ dev = [
"fastapi>=0.115",
"uvicorn>=0.30",
"scalar-fastapi>=1.6",
+ "websockets>=12.0",
+ "python-multipart>=0.0.9",
]
[tool.ruff]
@@ -104,6 +106,10 @@ ignore_missing_imports = true
module = ["huggingface_hub.*"]
ignore_missing_imports = true
+[[tool.mypy.overrides]]
+module = ["websockets.*"]
+ignore_missing_imports = true
+
[tool.pytest.ini_options]
testpaths = ["tests"]
addopts = "-v --cov=tensors --cov-report=term-missing"
diff --git a/tensors/server/__init__.py b/tensors/server/__init__.py
index e5d3537..ccf56c0 100644
--- a/tensors/server/__init__.py
+++ b/tensors/server/__init__.py
@@ -12,6 +12,7 @@ from scalar_fastapi import get_scalar_api_reference
from tensors.config import get_server_api_key
from tensors.server.civitai_routes import create_civitai_router
+from tensors.server.comfyui_routes import create_comfyui_router
from tensors.server.db_routes import create_db_router
from tensors.server.download_routes import create_download_router
from tensors.server.gallery_routes import create_gallery_router
@@ -69,7 +70,10 @@ def create_app() -> FastAPI:
title="tensors API",
)
- # Protected routers (auth required if configured)
+ # ComfyUI proxy (handles its own session auth)
+ app.include_router(create_comfyui_router())
+
+ # Protected routers (API key auth)
from tensors.server.auth import verify_api_key # noqa: PLC0415
app.include_router(create_search_router(), dependencies=[Depends(verify_api_key)])
diff --git a/tensors/server/comfyui_routes.py b/tensors/server/comfyui_routes.py
new file mode 100644
index 0000000..bf4f121
--- /dev/null
+++ b/tensors/server/comfyui_routes.py
@@ -0,0 +1,340 @@
+"""ComfyUI reverse proxy with session authentication."""
+
+from __future__ import annotations
+
+import asyncio
+import hashlib
+import hmac
+import os
+import time
+
+import httpx
+import websockets
+from fastapi import APIRouter, Cookie, Form, HTTPException, Request, Response, WebSocket, status
+from fastapi.responses import HTMLResponse, RedirectResponse
+
+router = APIRouter(tags=["ComfyUI"])
+
+# Number of parts in session token
+_SESSION_TOKEN_PARTS = 3
+
+# Config from environment
+COMFYUI_URL = os.environ.get("COMFYUI_URL", "http://127.0.0.1:8188")
+COMFYUI_USER = os.environ.get("COMFYUI_USER", "")
+COMFYUI_PASS = os.environ.get("COMFYUI_PASS", "")
+SESSION_SECRET = os.environ.get("SESSION_SECRET", "tensors-comfyui-secret-change-me")
+SESSION_MAX_AGE = 86400 * 7 # 7 days
+
+
+def _create_session_token(username: str) -> str:
+ """Create a signed session token."""
+ expires = int(time.time()) + SESSION_MAX_AGE
+ data = f"{username}:{expires}"
+ signature = hmac.new(SESSION_SECRET.encode(), data.encode(), hashlib.sha256).hexdigest()[:32]
+ return f"{data}:{signature}"
+
+
+def _verify_session_token(token: str | None) -> bool:
+ """Verify a session token."""
+ if not token:
+ return False
+ try:
+ parts = token.split(":")
+ if len(parts) != _SESSION_TOKEN_PARTS:
+ return False
+ username, expires_str, signature = parts
+ expires = int(expires_str)
+ if time.time() > expires:
+ return False
+ data = f"{username}:{expires_str}"
+ expected = hmac.new(SESSION_SECRET.encode(), data.encode(), hashlib.sha256).hexdigest()[:32]
+ return hmac.compare_digest(signature, expected)
+ except (ValueError, TypeError):
+ return False
+
+
+LOGIN_PAGE_HTML = """
+
+
+
+
+
+ ComfyUI - Login
+
+
+
+
+
+
ComfyUI
+
Stable Diffusion GUI
+
+ {{ERROR}}
+
+
+
+
+
+"""
+
+
+@router.get("/comfy/login")
+async def login_page(error: str | None = None) -> HTMLResponse:
+ """Show login page."""
+ error_html = ""
+ if error:
+ error_html = f'{error}
'
+ html = LOGIN_PAGE_HTML.replace("{{ERROR}}", error_html)
+ return HTMLResponse(content=html)
+
+
+@router.post("/comfy/login")
+async def login_submit(username: str = Form(...), password: str = Form(...)) -> Response:
+ """Handle login form submission."""
+ if not COMFYUI_USER or not COMFYUI_PASS:
+ return RedirectResponse(
+ url="/comfy/login?error=Authentication+not+configured",
+ status_code=status.HTTP_303_SEE_OTHER,
+ )
+
+ if username == COMFYUI_USER and password == COMFYUI_PASS:
+ token = _create_session_token(username)
+ response = RedirectResponse(url="/comfy/", status_code=status.HTTP_303_SEE_OTHER)
+ response.set_cookie(
+ key="comfy_session",
+ value=token,
+ max_age=SESSION_MAX_AGE,
+ httponly=True,
+ samesite="lax",
+ )
+ return response
+
+ return RedirectResponse(
+ url="/comfy/login?error=Invalid+username+or+password",
+ status_code=status.HTTP_303_SEE_OTHER,
+ )
+
+
+@router.get("/comfy/logout")
+async def logout() -> Response:
+ """Clear session and redirect to login."""
+ response = RedirectResponse(url="/comfy/login", status_code=status.HTTP_303_SEE_OTHER)
+ response.delete_cookie("comfy_session")
+ return response
+
+
+def _check_auth(comfy_session: str | None) -> None:
+ """Check if user is authenticated, raise 401 if not."""
+ if not COMFYUI_USER:
+ # Auth not configured, allow access
+ return
+ if not _verify_session_token(comfy_session):
+ raise HTTPException(
+ status_code=status.HTTP_307_TEMPORARY_REDIRECT,
+ headers={"Location": "/comfy/login"},
+ )
+
+
+@router.api_route("/comfy/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
+async def proxy_comfyui(request: Request, path: str, comfy_session: str | None = Cookie(default=None)) -> Response:
+ """Proxy all HTTP requests to ComfyUI."""
+ _check_auth(comfy_session)
+
+ # Build target URL
+ target_url = f"{COMFYUI_URL}/{path}"
+ if request.url.query:
+ target_url += f"?{request.url.query}"
+
+ # Forward headers (excluding host)
+ headers = dict(request.headers)
+ headers.pop("host", None)
+ headers.pop("cookie", None)
+
+ # Get request body
+ body = await request.body()
+
+ async with httpx.AsyncClient(timeout=300.0) as client:
+ try:
+ response = await client.request(
+ method=request.method,
+ url=target_url,
+ headers=headers,
+ content=body,
+ )
+ except httpx.ConnectError as e:
+ raise HTTPException(status_code=502, detail="ComfyUI is not running") from e
+ except httpx.TimeoutException as e:
+ raise HTTPException(status_code=504, detail="ComfyUI request timed out") from e
+
+ # Build response
+ excluded_headers = {"content-encoding", "content-length", "transfer-encoding", "connection"}
+ response_headers = {k: v for k, v in response.headers.items() if k.lower() not in excluded_headers}
+
+ return Response(
+ content=response.content,
+ status_code=response.status_code,
+ headers=response_headers,
+ media_type=response.headers.get("content-type"),
+ )
+
+
+@router.websocket("/comfy/ws")
+async def proxy_websocket(websocket: WebSocket, comfy_session: str | None = Cookie(default=None)) -> None:
+ """Proxy WebSocket connections to ComfyUI."""
+ # Check auth via cookie
+ if COMFYUI_USER and not _verify_session_token(comfy_session):
+ await websocket.close(code=4001, reason="Unauthorized")
+ return
+
+ await websocket.accept()
+
+ # Connect to ComfyUI WebSocket
+ comfy_ws_url = COMFYUI_URL.replace("http://", "ws://").replace("https://", "wss://") + "/ws"
+ if websocket.url.query:
+ comfy_ws_url += f"?{websocket.url.query}"
+
+ try:
+ async with websockets.connect(comfy_ws_url) as comfy_ws:
+
+ async def client_to_comfy() -> None:
+ try:
+ while True:
+ data = await websocket.receive_text()
+ await comfy_ws.send(data)
+ except Exception:
+ pass
+
+ async def comfy_to_client() -> None:
+ try:
+ async for message in comfy_ws:
+ if isinstance(message, bytes):
+ await websocket.send_bytes(message)
+ else:
+ await websocket.send_text(message)
+ except Exception:
+ pass
+
+ # Run both directions concurrently
+ await asyncio.gather(client_to_comfy(), comfy_to_client(), return_exceptions=True)
+
+ except Exception as e:
+ await websocket.close(code=1011, reason=str(e))
+
+
+def create_comfyui_router() -> APIRouter:
+ """Return the ComfyUI proxy router."""
+ return router