diff --git a/output.png b/output.png deleted file mode 100644 index 87b8f1b..0000000 Binary files a/output.png and /dev/null differ diff --git a/pyproject.toml b/pyproject.toml index f0b65e2..3afb800 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "httpx>=0.27.0", "rich>=13.0.0", "typer>=0.15.0", + "websocket-client>=1.9.0", ] [project.optional-dependencies] diff --git a/tensors/cli.py b/tensors/cli.py index 8937f73..acead72 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -453,6 +453,7 @@ def config( def generate( prompt: Annotated[str, typer.Argument(help="Text prompt for image generation.")], remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, + model: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model (remote mode only).")] = None, host: Annotated[str, typer.Option(help="sd-server address (local mode).")] = "127.0.0.1", port: Annotated[int, typer.Option(help="sd-server port (local mode).")] = 8080, output: Annotated[str, typer.Option("-o", help="Output directory (local mode).")] = ".", @@ -475,6 +476,11 @@ def generate( # Remote mode: use TsrClient API try: with TsrClient(remote_url) as client: + # Switch model if specified + if model: + console.print(f"[cyan]Switching to model: {model}[/cyan]") + client.switch_model(model) + console.print(f"[cyan]Generating {batch_size} image(s) on {remote_url}...[/cyan]") result = client.generate( prompt=prompt, @@ -501,6 +507,9 @@ def generate( console.print(f"[green]Generated:[/green] {img.get('id', 'unknown')}") else: # Local mode: direct sd-server connection + if model: + console.print("[yellow]Warning: --model ignored in local mode (sd-server loads model at startup)[/yellow]") + from tensors.generate import SDClient, Txt2ImgParams, save_images # noqa: PLC0415 params = Txt2ImgParams( @@ -1322,6 +1331,8 @@ def comfy_generate( prompt: Annotated[str, typer.Argument(help="Text prompt for generation")], url: Annotated[str, typer.Option("--url", "-u", help="ComfyUI server URL")] = COMFY_DEFAULT_URL, checkpoint: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None, + lora: Annotated[str | None, typer.Option("--lora", "-l", help="LoRA name")] = None, + lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8, negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "", width: Annotated[int, typer.Option("-W", "--width", help="Image width")] = 512, height: Annotated[int, typer.Option("-H", "--height", help="Image height")] = 512, @@ -1330,35 +1341,68 @@ def comfy_generate( seed: Annotated[int, typer.Option("-s", "--seed", help="RNG seed (-1 for random)")] = -1, sampler: Annotated[str, typer.Option("--sampler", help="Sampler name")] = "euler_ancestral", output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None, + no_restart: Annotated[bool, typer.Option("--no-restart", help="Don't auto-restart on model change")] = False, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, ) -> None: """Generate an image using ComfyUI.""" + from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn # noqa: PLC0415 + from tensors.comfy import ComfyClient # noqa: PLC0415 + progress_task = None + progress_ctx = None + + def on_status(msg: str) -> None: + console.print(f"[cyan]{msg}[/cyan]") + + def on_progress(current: int, total: int, stage: str) -> None: # noqa: ARG001 + nonlocal progress_task, progress_ctx + if progress_ctx is None: + progress_ctx = Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + console=console, + ) + progress_ctx.start() + progress_task = progress_ctx.add_task("Sampling", total=total) + if progress_task is not None: + progress_ctx.update(progress_task, completed=current) + try: client = ComfyClient(url) - console.print(f"[cyan]Generating with ComfyUI at {url}...[/cyan]") + console.print(f"[dim]ComfyUI: {url}[/dim]") result = client.generate( prompt=prompt, negative_prompt=negative, checkpoint=checkpoint, + lora=lora, + lora_strength=lora_strength, width=width, height=height, steps=steps, cfg=cfg, seed=seed, sampler=sampler, + on_status=on_status, + on_progress=on_progress, + auto_restart=not no_restart, ) except Exception as e: console.print(f"[red]Error: {e}[/red]") raise typer.Exit(1) from e + finally: + if progress_ctx: + progress_ctx.stop() if json_output: console.print_json(data=result) return - console.print(f"[green]Generated![/green] Seed: {result['seed']}, Checkpoint: {result['checkpoint']}") + lora_info = f", LoRA: {result['lora']}" if result.get("lora") else "" + console.print(f"[green]Generated![/green] Seed: {result['seed']}, Checkpoint: {result['checkpoint']}{lora_info}") if result["images"]: img_info = result["images"][0] diff --git a/tensors/comfy.py b/tensors/comfy.py index 33c88f6..43412ce 100644 --- a/tensors/comfy.py +++ b/tensors/comfy.py @@ -3,13 +3,19 @@ from __future__ import annotations import json +import subprocess import time import uuid from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import httpx +if TYPE_CHECKING: + from collections.abc import Callable + +from tensors.config import load_config, save_config + DEFAULT_WORKFLOW = { "3": { "class_type": "KSampler", @@ -52,6 +58,55 @@ DEFAULT_WORKFLOW = { }, } +COMFY_CONTAINER = "comfyui" +COMFY_HOST = "junkpile" + + +def get_last_checkpoint() -> str | None: + """Get last used checkpoint from config.""" + cfg = load_config() + value = cfg.get("comfy", {}).get("last_checkpoint") + return str(value) if value else None + + +def save_last_checkpoint(checkpoint: str) -> None: + """Save last used checkpoint to config.""" + cfg = load_config() + if "comfy" not in cfg: + cfg["comfy"] = {} + cfg["comfy"]["last_checkpoint"] = checkpoint + save_config(cfg) + + +def restart_comfy_container(on_status: Callable[[str], None] | None = None) -> None: + """Restart the ComfyUI container on junkpile and wait for it to be ready.""" + def status(msg: str) -> None: + if on_status: + on_status(msg) + + status("Restarting ComfyUI container...") + subprocess.run( + ["ssh", COMFY_HOST, f"docker restart {COMFY_CONTAINER}"], + check=True, + capture_output=True, + ) + + # Wait for ComfyUI to be ready + status("Waiting for ComfyUI to start...") + max_wait = 120 + start = time.time() + while time.time() - start < max_wait: + try: + resp = httpx.get(f"http://{COMFY_HOST}:8188/system_stats", timeout=5) + if resp.is_success: + status("ComfyUI is ready!") + return + except httpx.HTTPError: + pass + time.sleep(2) + + raise TimeoutError(f"ComfyUI did not start within {max_wait}s") + class ComfyClient: """Simple ComfyUI API client.""" @@ -97,15 +152,85 @@ class ComfyClient: return dict(data.get(prompt_id, {})) if prompt_id in data else None return None - def wait_for_completion(self, prompt_id: str, poll_interval: float = 0.5) -> dict[str, Any]: - """Poll until the prompt completes.""" + def wait_for_completion( + self, + prompt_id: str, + on_progress: Callable[[int, int, str], None] | None = None, + ) -> dict[str, Any]: + """Wait for prompt completion with progress updates via websocket.""" + import websocket # noqa: PLC0415 + + ws_url = self.base_url.replace("http://", "ws://").replace("https://", "wss://") + ws_url = f"{ws_url}/ws?clientId={self.client_id}" + + result: dict[str, Any] = {} + completed = False + + def on_message(ws: Any, message: str) -> None: # noqa: ARG001 + nonlocal completed, result + try: + data = json.loads(message) + msg_type = data.get("type") + + if msg_type == "progress": + progress_data = data.get("data", {}) + current = progress_data.get("value", 0) + total = progress_data.get("max", 1) + if on_progress: + on_progress(current, total, "sampling") + + elif msg_type == "executing": + exec_data = data.get("data", {}) + if exec_data.get("node") is None and exec_data.get("prompt_id") == prompt_id: + # Execution finished + completed = True + + elif msg_type == "executed": + exec_data = data.get("data", {}) + if exec_data.get("prompt_id") == prompt_id: + result = exec_data + + except json.JSONDecodeError: + pass + + def on_error(ws: Any, error: Exception) -> None: # noqa: ARG001 + nonlocal completed + completed = True + + def on_close(ws: Any, close_status_code: int, close_msg: str) -> None: # noqa: ARG001 + nonlocal completed + completed = True + + ws = websocket.WebSocketApp( + ws_url, + on_message=on_message, + on_error=on_error, + on_close=on_close, + ) + + # Run websocket in a thread + import threading # noqa: PLC0415 + + ws_thread = threading.Thread(target=ws.run_forever) + ws_thread.daemon = True + ws_thread.start() + + # Wait for completion start = time.time() - while time.time() - start < self.timeout: - history = self.get_history(prompt_id) - if history and history.get("outputs"): - return history - time.sleep(poll_interval) - raise TimeoutError(f"Prompt {prompt_id} did not complete within {self.timeout}s") + while not completed and time.time() - start < self.timeout: + time.sleep(0.1) + + ws.close() + + if not completed: + raise TimeoutError(f"Prompt {prompt_id} did not complete within {self.timeout}s") + + # Get final history + history = self.get_history(prompt_id) + if not history: + raise RuntimeError(f"Could not get history for prompt {prompt_id}") + + return history def get_image(self, filename: str, subfolder: str = "", folder_type: str = "output") -> bytes: """Download an image from ComfyUI.""" @@ -119,21 +244,40 @@ class ComfyClient: prompt: str, negative_prompt: str = "", checkpoint: str | None = None, + lora: str | None = None, + lora_strength: float = 0.8, width: int = 512, height: int = 512, steps: int = 20, cfg: float = 7.0, seed: int = -1, - sampler: str = "euler_a", + sampler: str = "euler_ancestral", scheduler: str = "normal", + on_progress: Callable[[int, int, str], None] | None = None, + on_status: Callable[[str], None] | None = None, + auto_restart: bool = True, ) -> dict[str, Any]: """Generate an image with a simple txt2img workflow.""" # Use first checkpoint if not specified if not checkpoint: - checkpoints = self.get_checkpoints() - if not checkpoints: - raise ValueError("No checkpoints available") - checkpoint = checkpoints[0] + # Try last used checkpoint first + checkpoint = get_last_checkpoint() + if not checkpoint: + checkpoints = self.get_checkpoints() + if not checkpoints: + raise ValueError("No checkpoints available") + checkpoint = checkpoints[0] + + # Check if we need to restart container for model change + if auto_restart: + last_checkpoint = get_last_checkpoint() + if last_checkpoint and last_checkpoint != checkpoint: + if on_status: + on_status(f"Model changed: {last_checkpoint} -> {checkpoint}") + restart_comfy_container(on_status) + + # Save checkpoint as last used + save_last_checkpoint(checkpoint) # Build workflow workflow = json.loads(json.dumps(DEFAULT_WORKFLOW)) @@ -148,9 +292,34 @@ class ComfyClient: workflow["3"]["inputs"]["sampler_name"] = sampler workflow["3"]["inputs"]["scheduler"] = scheduler + # Add LoRA if specified + if lora: + workflow["10"] = { + "class_type": "LoraLoader", + "inputs": { + "lora_name": lora, + "strength_model": lora_strength, + "strength_clip": lora_strength, + "model": ["4", 0], + "clip": ["4", 1], + }, + } + # Rewire: KSampler uses LoRA output instead of checkpoint + workflow["3"]["inputs"]["model"] = ["10", 0] + # Rewire: CLIP encoders use LoRA output + workflow["6"]["inputs"]["clip"] = ["10", 1] + workflow["7"]["inputs"]["clip"] = ["10", 1] + + if on_status: + on_status("Queueing prompt...") + # Queue and wait prompt_id = self.queue_prompt(workflow) - history = self.wait_for_completion(prompt_id) + + if on_status: + on_status("Generating...") + + history = self.wait_for_completion(prompt_id, on_progress) # Extract output images outputs = history.get("outputs", {}) @@ -168,6 +337,7 @@ class ComfyClient: "prompt_id": prompt_id, "images": images, "checkpoint": checkpoint, + "lora": lora, "seed": workflow["3"]["inputs"]["seed"], } diff --git a/uv.lock b/uv.lock index 2506473..12f4353 100644 --- a/uv.lock +++ b/uv.lock @@ -714,6 +714,7 @@ dependencies = [ { name = "rich" }, { name = "safetensors" }, { name = "typer" }, + { name = "websocket-client" }, ] [package.optional-dependencies] @@ -743,6 +744,7 @@ requires-dist = [ { name = "safetensors", specifier = ">=0.4.0" }, { name = "typer", specifier = ">=0.15.0" }, { name = "uvicorn", marker = "extra == 'server'", specifier = ">=0.30" }, + { name = "websocket-client", specifier = ">=1.9.0" }, ] provides-extras = ["server"] @@ -822,6 +824,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/2a/dc2228b2888f51192c7dc766106cd475f1b768c10caaf9727659726f7391/virtualenv-20.36.1-py3-none-any.whl", hash = "sha256:575a8d6b124ef88f6f51d56d656132389f961062a9177016a50e4f507bbcc19f", size = 6008258, upload-time = "2026-01-09T18:20:59.425Z" }, ] +[[package]] +name = "websocket-client" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/41/aa4bf9664e4cda14c3b39865b12251e8e7d239f4cd0e3cc1b6c2ccde25c1/websocket_client-1.9.0.tar.gz", hash = "sha256:9e813624b6eb619999a97dc7958469217c3176312b3a16a4bd1bc7e08a46ec98", size = 70576, upload-time = "2025-10-07T21:16:36.495Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/db/b10e48aa8fff7407e67470363eac595018441cf32d5e1001567a7aeba5d2/websocket_client-1.9.0-py3-none-any.whl", hash = "sha256:af248a825037ef591efbf6ed20cc5faa03d3b47b9e5a2230a529eeee1c1fc3ef", size = 82616, upload-time = "2025-10-07T21:16:34.951Z" }, +] + [[package]] name = "zstandard" version = "0.25.0"