diff --git a/tensors/comfyui.py b/tensors/comfyui.py
index 9db3f32..f72945c 100644
--- a/tensors/comfyui.py
+++ b/tensors/comfyui.py
@@ -9,20 +9,20 @@ import time
 import uuid
 from collections.abc import Callable
 from dataclasses import dataclass, field
-from http import HTTPStatus
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
 import httpx
-from rich.progress import Progress, SpinnerColumn, TextColumn
+import websocket
+from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
 
 from tensors.config import get_comfyui_url
 
 if TYPE_CHECKING:
     from rich.console import Console
 
-# Progress update throttle interval (seconds)
-_PROGRESS_UPDATE_INTERVAL = 0.25
+# WebSocket timeout for receiving messages (seconds)
+_WS_RECV_TIMEOUT = 1.0
 
 
 def _get_comfyui_url() -> str:
@@ -384,27 +384,194 @@ def queue_prompt(
     return None
 
 
-def _poll_for_completion(
+def _wait_for_completion_ws(
     prompt_id: str,
     url: str,
+    client_id: str,
     timeout: float = 600.0,
-    poll_interval: float = 0.5,
     on_progress: ProgressCallback | None = None,
 ) -> WorkflowResult:
-    """Poll history endpoint for workflow completion.
+    """Wait for workflow completion using WebSocket for real-time progress.
+
+    Falls back to HTTP polling when the WebSocket cannot be opened or drops
+    before the workflow finishes. The overall ``timeout`` budget is shared
+    between the WebSocket phase and any fallback polling.
 
     Args:
         prompt_id: The prompt ID to track
-        url: ComfyUI base URL
+        url: ComfyUI base URL (http://...)
+        client_id: Client ID used when queueing the prompt
         timeout: Maximum wait time in seconds
-        poll_interval: Time between polls in seconds
-        on_progress: Optional callback for progress updates
+        on_progress: Optional callback for progress updates (step, total, status)
 
     Returns:
         WorkflowResult with outputs or errors
     """
+    # Convert http(s) URL to ws(s) URL
+    ws_url = url.replace("http://", "ws://").replace("https://", "wss://")
+    ws_url = f"{ws_url}/ws?clientId={client_id}"
+
     start_time = time.time()
-    last_progress_time = 0.0
+    deadline = start_time + timeout
+    outputs: dict[str, Any] = {}
+    node_errors: dict[str, Any] = {}
+    finished = False
+
+    try:
+        # Cap the handshake so an unreachable socket cannot use up the whole wait
+        ws = websocket.create_connection(ws_url, timeout=min(timeout, 10.0))
+    except (OSError, ValueError, websocket.WebSocketException):
+        # Fall back to polling if WebSocket fails, spending only the time left
+        remaining = max(deadline - time.time(), 0.0)
+        return _poll_for_completion_fallback(prompt_id, url, remaining, on_progress)
+
+    try:
+        ws.settimeout(_WS_RECV_TIMEOUT)
+        while time.time() < deadline:
+            try:
+                msg = ws.recv()
+            except websocket.WebSocketTimeoutException:
+                # No message received, continue waiting
+                continue
+            except (websocket.WebSocketConnectionClosedException, OSError):
+                break
+
+            # Binary frames (ComfyUI uses them for previews) are not JSON
+            if not isinstance(msg, str) or not msg:
+                continue
+            try:
+                data = json.loads(msg)
+            except ValueError:
+                # Malformed frame: skip it rather than abort the wait
+                continue
+            if not isinstance(data, dict):
+                continue
+
+            msg_type = data.get("type", "")
+            msg_data = data.get("data")
+            if not isinstance(msg_data, dict):
+                msg_data = {}
+
+            # Only process messages for our prompt
+            msg_prompt_id = msg_data.get("prompt_id")
+            if msg_prompt_id and msg_prompt_id != prompt_id:
+                continue
+
+            if msg_type == "execution_start":
+                if on_progress:
+                    on_progress(0, 0, "Starting...")
+
+            elif msg_type == "execution_cached":
+                # Some nodes were cached
+                cached_nodes = msg_data.get("nodes", [])
+                if on_progress and cached_nodes:
+                    on_progress(0, 0, f"Cached {len(cached_nodes)} node(s)")
+
+            elif msg_type == "executing":
+                # node=None means execution finished; other nodes are ignored
+                # so the sampler step display is preserved
+                if msg_data.get("node") is None:
+                    finished = True
+                    break
+
+            elif msg_type == "progress":
+                # Sampling progress: {"value": 5, "max": 20}
+                value = msg_data.get("value", 0)
+                max_val = msg_data.get("max", 0)
+                if on_progress and max_val > 0:
+                    on_progress(value, max_val, f"Step {value}/{max_val}")
+
+            elif msg_type == "executed":
+                # A node finished, may have output
+                node_id = msg_data.get("node")
+                output = msg_data.get("output", {})
+                if node_id and output:
+                    outputs[node_id] = output
+
+            elif msg_type == "execution_error":
+                # Execution failed; the finally block closes the socket
+                node_id = msg_data.get("node_id", "unknown")
+                error_msg = msg_data.get("exception_message", "Unknown error")
+                node_errors[node_id] = error_msg
+                return WorkflowResult(
+                    prompt_id=prompt_id,
+                    outputs=outputs,
+                    node_errors=node_errors,
+                    success=False,
+                )
+
+            elif msg_type == "execution_success":
+                # Explicitly done
+                finished = True
+                break
+
+    finally:
+        try:
+            ws.close()
+        except Exception:
+            pass
+
+    # Fetch final outputs from history to ensure we have everything
+    in_history = False
+    try:
+        response = httpx.get(f"{url}/history/{prompt_id}", timeout=10.0)
+        response.raise_for_status()
+        history = response.json()
+        if isinstance(history, dict) and prompt_id in history:
+            in_history = True
+            entry = history[prompt_id]
+            outputs = entry.get("outputs", outputs)
+            status_info = entry.get("status", {})
+            if status_info.get("status_str") == "error":
+                return WorkflowResult(
+                    prompt_id=prompt_id,
+                    outputs=outputs,
+                    node_errors=status_info.get("messages", {}),
+                    success=False,
+                )
+    except (httpx.HTTPError, ValueError):
+        # Best effort: keep what the WebSocket messages already provided
+        pass
+
+    if not (finished or in_history):
+        # The socket closed or the deadline passed without a completion
+        # signal; never report that as success
+        remaining = deadline - time.time()
+        if remaining > 0:
+            return _poll_for_completion_fallback(prompt_id, url, remaining, on_progress)
+        return WorkflowResult(
+            prompt_id=prompt_id,
+            node_errors={"timeout": f"Workflow did not complete within {timeout}s"},
+            success=False,
+        )
+
+    return WorkflowResult(
+        prompt_id=prompt_id,
+        outputs=outputs,
+        node_errors=node_errors,
+        success=not node_errors,
+    )
+
+
+def _poll_for_completion_fallback(
+    prompt_id: str,
+    url: str,
+    timeout: float = 600.0,
+    on_progress: ProgressCallback | None = None,
+) -> WorkflowResult:
+    """Poll the history endpoint for completion when WebSocket is unavailable.
+
+    Args:
+        prompt_id: The prompt ID to track
+        url: ComfyUI base URL
+        timeout: Maximum wait time in seconds
+        on_progress: Optional callback for progress updates (step, total, status)
+
+    Returns:
+        WorkflowResult with outputs or errors
+    """
+    start_time = time.time()
+    poll_interval = 0.5
 
     while time.time() - start_time < timeout:
         try:
@@ -417,7 +584,6 @@ def _poll_for_completion(
                 outputs = entry.get("outputs", {})
                 status_info = entry.get("status", {})
 
-                # Check for errors
                 if status_info.get("status_str") == "error":
                     return WorkflowResult(
                         prompt_id=prompt_id,
@@ -426,32 +592,20 @@ def _poll_for_completion(
                         success=False,
                     )
 
-                # Success - return outputs
                 return WorkflowResult(
                     prompt_id=prompt_id,
                     outputs=outputs,
                     success=True,
                 )
 
-            # Still running - check queue for progress
             if on_progress:
-                now = time.time()
-                if now - last_progress_time >= _PROGRESS_UPDATE_INTERVAL:
-                    queue_response = httpx.get(f"{url}/queue", timeout=5.0)
-                    if queue_response.status_code == HTTPStatus.OK:
-                        queue_data = queue_response.json()
-                        running = queue_data.get("queue_running", [])
-                        pending = queue_data.get("queue_pending", [])
-                        total = len(running) + len(pending)
-                        on_progress(0, total, f"Queued ({len(pending)} pending)")
-                        last_progress_time = now
+                on_progress(0, 0, "Running...")
 
         except httpx.RequestError:
-            pass  # Connection error, keep polling
+            pass
 
         time.sleep(poll_interval)
 
-    # Timeout
     return WorkflowResult(
         prompt_id=prompt_id,
         node_errors={"timeout": f"Workflow did not complete within {timeout}s"},
@@ -491,11 +645,14 @@ def run_workflow(
     else:
         workflow_dict = workflow
 
+    # Generate client_id for WebSocket tracking
+    client_id = str(uuid.uuid4())
+
     # Queue the workflow
     if console:
         console.print("[cyan]Queueing workflow...[/cyan]")
 
-    result = queue_prompt(workflow_dict, url=base_url, console=console)
+    result = queue_prompt(workflow_dict, url=base_url, client_id=client_id, console=console)
     if not result:
         return None
 
@@ -503,23 +660,34 @@ def run_workflow(
     if console:
         console.print(f"[dim]Prompt ID: {prompt_id}[/dim]")
 
-    # Poll for completion with progress
+    # Wait for completion with WebSocket progress
     if console:
         with Progress(
             SpinnerColumn(),
             TextColumn("[progress.description]{task.description}"),
+            BarColumn(bar_width=20),
+            TaskProgressColumn(),
             console=console,
         ) as progress:
-            task = progress.add_task("[cyan]Running workflow...", total=None)
+            task = progress.add_task("[cyan]Starting...", total=None)
 
             def _console_progress(step: int, total: int, status: str) -> None:
-                progress.update(task, description=f"[cyan]{status}[/cyan]")
+                if total > 0:
+                    # Update to determinate progress bar
+                    progress.update(task, completed=step, total=total, description=f"[cyan]{status}[/cyan]")
+                else:
+                    # Indeterminate spinner
+                    progress.update(task, description=f"[cyan]{status}[/cyan]")
                 if on_progress:
                     on_progress(step, total, status)
 
-            return _poll_for_completion(prompt_id, base_url, timeout, on_progress=_console_progress)
+            return _wait_for_completion_ws(
+                prompt_id, base_url, client_id, timeout, on_progress=_console_progress
+            )
     else:
-        return _poll_for_completion(prompt_id, base_url, timeout, on_progress=on_progress)
+        return _wait_for_completion_ws(
+            prompt_id, base_url, client_id, timeout, on_progress=on_progress
+        )
 
 
 # ============================================================================