Add WebSocket-based progress tracking for ComfyUI generation
- Replace polling with WebSocket connection for real-time progress - Show step-by-step progress during sampling (Step 1/20, etc.) - Display progress bar with actual completion percentage - Fall back to polling if WebSocket connection fails - Import websocket-client for sync WebSocket support Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
+159
-32
@@ -9,20 +9,20 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from http import HTTPStatus
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
import websocket
|
||||||
|
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
|
||||||
|
|
||||||
from tensors.config import get_comfyui_url
|
from tensors.config import get_comfyui_url
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
# Progress update throttle interval (seconds)
|
# WebSocket timeout for receiving messages (seconds)
|
||||||
_PROGRESS_UPDATE_INTERVAL = 0.25
|
_WS_RECV_TIMEOUT = 1.0
|
||||||
|
|
||||||
|
|
||||||
def _get_comfyui_url() -> str:
|
def _get_comfyui_url() -> str:
|
||||||
@@ -384,27 +384,153 @@ def queue_prompt(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _poll_for_completion(
|
def _wait_for_completion_ws(
|
||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
url: str,
|
url: str,
|
||||||
|
client_id: str,
|
||||||
timeout: float = 600.0,
|
timeout: float = 600.0,
|
||||||
poll_interval: float = 0.5,
|
|
||||||
on_progress: ProgressCallback | None = None,
|
on_progress: ProgressCallback | None = None,
|
||||||
) -> WorkflowResult:
|
) -> WorkflowResult:
|
||||||
"""Poll history endpoint for workflow completion.
|
"""Wait for workflow completion using WebSocket for real-time progress.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt_id: The prompt ID to track
|
prompt_id: The prompt ID to track
|
||||||
url: ComfyUI base URL
|
url: ComfyUI base URL (http://...)
|
||||||
|
client_id: Client ID used when queueing the prompt
|
||||||
timeout: Maximum wait time in seconds
|
timeout: Maximum wait time in seconds
|
||||||
poll_interval: Time between polls in seconds
|
on_progress: Optional callback for progress updates (step, total, status)
|
||||||
on_progress: Optional callback for progress updates
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
WorkflowResult with outputs or errors
|
WorkflowResult with outputs or errors
|
||||||
"""
|
"""
|
||||||
|
# Convert http(s) URL to ws(s) URL
|
||||||
|
ws_url = url.replace("http://", "ws://").replace("https://", "wss://")
|
||||||
|
ws_url = f"{ws_url}/ws?clientId={client_id}"
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
last_progress_time = 0.0
|
outputs: dict[str, Any] = {}
|
||||||
|
node_errors: dict[str, Any] = {}
|
||||||
|
current_node: str | None = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
ws = websocket.create_connection(ws_url, timeout=timeout)
|
||||||
|
except Exception as e:
|
||||||
|
# Fall back to polling if WebSocket fails
|
||||||
|
return _poll_for_completion_fallback(prompt_id, url, timeout, on_progress)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
try:
|
||||||
|
ws.settimeout(_WS_RECV_TIMEOUT)
|
||||||
|
msg = ws.recv()
|
||||||
|
if not msg:
|
||||||
|
continue
|
||||||
|
|
||||||
|
data = json.loads(msg)
|
||||||
|
msg_type = data.get("type", "")
|
||||||
|
msg_data = data.get("data", {})
|
||||||
|
|
||||||
|
# Only process messages for our prompt
|
||||||
|
if msg_data.get("prompt_id") and msg_data.get("prompt_id") != prompt_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if msg_type == "execution_start":
|
||||||
|
if on_progress:
|
||||||
|
on_progress(0, 0, "Starting...")
|
||||||
|
|
||||||
|
elif msg_type == "execution_cached":
|
||||||
|
# Some nodes were cached
|
||||||
|
cached_nodes = msg_data.get("nodes", [])
|
||||||
|
if on_progress and cached_nodes:
|
||||||
|
on_progress(0, 0, f"Cached {len(cached_nodes)} node(s)")
|
||||||
|
|
||||||
|
elif msg_type == "executing":
|
||||||
|
# A node is being executed
|
||||||
|
current_node = msg_data.get("node")
|
||||||
|
if current_node is None:
|
||||||
|
# Execution finished (node=None means done)
|
||||||
|
break
|
||||||
|
# Don't update progress for non-sampler nodes to preserve step display
|
||||||
|
|
||||||
|
elif msg_type == "progress":
|
||||||
|
# Sampling progress: {"value": 5, "max": 20}
|
||||||
|
value = msg_data.get("value", 0)
|
||||||
|
max_val = msg_data.get("max", 0)
|
||||||
|
if on_progress and max_val > 0:
|
||||||
|
on_progress(value, max_val, f"Step {value}/{max_val}")
|
||||||
|
|
||||||
|
elif msg_type == "executed":
|
||||||
|
# A node finished, may have output
|
||||||
|
node_id = msg_data.get("node")
|
||||||
|
output = msg_data.get("output", {})
|
||||||
|
if node_id and output:
|
||||||
|
outputs[node_id] = output
|
||||||
|
|
||||||
|
elif msg_type == "execution_error":
|
||||||
|
# Execution failed
|
||||||
|
node_id = msg_data.get("node_id", "unknown")
|
||||||
|
error_msg = msg_data.get("exception_message", "Unknown error")
|
||||||
|
node_errors[node_id] = error_msg
|
||||||
|
ws.close()
|
||||||
|
return WorkflowResult(
|
||||||
|
prompt_id=prompt_id,
|
||||||
|
outputs=outputs,
|
||||||
|
node_errors=node_errors,
|
||||||
|
success=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif msg_type == "execution_success":
|
||||||
|
# Explicitly done
|
||||||
|
break
|
||||||
|
|
||||||
|
except websocket.WebSocketTimeoutException:
|
||||||
|
# No message received, continue waiting
|
||||||
|
continue
|
||||||
|
except websocket.WebSocketConnectionClosedException:
|
||||||
|
break
|
||||||
|
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
ws.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fetch final outputs from history to ensure we have everything
|
||||||
|
try:
|
||||||
|
response = httpx.get(f"{url}/history/{prompt_id}", timeout=10.0)
|
||||||
|
response.raise_for_status()
|
||||||
|
history = response.json()
|
||||||
|
if prompt_id in history:
|
||||||
|
entry = history[prompt_id]
|
||||||
|
outputs = entry.get("outputs", outputs)
|
||||||
|
status_info = entry.get("status", {})
|
||||||
|
if status_info.get("status_str") == "error":
|
||||||
|
return WorkflowResult(
|
||||||
|
prompt_id=prompt_id,
|
||||||
|
outputs=outputs,
|
||||||
|
node_errors=status_info.get("messages", {}),
|
||||||
|
success=False,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return WorkflowResult(
|
||||||
|
prompt_id=prompt_id,
|
||||||
|
outputs=outputs,
|
||||||
|
node_errors=node_errors,
|
||||||
|
success=len(node_errors) == 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _poll_for_completion_fallback(
|
||||||
|
prompt_id: str,
|
||||||
|
url: str,
|
||||||
|
timeout: float = 600.0,
|
||||||
|
on_progress: ProgressCallback | None = None,
|
||||||
|
) -> WorkflowResult:
|
||||||
|
"""Fallback polling method when WebSocket is unavailable."""
|
||||||
|
start_time = time.time()
|
||||||
|
poll_interval = 0.5
|
||||||
|
|
||||||
while time.time() - start_time < timeout:
|
while time.time() - start_time < timeout:
|
||||||
try:
|
try:
|
||||||
@@ -417,7 +543,6 @@ def _poll_for_completion(
|
|||||||
outputs = entry.get("outputs", {})
|
outputs = entry.get("outputs", {})
|
||||||
status_info = entry.get("status", {})
|
status_info = entry.get("status", {})
|
||||||
|
|
||||||
# Check for errors
|
|
||||||
if status_info.get("status_str") == "error":
|
if status_info.get("status_str") == "error":
|
||||||
return WorkflowResult(
|
return WorkflowResult(
|
||||||
prompt_id=prompt_id,
|
prompt_id=prompt_id,
|
||||||
@@ -426,32 +551,20 @@ def _poll_for_completion(
|
|||||||
success=False,
|
success=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Success - return outputs
|
|
||||||
return WorkflowResult(
|
return WorkflowResult(
|
||||||
prompt_id=prompt_id,
|
prompt_id=prompt_id,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
success=True,
|
success=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Still running - check queue for progress
|
|
||||||
if on_progress:
|
if on_progress:
|
||||||
now = time.time()
|
on_progress(0, 0, "Running...")
|
||||||
if now - last_progress_time >= _PROGRESS_UPDATE_INTERVAL:
|
|
||||||
queue_response = httpx.get(f"{url}/queue", timeout=5.0)
|
|
||||||
if queue_response.status_code == HTTPStatus.OK:
|
|
||||||
queue_data = queue_response.json()
|
|
||||||
running = queue_data.get("queue_running", [])
|
|
||||||
pending = queue_data.get("queue_pending", [])
|
|
||||||
total = len(running) + len(pending)
|
|
||||||
on_progress(0, total, f"Queued ({len(pending)} pending)")
|
|
||||||
last_progress_time = now
|
|
||||||
|
|
||||||
except httpx.RequestError:
|
except httpx.RequestError:
|
||||||
pass # Connection error, keep polling
|
pass
|
||||||
|
|
||||||
time.sleep(poll_interval)
|
time.sleep(poll_interval)
|
||||||
|
|
||||||
# Timeout
|
|
||||||
return WorkflowResult(
|
return WorkflowResult(
|
||||||
prompt_id=prompt_id,
|
prompt_id=prompt_id,
|
||||||
node_errors={"timeout": f"Workflow did not complete within {timeout}s"},
|
node_errors={"timeout": f"Workflow did not complete within {timeout}s"},
|
||||||
@@ -491,11 +604,14 @@ def run_workflow(
|
|||||||
else:
|
else:
|
||||||
workflow_dict = workflow
|
workflow_dict = workflow
|
||||||
|
|
||||||
|
# Generate client_id for WebSocket tracking
|
||||||
|
client_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# Queue the workflow
|
# Queue the workflow
|
||||||
if console:
|
if console:
|
||||||
console.print("[cyan]Queueing workflow...[/cyan]")
|
console.print("[cyan]Queueing workflow...[/cyan]")
|
||||||
|
|
||||||
result = queue_prompt(workflow_dict, url=base_url, console=console)
|
result = queue_prompt(workflow_dict, url=base_url, client_id=client_id, console=console)
|
||||||
if not result:
|
if not result:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -503,23 +619,34 @@ def run_workflow(
|
|||||||
if console:
|
if console:
|
||||||
console.print(f"[dim]Prompt ID: {prompt_id}[/dim]")
|
console.print(f"[dim]Prompt ID: {prompt_id}[/dim]")
|
||||||
|
|
||||||
# Poll for completion with progress
|
# Wait for completion with WebSocket progress
|
||||||
if console:
|
if console:
|
||||||
with Progress(
|
with Progress(
|
||||||
SpinnerColumn(),
|
SpinnerColumn(),
|
||||||
TextColumn("[progress.description]{task.description}"),
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
BarColumn(bar_width=20),
|
||||||
|
TaskProgressColumn(),
|
||||||
console=console,
|
console=console,
|
||||||
) as progress:
|
) as progress:
|
||||||
task = progress.add_task("[cyan]Running workflow...", total=None)
|
task = progress.add_task("[cyan]Starting...", total=None)
|
||||||
|
|
||||||
def _console_progress(step: int, total: int, status: str) -> None:
|
def _console_progress(step: int, total: int, status: str) -> None:
|
||||||
progress.update(task, description=f"[cyan]{status}[/cyan]")
|
if total > 0:
|
||||||
|
# Update to determinate progress bar
|
||||||
|
progress.update(task, completed=step, total=total, description=f"[cyan]{status}[/cyan]")
|
||||||
|
else:
|
||||||
|
# Indeterminate spinner
|
||||||
|
progress.update(task, description=f"[cyan]{status}[/cyan]")
|
||||||
if on_progress:
|
if on_progress:
|
||||||
on_progress(step, total, status)
|
on_progress(step, total, status)
|
||||||
|
|
||||||
return _poll_for_completion(prompt_id, base_url, timeout, on_progress=_console_progress)
|
return _wait_for_completion_ws(
|
||||||
|
prompt_id, base_url, client_id, timeout, on_progress=_console_progress
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return _poll_for_completion(prompt_id, base_url, timeout, on_progress=on_progress)
|
return _wait_for_completion_ws(
|
||||||
|
prompt_id, base_url, client_id, timeout, on_progress=on_progress
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
Reference in New Issue
Block a user