diff --git a/output.png b/output.png new file mode 100644 index 0000000..87b8f1b Binary files /dev/null and b/output.png differ diff --git a/tensors/cli.py b/tensors/cli.py index 5b12c12..8937f73 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -8,6 +8,7 @@ from importlib.metadata import version from pathlib import Path from typing import Annotated, Any +import httpx import typer from rich.console import Console from rich.table import Table @@ -1198,6 +1199,177 @@ def remote_default( console.print("[green]Default remote cleared.[/green]") +# ============================================================================= +# ComfyUI Commands +# ============================================================================= + +comfy_app = typer.Typer( + name="comfy", + help="ComfyUI client commands.", + no_args_is_help=True, +) +app.add_typer(comfy_app, name="comfy") + +COMFY_DEFAULT_URL = "http://junkpile:8188" + + +@comfy_app.command("status") +def comfy_status( + url: Annotated[str, typer.Option("--url", "-u", help="ComfyUI server URL")] = COMFY_DEFAULT_URL, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Show ComfyUI server status.""" + try: + resp = httpx.get(f"{url}/system_stats", timeout=10) + resp.raise_for_status() + data = resp.json() + except httpx.HTTPError as e: + console.print(f"[red]Error: Cannot connect to ComfyUI at {url}: {e}[/red]") + raise typer.Exit(1) from e + + if json_output: + console.print_json(data=data) + return + + system = data.get("system", {}) + devices = data.get("devices", []) + + table = Table(title="ComfyUI Status", show_header=True, header_style="bold magenta") + table.add_column("Property", style="cyan") + table.add_column("Value", style="green") + table.add_row("URL", url) + table.add_row("Version", system.get("comfyui_version", "N/A")) + table.add_row("Python", system.get("python_version", "N/A").split()[0]) + table.add_row("PyTorch", system.get("pytorch_version", "N/A")) + table.add_row("RAM Free", f"{system.get('ram_free', 0) / 1024**3:.1f} GB") + + for dev in devices: + vram_free = dev.get("vram_free", 0) / 1024**3 + vram_total = dev.get("vram_total", 0) / 1024**3 + table.add_row(f"GPU {dev.get('index', 0)}", f"{dev.get('name', 'N/A')} ({vram_free:.1f}/{vram_total:.1f} GB free)") + + console.print(table) + + +@comfy_app.command("models") +def comfy_models( + url: Annotated[str, typer.Option("--url", "-u", help="ComfyUI server URL")] = COMFY_DEFAULT_URL, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """List available checkpoints in ComfyUI.""" + from tensors.comfy import ComfyClient # noqa: PLC0415 + + try: + client = ComfyClient(url) + checkpoints = client.get_checkpoints() + except httpx.HTTPError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) from e + + if json_output: + console.print_json(data={"checkpoints": checkpoints}) + return + + if not checkpoints: + console.print("[yellow]No checkpoints found.[/yellow]") + return + + table = Table(title="ComfyUI Checkpoints", show_header=True, header_style="bold magenta") + table.add_column("#", style="dim", width=3) + table.add_column("Name", style="cyan") + + for i, ckpt in enumerate(checkpoints, 1): + table.add_row(str(i), ckpt) + + console.print(table) + + +@comfy_app.command("loras") +def comfy_loras( + url: Annotated[str, typer.Option("--url", "-u", help="ComfyUI server URL")] = COMFY_DEFAULT_URL, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """List available LoRAs in ComfyUI.""" + from tensors.comfy import ComfyClient # noqa: PLC0415 + + try: + client = ComfyClient(url) + loras = client.get_loras() + except httpx.HTTPError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) from e + + if json_output: + console.print_json(data={"loras": loras}) + return + + if not loras: + console.print("[yellow]No LoRAs found.[/yellow]") + return + + table = Table(title="ComfyUI LoRAs", show_header=True, header_style="bold magenta") + table.add_column("#", style="dim", width=3) + table.add_column("Name", style="cyan") + + for i, lora in enumerate(loras, 1): + table.add_row(str(i), lora) + + console.print(table) + + +@comfy_app.command("generate") +def comfy_generate( + prompt: Annotated[str, typer.Argument(help="Text prompt for generation")], + url: Annotated[str, typer.Option("--url", "-u", help="ComfyUI server URL")] = COMFY_DEFAULT_URL, + checkpoint: Annotated[str | None, typer.Option("-m", "--model", help="Checkpoint model name")] = None, + negative: Annotated[str, typer.Option("-n", "--negative", help="Negative prompt")] = "", + width: Annotated[int, typer.Option("-W", "--width", help="Image width")] = 512, + height: Annotated[int, typer.Option("-H", "--height", help="Image height")] = 512, + steps: Annotated[int, typer.Option("--steps", help="Sampling steps")] = 20, + cfg: Annotated[float, typer.Option("--cfg", help="CFG scale")] = 7.0, + seed: Annotated[int, typer.Option("-s", "--seed", help="RNG seed (-1 for random)")] = -1, + sampler: Annotated[str, typer.Option("--sampler", help="Sampler name")] = "euler_ancestral", + output: Annotated[Path | None, typer.Option("-o", "--output", help="Output file path")] = None, + json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, +) -> None: + """Generate an image using ComfyUI.""" + from tensors.comfy import ComfyClient # noqa: PLC0415 + + try: + client = ComfyClient(url) + + console.print(f"[cyan]Generating with ComfyUI at {url}...[/cyan]") + result = client.generate( + prompt=prompt, + negative_prompt=negative, + checkpoint=checkpoint, + width=width, + height=height, + steps=steps, + cfg=cfg, + seed=seed, + sampler=sampler, + ) + except Exception as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) from e + + if json_output: + console.print_json(data=result) + return + + console.print(f"[green]Generated![/green] Seed: {result['seed']}, Checkpoint: {result['checkpoint']}") + + if result["images"]: + img_info = result["images"][0] + console.print(f"[dim]Image: {img_info['filename']}[/dim]") + + if output: + img_data = client.get_image(img_info["filename"], img_info["subfolder"], img_info["type"]) + output.write_bytes(img_data) + console.print(f"[green]Saved to:[/green] {output}") + + def main() -> int: """Main entry point.""" # Handle legacy invocation: tsr -> tsr info @@ -1216,6 +1388,7 @@ def main() -> int: "images", "models", "remote", + "comfy", ) if len(sys.argv) > 1 and not sys.argv[1].startswith("-"): arg = sys.argv[1] diff --git a/tensors/comfy.py b/tensors/comfy.py new file mode 100644 index 0000000..33c88f6 --- /dev/null +++ b/tensors/comfy.py @@ -0,0 +1,190 @@ +"""Simple ComfyUI client for basic txt2img generation.""" + +from __future__ import annotations + +import json +import time +import uuid +from pathlib import Path +from typing import Any + +import httpx + +DEFAULT_WORKFLOW = { + "3": { + "class_type": "KSampler", + "inputs": { + "cfg": 7, + "denoise": 1, + "latent_image": ["5", 0], + "model": ["4", 0], + "negative": ["7", 0], + "positive": ["6", 0], + "sampler_name": "euler_ancestral", + "scheduler": "normal", + "seed": -1, + "steps": 20, + }, + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "inputs": {"ckpt_name": ""}, + }, + "5": { + "class_type": "EmptyLatentImage", + "inputs": {"batch_size": 1, "height": 512, "width": 512}, + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": {"clip": ["4", 1], "text": ""}, + }, + "7": { + "class_type": "CLIPTextEncode", + "inputs": {"clip": ["4", 1], "text": ""}, + }, + "8": { + "class_type": "VAEDecode", + "inputs": {"samples": ["3", 0], "vae": ["4", 2]}, + }, + "9": { + "class_type": "SaveImage", + "inputs": {"filename_prefix": "comfy", "images": ["8", 0]}, + }, +} + + +class ComfyClient: + """Simple ComfyUI API client.""" + + def __init__(self, base_url: str = "http://127.0.0.1:8188", timeout: float = 300.0) -> None: + self.base_url = base_url.rstrip("/") + self.timeout = timeout + self.client_id = str(uuid.uuid4()) + + def get_checkpoints(self) -> list[str]: + """List available checkpoint models.""" + resp = httpx.get(f"{self.base_url}/object_info/CheckpointLoaderSimple", timeout=10) + resp.raise_for_status() + data = resp.json() + return list(data.get("CheckpointLoaderSimple", {}).get("input", {}).get("required", {}).get("ckpt_name", [[]])[0]) + + def get_loras(self) -> list[str]: + """List available LoRAs.""" + resp = httpx.get(f"{self.base_url}/object_info/LoraLoader", timeout=10) + resp.raise_for_status() + data = resp.json() + return list(data.get("LoraLoader", {}).get("input", {}).get("required", {}).get("lora_name", [[]])[0]) + + def get_samplers(self) -> list[str]: + """List available samplers.""" + resp = httpx.get(f"{self.base_url}/object_info/KSampler", timeout=10) + resp.raise_for_status() + data = resp.json() + return list(data.get("KSampler", {}).get("input", {}).get("required", {}).get("sampler_name", [[]])[0]) + + def queue_prompt(self, workflow: dict[str, Any]) -> str: + """Queue a prompt and return the prompt_id.""" + payload = {"prompt": workflow, "client_id": self.client_id} + resp = httpx.post(f"{self.base_url}/prompt", json=payload, timeout=30) + resp.raise_for_status() + return str(resp.json()["prompt_id"]) + + def get_history(self, prompt_id: str) -> dict[str, Any] | None: + """Get history for a prompt_id.""" + resp = httpx.get(f"{self.base_url}/history/{prompt_id}", timeout=10) + if resp.is_success: + data = resp.json() + return dict(data.get(prompt_id, {})) if prompt_id in data else None + return None + + def wait_for_completion(self, prompt_id: str, poll_interval: float = 0.5) -> dict[str, Any]: + """Poll until the prompt completes.""" + start = time.time() + while time.time() - start < self.timeout: + history = self.get_history(prompt_id) + if history and history.get("outputs"): + return history + time.sleep(poll_interval) + raise TimeoutError(f"Prompt {prompt_id} did not complete within {self.timeout}s") + + def get_image(self, filename: str, subfolder: str = "", folder_type: str = "output") -> bytes: + """Download an image from ComfyUI.""" + params = {"filename": filename, "subfolder": subfolder, "type": folder_type} + resp = httpx.get(f"{self.base_url}/view", params=params, timeout=30) + resp.raise_for_status() + return resp.content + + def generate( + self, + prompt: str, + negative_prompt: str = "", + checkpoint: str | None = None, + width: int = 512, + height: int = 512, + steps: int = 20, + cfg: float = 7.0, + seed: int = -1, + sampler: str = "euler_a", + scheduler: str = "normal", + ) -> dict[str, Any]: + """Generate an image with a simple txt2img workflow.""" + # Use first checkpoint if not specified + if not checkpoint: + checkpoints = self.get_checkpoints() + if not checkpoints: + raise ValueError("No checkpoints available") + checkpoint = checkpoints[0] + + # Build workflow + workflow = json.loads(json.dumps(DEFAULT_WORKFLOW)) + workflow["4"]["inputs"]["ckpt_name"] = checkpoint + workflow["5"]["inputs"]["width"] = width + workflow["5"]["inputs"]["height"] = height + workflow["6"]["inputs"]["text"] = prompt + workflow["7"]["inputs"]["text"] = negative_prompt + workflow["3"]["inputs"]["steps"] = steps + workflow["3"]["inputs"]["cfg"] = cfg + workflow["3"]["inputs"]["seed"] = seed if seed >= 0 else int(time.time() * 1000) % (2**32) + workflow["3"]["inputs"]["sampler_name"] = sampler + workflow["3"]["inputs"]["scheduler"] = scheduler + + # Queue and wait + prompt_id = self.queue_prompt(workflow) + history = self.wait_for_completion(prompt_id) + + # Extract output images + outputs = history.get("outputs", {}) + images = [] + for _node_id, node_output in outputs.items(): + if "images" in node_output: + for img in node_output["images"]: + images.append({ + "filename": img["filename"], + "subfolder": img.get("subfolder", ""), + "type": img.get("type", "output"), + }) + + return { + "prompt_id": prompt_id, + "images": images, + "checkpoint": checkpoint, + "seed": workflow["3"]["inputs"]["seed"], + } + + def generate_and_save( + self, + prompt: str, + output_path: str | Path, + **kwargs: Any, + ) -> Path: + """Generate an image and save it locally.""" + result = self.generate(prompt, **kwargs) + if not result["images"]: + raise RuntimeError("No images generated") + + img_info = result["images"][0] + img_data = self.get_image(img_info["filename"], img_info["subfolder"], img_info["type"]) + + output = Path(output_path) + output.write_bytes(img_data) + return output