diff --git a/tensors/cli.py b/tensors/cli.py index 040b13c..ba0fb93 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -1431,6 +1431,20 @@ def style_sweep( # noqa: PLR0915 str | None, typer.Option("-r", "--remote", help="Remote server name or URL (overrides template)"), ] = None, + parallel_queue: Annotated[ + int, + typer.Option( + "--parallel-queue", + "-P", + help=( + "Concurrent ComfyUI submissions (default 1). Values >1 submit N " + "prompts to ComfyUI's HTTP queue in parallel; the GPU still " + "processes one at a time, but HTTP/init/download overhead is " + "pipelined for a ~5-15%% speedup. Per-task console output will " + "interleave; use the manifest for accurate per-slug timing." + ), + ), + ] = 1, ) -> None: """Sweep a base prompt across a list of style suffixes, one image per style. @@ -1451,6 +1465,7 @@ def style_sweep( # noqa: PLR0915 tsr style-sweep -t template.json --list tsr style-sweep --styles styles.json --list tsr style-sweep -t template.json -S 38-manara -S 40-elder-kurtzman + tsr style-sweep -t template.json -P 4 # 4 concurrent submissions """ # ---- Validate required inputs ---- # Template is required for generation, but optional when --list is paired @@ -1459,6 +1474,10 @@ def style_sweep( # noqa: PLR0915 console.print("[red]--template is required (or use --list with --styles to inspect a styles file)[/red]") raise typer.Exit(1) + if parallel_queue < 1: + console.print("[red]--parallel-queue must be >= 1[/red]") + raise typer.Exit(1) + # ---- Load template (if provided) ---- tpl_data: dict[str, Any] = {} if template is not None: @@ -1586,6 +1605,11 @@ def style_sweep( # noqa: PLR0915 results: list[dict[str, Any]] = [] failed_slugs: list[str] = [] + # Pre-compute per-style work items and short-circuit skip/dry-run cases + # synchronously (no point pipelining no-ops). Only real generation tasks + # go through the executor path. + pending_tasks: list[tuple[int, dict[str, str], dict[str, Any], Path]] = [] + for i, entry in enumerate(style_entries, start=1): slug = entry["slug"] suffix = entry["suffix"] @@ -1602,7 +1626,6 @@ def style_sweep( # noqa: PLR0915 "error": None, } - # Skip if exists if skip_existing and out_path.exists(): console.print(f"[dim]\\[{i}/{total}] {slug} skip (exists)[/dim]") result["success"] = True @@ -1619,60 +1642,130 @@ def style_sweep( # noqa: PLR0915 results.append(result) continue + pending_tasks.append((i, entry, result, out_path)) + + # Common kwargs for every _run_generation call — extracted from the + # template once, reused across sequential and parallel paths. + base_gen_kwargs: dict[str, Any] = { + "model": _t("model"), + "width": _t("width", cast=int), + "height": _t("height", cast=int), + "steps": _t("steps", cast=int), + "cfg": _t("cfg", cast=float), + "guidance": _t("guidance", cast=float), + "seed": _t("seed", cast=int, default=-1), + "sampler": _t("sampler"), + "scheduler": _t("scheduler"), + "vae": _t("vae"), + "orientation": _t("orientation", default="square"), + "lora": _t("lora"), + "lora_strength": _t("lora_strength", cast=float, default=0.8), + "negative": negative_val, + "count": 1, + "rating": _t("rating"), + "no_quality": bool(_t("no_quality", default=False)), + "no_negative": bool(_t("no_negative", default=False)), + "family": _t("family"), + "remote": gen_remote, + "json_output": False, + } + + def _run_one(task: tuple[int, dict[str, str], dict[str, Any], Path]) -> dict[str, Any]: + """Run a single style. Returns the result dict (success or error captured).""" + idx, entry_in, res, opath = task + composed = res["prompt"] start = time.perf_counter() try: - _run_generation( - prompt=composed_prompt, - model=_t("model"), - width=_t("width", cast=int), - height=_t("height", cast=int), - steps=_t("steps", cast=int), - cfg=_t("cfg", cast=float), - guidance=_t("guidance", cast=float), - seed=_t("seed", cast=int, default=-1), - sampler=_t("sampler"), - scheduler=_t("scheduler"), - vae=_t("vae"), - orientation=_t("orientation", default="square"), - lora=_t("lora"), - lora_strength=_t("lora_strength", cast=float, default=0.8), - negative=negative_val, - count=1, - rating=_t("rating"), - no_quality=bool(_t("no_quality", default=False)), - no_negative=bool(_t("no_negative", default=False)), - family=_t("family"), - output=out_path, - remote=gen_remote, - json_output=False, - ) - duration = time.perf_counter() - start - result["duration_sec"] = round(duration, 2) - result["success"] = True - console.print(f"[green]\\[{i}/{total}] {slug} ok in {duration:.1f}s[/green]") - except typer.Exit as e: - duration = time.perf_counter() - start - result["duration_sec"] = round(duration, 2) - err_msg = f"generate exited with code {e.exit_code}" - result["error"] = err_msg - failed_slugs.append(slug) - console.print(f"[red]\\[{i}/{total}] {slug} FAIL: {err_msg}[/red]") - if not continue_on_error: - results.append(result) - _write_sweep_manifest(out_dir, template, styles_origin, results) - raise - except Exception as e: - duration = time.perf_counter() - start - result["duration_sec"] = round(duration, 2) - result["error"] = str(e) - failed_slugs.append(slug) - console.print(f"[red]\\[{i}/{total}] {slug} FAIL: {e}[/red]") - if not continue_on_error: - results.append(result) - _write_sweep_manifest(out_dir, template, styles_origin, results) - raise typer.Exit(1) from e + _run_generation(prompt=composed, output=opath, **base_gen_kwargs) + res["duration_sec"] = round(time.perf_counter() - start, 2) + res["success"] = True + except typer.Exit as ex: + res["duration_sec"] = round(time.perf_counter() - start, 2) + res["error"] = f"generate exited with code {ex.exit_code}" + except Exception as ex: # noqa: BLE001 + res["duration_sec"] = round(time.perf_counter() - start, 2) + res["error"] = str(ex) + return res - results.append(result) + if parallel_queue == 1: + # Sequential path — preserves the original "ok in Xs" / "FAIL" lines + # exactly so existing log-scraping stays valid. + for task in pending_tasks: + idx, _entry, result, _out_path = task + slug = result["slug"] + res = _run_one(task) + if res["success"]: + console.print(f"[green]\\[{idx}/{total}] {slug} ok in {res['duration_sec']:.1f}s[/green]") + else: + failed_slugs.append(slug) + console.print(f"[red]\\[{idx}/{total}] {slug} FAIL: {res['error']}[/red]") + if not continue_on_error: + results.append(res) + _write_sweep_manifest(out_dir, template, styles_origin, results) + raise typer.Exit(1) + results.append(res) + else: + # Parallel path — N concurrent ComfyUI submissions. The GPU still + # processes one prompt at a time, but the HTTP queueing, websocket + # polling, image download, and disk write phases overlap with the + # next prompt's submission. Net effect: 5-15%% speedup vs sequential. + # Per-task console output WILL interleave (each _run_generation + # prints its own progress); use the manifest for clean per-slug + # timing data. + from concurrent.futures import ThreadPoolExecutor, as_completed # noqa: PLC0415 + + console.print( + f"[dim]Parallel queue: {parallel_queue} concurrent submissions " + f"(output may interleave)[/dim]" + ) + # abort-on-error is incompatible with parallelism — we can't reliably + # stop in-flight workers without losing their state. Warn and continue. + if not continue_on_error: + console.print( + "[yellow]Note: --abort-on-error is ignored when --parallel-queue > 1; " + "in-flight tasks always complete[/yellow]" + ) + + with ThreadPoolExecutor(max_workers=parallel_queue) as pool: + futures = {pool.submit(_run_one, task): task for task in pending_tasks} + completed = 0 + for fut in as_completed(futures): + completed += 1 + task = futures[fut] + idx, _entry, _res, _out_path = task + try: + res = fut.result() + except Exception as ex: # noqa: BLE001 + # Pathological — _run_one is supposed to catch everything. + # Re-build a result dict so the manifest is still well-formed. + res = { + "slug": task[2]["slug"], + "prompt": task[2]["prompt"], + "output": task[2]["output"], + "seed": task[2]["seed"], + "duration_sec": 0.0, + "success": False, + "error": f"executor exception: {ex}", + } + if res["success"]: + console.print( + f"[green]\\[{completed}/{len(pending_tasks)}] " + f"{res['slug']} ok in {res['duration_sec']:.1f}s " + f"(submit #{idx})[/green]" + ) + else: + failed_slugs.append(res["slug"]) + console.print( + f"[red]\\[{completed}/{len(pending_tasks)}] " + f"{res['slug']} FAIL: {res['error']}[/red]" + ) + results.append(res) + + # Reorder results to match the original styles list order so the manifest + # is human-readable. Skipped/dry-run entries already in `results` keep + # their position from the pre-loop walk. + slug_order = {e["slug"]: i for i, e in enumerate(style_entries)} + results.sort(key=lambda r: slug_order.get(r["slug"], 1_000_000)) # ---- Manifest ---- if not dry_run: diff --git a/tests/test_style_sweep.py b/tests/test_style_sweep.py index b519a15..6781adc 100644 --- a/tests/test_style_sweep.py +++ b/tests/test_style_sweep.py @@ -514,3 +514,225 @@ def test_style_filter_with_list(tmp_path: Path, calls: list[dict[str, Any]]) -> assert "01-foo" not in result.output assert "03-baz" not in result.output assert "1 entries" in result.output + + +# ----------------------------------------------------------------------------- +# --parallel-queue tests +# ----------------------------------------------------------------------------- + + +def test_parallel_queue_invalid_value(tmp_path: Path, calls: list[dict[str, Any]]) -> None: + """--parallel-queue 0 (or negative) is rejected.""" + styles_file = _write_styles_file(tmp_path, [{"slug": "01-foo", "suffix": "Foo"}]) + tpl = _write_template(tmp_path, output_dir=tmp_path / "out", styles=str(styles_file)) + + result = runner.invoke( + app, + ["style-sweep", "--template", str(tpl), "--parallel-queue", "0"], + ) + assert result.exit_code == 1 + assert ">= 1" in result.output + + +def test_parallel_queue_one_is_equivalent_to_sequential( + tmp_path: Path, calls: list[dict[str, Any]] +) -> None: + """-P 1 is identical to omitting the flag (the default sequential path).""" + styles_file = _write_styles_file( + tmp_path, + [ + {"slug": "01-foo", "suffix": "Foo"}, + {"slug": "02-bar", "suffix": "Bar"}, + ], + ) + tpl = _write_template(tmp_path, output_dir=tmp_path / "out", styles=str(styles_file)) + + result = runner.invoke( + app, + ["style-sweep", "--template", str(tpl), "--parallel-queue", "1"], + ) + assert result.exit_code == 0, result.output + assert len(calls) == 2 + # Sequential path emits "ok in X.Xs" without the "(submit #N)" suffix + assert "(submit #" not in result.output + # Order is deterministic: 01-foo then 02-bar + assert calls[0]["prompt"].endswith("Foo") + assert calls[1]["prompt"].endswith("Bar") + + +def test_parallel_queue_runs_all_styles( + tmp_path: Path, calls: list[dict[str, Any]] +) -> None: + """-P 3 still produces N outputs and N _run_generation calls (no skipped work).""" + styles_file = _write_styles_file( + tmp_path, + [ + {"slug": "01-foo", "suffix": "Foo"}, + {"slug": "02-bar", "suffix": "Bar"}, + {"slug": "03-baz", "suffix": "Baz"}, + {"slug": "04-qux", "suffix": "Qux"}, + ], + ) + out_dir = tmp_path / "out" + tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file)) + + result = runner.invoke( + app, + ["style-sweep", "--template", str(tpl), "--parallel-queue", "3"], + ) + assert result.exit_code == 0, result.output + assert len(calls) == 4 + # Output files exist regardless of submission order + assert (out_dir / "01-foo.png").is_file() + assert (out_dir / "02-bar.png").is_file() + assert (out_dir / "03-baz.png").is_file() + assert (out_dir / "04-qux.png").is_file() + # Parallel mode announces itself + assert "Parallel queue: 3 concurrent submissions" in result.output + + +def test_parallel_queue_manifest_preserves_source_order( + tmp_path: Path, calls: list[dict[str, Any]] +) -> None: + """Manifest results are sorted by original styles-list order, not completion order.""" + import time as time_mod + + styles_file = _write_styles_file( + tmp_path, + [ + {"slug": f"{n:02d}-s{n}", "suffix": f"S{n}"} for n in range(1, 7) + ], + ) + out_dir = tmp_path / "out" + tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file)) + + # Make completion order intentionally chaotic: later slugs finish faster. + def staggered_fake(**kwargs: Any) -> None: + prompt = kwargs.get("prompt", "") + # Heavier work for earlier slugs so they finish last. + # S1 sleeps 30ms, S6 sleeps 5ms. + sleep_ms = max(35 - int(prompt[-1]) * 5, 1) + time_mod.sleep(sleep_ms / 1000) + out: Path | None = kwargs.get("output") + if out is not None: + out.parent.mkdir(parents=True, exist_ok=True) + out.write_bytes(b"png") + calls.append(kwargs) + + import tensors.cli as cli_module # noqa: PLC0415 + + cli_module._run_generation = staggered_fake # type: ignore[assignment] + + result = runner.invoke( + app, + ["style-sweep", "--template", str(tpl), "--parallel-queue", "4"], + ) + assert result.exit_code == 0, result.output + + manifest = json.loads((out_dir / "_sweep.json").read_text()) + slugs_in_manifest = [r["slug"] for r in manifest["results"]] + assert slugs_in_manifest == [f"{n:02d}-s{n}" for n in range(1, 7)], ( + "Manifest results must be in source-list order, " + f"not completion order. Got: {slugs_in_manifest}" + ) + + +def test_parallel_queue_skip_existing_runs_synchronously( + tmp_path: Path, calls: list[dict[str, Any]] +) -> None: + """Pre-existing outputs are skipped before the executor even starts.""" + styles_file = _write_styles_file( + tmp_path, + [ + {"slug": "01-foo", "suffix": "Foo"}, + {"slug": "02-bar", "suffix": "Bar"}, + {"slug": "03-baz", "suffix": "Baz"}, + ], + ) + out_dir = tmp_path / "out" + out_dir.mkdir() + # Pre-create 02-bar + (out_dir / "02-bar.png").write_bytes(b"existing") + tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file)) + + result = runner.invoke( + app, + ["style-sweep", "--template", str(tpl), "--parallel-queue", "2"], + ) + assert result.exit_code == 0, result.output + # Only 01-foo and 03-baz generated; 02-bar skipped + assert len(calls) == 2 + generated_slugs = sorted(c["output"].name for c in calls) + assert generated_slugs == ["01-foo.png", "03-baz.png"] + assert "02-bar skip (exists)" in result.output + + +def test_parallel_queue_continues_after_individual_failure( + tmp_path: Path, calls: list[dict[str, Any]] +) -> None: + """One task raising doesn't kill the others; manifest records the failure.""" + import typer as typer_mod + + styles_file = _write_styles_file( + tmp_path, + [ + {"slug": "01-foo", "suffix": "Foo"}, + {"slug": "02-bar", "suffix": "Bar"}, + {"slug": "03-baz", "suffix": "Baz"}, + ], + ) + out_dir = tmp_path / "out" + tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file)) + + def selective_fail(**kwargs: Any) -> None: + prompt = kwargs.get("prompt", "") + out: Path | None = kwargs.get("output") + if "Bar" in prompt: + raise typer_mod.Exit(1) + if out is not None: + out.parent.mkdir(parents=True, exist_ok=True) + out.write_bytes(b"png") + calls.append(kwargs) + + import tensors.cli as cli_module # noqa: PLC0415 + + cli_module._run_generation = selective_fail # type: ignore[assignment] + + result = runner.invoke( + app, + ["style-sweep", "--template", str(tpl), "--parallel-queue", "3"], + ) + # Exit 1 because one slug failed (sweep exits non-zero on any failure even + # with --continue-on-error) + assert result.exit_code == 1, result.output + # Other two succeeded + assert (out_dir / "01-foo.png").is_file() + assert (out_dir / "03-baz.png").is_file() + assert not (out_dir / "02-bar.png").is_file() + # Manifest records all three with 02-bar marked failed + manifest = json.loads((out_dir / "_sweep.json").read_text()) + by_slug = {r["slug"]: r for r in manifest["results"]} + assert by_slug["01-foo"]["success"] is True + assert by_slug["02-bar"]["success"] is False + assert "exited with code 1" in by_slug["02-bar"]["error"] + assert by_slug["03-baz"]["success"] is True + + +def test_parallel_queue_warns_about_abort_on_error( + tmp_path: Path, calls: list[dict[str, Any]] +) -> None: + """--abort-on-error + parallel is contradictory; we warn and continue.""" + styles_file = _write_styles_file(tmp_path, [{"slug": "01-foo", "suffix": "Foo"}]) + tpl = _write_template(tmp_path, output_dir=tmp_path / "out", styles=str(styles_file)) + + result = runner.invoke( + app, + [ + "style-sweep", + "--template", str(tpl), + "--parallel-queue", "2", + "--abort-on-error", + ], + ) + assert result.exit_code == 0, result.output + assert "--abort-on-error is ignored when --parallel-queue > 1" in result.output