Add WebSocket-based progress tracking for ComfyUI generation

- Replace polling with WebSocket connection for real-time progress
- Show step-by-step progress during sampling (Step 1/20, etc.)
- Display progress bar with actual completion percentage
- Fall back to polling if WebSocket connection fails
- Import websocket-client for sync WebSocket support

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Adam Ladachowski
2026-02-22 04:07:07 +00:00
parent 3888558214
commit 5ddfb07448
+158 -31
View File
@@ -9,20 +9,20 @@ import time
import uuid import uuid
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from http import HTTPStatus
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import httpx import httpx
from rich.progress import Progress, SpinnerColumn, TextColumn import websocket
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
from tensors.config import get_comfyui_url from tensors.config import get_comfyui_url
if TYPE_CHECKING: if TYPE_CHECKING:
from rich.console import Console from rich.console import Console
# Progress update throttle interval (seconds) # WebSocket timeout for receiving messages (seconds)
_PROGRESS_UPDATE_INTERVAL = 0.25 _WS_RECV_TIMEOUT = 1.0
def _get_comfyui_url() -> str: def _get_comfyui_url() -> str:
@@ -384,27 +384,153 @@ def queue_prompt(
return None return None
def _poll_for_completion( def _wait_for_completion_ws(
prompt_id: str, prompt_id: str,
url: str, url: str,
client_id: str,
timeout: float = 600.0, timeout: float = 600.0,
poll_interval: float = 0.5,
on_progress: ProgressCallback | None = None, on_progress: ProgressCallback | None = None,
) -> WorkflowResult: ) -> WorkflowResult:
"""Poll history endpoint for workflow completion. """Wait for workflow completion using WebSocket for real-time progress.
Args: Args:
prompt_id: The prompt ID to track prompt_id: The prompt ID to track
url: ComfyUI base URL url: ComfyUI base URL (http://...)
client_id: Client ID used when queueing the prompt
timeout: Maximum wait time in seconds timeout: Maximum wait time in seconds
poll_interval: Time between polls in seconds on_progress: Optional callback for progress updates (step, total, status)
on_progress: Optional callback for progress updates
Returns: Returns:
WorkflowResult with outputs or errors WorkflowResult with outputs or errors
""" """
# Convert http(s) URL to ws(s) URL
ws_url = url.replace("http://", "ws://").replace("https://", "wss://")
ws_url = f"{ws_url}/ws?clientId={client_id}"
start_time = time.time() start_time = time.time()
last_progress_time = 0.0 outputs: dict[str, Any] = {}
node_errors: dict[str, Any] = {}
current_node: str | None = None
try:
ws = websocket.create_connection(ws_url, timeout=timeout)
except Exception as e:
# Fall back to polling if WebSocket fails
return _poll_for_completion_fallback(prompt_id, url, timeout, on_progress)
try:
while time.time() - start_time < timeout:
try:
ws.settimeout(_WS_RECV_TIMEOUT)
msg = ws.recv()
if not msg:
continue
data = json.loads(msg)
msg_type = data.get("type", "")
msg_data = data.get("data", {})
# Only process messages for our prompt
if msg_data.get("prompt_id") and msg_data.get("prompt_id") != prompt_id:
continue
if msg_type == "execution_start":
if on_progress:
on_progress(0, 0, "Starting...")
elif msg_type == "execution_cached":
# Some nodes were cached
cached_nodes = msg_data.get("nodes", [])
if on_progress and cached_nodes:
on_progress(0, 0, f"Cached {len(cached_nodes)} node(s)")
elif msg_type == "executing":
# A node is being executed
current_node = msg_data.get("node")
if current_node is None:
# Execution finished (node=None means done)
break
# Don't update progress for non-sampler nodes to preserve step display
elif msg_type == "progress":
# Sampling progress: {"value": 5, "max": 20}
value = msg_data.get("value", 0)
max_val = msg_data.get("max", 0)
if on_progress and max_val > 0:
on_progress(value, max_val, f"Step {value}/{max_val}")
elif msg_type == "executed":
# A node finished, may have output
node_id = msg_data.get("node")
output = msg_data.get("output", {})
if node_id and output:
outputs[node_id] = output
elif msg_type == "execution_error":
# Execution failed
node_id = msg_data.get("node_id", "unknown")
error_msg = msg_data.get("exception_message", "Unknown error")
node_errors[node_id] = error_msg
ws.close()
return WorkflowResult(
prompt_id=prompt_id,
outputs=outputs,
node_errors=node_errors,
success=False,
)
elif msg_type == "execution_success":
# Explicitly done
break
except websocket.WebSocketTimeoutException:
# No message received, continue waiting
continue
except websocket.WebSocketConnectionClosedException:
break
finally:
try:
ws.close()
except Exception:
pass
# Fetch final outputs from history to ensure we have everything
try:
response = httpx.get(f"{url}/history/{prompt_id}", timeout=10.0)
response.raise_for_status()
history = response.json()
if prompt_id in history:
entry = history[prompt_id]
outputs = entry.get("outputs", outputs)
status_info = entry.get("status", {})
if status_info.get("status_str") == "error":
return WorkflowResult(
prompt_id=prompt_id,
outputs=outputs,
node_errors=status_info.get("messages", {}),
success=False,
)
except Exception:
pass
return WorkflowResult(
prompt_id=prompt_id,
outputs=outputs,
node_errors=node_errors,
success=len(node_errors) == 0,
)
def _poll_for_completion_fallback(
prompt_id: str,
url: str,
timeout: float = 600.0,
on_progress: ProgressCallback | None = None,
) -> WorkflowResult:
"""Fallback polling method when WebSocket is unavailable."""
start_time = time.time()
poll_interval = 0.5
while time.time() - start_time < timeout: while time.time() - start_time < timeout:
try: try:
@@ -417,7 +543,6 @@ def _poll_for_completion(
outputs = entry.get("outputs", {}) outputs = entry.get("outputs", {})
status_info = entry.get("status", {}) status_info = entry.get("status", {})
# Check for errors
if status_info.get("status_str") == "error": if status_info.get("status_str") == "error":
return WorkflowResult( return WorkflowResult(
prompt_id=prompt_id, prompt_id=prompt_id,
@@ -426,32 +551,20 @@ def _poll_for_completion(
success=False, success=False,
) )
# Success - return outputs
return WorkflowResult( return WorkflowResult(
prompt_id=prompt_id, prompt_id=prompt_id,
outputs=outputs, outputs=outputs,
success=True, success=True,
) )
# Still running - check queue for progress
if on_progress: if on_progress:
now = time.time() on_progress(0, 0, "Running...")
if now - last_progress_time >= _PROGRESS_UPDATE_INTERVAL:
queue_response = httpx.get(f"{url}/queue", timeout=5.0)
if queue_response.status_code == HTTPStatus.OK:
queue_data = queue_response.json()
running = queue_data.get("queue_running", [])
pending = queue_data.get("queue_pending", [])
total = len(running) + len(pending)
on_progress(0, total, f"Queued ({len(pending)} pending)")
last_progress_time = now
except httpx.RequestError: except httpx.RequestError:
pass # Connection error, keep polling pass
time.sleep(poll_interval) time.sleep(poll_interval)
# Timeout
return WorkflowResult( return WorkflowResult(
prompt_id=prompt_id, prompt_id=prompt_id,
node_errors={"timeout": f"Workflow did not complete within {timeout}s"}, node_errors={"timeout": f"Workflow did not complete within {timeout}s"},
@@ -491,11 +604,14 @@ def run_workflow(
else: else:
workflow_dict = workflow workflow_dict = workflow
# Generate client_id for WebSocket tracking
client_id = str(uuid.uuid4())
# Queue the workflow # Queue the workflow
if console: if console:
console.print("[cyan]Queueing workflow...[/cyan]") console.print("[cyan]Queueing workflow...[/cyan]")
result = queue_prompt(workflow_dict, url=base_url, console=console) result = queue_prompt(workflow_dict, url=base_url, client_id=client_id, console=console)
if not result: if not result:
return None return None
@@ -503,23 +619,34 @@ def run_workflow(
if console: if console:
console.print(f"[dim]Prompt ID: {prompt_id}[/dim]") console.print(f"[dim]Prompt ID: {prompt_id}[/dim]")
# Poll for completion with progress # Wait for completion with WebSocket progress
if console: if console:
with Progress( with Progress(
SpinnerColumn(), SpinnerColumn(),
TextColumn("[progress.description]{task.description}"), TextColumn("[progress.description]{task.description}"),
BarColumn(bar_width=20),
TaskProgressColumn(),
console=console, console=console,
) as progress: ) as progress:
task = progress.add_task("[cyan]Running workflow...", total=None) task = progress.add_task("[cyan]Starting...", total=None)
def _console_progress(step: int, total: int, status: str) -> None: def _console_progress(step: int, total: int, status: str) -> None:
if total > 0:
# Update to determinate progress bar
progress.update(task, completed=step, total=total, description=f"[cyan]{status}[/cyan]")
else:
# Indeterminate spinner
progress.update(task, description=f"[cyan]{status}[/cyan]") progress.update(task, description=f"[cyan]{status}[/cyan]")
if on_progress: if on_progress:
on_progress(step, total, status) on_progress(step, total, status)
return _poll_for_completion(prompt_id, base_url, timeout, on_progress=_console_progress) return _wait_for_completion_ws(
prompt_id, base_url, client_id, timeout, on_progress=_console_progress
)
else: else:
return _poll_for_completion(prompt_id, base_url, timeout, on_progress=on_progress) return _wait_for_completion_ws(
prompt_id, base_url, client_id, timeout, on_progress=on_progress
)
# ============================================================================ # ============================================================================