💬 Commit message: Update 2026-02-15 06:14:53, 5 files, 260 lines

📁 Files changed: 5
📝 Lines changed: 260

  • output.png
  • pyproject.toml
  • cli.py
  • comfy.py
  • uv.lock
This commit is contained in:
Adam Ladachowski
2026-02-15 06:14:53 +01:00
parent dfca42ac72
commit e016c01370
5 changed files with 243 additions and 17 deletions
BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 324 KiB

+1
View File
@@ -9,6 +9,7 @@ dependencies = [
"httpx>=0.27.0", "httpx>=0.27.0",
"rich>=13.0.0", "rich>=13.0.0",
"typer>=0.15.0", "typer>=0.15.0",
"websocket-client>=1.9.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]
+46 -2
View File
@@ -453,6 +453,7 @@ def config(
def generate( def generate(
prompt: Annotated[str, typer.Argument(help="Text prompt for image generation.")], prompt: Annotated[str, typer.Argument(help="Text prompt for image generation.")],
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model (remote mode only).")] = None,
host: Annotated[str, typer.Option(help="sd-server address (local mode).")] = "127.0.0.1", host: Annotated[str, typer.Option(help="sd-server address (local mode).")] = "127.0.0.1",
port: Annotated[int, typer.Option(help="sd-server port (local mode).")] = 8080, port: Annotated[int, typer.Option(help="sd-server port (local mode).")] = 8080,
output: Annotated[str, typer.Option("-o", help="Output directory (local mode).")] = ".", output: Annotated[str, typer.Option("-o", help="Output directory (local mode).")] = ".",
@@ -475,6 +476,11 @@ def generate(
# Remote mode: use TsrClient API # Remote mode: use TsrClient API
try: try:
with TsrClient(remote_url) as client: with TsrClient(remote_url) as client:
# Switch model if specified
if model:
console.print(f"[cyan]Switching to model: {model}[/cyan]")
client.switch_model(model)
console.print(f"[cyan]Generating {batch_size} image(s) on {remote_url}...[/cyan]") console.print(f"[cyan]Generating {batch_size} image(s) on {remote_url}...[/cyan]")
result = client.generate( result = client.generate(
prompt=prompt, prompt=prompt,
@@ -501,6 +507,9 @@ def generate(
console.print(f"[green]Generated:[/green] {img.get('id', 'unknown')}") console.print(f"[green]Generated:[/green] {img.get('id', 'unknown')}")
else: else:
# Local mode: direct sd-server connection # Local mode: direct sd-server connection
if model:
console.print("[yellow]Warning: --model ignored in local mode (sd-server loads model at startup)[/yellow]")
from tensors.generate import SDClient, Txt2ImgParams, save_images # noqa: PLC0415 from tensors.generate import SDClient, Txt2ImgParams, save_images # noqa: PLC0415
params = Txt2ImgParams( params = Txt2ImgParams(
@@ -1322,6 +1331,8 @@ def comfy_generate(
prompt: Annotated[str, typer.Argument(help="Text prompt for generation")], prompt: Annotated[str, typer.Argument(help="Text prompt for generation")],
url: Annotated[str, typer.Option("--url", "-u", help="ComfyUI server URL")] = COMFY_DEFAULT_URL, url: Annotated[str, typer.Option("--url", "-u", help="ComfyUI server URL")] = COMFY_DEFAULT_URL,
checkpoint: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None, checkpoint: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None,
lora: Annotated[str | None, typer.Option("--lora", "-l", help="LoRA name")] = None,
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "", negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "",
width: Annotated[int, typer.Option("-W", "--width", help="Image width")] = 512, width: Annotated[int, typer.Option("-W", "--width", help="Image width")] = 512,
height: Annotated[int, typer.Option("-H", "--height", help="Image height")] = 512, height: Annotated[int, typer.Option("-H", "--height", help="Image height")] = 512,
@@ -1330,35 +1341,68 @@ def comfy_generate(
seed: Annotated[int, typer.Option("-s", "--seed", help="RNG seed (-1 for random)")] = -1, seed: Annotated[int, typer.Option("-s", "--seed", help="RNG seed (-1 for random)")] = -1,
sampler: Annotated[str, typer.Option("--sampler", help="Sampler name")] = "euler_ancestral", sampler: Annotated[str, typer.Option("--sampler", help="Sampler name")] = "euler_ancestral",
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None, output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None,
no_restart: Annotated[bool, typer.Option("--no-restart", help="Don't auto-restart on model change")] = False,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None: ) -> None:
"""Generate an image using ComfyUI.""" """Generate an image using ComfyUI."""
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn # noqa: PLC0415
from tensors.comfy import ComfyClient # noqa: PLC0415 from tensors.comfy import ComfyClient # noqa: PLC0415
progress_task = None
progress_ctx = None
def on_status(msg: str) -> None:
console.print(f"[cyan]{msg}[/cyan]")
def on_progress(current: int, total: int, stage: str) -> None: # noqa: ARG001
nonlocal progress_task, progress_ctx
if progress_ctx is None:
progress_ctx = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
console=console,
)
progress_ctx.start()
progress_task = progress_ctx.add_task("Sampling", total=total)
if progress_task is not None:
progress_ctx.update(progress_task, completed=current)
try: try:
client = ComfyClient(url) client = ComfyClient(url)
console.print(f"[cyan]Generating with ComfyUI at {url}...[/cyan]") console.print(f"[dim]ComfyUI: {url}[/dim]")
result = client.generate( result = client.generate(
prompt=prompt, prompt=prompt,
negative_prompt=negative, negative_prompt=negative,
checkpoint=checkpoint, checkpoint=checkpoint,
lora=lora,
lora_strength=lora_strength,
width=width, width=width,
height=height, height=height,
steps=steps, steps=steps,
cfg=cfg, cfg=cfg,
seed=seed, seed=seed,
sampler=sampler, sampler=sampler,
on_status=on_status,
on_progress=on_progress,
auto_restart=not no_restart,
) )
except Exception as e: except Exception as e:
console.print(f"[red]Error: {e}[/red]") console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e raise typer.Exit(1) from e
finally:
if progress_ctx:
progress_ctx.stop()
if json_output: if json_output:
console.print_json(data=result) console.print_json(data=result)
return return
console.print(f"[green]Generated![/green] Seed: {result['seed']}, Checkpoint: {result['checkpoint']}") lora_info = f", LoRA: {result['lora']}" if result.get("lora") else ""
console.print(f"[green]Generated![/green] Seed: {result['seed']}, Checkpoint: {result['checkpoint']}{lora_info}")
if result["images"]: if result["images"]:
img_info = result["images"][0] img_info = result["images"][0]
+180 -10
View File
@@ -3,13 +3,19 @@
from __future__ import annotations from __future__ import annotations
import json import json
import subprocess
import time import time
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Any from typing import TYPE_CHECKING, Any
import httpx import httpx
if TYPE_CHECKING:
from collections.abc import Callable
from tensors.config import load_config, save_config
DEFAULT_WORKFLOW = { DEFAULT_WORKFLOW = {
"3": { "3": {
"class_type": "KSampler", "class_type": "KSampler",
@@ -52,6 +58,55 @@ DEFAULT_WORKFLOW = {
}, },
} }
COMFY_CONTAINER = "comfyui"
COMFY_HOST = "junkpile"
def get_last_checkpoint() -> str | None:
"""Get last used checkpoint from config."""
cfg = load_config()
value = cfg.get("comfy", {}).get("last_checkpoint")
return str(value) if value else None
def save_last_checkpoint(checkpoint: str) -> None:
"""Save last used checkpoint to config."""
cfg = load_config()
if "comfy" not in cfg:
cfg["comfy"] = {}
cfg["comfy"]["last_checkpoint"] = checkpoint
save_config(cfg)
def restart_comfy_container(on_status: Callable[[str], None] | None = None) -> None:
"""Restart the ComfyUI container on junkpile and wait for it to be ready."""
def status(msg: str) -> None:
if on_status:
on_status(msg)
status("Restarting ComfyUI container...")
subprocess.run(
["ssh", COMFY_HOST, f"docker restart {COMFY_CONTAINER}"],
check=True,
capture_output=True,
)
# Wait for ComfyUI to be ready
status("Waiting for ComfyUI to start...")
max_wait = 120
start = time.time()
while time.time() - start < max_wait:
try:
resp = httpx.get(f"http://{COMFY_HOST}:8188/system_stats", timeout=5)
if resp.is_success:
status("ComfyUI is ready!")
return
except httpx.HTTPError:
pass
time.sleep(2)
raise TimeoutError(f"ComfyUI did not start within {max_wait}s")
class ComfyClient: class ComfyClient:
"""Simple ComfyUI API client.""" """Simple ComfyUI API client."""
@@ -97,16 +152,86 @@ class ComfyClient:
return dict(data.get(prompt_id, {})) if prompt_id in data else None return dict(data.get(prompt_id, {})) if prompt_id in data else None
return None return None
def wait_for_completion(self, prompt_id: str, poll_interval: float = 0.5) -> dict[str, Any]: def wait_for_completion(
"""Poll until the prompt completes.""" self,
prompt_id: str,
on_progress: Callable[[int, int, str], None] | None = None,
) -> dict[str, Any]:
"""Wait for prompt completion with progress updates via websocket."""
import websocket # noqa: PLC0415
ws_url = self.base_url.replace("http://", "ws://").replace("https://", "wss://")
ws_url = f"{ws_url}/ws?clientId={self.client_id}"
result: dict[str, Any] = {}
completed = False
def on_message(ws: Any, message: str) -> None: # noqa: ARG001
nonlocal completed, result
try:
data = json.loads(message)
msg_type = data.get("type")
if msg_type == "progress":
progress_data = data.get("data", {})
current = progress_data.get("value", 0)
total = progress_data.get("max", 1)
if on_progress:
on_progress(current, total, "sampling")
elif msg_type == "executing":
exec_data = data.get("data", {})
if exec_data.get("node") is None and exec_data.get("prompt_id") == prompt_id:
# Execution finished
completed = True
elif msg_type == "executed":
exec_data = data.get("data", {})
if exec_data.get("prompt_id") == prompt_id:
result = exec_data
except json.JSONDecodeError:
pass
def on_error(ws: Any, error: Exception) -> None: # noqa: ARG001
nonlocal completed
completed = True
def on_close(ws: Any, close_status_code: int, close_msg: str) -> None: # noqa: ARG001
nonlocal completed
completed = True
ws = websocket.WebSocketApp(
ws_url,
on_message=on_message,
on_error=on_error,
on_close=on_close,
)
# Run websocket in a thread
import threading # noqa: PLC0415
ws_thread = threading.Thread(target=ws.run_forever)
ws_thread.daemon = True
ws_thread.start()
# Wait for completion
start = time.time() start = time.time()
while time.time() - start < self.timeout: while not completed and time.time() - start < self.timeout:
history = self.get_history(prompt_id) time.sleep(0.1)
if history and history.get("outputs"):
return history ws.close()
time.sleep(poll_interval)
if not completed:
raise TimeoutError(f"Prompt {prompt_id} did not complete within {self.timeout}s") raise TimeoutError(f"Prompt {prompt_id} did not complete within {self.timeout}s")
# Get final history
history = self.get_history(prompt_id)
if not history:
raise RuntimeError(f"Could not get history for prompt {prompt_id}")
return history
def get_image(self, filename: str, subfolder: str = "", folder_type: str = "output") -> bytes: def get_image(self, filename: str, subfolder: str = "", folder_type: str = "output") -> bytes:
"""Download an image from ComfyUI.""" """Download an image from ComfyUI."""
params = {"filename": filename, "subfolder": subfolder, "type": folder_type} params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
@@ -119,22 +244,41 @@ class ComfyClient:
prompt: str, prompt: str,
negative_prompt: str = "", negative_prompt: str = "",
checkpoint: str | None = None, checkpoint: str | None = None,
lora: str | None = None,
lora_strength: float = 0.8,
width: int = 512, width: int = 512,
height: int = 512, height: int = 512,
steps: int = 20, steps: int = 20,
cfg: float = 7.0, cfg: float = 7.0,
seed: int = -1, seed: int = -1,
sampler: str = "euler_a", sampler: str = "euler_ancestral",
scheduler: str = "normal", scheduler: str = "normal",
on_progress: Callable[[int, int, str], None] | None = None,
on_status: Callable[[str], None] | None = None,
auto_restart: bool = True,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Generate an image with a simple txt2img workflow.""" """Generate an image with a simple txt2img workflow."""
# Use first checkpoint if not specified # Use first checkpoint if not specified
if not checkpoint:
# Try last used checkpoint first
checkpoint = get_last_checkpoint()
if not checkpoint: if not checkpoint:
checkpoints = self.get_checkpoints() checkpoints = self.get_checkpoints()
if not checkpoints: if not checkpoints:
raise ValueError("No checkpoints available") raise ValueError("No checkpoints available")
checkpoint = checkpoints[0] checkpoint = checkpoints[0]
# Check if we need to restart container for model change
if auto_restart:
last_checkpoint = get_last_checkpoint()
if last_checkpoint and last_checkpoint != checkpoint:
if on_status:
on_status(f"Model changed: {last_checkpoint} -> {checkpoint}")
restart_comfy_container(on_status)
# Save checkpoint as last used
save_last_checkpoint(checkpoint)
# Build workflow # Build workflow
workflow = json.loads(json.dumps(DEFAULT_WORKFLOW)) workflow = json.loads(json.dumps(DEFAULT_WORKFLOW))
workflow["4"]["inputs"]["ckpt_name"] = checkpoint workflow["4"]["inputs"]["ckpt_name"] = checkpoint
@@ -148,9 +292,34 @@ class ComfyClient:
workflow["3"]["inputs"]["sampler_name"] = sampler workflow["3"]["inputs"]["sampler_name"] = sampler
workflow["3"]["inputs"]["scheduler"] = scheduler workflow["3"]["inputs"]["scheduler"] = scheduler
# Add LoRA if specified
if lora:
workflow["10"] = {
"class_type": "LoraLoader",
"inputs": {
"lora_name": lora,
"strength_model": lora_strength,
"strength_clip": lora_strength,
"model": ["4", 0],
"clip": ["4", 1],
},
}
# Rewire: KSampler uses LoRA output instead of checkpoint
workflow["3"]["inputs"]["model"] = ["10", 0]
# Rewire: CLIP encoders use LoRA output
workflow["6"]["inputs"]["clip"] = ["10", 1]
workflow["7"]["inputs"]["clip"] = ["10", 1]
if on_status:
on_status("Queueing prompt...")
# Queue and wait # Queue and wait
prompt_id = self.queue_prompt(workflow) prompt_id = self.queue_prompt(workflow)
history = self.wait_for_completion(prompt_id)
if on_status:
on_status("Generating...")
history = self.wait_for_completion(prompt_id, on_progress)
# Extract output images # Extract output images
outputs = history.get("outputs", {}) outputs = history.get("outputs", {})
@@ -168,6 +337,7 @@ class ComfyClient:
"prompt_id": prompt_id, "prompt_id": prompt_id,
"images": images, "images": images,
"checkpoint": checkpoint, "checkpoint": checkpoint,
"lora": lora,
"seed": workflow["3"]["inputs"]["seed"], "seed": workflow["3"]["inputs"]["seed"],
} }
Generated
+11
View File
@@ -714,6 +714,7 @@ dependencies = [
{ name = "rich" }, { name = "rich" },
{ name = "safetensors" }, { name = "safetensors" },
{ name = "typer" }, { name = "typer" },
{ name = "websocket-client" },
] ]
[package.optional-dependencies] [package.optional-dependencies]
@@ -743,6 +744,7 @@ requires-dist = [
{ name = "safetensors", specifier = ">=0.4.0" }, { name = "safetensors", specifier = ">=0.4.0" },
{ name = "typer", specifier = ">=0.15.0" }, { name = "typer", specifier = ">=0.15.0" },
{ name = "uvicorn", marker = "extra == 'server'", specifier = ">=0.30" }, { name = "uvicorn", marker = "extra == 'server'", specifier = ">=0.30" },
{ name = "websocket-client", specifier = ">=1.9.0" },
] ]
provides-extras = ["server"] provides-extras = ["server"]
@@ -822,6 +824,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/6a/2a/dc2228b2888f51192c7dc766106cd475f1b768c10caaf9727659726f7391/virtualenv-20.36.1-py3-none-any.whl", hash = "sha256:575a8d6b124ef88f6f51d56d656132389f961062a9177016a50e4f507bbcc19f", size = 6008258, upload-time = "2026-01-09T18:20:59.425Z" }, { url = "https://files.pythonhosted.org/packages/6a/2a/dc2228b2888f51192c7dc766106cd475f1b768c10caaf9727659726f7391/virtualenv-20.36.1-py3-none-any.whl", hash = "sha256:575a8d6b124ef88f6f51d56d656132389f961062a9177016a50e4f507bbcc19f", size = 6008258, upload-time = "2026-01-09T18:20:59.425Z" },
] ]
[[package]]
name = "websocket-client"
version = "1.9.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/2c/41/aa4bf9664e4cda14c3b39865b12251e8e7d239f4cd0e3cc1b6c2ccde25c1/websocket_client-1.9.0.tar.gz", hash = "sha256:9e813624b6eb619999a97dc7958469217c3176312b3a16a4bd1bc7e08a46ec98", size = 70576, upload-time = "2025-10-07T21:16:36.495Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/34/db/b10e48aa8fff7407e67470363eac595018441cf32d5e1001567a7aeba5d2/websocket_client-1.9.0-py3-none-any.whl", hash = "sha256:af248a825037ef591efbf6ed20cc5faa03d3b47b9e5a2230a529eeee1c1fc3ef", size = 82616, upload-time = "2025-10-07T21:16:34.951Z" },
]
[[package]] [[package]]
name = "zstandard" name = "zstandard"
version = "0.25.0" version = "0.25.0"