💬 Commit message: Update 2026-02-15 06:14:53, 5 files, 260 lines

📁 Files changed: 5
📝 Lines changed: 260

  • output.png
  • pyproject.toml
  • cli.py
  • comfy.py
  • uv.lock
This commit is contained in:
Adam Ladachowski
2026-02-15 06:14:53 +01:00
parent dfca42ac72
commit e016c01370
5 changed files with 243 additions and 17 deletions
BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 324 KiB

+1
View File
@@ -9,6 +9,7 @@ dependencies = [
"httpx>=0.27.0",
"rich>=13.0.0",
"typer>=0.15.0",
"websocket-client>=1.9.0",
]
[project.optional-dependencies]
+46 -2
View File
@@ -453,6 +453,7 @@ def config(
def generate(
prompt: Annotated[str, typer.Argument(help="Text prompt for image generation.")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model (remote mode only).")] = None,
host: Annotated[str, typer.Option(help="sd-server address (local mode).")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="sd-server port (local mode).")] = 8080,
output: Annotated[str, typer.Option("-o", help="Output directory (local mode).")] = ".",
@@ -475,6 +476,11 @@ def generate(
# Remote mode: use TsrClient API
try:
with TsrClient(remote_url) as client:
# Switch model if specified
if model:
console.print(f"[cyan]Switching to model: {model}[/cyan]")
client.switch_model(model)
console.print(f"[cyan]Generating {batch_size} image(s) on {remote_url}...[/cyan]")
result = client.generate(
prompt=prompt,
@@ -501,6 +507,9 @@ def generate(
console.print(f"[green]Generated:[/green] {img.get('id', 'unknown')}")
else:
# Local mode: direct sd-server connection
if model:
console.print("[yellow]Warning: --model ignored in local mode (sd-server loads model at startup)[/yellow]")
from tensors.generate import SDClient, Txt2ImgParams, save_images # noqa: PLC0415
params = Txt2ImgParams(
@@ -1322,6 +1331,8 @@ def comfy_generate(
prompt: Annotated[str, typer.Argument(help="Text prompt for generation")],
url: Annotated[str, typer.Option("--url", "-u", help="ComfyUI server URL")] = COMFY_DEFAULT_URL,
checkpoint: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None,
lora: Annotated[str | None, typer.Option("--lora", "-l", help="LoRA name")] = None,
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "",
width: Annotated[int, typer.Option("-W", "--width", help="Image width")] = 512,
height: Annotated[int, typer.Option("-H", "--height", help="Image height")] = 512,
@@ -1330,35 +1341,68 @@ def comfy_generate(
seed: Annotated[int, typer.Option("-s", "--seed", help="RNG seed (-1 for random)")] = -1,
sampler: Annotated[str, typer.Option("--sampler", help="Sampler name")] = "euler_ancestral",
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None,
no_restart: Annotated[bool, typer.Option("--no-restart", help="Don't auto-restart on model change")] = False,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""Generate an image using ComfyUI."""
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn # noqa: PLC0415
from tensors.comfy import ComfyClient # noqa: PLC0415
progress_task = None
progress_ctx = None
def on_status(msg: str) -> None:
console.print(f"[cyan]{msg}[/cyan]")
def on_progress(current: int, total: int, stage: str) -> None: # noqa: ARG001
nonlocal progress_task, progress_ctx
if progress_ctx is None:
progress_ctx = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
console=console,
)
progress_ctx.start()
progress_task = progress_ctx.add_task("Sampling", total=total)
if progress_task is not None:
progress_ctx.update(progress_task, completed=current)
try:
client = ComfyClient(url)
console.print(f"[cyan]Generating with ComfyUI at {url}...[/cyan]")
console.print(f"[dim]ComfyUI: {url}[/dim]")
result = client.generate(
prompt=prompt,
negative_prompt=negative,
checkpoint=checkpoint,
lora=lora,
lora_strength=lora_strength,
width=width,
height=height,
steps=steps,
cfg=cfg,
seed=seed,
sampler=sampler,
on_status=on_status,
on_progress=on_progress,
auto_restart=not no_restart,
)
except Exception as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
finally:
if progress_ctx:
progress_ctx.stop()
if json_output:
console.print_json(data=result)
return
console.print(f"[green]Generated![/green] Seed: {result['seed']}, Checkpoint: {result['checkpoint']}")
lora_info = f", LoRA: {result['lora']}" if result.get("lora") else ""
console.print(f"[green]Generated![/green] Seed: {result['seed']}, Checkpoint: {result['checkpoint']}{lora_info}")
if result["images"]:
img_info = result["images"][0]
+185 -15
View File
@@ -3,13 +3,19 @@
from __future__ import annotations
import json
import subprocess
import time
import uuid
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any
import httpx
if TYPE_CHECKING:
from collections.abc import Callable
from tensors.config import load_config, save_config
DEFAULT_WORKFLOW = {
"3": {
"class_type": "KSampler",
@@ -52,6 +58,55 @@ DEFAULT_WORKFLOW = {
},
}
COMFY_CONTAINER = "comfyui"
COMFY_HOST = "junkpile"
def get_last_checkpoint() -> str | None:
"""Get last used checkpoint from config."""
cfg = load_config()
value = cfg.get("comfy", {}).get("last_checkpoint")
return str(value) if value else None
def save_last_checkpoint(checkpoint: str) -> None:
"""Save last used checkpoint to config."""
cfg = load_config()
if "comfy" not in cfg:
cfg["comfy"] = {}
cfg["comfy"]["last_checkpoint"] = checkpoint
save_config(cfg)
def restart_comfy_container(on_status: Callable[[str], None] | None = None) -> None:
"""Restart the ComfyUI container on junkpile and wait for it to be ready."""
def status(msg: str) -> None:
if on_status:
on_status(msg)
status("Restarting ComfyUI container...")
subprocess.run(
["ssh", COMFY_HOST, f"docker restart {COMFY_CONTAINER}"],
check=True,
capture_output=True,
)
# Wait for ComfyUI to be ready
status("Waiting for ComfyUI to start...")
max_wait = 120
start = time.time()
while time.time() - start < max_wait:
try:
resp = httpx.get(f"http://{COMFY_HOST}:8188/system_stats", timeout=5)
if resp.is_success:
status("ComfyUI is ready!")
return
except httpx.HTTPError:
pass
time.sleep(2)
raise TimeoutError(f"ComfyUI did not start within {max_wait}s")
class ComfyClient:
"""Simple ComfyUI API client."""
@@ -97,15 +152,85 @@ class ComfyClient:
return dict(data.get(prompt_id, {})) if prompt_id in data else None
return None
def wait_for_completion(self, prompt_id: str, poll_interval: float = 0.5) -> dict[str, Any]:
"""Poll until the prompt completes."""
def wait_for_completion(
self,
prompt_id: str,
on_progress: Callable[[int, int, str], None] | None = None,
) -> dict[str, Any]:
"""Wait for prompt completion with progress updates via websocket."""
import websocket # noqa: PLC0415
ws_url = self.base_url.replace("http://", "ws://").replace("https://", "wss://")
ws_url = f"{ws_url}/ws?clientId={self.client_id}"
result: dict[str, Any] = {}
completed = False
def on_message(ws: Any, message: str) -> None: # noqa: ARG001
nonlocal completed, result
try:
data = json.loads(message)
msg_type = data.get("type")
if msg_type == "progress":
progress_data = data.get("data", {})
current = progress_data.get("value", 0)
total = progress_data.get("max", 1)
if on_progress:
on_progress(current, total, "sampling")
elif msg_type == "executing":
exec_data = data.get("data", {})
if exec_data.get("node") is None and exec_data.get("prompt_id") == prompt_id:
# Execution finished
completed = True
elif msg_type == "executed":
exec_data = data.get("data", {})
if exec_data.get("prompt_id") == prompt_id:
result = exec_data
except json.JSONDecodeError:
pass
def on_error(ws: Any, error: Exception) -> None: # noqa: ARG001
nonlocal completed
completed = True
def on_close(ws: Any, close_status_code: int, close_msg: str) -> None: # noqa: ARG001
nonlocal completed
completed = True
ws = websocket.WebSocketApp(
ws_url,
on_message=on_message,
on_error=on_error,
on_close=on_close,
)
# Run websocket in a thread
import threading # noqa: PLC0415
ws_thread = threading.Thread(target=ws.run_forever)
ws_thread.daemon = True
ws_thread.start()
# Wait for completion
start = time.time()
while time.time() - start < self.timeout:
history = self.get_history(prompt_id)
if history and history.get("outputs"):
return history
time.sleep(poll_interval)
raise TimeoutError(f"Prompt {prompt_id} did not complete within {self.timeout}s")
while not completed and time.time() - start < self.timeout:
time.sleep(0.1)
ws.close()
if not completed:
raise TimeoutError(f"Prompt {prompt_id} did not complete within {self.timeout}s")
# Get final history
history = self.get_history(prompt_id)
if not history:
raise RuntimeError(f"Could not get history for prompt {prompt_id}")
return history
def get_image(self, filename: str, subfolder: str = "", folder_type: str = "output") -> bytes:
"""Download an image from ComfyUI."""
@@ -119,21 +244,40 @@ class ComfyClient:
prompt: str,
negative_prompt: str = "",
checkpoint: str | None = None,
lora: str | None = None,
lora_strength: float = 0.8,
width: int = 512,
height: int = 512,
steps: int = 20,
cfg: float = 7.0,
seed: int = -1,
sampler: str = "euler_a",
sampler: str = "euler_ancestral",
scheduler: str = "normal",
on_progress: Callable[[int, int, str], None] | None = None,
on_status: Callable[[str], None] | None = None,
auto_restart: bool = True,
) -> dict[str, Any]:
"""Generate an image with a simple txt2img workflow."""
# Use first checkpoint if not specified
if not checkpoint:
checkpoints = self.get_checkpoints()
if not checkpoints:
raise ValueError("No checkpoints available")
checkpoint = checkpoints[0]
# Try last used checkpoint first
checkpoint = get_last_checkpoint()
if not checkpoint:
checkpoints = self.get_checkpoints()
if not checkpoints:
raise ValueError("No checkpoints available")
checkpoint = checkpoints[0]
# Check if we need to restart container for model change
if auto_restart:
last_checkpoint = get_last_checkpoint()
if last_checkpoint and last_checkpoint != checkpoint:
if on_status:
on_status(f"Model changed: {last_checkpoint} -> {checkpoint}")
restart_comfy_container(on_status)
# Save checkpoint as last used
save_last_checkpoint(checkpoint)
# Build workflow
workflow = json.loads(json.dumps(DEFAULT_WORKFLOW))
@@ -148,9 +292,34 @@ class ComfyClient:
workflow["3"]["inputs"]["sampler_name"] = sampler
workflow["3"]["inputs"]["scheduler"] = scheduler
# Add LoRA if specified
if lora:
workflow["10"] = {
"class_type": "LoraLoader",
"inputs": {
"lora_name": lora,
"strength_model": lora_strength,
"strength_clip": lora_strength,
"model": ["4", 0],
"clip": ["4", 1],
},
}
# Rewire: KSampler uses LoRA output instead of checkpoint
workflow["3"]["inputs"]["model"] = ["10", 0]
# Rewire: CLIP encoders use LoRA output
workflow["6"]["inputs"]["clip"] = ["10", 1]
workflow["7"]["inputs"]["clip"] = ["10", 1]
if on_status:
on_status("Queueing prompt...")
# Queue and wait
prompt_id = self.queue_prompt(workflow)
history = self.wait_for_completion(prompt_id)
if on_status:
on_status("Generating...")
history = self.wait_for_completion(prompt_id, on_progress)
# Extract output images
outputs = history.get("outputs", {})
@@ -168,6 +337,7 @@ class ComfyClient:
"prompt_id": prompt_id,
"images": images,
"checkpoint": checkpoint,
"lora": lora,
"seed": workflow["3"]["inputs"]["seed"],
}
Generated
+11
View File
@@ -714,6 +714,7 @@ dependencies = [
{ name = "rich" },
{ name = "safetensors" },
{ name = "typer" },
{ name = "websocket-client" },
]
[package.optional-dependencies]
@@ -743,6 +744,7 @@ requires-dist = [
{ name = "safetensors", specifier = ">=0.4.0" },
{ name = "typer", specifier = ">=0.15.0" },
{ name = "uvicorn", marker = "extra == 'server'", specifier = ">=0.30" },
{ name = "websocket-client", specifier = ">=1.9.0" },
]
provides-extras = ["server"]
@@ -822,6 +824,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/6a/2a/dc2228b2888f51192c7dc766106cd475f1b768c10caaf9727659726f7391/virtualenv-20.36.1-py3-none-any.whl", hash = "sha256:575a8d6b124ef88f6f51d56d656132389f961062a9177016a50e4f507bbcc19f", size = 6008258, upload-time = "2026-01-09T18:20:59.425Z" },
]
[[package]]
name = "websocket-client"
version = "1.9.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/2c/41/aa4bf9664e4cda14c3b39865b12251e8e7d239f4cd0e3cc1b6c2ccde25c1/websocket_client-1.9.0.tar.gz", hash = "sha256:9e813624b6eb619999a97dc7958469217c3176312b3a16a4bd1bc7e08a46ec98", size = 70576, upload-time = "2025-10-07T21:16:36.495Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/34/db/b10e48aa8fff7407e67470363eac595018441cf32d5e1001567a7aeba5d2/websocket_client-1.9.0-py3-none-any.whl", hash = "sha256:af248a825037ef591efbf6ed20cc5faa03d3b47b9e5a2230a529eeee1c1fc3ef", size = 82616, upload-time = "2025-10-07T21:16:34.951Z" },
]
[[package]]
name = "zstandard"
version = "0.25.0"