Phase 1: Core Client (tensors/comfyui.py)

Add ComfyUI API client module with:
- Basic query functions: get_system_stats(), get_queue_status(),
  get_loaded_models(), get_history(), clear_queue(), get_object_info()
- Workflow execution: queue_prompt(), run_workflow() with progress tracking
- Text-to-image: generate_image() with embedded SDXL/Flux workflow template
- Data classes: GenerationResult, WorkflowResult
- Helper: get_image() to download generated images

Follows existing api.py patterns with console: Console | None for optional
Rich progress output, httpx for HTTP requests, and proper error handling.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Adam Ladachowski
2026-02-16 16:03:49 +01:00
parent e9d1745c62
commit e776f9846e
2 changed files with 777 additions and 5 deletions
+6 -5
View File
@@ -1,17 +1,18 @@
# TODO: ComfyUI CLI & API Integration
## Phase 1: Core Client (`tensors/comfyui.py`)
- [ ] Step 1.1: Create ComfyUI client module with basic query functions
- [x] Step 1.1: Create ComfyUI client module with basic query functions
- `get_system_stats()` - System stats (GPU, RAM)
- `get_queue_status()` - Queue status
- `get_loaded_models()` - List loaded checkpoints/loras
- `get_history()` - View history
- [ ] Step 1.2: Add workflow execution with WebSocket progress tracking
- Also: `clear_queue()`, `get_object_info()`, `get_image()`
- [x] Step 1.2: Add workflow execution with WebSocket progress tracking
- `queue_prompt()` - Queue a workflow
- `run_workflow()` - Run workflow with progress callback
- [ ] Step 1.3: Add simple text-to-image generation
- `run_workflow()` - Run workflow with progress callback (uses polling)
- [x] Step 1.3: Add simple text-to-image generation
- `generate_image()` - Text-to-image with embedded workflow template
- Include SDXL/Flux-compatible default workflow
- Include SDXL/Flux-compatible default workflow (DEFAULT_WORKFLOW_TEMPLATE)
## Phase 2: CLI Commands (`tensors/cli.py`)
- [ ] Step 2.1: Add `comfy` subcommand group with status commands
+771
View File
@@ -0,0 +1,771 @@
"""ComfyUI API client for programmatic workflow execution."""
from __future__ import annotations
import copy
import json
import os
import random
import time
import uuid
from collections.abc import Callable
from dataclasses import dataclass, field
from http import HTTPStatus
from pathlib import Path
from typing import TYPE_CHECKING, Any
import httpx
from rich.progress import Progress, SpinnerColumn, TextColumn
if TYPE_CHECKING:
from rich.console import Console
# Default ComfyUI URL (same as comfyui_routes.py)
COMFYUI_DEFAULT_URL = "http://127.0.0.1:8188"
# Progress update throttle interval (seconds)
_PROGRESS_UPDATE_INTERVAL = 0.25
def _get_comfyui_url() -> str:
"""Get ComfyUI URL from environment or default."""
return os.environ.get("COMFYUI_URL", COMFYUI_DEFAULT_URL)
# ============================================================================
# Data Classes
# ============================================================================
@dataclass
class GenerationResult:
"""Result from image generation."""
prompt_id: str
images: list[Path] = field(default_factory=list)
node_errors: dict[str, Any] = field(default_factory=dict)
success: bool = True
@dataclass
class WorkflowResult:
"""Result from workflow execution."""
prompt_id: str
outputs: dict[str, Any] = field(default_factory=dict)
node_errors: dict[str, Any] = field(default_factory=dict)
success: bool = True
# ============================================================================
# Progress Callback Type
# ============================================================================
# (current_step, total_steps, status_message)
ProgressCallback = Callable[[int, int, str], None]
# ============================================================================
# Basic Query Functions
# ============================================================================
def get_system_stats(url: str | None = None, console: Console | None = None) -> dict[str, Any] | None:
"""Get ComfyUI system stats (GPU, RAM, etc.).
Args:
url: ComfyUI base URL (defaults to COMFYUI_URL env var or localhost:8188)
console: Rich console for progress/error output
Returns:
System stats dict or None on error
"""
base_url = url or _get_comfyui_url()
def _do_fetch() -> dict[str, Any] | None:
try:
response = httpx.get(f"{base_url}/system_stats", timeout=10.0)
response.raise_for_status()
result: dict[str, Any] = response.json()
return result
except httpx.HTTPStatusError as e:
if console:
console.print(f"[red]API error: {e.response.status_code}[/red]")
return None
except httpx.RequestError as e:
if console:
console.print(f"[red]Connection error: {e}[/red]")
return None
if console:
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
transient=True,
) as progress:
progress.add_task("[cyan]Fetching system stats...", total=None)
return _do_fetch()
else:
return _do_fetch()
def get_queue_status(url: str | None = None, console: Console | None = None) -> dict[str, Any] | None:
"""Get ComfyUI queue status.
Args:
url: ComfyUI base URL
console: Rich console for output
Returns:
Queue status dict with 'queue_running' and 'queue_pending' lists, or None on error
"""
base_url = url or _get_comfyui_url()
def _do_fetch() -> dict[str, Any] | None:
try:
response = httpx.get(f"{base_url}/queue", timeout=10.0)
response.raise_for_status()
result: dict[str, Any] = response.json()
return result
except httpx.HTTPStatusError as e:
if console:
console.print(f"[red]API error: {e.response.status_code}[/red]")
return None
except httpx.RequestError as e:
if console:
console.print(f"[red]Connection error: {e}[/red]")
return None
if console:
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
transient=True,
) as progress:
progress.add_task("[cyan]Fetching queue status...", total=None)
return _do_fetch()
else:
return _do_fetch()
def clear_queue(url: str | None = None, console: Console | None = None) -> bool:
"""Clear the ComfyUI queue.
Args:
url: ComfyUI base URL
console: Rich console for output
Returns:
True if successful, False on error
"""
base_url = url or _get_comfyui_url()
try:
# Clear both pending and running
response = httpx.post(f"{base_url}/queue", json={"clear": True}, timeout=10.0)
response.raise_for_status()
if console:
console.print("[green]Queue cleared[/green]")
return True
except httpx.HTTPStatusError as e:
if console:
console.print(f"[red]API error: {e.response.status_code}[/red]")
return False
except httpx.RequestError as e:
if console:
console.print(f"[red]Connection error: {e}[/red]")
return False
def get_object_info(url: str | None = None, console: Console | None = None) -> dict[str, Any] | None:
"""Get ComfyUI object info (available nodes and their configurations).
Args:
url: ComfyUI base URL
console: Rich console for output
Returns:
Object info dict or None on error
"""
base_url = url or _get_comfyui_url()
def _do_fetch() -> dict[str, Any] | None:
try:
response = httpx.get(f"{base_url}/object_info", timeout=30.0)
response.raise_for_status()
result: dict[str, Any] = response.json()
return result
except httpx.HTTPStatusError as e:
if console:
console.print(f"[red]API error: {e.response.status_code}[/red]")
return None
except httpx.RequestError as e:
if console:
console.print(f"[red]Connection error: {e}[/red]")
return None
if console:
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
transient=True,
) as progress:
progress.add_task("[cyan]Fetching object info...", total=None)
return _do_fetch()
else:
return _do_fetch()
def get_loaded_models(url: str | None = None, console: Console | None = None) -> dict[str, list[str]] | None:
"""Get list of loaded/available models (checkpoints, loras, etc.).
Args:
url: ComfyUI base URL
console: Rich console for output
Returns:
Dict mapping model type to list of model names, or None on error
"""
base_url = url or _get_comfyui_url()
def _do_fetch() -> dict[str, list[str]] | None:
result: dict[str, list[str]] = {}
# Model type to node class and input name mapping
model_types = {
"checkpoints": ("CheckpointLoaderSimple", "ckpt_name"),
"loras": ("LoraLoader", "lora_name"),
"vae": ("VAELoader", "vae_name"),
"clip": ("CLIPLoader", "clip_name"),
"controlnet": ("ControlNetLoader", "control_net_name"),
"upscale_models": ("UpscaleModelLoader", "model_name"),
}
try:
response = httpx.get(f"{base_url}/object_info", timeout=30.0)
response.raise_for_status()
object_info: dict[str, Any] = response.json()
for model_type, (node_class, input_name) in model_types.items():
if node_class in object_info:
node_info = object_info[node_class]
inputs = node_info.get("input", {}).get("required", {})
if input_name in inputs:
input_def = inputs[input_name]
if isinstance(input_def, list) and len(input_def) > 0 and isinstance(input_def[0], list):
result[model_type] = input_def[0]
return result
except httpx.HTTPStatusError as e:
if console:
console.print(f"[red]API error: {e.response.status_code}[/red]")
return None
except httpx.RequestError as e:
if console:
console.print(f"[red]Connection error: {e}[/red]")
return None
if console:
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
transient=True,
) as progress:
progress.add_task("[cyan]Fetching loaded models...", total=None)
return _do_fetch()
else:
return _do_fetch()
def get_history(
url: str | None = None,
prompt_id: str | None = None,
max_items: int = 100,
console: Console | None = None,
) -> dict[str, Any] | None:
"""Get ComfyUI history.
Args:
url: ComfyUI base URL
prompt_id: Specific prompt ID to fetch (if None, fetches recent history)
max_items: Maximum number of history items to return
console: Rich console for output
Returns:
History dict (keyed by prompt_id) or None on error
"""
base_url = url or _get_comfyui_url()
def _do_fetch() -> dict[str, Any] | None:
try:
endpoint = f"{base_url}/history/{prompt_id}" if prompt_id else f"{base_url}/history?max_items={max_items}"
response = httpx.get(endpoint, timeout=30.0)
response.raise_for_status()
result: dict[str, Any] = response.json()
return result
except httpx.HTTPStatusError as e:
if console:
console.print(f"[red]API error: {e.response.status_code}[/red]")
return None
except httpx.RequestError as e:
if console:
console.print(f"[red]Connection error: {e}[/red]")
return None
if console:
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
transient=True,
) as progress:
progress.add_task("[cyan]Fetching history...", total=None)
return _do_fetch()
else:
return _do_fetch()
# ============================================================================
# Workflow Execution
# ============================================================================
def queue_prompt(
workflow: dict[str, Any],
url: str | None = None,
client_id: str | None = None,
console: Console | None = None,
) -> dict[str, Any] | None:
"""Queue a workflow prompt for execution.
Args:
workflow: ComfyUI workflow dict (API format)
url: ComfyUI base URL
client_id: Client ID for WebSocket tracking
console: Rich console for output
Returns:
Response dict with 'prompt_id' and 'number', or None on error
"""
base_url = url or _get_comfyui_url()
client_id = client_id or str(uuid.uuid4())
try:
payload = {"prompt": workflow, "client_id": client_id}
response = httpx.post(f"{base_url}/prompt", json=payload, timeout=30.0)
response.raise_for_status()
result: dict[str, Any] = response.json()
if "error" in result:
if console:
console.print(f"[red]Workflow error: {result['error']}[/red]")
if "node_errors" in result:
for node_id, errors in result["node_errors"].items():
console.print(f" [yellow]Node {node_id}:[/yellow] {errors}")
return None
return result
except httpx.HTTPStatusError as e:
if console:
console.print(f"[red]API error: {e.response.status_code}[/red]")
try:
error_detail = e.response.json()
if "error" in error_detail:
console.print(f" [yellow]{error_detail['error']}[/yellow]")
except Exception:
pass
return None
except httpx.RequestError as e:
if console:
console.print(f"[red]Connection error: {e}[/red]")
return None
def _poll_for_completion(
prompt_id: str,
url: str,
timeout: float = 600.0,
poll_interval: float = 0.5,
on_progress: ProgressCallback | None = None,
) -> WorkflowResult:
"""Poll history endpoint for workflow completion.
Args:
prompt_id: The prompt ID to track
url: ComfyUI base URL
timeout: Maximum wait time in seconds
poll_interval: Time between polls in seconds
on_progress: Optional callback for progress updates
Returns:
WorkflowResult with outputs or errors
"""
start_time = time.time()
last_progress_time = 0.0
while time.time() - start_time < timeout:
try:
response = httpx.get(f"{url}/history/{prompt_id}", timeout=10.0)
response.raise_for_status()
history = response.json()
if prompt_id in history:
entry = history[prompt_id]
outputs = entry.get("outputs", {})
status_info = entry.get("status", {})
# Check for errors
if status_info.get("status_str") == "error":
return WorkflowResult(
prompt_id=prompt_id,
outputs=outputs,
node_errors=status_info.get("messages", {}),
success=False,
)
# Success - return outputs
return WorkflowResult(
prompt_id=prompt_id,
outputs=outputs,
success=True,
)
# Still running - check queue for progress
if on_progress:
now = time.time()
if now - last_progress_time >= _PROGRESS_UPDATE_INTERVAL:
queue_response = httpx.get(f"{url}/queue", timeout=5.0)
if queue_response.status_code == HTTPStatus.OK:
queue_data = queue_response.json()
running = queue_data.get("queue_running", [])
pending = queue_data.get("queue_pending", [])
total = len(running) + len(pending)
on_progress(0, total, f"Queued ({len(pending)} pending)")
last_progress_time = now
except httpx.RequestError:
pass # Connection error, keep polling
time.sleep(poll_interval)
# Timeout
return WorkflowResult(
prompt_id=prompt_id,
node_errors={"timeout": f"Workflow did not complete within {timeout}s"},
success=False,
)
def run_workflow(
workflow: dict[str, Any] | Path,
url: str | None = None,
console: Console | None = None,
on_progress: ProgressCallback | None = None,
timeout: float = 600.0,
) -> WorkflowResult | None:
"""Run a workflow and wait for completion.
Args:
workflow: ComfyUI workflow dict (API format) or path to JSON file
url: ComfyUI base URL
console: Rich console for progress output
on_progress: Optional callback for progress updates
timeout: Maximum wait time in seconds
Returns:
WorkflowResult with outputs, or None if queuing failed
"""
base_url = url or _get_comfyui_url()
# Load workflow from file if needed
workflow_dict: dict[str, Any]
if isinstance(workflow, Path):
if not workflow.exists():
if console:
console.print(f"[red]Workflow file not found: {workflow}[/red]")
return None
workflow_dict = json.loads(workflow.read_text())
else:
workflow_dict = workflow
# Queue the workflow
if console:
console.print("[cyan]Queueing workflow...[/cyan]")
result = queue_prompt(workflow_dict, url=base_url, console=console)
if not result:
return None
prompt_id = result["prompt_id"]
if console:
console.print(f"[dim]Prompt ID: {prompt_id}[/dim]")
# Poll for completion with progress
if console:
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
) as progress:
task = progress.add_task("[cyan]Running workflow...", total=None)
def _console_progress(step: int, total: int, status: str) -> None:
progress.update(task, description=f"[cyan]{status}[/cyan]")
if on_progress:
on_progress(step, total, status)
return _poll_for_completion(prompt_id, base_url, timeout, on_progress=_console_progress)
else:
return _poll_for_completion(prompt_id, base_url, timeout, on_progress=on_progress)
# ============================================================================
# Simple Text-to-Image Generation
# ============================================================================
# Default SDXL/Flux compatible workflow template
# This is a minimal text-to-image workflow that works with most models
DEFAULT_WORKFLOW_TEMPLATE: dict[str, Any] = {
"3": {
"class_type": "KSampler",
"inputs": {
"seed": 0,
"steps": 20,
"cfg": 7.0,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1.0,
"model": ["4", 0],
"positive": ["6", 0],
"negative": ["7", 0],
"latent_image": ["5", 0],
},
},
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {"ckpt_name": ""},
},
"5": {
"class_type": "EmptyLatentImage",
"inputs": {"width": 1024, "height": 1024, "batch_size": 1},
},
"6": {
"class_type": "CLIPTextEncode",
"inputs": {"text": "", "clip": ["4", 1]},
},
"7": {
"class_type": "CLIPTextEncode",
"inputs": {"text": "", "clip": ["4", 1]},
},
"8": {
"class_type": "VAEDecode",
"inputs": {"samples": ["3", 0], "vae": ["4", 2]},
},
"9": {
"class_type": "SaveImage",
"inputs": {"filename_prefix": "comfy", "images": ["8", 0]},
},
}
def _build_workflow(
prompt: str,
negative_prompt: str = "",
model: str | None = None,
width: int = 1024,
height: int = 1024,
steps: int = 20,
cfg: float = 7.0,
seed: int = -1,
sampler: str = "euler",
scheduler: str = "normal",
) -> dict[str, Any]:
"""Build a text-to-image workflow from parameters.
Args:
prompt: Positive prompt text
negative_prompt: Negative prompt text
model: Checkpoint filename (if None, uses first available)
width: Image width
height: Image height
steps: Number of sampling steps
cfg: CFG scale
seed: Random seed (-1 for random)
sampler: Sampler name
scheduler: Scheduler name
Returns:
ComfyUI workflow dict
"""
workflow = copy.deepcopy(DEFAULT_WORKFLOW_TEMPLATE)
# Set seed (random if -1)
actual_seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
# Update KSampler settings
workflow["3"]["inputs"]["seed"] = actual_seed
workflow["3"]["inputs"]["steps"] = steps
workflow["3"]["inputs"]["cfg"] = cfg
workflow["3"]["inputs"]["sampler_name"] = sampler
workflow["3"]["inputs"]["scheduler"] = scheduler
# Set model
if model:
workflow["4"]["inputs"]["ckpt_name"] = model
# Set dimensions
workflow["5"]["inputs"]["width"] = width
workflow["5"]["inputs"]["height"] = height
# Set prompts
workflow["6"]["inputs"]["text"] = prompt
workflow["7"]["inputs"]["text"] = negative_prompt
return workflow
def generate_image(
prompt: str,
url: str | None = None,
negative_prompt: str = "",
model: str | None = None,
width: int = 1024,
height: int = 1024,
steps: int = 20,
cfg: float = 7.0,
seed: int = -1,
sampler: str = "euler",
scheduler: str = "normal",
console: Console | None = None,
on_progress: ProgressCallback | None = None,
timeout: float = 600.0,
) -> GenerationResult | None:
"""Generate an image using a simple text-to-image workflow.
Args:
prompt: Positive prompt text
url: ComfyUI base URL
negative_prompt: Negative prompt text
model: Checkpoint filename (if None, must be pre-loaded in ComfyUI)
width: Image width
height: Image height
steps: Number of sampling steps
cfg: CFG scale
seed: Random seed (-1 for random)
sampler: Sampler name (euler, dpm_2, etc.)
scheduler: Scheduler name (normal, karras, etc.)
console: Rich console for progress output
on_progress: Optional callback for progress updates
timeout: Maximum wait time in seconds
Returns:
GenerationResult with image paths, or None if generation failed
"""
base_url = url or _get_comfyui_url()
# Get available models if none specified
if not model:
models = get_loaded_models(url=base_url)
if models and models.get("checkpoints"):
model = models["checkpoints"][0]
if console:
console.print(f"[dim]Using model: {model}[/dim]")
else:
if console:
console.print("[red]No checkpoints available. Specify a model with --model[/red]")
return None
# Build workflow
workflow = _build_workflow(
prompt=prompt,
negative_prompt=negative_prompt,
model=model,
width=width,
height=height,
steps=steps,
cfg=cfg,
seed=seed,
sampler=sampler,
scheduler=scheduler,
)
# Run workflow
result = run_workflow(
workflow=workflow,
url=base_url,
console=console,
on_progress=on_progress,
timeout=timeout,
)
if not result:
return None
if not result.success:
if console:
console.print("[red]Generation failed[/red]")
for node_id, errors in result.node_errors.items():
console.print(f" [yellow]Node {node_id}:[/yellow] {errors}")
return GenerationResult(
prompt_id=result.prompt_id,
node_errors=result.node_errors,
success=False,
)
# Extract image paths from outputs
images: list[Path] = []
for _node_id, output in result.outputs.items():
if "images" in output:
for img_info in output["images"]:
filename = img_info.get("filename", "")
subfolder = img_info.get("subfolder", "")
img_type = img_info.get("type", "output")
# Construct path (ComfyUI default output structure)
if img_type == "output":
img_path = Path(subfolder) / filename if subfolder else Path(filename)
images.append(img_path)
if console and images:
console.print(f"[green]Generated {len(images)} image(s)[/green]")
for img in images:
console.print(f" [dim]{img}[/dim]")
return GenerationResult(
prompt_id=result.prompt_id,
images=images,
success=True,
)
def get_image(
filename: str,
url: str | None = None,
subfolder: str = "",
folder_type: str = "output",
) -> bytes | None:
"""Download a generated image from ComfyUI.
Args:
filename: Image filename
url: ComfyUI base URL
subfolder: Subfolder within the output directory
folder_type: Folder type (output, input, temp)
Returns:
Image bytes or None on error
"""
base_url = url or _get_comfyui_url()
try:
params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
response = httpx.get(f"{base_url}/view", params=params, timeout=30.0)
response.raise_for_status()
return response.content
except httpx.RequestError:
return None