diff --git a/pyproject.toml b/pyproject.toml index 31c1ffe..6fed577 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ ] [project.optional-dependencies] -server = ["fastapi>=0.115", "uvicorn>=0.30", "scalar-fastapi>=1.6"] +server = ["fastapi>=0.115", "uvicorn>=0.30", "scalar-fastapi>=1.6", "websockets>=12.0", "python-multipart>=0.0.9"] [project.scripts] tsr = "tensors:main" @@ -39,6 +39,8 @@ dev = [ "fastapi>=0.115", "uvicorn>=0.30", "scalar-fastapi>=1.6", + "websockets>=12.0", + "python-multipart>=0.0.9", ] [tool.ruff] @@ -104,6 +106,10 @@ ignore_missing_imports = true module = ["huggingface_hub.*"] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["websockets.*"] +ignore_missing_imports = true + [tool.pytest.ini_options] testpaths = ["tests"] addopts = "-v --cov=tensors --cov-report=term-missing" diff --git a/tensors/server/__init__.py b/tensors/server/__init__.py index e5d3537..ccf56c0 100644 --- a/tensors/server/__init__.py +++ b/tensors/server/__init__.py @@ -12,6 +12,7 @@ from scalar_fastapi import get_scalar_api_reference from tensors.config import get_server_api_key from tensors.server.civitai_routes import create_civitai_router +from tensors.server.comfyui_routes import create_comfyui_router from tensors.server.db_routes import create_db_router from tensors.server.download_routes import create_download_router from tensors.server.gallery_routes import create_gallery_router @@ -69,7 +70,10 @@ def create_app() -> FastAPI: title="tensors API", ) - # Protected routers (auth required if configured) + # ComfyUI proxy (handles its own session auth) + app.include_router(create_comfyui_router()) + + # Protected routers (API key auth) from tensors.server.auth import verify_api_key # noqa: PLC0415 app.include_router(create_search_router(), dependencies=[Depends(verify_api_key)]) diff --git a/tensors/server/comfyui_routes.py b/tensors/server/comfyui_routes.py new file mode 100644 index 0000000..bf4f121 --- /dev/null +++ b/tensors/server/comfyui_routes.py @@ -0,0 +1,340 @@ +"""ComfyUI reverse proxy with session authentication.""" + +from __future__ import annotations + +import asyncio +import hashlib +import hmac +import os +import time + +import httpx +import websockets +from fastapi import APIRouter, Cookie, Form, HTTPException, Request, Response, WebSocket, status +from fastapi.responses import HTMLResponse, RedirectResponse + +router = APIRouter(tags=["ComfyUI"]) + +# Number of parts in session token +_SESSION_TOKEN_PARTS = 3 + +# Config from environment +COMFYUI_URL = os.environ.get("COMFYUI_URL", "http://127.0.0.1:8188") +COMFYUI_USER = os.environ.get("COMFYUI_USER", "") +COMFYUI_PASS = os.environ.get("COMFYUI_PASS", "") +SESSION_SECRET = os.environ.get("SESSION_SECRET", "tensors-comfyui-secret-change-me") +SESSION_MAX_AGE = 86400 * 7 # 7 days + + +def _create_session_token(username: str) -> str: + """Create a signed session token.""" + expires = int(time.time()) + SESSION_MAX_AGE + data = f"{username}:{expires}" + signature = hmac.new(SESSION_SECRET.encode(), data.encode(), hashlib.sha256).hexdigest()[:32] + return f"{data}:{signature}" + + +def _verify_session_token(token: str | None) -> bool: + """Verify a session token.""" + if not token: + return False + try: + parts = token.split(":") + if len(parts) != _SESSION_TOKEN_PARTS: + return False + username, expires_str, signature = parts + expires = int(expires_str) + if time.time() > expires: + return False + data = f"{username}:{expires_str}" + expected = hmac.new(SESSION_SECRET.encode(), data.encode(), hashlib.sha256).hexdigest()[:32] + return hmac.compare_digest(signature, expected) + except (ValueError, TypeError): + return False + + +LOGIN_PAGE_HTML = """ + + + + + + ComfyUI - Login + + + +
+ + {{ERROR}} +
+
+ + +
+
+ + +
+ +
+ +
+ + +""" + + +@router.get("/comfy/login") +async def login_page(error: str | None = None) -> HTMLResponse: + """Show login page.""" + error_html = "" + if error: + error_html = f'
{error}
' + html = LOGIN_PAGE_HTML.replace("{{ERROR}}", error_html) + return HTMLResponse(content=html) + + +@router.post("/comfy/login") +async def login_submit(username: str = Form(...), password: str = Form(...)) -> Response: + """Handle login form submission.""" + if not COMFYUI_USER or not COMFYUI_PASS: + return RedirectResponse( + url="/comfy/login?error=Authentication+not+configured", + status_code=status.HTTP_303_SEE_OTHER, + ) + + if username == COMFYUI_USER and password == COMFYUI_PASS: + token = _create_session_token(username) + response = RedirectResponse(url="/comfy/", status_code=status.HTTP_303_SEE_OTHER) + response.set_cookie( + key="comfy_session", + value=token, + max_age=SESSION_MAX_AGE, + httponly=True, + samesite="lax", + ) + return response + + return RedirectResponse( + url="/comfy/login?error=Invalid+username+or+password", + status_code=status.HTTP_303_SEE_OTHER, + ) + + +@router.get("/comfy/logout") +async def logout() -> Response: + """Clear session and redirect to login.""" + response = RedirectResponse(url="/comfy/login", status_code=status.HTTP_303_SEE_OTHER) + response.delete_cookie("comfy_session") + return response + + +def _check_auth(comfy_session: str | None) -> None: + """Check if user is authenticated, raise 401 if not.""" + if not COMFYUI_USER: + # Auth not configured, allow access + return + if not _verify_session_token(comfy_session): + raise HTTPException( + status_code=status.HTTP_307_TEMPORARY_REDIRECT, + headers={"Location": "/comfy/login"}, + ) + + +@router.api_route("/comfy/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"]) +async def proxy_comfyui(request: Request, path: str, comfy_session: str | None = Cookie(default=None)) -> Response: + """Proxy all HTTP requests to ComfyUI.""" + _check_auth(comfy_session) + + # Build target URL + target_url = f"{COMFYUI_URL}/{path}" + if request.url.query: + target_url += f"?{request.url.query}" + + # Forward headers (excluding host) + headers = dict(request.headers) + headers.pop("host", None) + headers.pop("cookie", None) + + # Get request body + body = await request.body() + + async with httpx.AsyncClient(timeout=300.0) as client: + try: + response = await client.request( + method=request.method, + url=target_url, + headers=headers, + content=body, + ) + except httpx.ConnectError as e: + raise HTTPException(status_code=502, detail="ComfyUI is not running") from e + except httpx.TimeoutException as e: + raise HTTPException(status_code=504, detail="ComfyUI request timed out") from e + + # Build response + excluded_headers = {"content-encoding", "content-length", "transfer-encoding", "connection"} + response_headers = {k: v for k, v in response.headers.items() if k.lower() not in excluded_headers} + + return Response( + content=response.content, + status_code=response.status_code, + headers=response_headers, + media_type=response.headers.get("content-type"), + ) + + +@router.websocket("/comfy/ws") +async def proxy_websocket(websocket: WebSocket, comfy_session: str | None = Cookie(default=None)) -> None: + """Proxy WebSocket connections to ComfyUI.""" + # Check auth via cookie + if COMFYUI_USER and not _verify_session_token(comfy_session): + await websocket.close(code=4001, reason="Unauthorized") + return + + await websocket.accept() + + # Connect to ComfyUI WebSocket + comfy_ws_url = COMFYUI_URL.replace("http://", "ws://").replace("https://", "wss://") + "/ws" + if websocket.url.query: + comfy_ws_url += f"?{websocket.url.query}" + + try: + async with websockets.connect(comfy_ws_url) as comfy_ws: + + async def client_to_comfy() -> None: + try: + while True: + data = await websocket.receive_text() + await comfy_ws.send(data) + except Exception: + pass + + async def comfy_to_client() -> None: + try: + async for message in comfy_ws: + if isinstance(message, bytes): + await websocket.send_bytes(message) + else: + await websocket.send_text(message) + except Exception: + pass + + # Run both directions concurrently + await asyncio.gather(client_to_comfy(), comfy_to_client(), return_exceptions=True) + + except Exception as e: + await websocket.close(code=1011, reason=str(e)) + + +def create_comfyui_router() -> APIRouter: + """Return the ComfyUI proxy router.""" + return router