💬 Commit message: Update 2026-02-15 06:14:53, 5 files, 260 lines
📁 Files changed: 5 📝 Lines changed: 260 • output.png • pyproject.toml • cli.py • comfy.py • uv.lock
This commit is contained in:
+46
-2
@@ -453,6 +453,7 @@ def config(
|
||||
def generate(
|
||||
prompt: Annotated[str, typer.Argument(help="Text prompt for image generation.")],
|
||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||
model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model (remote mode only).")] = None,
|
||||
host: Annotated[str, typer.Option(help="sd-server address (local mode).")] = "127.0.0.1",
|
||||
port: Annotated[int, typer.Option(help="sd-server port (local mode).")] = 8080,
|
||||
output: Annotated[str, typer.Option("-o", help="Output directory (local mode).")] = ".",
|
||||
@@ -475,6 +476,11 @@ def generate(
|
||||
# Remote mode: use TsrClient API
|
||||
try:
|
||||
with TsrClient(remote_url) as client:
|
||||
# Switch model if specified
|
||||
if model:
|
||||
console.print(f"[cyan]Switching to model: {model}[/cyan]")
|
||||
client.switch_model(model)
|
||||
|
||||
console.print(f"[cyan]Generating {batch_size} image(s) on {remote_url}...[/cyan]")
|
||||
result = client.generate(
|
||||
prompt=prompt,
|
||||
@@ -501,6 +507,9 @@ def generate(
|
||||
console.print(f"[green]Generated:[/green] {img.get('id', 'unknown')}")
|
||||
else:
|
||||
# Local mode: direct sd-server connection
|
||||
if model:
|
||||
console.print("[yellow]Warning: --model ignored in local mode (sd-server loads model at startup)[/yellow]")
|
||||
|
||||
from tensors.generate import SDClient, Txt2ImgParams, save_images # noqa: PLC0415
|
||||
|
||||
params = Txt2ImgParams(
|
||||
@@ -1322,6 +1331,8 @@ def comfy_generate(
|
||||
prompt: Annotated[str, typer.Argument(help="Text prompt for generation")],
|
||||
url: Annotated[str, typer.Option("--url", "-u", help="ComfyUI server URL")] = COMFY_DEFAULT_URL,
|
||||
checkpoint: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None,
|
||||
lora: Annotated[str | None, typer.Option("--lora", "-l", help="LoRA name")] = None,
|
||||
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
|
||||
negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "",
|
||||
width: Annotated[int, typer.Option("-W", "--width", help="Image width")] = 512,
|
||||
height: Annotated[int, typer.Option("-H", "--height", help="Image height")] = 512,
|
||||
@@ -1330,35 +1341,68 @@ def comfy_generate(
|
||||
seed: Annotated[int, typer.Option("-s", "--seed", help="RNG seed (-1 for random)")] = -1,
|
||||
sampler: Annotated[str, typer.Option("--sampler", help="Sampler name")] = "euler_ancestral",
|
||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None,
|
||||
no_restart: Annotated[bool, typer.Option("--no-restart", help="Don't auto-restart on model change")] = False,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""Generate an image using ComfyUI."""
|
||||
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn # noqa: PLC0415
|
||||
|
||||
from tensors.comfy import ComfyClient # noqa: PLC0415
|
||||
|
||||
progress_task = None
|
||||
progress_ctx = None
|
||||
|
||||
def on_status(msg: str) -> None:
|
||||
console.print(f"[cyan]{msg}[/cyan]")
|
||||
|
||||
def on_progress(current: int, total: int, stage: str) -> None: # noqa: ARG001
|
||||
nonlocal progress_task, progress_ctx
|
||||
if progress_ctx is None:
|
||||
progress_ctx = Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
console=console,
|
||||
)
|
||||
progress_ctx.start()
|
||||
progress_task = progress_ctx.add_task("Sampling", total=total)
|
||||
if progress_task is not None:
|
||||
progress_ctx.update(progress_task, completed=current)
|
||||
|
||||
try:
|
||||
client = ComfyClient(url)
|
||||
|
||||
console.print(f"[cyan]Generating with ComfyUI at {url}...[/cyan]")
|
||||
console.print(f"[dim]ComfyUI: {url}[/dim]")
|
||||
result = client.generate(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative,
|
||||
checkpoint=checkpoint,
|
||||
lora=lora,
|
||||
lora_strength=lora_strength,
|
||||
width=width,
|
||||
height=height,
|
||||
steps=steps,
|
||||
cfg=cfg,
|
||||
seed=seed,
|
||||
sampler=sampler,
|
||||
on_status=on_status,
|
||||
on_progress=on_progress,
|
||||
auto_restart=not no_restart,
|
||||
)
|
||||
except Exception as e:
|
||||
console.print(f"[red]Error: {e}[/red]")
|
||||
raise typer.Exit(1) from e
|
||||
finally:
|
||||
if progress_ctx:
|
||||
progress_ctx.stop()
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=result)
|
||||
return
|
||||
|
||||
console.print(f"[green]Generated![/green] Seed: {result['seed']}, Checkpoint: {result['checkpoint']}")
|
||||
lora_info = f", LoRA: {result['lora']}" if result.get("lora") else ""
|
||||
console.print(f"[green]Generated![/green] Seed: {result['seed']}, Checkpoint: {result['checkpoint']}{lora_info}")
|
||||
|
||||
if result["images"]:
|
||||
img_info = result["images"][0]
|
||||
|
||||
+185
-15
@@ -3,13 +3,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from tensors.config import load_config, save_config
|
||||
|
||||
DEFAULT_WORKFLOW = {
|
||||
"3": {
|
||||
"class_type": "KSampler",
|
||||
@@ -52,6 +58,55 @@ DEFAULT_WORKFLOW = {
|
||||
},
|
||||
}
|
||||
|
||||
COMFY_CONTAINER = "comfyui"
|
||||
COMFY_HOST = "junkpile"
|
||||
|
||||
|
||||
def get_last_checkpoint() -> str | None:
|
||||
"""Get last used checkpoint from config."""
|
||||
cfg = load_config()
|
||||
value = cfg.get("comfy", {}).get("last_checkpoint")
|
||||
return str(value) if value else None
|
||||
|
||||
|
||||
def save_last_checkpoint(checkpoint: str) -> None:
|
||||
"""Save last used checkpoint to config."""
|
||||
cfg = load_config()
|
||||
if "comfy" not in cfg:
|
||||
cfg["comfy"] = {}
|
||||
cfg["comfy"]["last_checkpoint"] = checkpoint
|
||||
save_config(cfg)
|
||||
|
||||
|
||||
def restart_comfy_container(on_status: Callable[[str], None] | None = None) -> None:
|
||||
"""Restart the ComfyUI container on junkpile and wait for it to be ready."""
|
||||
def status(msg: str) -> None:
|
||||
if on_status:
|
||||
on_status(msg)
|
||||
|
||||
status("Restarting ComfyUI container...")
|
||||
subprocess.run(
|
||||
["ssh", COMFY_HOST, f"docker restart {COMFY_CONTAINER}"],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
# Wait for ComfyUI to be ready
|
||||
status("Waiting for ComfyUI to start...")
|
||||
max_wait = 120
|
||||
start = time.time()
|
||||
while time.time() - start < max_wait:
|
||||
try:
|
||||
resp = httpx.get(f"http://{COMFY_HOST}:8188/system_stats", timeout=5)
|
||||
if resp.is_success:
|
||||
status("ComfyUI is ready!")
|
||||
return
|
||||
except httpx.HTTPError:
|
||||
pass
|
||||
time.sleep(2)
|
||||
|
||||
raise TimeoutError(f"ComfyUI did not start within {max_wait}s")
|
||||
|
||||
|
||||
class ComfyClient:
|
||||
"""Simple ComfyUI API client."""
|
||||
@@ -97,15 +152,85 @@ class ComfyClient:
|
||||
return dict(data.get(prompt_id, {})) if prompt_id in data else None
|
||||
return None
|
||||
|
||||
def wait_for_completion(self, prompt_id: str, poll_interval: float = 0.5) -> dict[str, Any]:
|
||||
"""Poll until the prompt completes."""
|
||||
def wait_for_completion(
|
||||
self,
|
||||
prompt_id: str,
|
||||
on_progress: Callable[[int, int, str], None] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Wait for prompt completion with progress updates via websocket."""
|
||||
import websocket # noqa: PLC0415
|
||||
|
||||
ws_url = self.base_url.replace("http://", "ws://").replace("https://", "wss://")
|
||||
ws_url = f"{ws_url}/ws?clientId={self.client_id}"
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
completed = False
|
||||
|
||||
def on_message(ws: Any, message: str) -> None: # noqa: ARG001
|
||||
nonlocal completed, result
|
||||
try:
|
||||
data = json.loads(message)
|
||||
msg_type = data.get("type")
|
||||
|
||||
if msg_type == "progress":
|
||||
progress_data = data.get("data", {})
|
||||
current = progress_data.get("value", 0)
|
||||
total = progress_data.get("max", 1)
|
||||
if on_progress:
|
||||
on_progress(current, total, "sampling")
|
||||
|
||||
elif msg_type == "executing":
|
||||
exec_data = data.get("data", {})
|
||||
if exec_data.get("node") is None and exec_data.get("prompt_id") == prompt_id:
|
||||
# Execution finished
|
||||
completed = True
|
||||
|
||||
elif msg_type == "executed":
|
||||
exec_data = data.get("data", {})
|
||||
if exec_data.get("prompt_id") == prompt_id:
|
||||
result = exec_data
|
||||
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
def on_error(ws: Any, error: Exception) -> None: # noqa: ARG001
|
||||
nonlocal completed
|
||||
completed = True
|
||||
|
||||
def on_close(ws: Any, close_status_code: int, close_msg: str) -> None: # noqa: ARG001
|
||||
nonlocal completed
|
||||
completed = True
|
||||
|
||||
ws = websocket.WebSocketApp(
|
||||
ws_url,
|
||||
on_message=on_message,
|
||||
on_error=on_error,
|
||||
on_close=on_close,
|
||||
)
|
||||
|
||||
# Run websocket in a thread
|
||||
import threading # noqa: PLC0415
|
||||
|
||||
ws_thread = threading.Thread(target=ws.run_forever)
|
||||
ws_thread.daemon = True
|
||||
ws_thread.start()
|
||||
|
||||
# Wait for completion
|
||||
start = time.time()
|
||||
while time.time() - start < self.timeout:
|
||||
history = self.get_history(prompt_id)
|
||||
if history and history.get("outputs"):
|
||||
return history
|
||||
time.sleep(poll_interval)
|
||||
raise TimeoutError(f"Prompt {prompt_id} did not complete within {self.timeout}s")
|
||||
while not completed and time.time() - start < self.timeout:
|
||||
time.sleep(0.1)
|
||||
|
||||
ws.close()
|
||||
|
||||
if not completed:
|
||||
raise TimeoutError(f"Prompt {prompt_id} did not complete within {self.timeout}s")
|
||||
|
||||
# Get final history
|
||||
history = self.get_history(prompt_id)
|
||||
if not history:
|
||||
raise RuntimeError(f"Could not get history for prompt {prompt_id}")
|
||||
|
||||
return history
|
||||
|
||||
def get_image(self, filename: str, subfolder: str = "", folder_type: str = "output") -> bytes:
|
||||
"""Download an image from ComfyUI."""
|
||||
@@ -119,21 +244,40 @@ class ComfyClient:
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
checkpoint: str | None = None,
|
||||
lora: str | None = None,
|
||||
lora_strength: float = 0.8,
|
||||
width: int = 512,
|
||||
height: int = 512,
|
||||
steps: int = 20,
|
||||
cfg: float = 7.0,
|
||||
seed: int = -1,
|
||||
sampler: str = "euler_a",
|
||||
sampler: str = "euler_ancestral",
|
||||
scheduler: str = "normal",
|
||||
on_progress: Callable[[int, int, str], None] | None = None,
|
||||
on_status: Callable[[str], None] | None = None,
|
||||
auto_restart: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Generate an image with a simple txt2img workflow."""
|
||||
# Use first checkpoint if not specified
|
||||
if not checkpoint:
|
||||
checkpoints = self.get_checkpoints()
|
||||
if not checkpoints:
|
||||
raise ValueError("No checkpoints available")
|
||||
checkpoint = checkpoints[0]
|
||||
# Try last used checkpoint first
|
||||
checkpoint = get_last_checkpoint()
|
||||
if not checkpoint:
|
||||
checkpoints = self.get_checkpoints()
|
||||
if not checkpoints:
|
||||
raise ValueError("No checkpoints available")
|
||||
checkpoint = checkpoints[0]
|
||||
|
||||
# Check if we need to restart container for model change
|
||||
if auto_restart:
|
||||
last_checkpoint = get_last_checkpoint()
|
||||
if last_checkpoint and last_checkpoint != checkpoint:
|
||||
if on_status:
|
||||
on_status(f"Model changed: {last_checkpoint} -> {checkpoint}")
|
||||
restart_comfy_container(on_status)
|
||||
|
||||
# Save checkpoint as last used
|
||||
save_last_checkpoint(checkpoint)
|
||||
|
||||
# Build workflow
|
||||
workflow = json.loads(json.dumps(DEFAULT_WORKFLOW))
|
||||
@@ -148,9 +292,34 @@ class ComfyClient:
|
||||
workflow["3"]["inputs"]["sampler_name"] = sampler
|
||||
workflow["3"]["inputs"]["scheduler"] = scheduler
|
||||
|
||||
# Add LoRA if specified
|
||||
if lora:
|
||||
workflow["10"] = {
|
||||
"class_type": "LoraLoader",
|
||||
"inputs": {
|
||||
"lora_name": lora,
|
||||
"strength_model": lora_strength,
|
||||
"strength_clip": lora_strength,
|
||||
"model": ["4", 0],
|
||||
"clip": ["4", 1],
|
||||
},
|
||||
}
|
||||
# Rewire: KSampler uses LoRA output instead of checkpoint
|
||||
workflow["3"]["inputs"]["model"] = ["10", 0]
|
||||
# Rewire: CLIP encoders use LoRA output
|
||||
workflow["6"]["inputs"]["clip"] = ["10", 1]
|
||||
workflow["7"]["inputs"]["clip"] = ["10", 1]
|
||||
|
||||
if on_status:
|
||||
on_status("Queueing prompt...")
|
||||
|
||||
# Queue and wait
|
||||
prompt_id = self.queue_prompt(workflow)
|
||||
history = self.wait_for_completion(prompt_id)
|
||||
|
||||
if on_status:
|
||||
on_status("Generating...")
|
||||
|
||||
history = self.wait_for_completion(prompt_id, on_progress)
|
||||
|
||||
# Extract output images
|
||||
outputs = history.get("outputs", {})
|
||||
@@ -168,6 +337,7 @@ class ComfyClient:
|
||||
"prompt_id": prompt_id,
|
||||
"images": images,
|
||||
"checkpoint": checkpoint,
|
||||
"lora": lora,
|
||||
"seed": workflow["3"]["inputs"]["seed"],
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user