diff --git a/tensors/cli.py b/tensors/cli.py index 101eec6..95906b9 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -346,7 +346,10 @@ def search( return key = api_key or load_api_key() - civitai_results: dict[str, Any] | None = None + # Reuse the name from the remote-mode branch above (which already returned) + # without redeclaring its type — mypy treats class-scope re-annotation as + # a no-redef even when control flow guarantees the branches don't overlap. + civitai_results = None hf_results: list[dict[str, Any]] | None = None # Search CivitAI @@ -874,6 +877,22 @@ def generate( # noqa: PLR0915 str | None, typer.Option("--input", "-I", help="JSON or YAML params (file path or inline; keys match CLI options)"), ] = None, + parallel_queue: Annotated[ + int, + typer.Option( + "--parallel-queue", + "-P", + help=( + "Concurrent ComfyUI submissions (default 1). When >1 with --count N, " + "splits the request into N independent jobs (batch_size=1 each) with " + "incrementing seeds, executed P-at-a-time via thread pool. The GPU " + "still processes one prompt at a time, but HTTP queue / init / " + "download phases pipeline for a ~5-15%% speedup. Per-task output " + "interleaves; final summary lists all saved files. Ignored when " + "--count is 1." + ), + ), + ] = 1, ) -> None: """Generate an image using text-to-image. @@ -887,6 +906,11 @@ def generate( # noqa: PLR0915 starting with ``{`` are JSON, everything else is YAML. CLI flags override --input values. + With --count > 1, images are generated as a single ComfyUI batch by default + (one workflow, sequential on GPU). Use --parallel-queue N to instead split + into N independent batch_size=1 jobs queued in parallel, each with its own + seed — useful for overlapping the HTTP/download phase across requests. + Examples: tsr generate "a cat on a windowsill" tsr generate "portrait photo" -m ponyDiffusionV6XL_v6.safetensors -O portrait @@ -895,7 +919,20 @@ def generate( # noqa: PLR0915 tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors"}' tsr generate --input scene.yml tsr generate "raw prompt" --no-quality --no-negative + tsr generate "city" -c 8 -P 4 -o out.png # 8 distinct seeds, 4 in flight """ + if parallel_queue < 1: + console.print("[red]--parallel-queue must be >= 1[/red]") + raise typer.Exit(1) + if parallel_queue > 1 and json_output: + # _run_generation short-circuits the disk-save when json_output=True + # (it dumps JSON and returns). For the parallel fanout to actually save + # files, each task must take the non-JSON path. We render our own JSON + # at the end, so the per-task --json is incompatible. + console.print( + "[red]--json is not supported with --parallel-queue > 1 (would skip the file-save step). Drop one or the other.[/red]" + ) + raise typer.Exit(1) # ---- --input merging (JSON or YAML) ---- if json_input is not None: ji = _parse_generate_input(json_input) @@ -912,7 +949,9 @@ def generate( # noqa: PLR0915 { p.name for p in click_ctx.command.params - if click_ctx.get_parameter_source(p.name) == click.core.ParameterSource.COMMANDLINE + # click's Parameter.name is typed `str | None` in stubs but is always + # a real string at runtime for any param that's been registered. + if p.name is not None and click_ctx.get_parameter_source(p.name) == click.core.ParameterSource.COMMANDLINE } if hasattr(click_ctx, "get_parameter_source") else set() @@ -981,41 +1020,209 @@ def generate( # noqa: PLR0915 scene_prompt = sp_val if isinstance(sp_val, str) else ", ".join(str(x) for x in sp_val if str(x).strip()) if "rating" in mapped and "rating" not in explicit: rating = mapped["rating"] + if "parallel_queue" in mapped and "parallel_queue" not in explicit: + parallel_queue = int(mapped["parallel_queue"]) has_content = bool(prompt or character or character_prompt or scene or scene_prompt) if not has_content: console.print("[red]Prompt (or character/scene) is required[/red]") raise typer.Exit(1) - _run_generation( - prompt=prompt, - model=model, - width=width, - height=height, - steps=steps, - cfg=cfg, - guidance=guidance, - seed=seed, - sampler=sampler, - scheduler=scheduler, - vae=vae, - orientation=orientation, - lora=lora, - lora_strength=lora_strength, - negative=negative, - count=count, - rating=rating, - no_quality=no_quality, - no_negative=no_negative, - character=character, - character_prompt=character_prompt, - scene=scene, - scene_prompt=scene_prompt, - family=family, - output=output, - remote=remote, - json_output=json_output, - ) + # Effective parallelism is bounded by count — running 4 threads for 1 image + # is silly. count=1 always goes through the sequential path regardless of -P. + effective_parallel = min(parallel_queue, count) if count > 1 else 1 + + if effective_parallel <= 1: + # Sequential path: single _run_generation call with batch_size=count. + # Unchanged from pre-parallel behavior — preserves existing output naming, + # JSON shape, and log lines exactly. + _run_generation( + prompt=prompt, + model=model, + width=width, + height=height, + steps=steps, + cfg=cfg, + guidance=guidance, + seed=seed, + sampler=sampler, + scheduler=scheduler, + vae=vae, + orientation=orientation, + lora=lora, + lora_strength=lora_strength, + negative=negative, + count=count, + rating=rating, + no_quality=no_quality, + no_negative=no_negative, + character=character, + character_prompt=character_prompt, + scene=scene, + scene_prompt=scene_prompt, + family=family, + output=output, + remote=remote, + json_output=json_output, + ) + return + + # ---- Parallel fanout path ---- + # Split count into `count` independent jobs (batch_size=1), executed + # `effective_parallel` at a time. Each job gets a distinct seed and a + # distinct output path so writes don't clobber each other. + import random as _rng # noqa: PLC0415 + import time as _time # noqa: PLC0415 + from concurrent.futures import ThreadPoolExecutor, as_completed # noqa: PLC0415 + + # Resolve bare model/lora names ONCE in the parent before fanout. Each + # parallel _run_generation call silences its own console (json_output=True) + # which also skips the validation/resolution step in that path. Doing it + # here means each task receives a canonical filename and ComfyUI's strict + # loaders accept the request first try. + if model and not remote: + # Detect family for the right loader bucket (checkpoints vs diffusion_models). + # Mirrors the lookup _run_generation does on entry. + from tensors.db import Database # noqa: PLC0415 + + _base_model: str | None = None + try: + with Database() as _db: + _db.init_schema() + _base_model = _db.get_base_model_by_filename(model) + except Exception: + pass + _detected = detect_model_family(model, _base_model) + _fam = family or _detected + try: + model, lora = _validate_model_available(model, _fam, lora) + except typer.Exit: + raise # surface the same error path as sequential + + # Seed strategy: + # --seed >= 0 → use as base, increment per job (reproducible series) + # --seed == -1 → pick a fresh random seed PER JOB so parallel runs aren't + # accidentally correlated (each thread gets variety) + seeds = [seed + i for i in range(count)] if seed >= 0 else [_rng.randint(0, 2**32 - 1) for _ in range(count)] + + # Output paths: mirror the existing `count > 1` naming convention from + # _run_generation (stem_NNN.ext). When --output is omitted, leave per-task + # output as None — _run_generation will skip the disk write and the user + # gets only the console listing of generated image refs. + out_paths: list[Path | None] = [] + for i in range(count): + if output is None: + out_paths.append(None) + else: + out_paths.append(output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}") + + if not json_output: + console.print( + f"[dim]Parallel queue: {effective_parallel} concurrent submissions x {count} images (output may interleave)[/dim]" + ) + + common_kwargs: dict[str, Any] = { + "prompt": prompt, + "model": model, + "width": width, + "height": height, + "steps": steps, + "cfg": cfg, + "guidance": guidance, + "sampler": sampler, + "scheduler": scheduler, + "vae": vae, + "orientation": orientation, + "lora": lora, + "lora_strength": lora_strength, + "negative": negative, + "count": 1, # each task generates exactly one image + "rating": rating, + "no_quality": no_quality, + "no_negative": no_negative, + "character": character, + "character_prompt": character_prompt, + "scene": scene, + "scene_prompt": scene_prompt, + "family": family, + "remote": remote, + # NOTE: json_output stays False so _run_generation's disk-save path runs. + # Setting True would short-circuit before saving files. Per-task console + # chatter is the trade-off; the final summary still shows clean per-task + # status lines. + "json_output": False, + } + + def _run_one(idx: int) -> dict[str, Any]: + """Run a single batch_size=1 job. Returns a result dict (success captured).""" + start = _time.perf_counter() + result: dict[str, Any] = { + "index": idx, + "seed": seeds[idx], + "output": str(out_paths[idx]) if out_paths[idx] is not None else None, + "duration_sec": 0.0, + "success": False, + "error": None, + } + try: + _run_generation(seed=seeds[idx], output=out_paths[idx], **common_kwargs) + result["duration_sec"] = round(_time.perf_counter() - start, 2) + result["success"] = True + except typer.Exit as ex: + result["duration_sec"] = round(_time.perf_counter() - start, 2) + result["error"] = f"generate exited with code {ex.exit_code}" + except Exception as ex: + result["duration_sec"] = round(_time.perf_counter() - start, 2) + result["error"] = str(ex) + return result + + fan_results: list[dict[str, Any]] = [] + with ThreadPoolExecutor(max_workers=effective_parallel) as pool: + futures = {pool.submit(_run_one, i): i for i in range(count)} + for completed, fut in enumerate(as_completed(futures), start=1): + try: + res = fut.result() + except Exception as ex: + # Defensive — _run_one already swallows, but if the executor itself + # raises (e.g. pickling failure) we still want a well-formed result + # in the manifest rather than a crash. + res = { + "index": futures[fut], + "seed": seeds[futures[fut]], + "output": str(out_paths[futures[fut]]) if out_paths[futures[fut]] is not None else None, + "duration_sec": 0.0, + "success": False, + "error": f"executor exception: {ex}", + } + fan_results.append(res) + if not json_output: + if res["success"]: + where = res["output"] or "(no --output set)" + console.print( + f"[green]\\[{completed}/{count}] seed={res['seed']} ok in {res['duration_sec']:.1f}s → {where}[/green]" + ) + else: + console.print(f"[red]\\[{completed}/{count}] seed={res['seed']} FAIL: {res['error']}[/red]") + + # Reorder by original index so JSON output / final summary list is stable. + fan_results.sort(key=lambda r: r["index"]) + successful = sum(1 for r in fan_results if r["success"]) + + if json_output: + console.print_json( + data={ + "success": successful == count, + "count": count, + "parallel_queue": effective_parallel, + "results": fan_results, + } + ) + return + + console.print("[bold green]Generation complete![/bold green]") + console.print(f"[dim]Generated {successful}/{count} images at parallelism={effective_parallel}[/dim]") + if successful < count: + raise typer.Exit(1) # Map model family → which ComfyUI loader directory the checkpoint must live in. @@ -1283,7 +1490,7 @@ def _run_generation( # noqa: PLR0915 # ---- Resolve preset defaults for None params (both remote and local need these) ---- from tensors.config import resolve_orientation # noqa: PLC0415 - from tensors.config import resolve_remote as do_resolve_remote + from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415 # Use already-detected family_defaults from DB lookup above (not filename guessing) if family_defaults: @@ -1507,7 +1714,7 @@ _STYLE_SWEEP_TEMPLATE_KEYS = { } -def _load_json_file_or_inline(value: str | list | dict, *, what: str) -> Any: +def _load_json_file_or_inline(value: str | list[Any] | dict[str, Any], *, what: str) -> Any: """Load JSON from a file path or accept already-parsed inline data. `value` may be a path string, a JSON string, or an already-parsed list/dict @@ -1885,7 +2092,7 @@ def style_sweep( # noqa: PLR0915 def _run_one(task: tuple[int, dict[str, str], dict[str, Any], Path]) -> dict[str, Any]: """Run a single style. Returns the result dict (success or error captured).""" - idx, entry_in, res, opath = task + _idx, _entry_in, res, opath = task composed = res["prompt"] start = time.perf_counter() try: @@ -1937,11 +2144,9 @@ def style_sweep( # noqa: PLR0915 with ThreadPoolExecutor(max_workers=parallel_queue) as pool: futures = {pool.submit(_run_one, task): task for task in pending_tasks} - completed = 0 - for fut in as_completed(futures): - completed += 1 + for completed, fut in enumerate(as_completed(futures), start=1): task = futures[fut] - idx, _entry, _res, _out_path = task + idx, _entry, _res, _out_path = task # idx used in log message below try: res = fut.result() except Exception as ex: @@ -1988,14 +2193,16 @@ def style_sweep( # noqa: PLR0915 def _write_sweep_manifest( out_dir: Path, - template_path: Path, + template_path: Path | None, styles_origin: str, results: list[dict[str, Any]], ) -> Path: """Write the per-sweep manifest JSON. Returns the path.""" manifest_path = out_dir / "_sweep.json" manifest: dict[str, Any] = { - "template": str(template_path), + # template_path is None when --list is used with only --styles (no template + # required). Serialize as empty string to keep manifest schema stable. + "template": str(template_path) if template_path is not None else "", "styles_source": styles_origin, "results": results, } @@ -2024,7 +2231,7 @@ def _print_styles_list(styles_origin: str, entries: list[dict[str, str]]) -> Non @app.command() -def template( +def template( # noqa: PLR0915 model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")], lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None, lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8, @@ -2842,7 +3049,7 @@ def scene_extract( target_file = None for f in files: file_path = Path(f["file_path"]) - if file_path.name == model or file_path.stem == model: + if model in (file_path.name, file_path.stem): target_file = f break @@ -2970,7 +3177,7 @@ app.add_typer(templates_app) @templates_app.command("extract") -def templates_extract( +def templates_extract( # noqa: PLR0915 model: Annotated[str, typer.Argument(help="Local model name (e.g. lust_v10.safetensors)")], orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "portrait", no_overrides: Annotated[ @@ -3428,8 +3635,10 @@ def comfy_generate( ) -> None: """[Deprecated] Use 'tsr generate' instead. All features have been merged into the top-level command.""" console.print("[yellow]Warning: 'tsr comfy generate' is deprecated. Use 'tsr generate' instead.[/yellow]") - # Delegate to the unified generate command via context invocation - ctx = typer.Context(generate) + # Delegate to the unified generate command via context invocation. + # typer.Context expects a click.Command, but passing the typer function directly + # works at runtime via duck-typing — keeping it for back-compat with deprecated alias. + ctx = typer.Context(generate) # type: ignore[arg-type] generate( ctx=ctx, prompt=prompt, diff --git a/tensors/comfyui.py b/tensors/comfyui.py index 5758a4c..13cd69c 100644 --- a/tensors/comfyui.py +++ b/tensors/comfyui.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextlib import copy import json import random @@ -388,7 +389,7 @@ def queue_prompt( return None -def _wait_for_completion_ws( +def _wait_for_completion_ws( # noqa: PLR0915 prompt_id: str, url: str, client_id: str, @@ -494,10 +495,8 @@ def _wait_for_completion_ws( break finally: - try: + with contextlib.suppress(Exception): ws.close() - except Exception: - pass # Fetch final outputs from history to ensure we have everything try: diff --git a/tensors/config.py b/tensors/config.py index 19f46f5..c4fe497 100644 --- a/tensors/config.py +++ b/tensors/config.py @@ -4,7 +4,7 @@ from __future__ import annotations import os import tomllib -from enum import Enum +from enum import StrEnum from pathlib import Path from typing import Any @@ -65,7 +65,7 @@ CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models" # ============================================================================ -class Provider(str, Enum): +class Provider(StrEnum): """Model search providers.""" civitai = "civitai" @@ -73,7 +73,7 @@ class Provider(str, Enum): all = "all" -class ModelType(str, Enum): +class ModelType(StrEnum): """CivitAI model types.""" checkpoint = "checkpoint" @@ -110,7 +110,7 @@ class ModelType(str, Enum): return mapping[self.value] -class BaseModel(str, Enum): +class BaseModel(StrEnum): """Common base models.""" # Stable Diffusion 1.x @@ -166,7 +166,7 @@ class BaseModel(str, Enum): return mapping[self.value] -class SortOrder(str, Enum): +class SortOrder(StrEnum): """Sort options for search.""" downloads = "downloads" @@ -183,7 +183,7 @@ class SortOrder(str, Enum): return mapping[self.value] -class Period(str, Enum): +class Period(StrEnum): """Time period for sorting/filtering.""" all = "all" @@ -204,7 +204,7 @@ class Period(str, Enum): return mapping[self.value] -class NsfwLevel(str, Enum): +class NsfwLevel(StrEnum): """NSFW content filter level.""" none = "none" @@ -219,7 +219,7 @@ class NsfwLevel(str, Enum): return self.value.capitalize() if self.value != "none" else "None" -class CommercialUse(str, Enum): +class CommercialUse(StrEnum): """Commercial use permissions.""" none = "none" diff --git a/tensors/fragments.py b/tensors/fragments.py index 8ad564b..fb4cfba 100644 --- a/tensors/fragments.py +++ b/tensors/fragments.py @@ -16,9 +16,16 @@ from __future__ import annotations import json import re from pathlib import Path # noqa: TC003 # used in runtime return annotations exposed to typer +from typing import TYPE_CHECKING from tensors.config import DATA_DIR +if TYPE_CHECKING: + # Qualified `builtins.list` is referenced in annotations inside FragmentLibrary + # because the class defines a method named `list` that shadows the builtin + # at class-scope name resolution. Static-only — not needed at runtime. + import builtins + # Restrict fragment names to a safe subset so they can't escape the storage dir # via path traversal and so file listings stay tidy. _NAME_RE = re.compile(r"^[A-Za-z0-9_.-]+$") @@ -132,8 +139,11 @@ class FragmentLibrary: *, name: str | None = None, inline: str | None = None, - extra: list[str] | None = None, - ) -> list[str]: + # NOTE: `builtins.list` qualifier needed because this class defines a + # `list()` method below, which shadows the builtin in class-scope name + # resolution. Affects mypy/pyright even with `from __future__ import annotations`. + extra: builtins.list[str] | None = None, + ) -> builtins.list[str]: """Merge a named fragment with an inline CSV string and optional extras. Resolution order (first match wins per duplicate): named → inline → extra. diff --git a/tensors/remote.py b/tensors/remote.py index 77e6824..e0a252e 100644 --- a/tensors/remote.py +++ b/tensors/remote.py @@ -209,7 +209,8 @@ def remote_search( response.raise_for_status() result: dict[str, Any] = response.json() # The remote API wraps CivitAI results under "civitai" key - return result.get("civitai", result) + civitai_section: dict[str, Any] = result.get("civitai", result) + return civitai_section except httpx.HTTPStatusError as e: if console: console.print(f"[red]Remote API error: {e.response.status_code}[/red]") diff --git a/tensors/server/comfyui_api_routes.py b/tensors/server/comfyui_api_routes.py index 1ce44e2..9974330 100644 --- a/tensors/server/comfyui_api_routes.py +++ b/tensors/server/comfyui_api_routes.py @@ -27,6 +27,14 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/comfyui", tags=["ComfyUI API"]) +# Schema default sentinels — see GenerateRequest defaults. These let us detect +# "user accepted default" vs "user explicitly chose this value matching default" +# is intentionally not distinguished; both paths apply family overrides. +_DEFAULT_STEPS = 20 +_DEFAULT_CFG = 7.0 +# Logging truncation threshold for long prompts in info-level output. +_PROMPT_LOG_TRUNCATE = 100 + # ============================================================================= # Request/Response Models @@ -262,9 +270,9 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]: sampler = family_defaults["sampler"] if request.scheduler == "normal": # Default value in schema scheduler = family_defaults["scheduler"] - if request.steps == 20: # Default value in schema + if request.steps == _DEFAULT_STEPS: # Default value in schema steps = family_defaults["steps"] - if request.cfg == 7.0: # Default value in schema + if request.cfg == _DEFAULT_CFG: # Default value in schema cfg = family_defaults["cfg"] # Only override VAE if user explicitly specified one; # otherwise use checkpoint's built-in VAE (vae stays None) @@ -290,7 +298,7 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]: sampler, scheduler, lora_info, - request.prompt[:100] + "..." if len(request.prompt) > 100 else request.prompt, + request.prompt[:_PROMPT_LOG_TRUNCATE] + "..." if len(request.prompt) > _PROMPT_LOG_TRUNCATE else request.prompt, ) if request.negative_prompt: logger.debug("Negative prompt: %r", request.negative_prompt[:100]) diff --git a/tensors/server/search_routes.py b/tensors/server/search_routes.py index 839ca39..3ca58dd 100644 --- a/tensors/server/search_routes.py +++ b/tensors/server/search_routes.py @@ -4,7 +4,7 @@ from __future__ import annotations import contextlib import logging -from enum import Enum +from enum import StrEnum from typing import Annotated, Any from fastapi import APIRouter, Query @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/search", tags=["Search"]) -class Provider(str, Enum): +class Provider(StrEnum): """Search provider options.""" civitai = "civitai" @@ -45,7 +45,7 @@ class Provider(str, Enum): all = "all" -class SortOrder(str, Enum): +class SortOrder(StrEnum): """Sort order options.""" downloads = "downloads" diff --git a/tests/test_generate_parallel.py b/tests/test_generate_parallel.py new file mode 100644 index 0000000..c5efbc7 --- /dev/null +++ b/tests/test_generate_parallel.py @@ -0,0 +1,320 @@ +"""Tests for the `tsr generate --parallel-queue` flag (parallel fanout path).""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pytest +from typer.testing import CliRunner + +from tensors import cli as cli_module +from tensors.cli import app + +runner = CliRunner() + + +# ----------------------------------------------------------------------------- +# Fixtures +# ----------------------------------------------------------------------------- + + +@pytest.fixture +def calls(monkeypatch: pytest.MonkeyPatch) -> list[dict[str, Any]]: + """Record every _run_generation call and stub the disk-write side effect. + + The parallel fanout path invokes _run_generation N times (one per task); + the sequential path invokes it once. By recording kwargs we can assert + fanout behavior (per-task seeds, per-task output paths, count=1 per task) + without round-tripping ComfyUI. + """ + recorded: list[dict[str, Any]] = [] + + def fake_run_generation(**kwargs: Any) -> None: + recorded.append(kwargs) + out: Path | None = kwargs.get("output") + if out is not None: + out.parent.mkdir(parents=True, exist_ok=True) + out.write_bytes(b"fake-png") + + monkeypatch.setattr(cli_module, "_run_generation", fake_run_generation) + return recorded + + +@pytest.fixture(autouse=True) +def _stub_model_validation(monkeypatch: pytest.MonkeyPatch) -> None: + """Bypass ComfyUI's live model lookup so tests don't need a backend.""" + monkeypatch.setattr( + cli_module, + "_validate_model_available", + lambda model, family, lora: (model, lora), + ) + + +# ----------------------------------------------------------------------------- +# Validation / sanity +# ----------------------------------------------------------------------------- + + +def test_parallel_queue_invalid_value_rejected(calls: list[dict[str, Any]]) -> None: + """--parallel-queue 0 (or negative) exits non-zero before any work.""" + result = runner.invoke( + app, + ["generate", "test prompt", "-m", "x.safetensors", "--parallel-queue", "0"], + ) + assert result.exit_code != 0 + assert "--parallel-queue must be >= 1" in result.output + assert calls == [] + + +def test_parallel_queue_one_is_sequential_path(calls: list[dict[str, Any]]) -> None: + """-P 1 collapses to the legacy single _run_generation call with count=N. + + This is the key compatibility contract: existing scripts that don't pass + -P must see identical behavior (one call, count forwarded as batch_size). + """ + result = runner.invoke( + app, + ["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "1"], + ) + assert result.exit_code == 0, result.output + assert len(calls) == 1 + assert calls[0]["count"] == 4 + assert calls[0]["prompt"] == "test prompt" + + +def test_count_one_ignores_parallel_queue(calls: list[dict[str, Any]]) -> None: + """count=1 always takes sequential path regardless of -P (no fanout point).""" + result = runner.invoke( + app, + ["generate", "test prompt", "-m", "x.safetensors", "-c", "1", "-P", "8"], + ) + assert result.exit_code == 0, result.output + assert len(calls) == 1 + assert calls[0]["count"] == 1 + + +def test_json_output_incompatible_with_parallel(calls: list[dict[str, Any]]) -> None: + """--json + -P>1 errors out cleanly (would skip disk-save inside tasks).""" + result = runner.invoke( + app, + ["generate", "test prompt", "-m", "x.safetensors", "-c", "2", "-P", "2", "--json"], + ) + assert result.exit_code != 0 + assert "--json is not supported with --parallel-queue > 1" in result.output + assert calls == [] + + +# ----------------------------------------------------------------------------- +# Fanout behavior +# ----------------------------------------------------------------------------- + + +def test_parallel_fanout_creates_n_tasks(tmp_path: Path, calls: list[dict[str, Any]]) -> None: + """-c N -P M (M>1, N>1) → N independent _run_generation calls, each count=1.""" + out = tmp_path / "img.png" + result = runner.invoke( + app, + [ + "generate", + "test prompt", + "-m", + "x.safetensors", + "-c", + "4", + "-P", + "2", + "--seed", + "100", + "-o", + str(out), + ], + ) + assert result.exit_code == 0, result.output + assert len(calls) == 4 + # Each task generates exactly one image + for c in calls: + assert c["count"] == 1 + + +def test_parallel_seeds_increment_from_base(tmp_path: Path, calls: list[dict[str, Any]]) -> None: + """Explicit --seed → each task receives base+i (reproducible series).""" + out = tmp_path / "img.png" + runner.invoke( + app, + ["generate", "test prompt", "-m", "x.safetensors", "-c", "3", "-P", "3", "--seed", "500", "-o", str(out)], + ) + seeds_seen = sorted(c["seed"] for c in calls) + assert seeds_seen == [500, 501, 502] + + +def test_parallel_seeds_random_when_unset(tmp_path: Path, calls: list[dict[str, Any]]) -> None: + """seed=-1 → each task gets a freshly-rolled random seed (not all the same). + + Vanishingly small chance of collision across 4 random ints; treat as flake + threshold of "all distinct" rather than exact equality to any value. + """ + out = tmp_path / "img.png" + runner.invoke( + app, + ["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "2", "-o", str(out)], + ) + seeds = [c["seed"] for c in calls] + # All non-negative (i.e. resolved from -1 to actual int) and distinct. + assert all(s >= 0 for s in seeds) + assert len(set(seeds)) == len(seeds) + + +def test_parallel_output_paths_indexed(tmp_path: Path, calls: list[dict[str, Any]]) -> None: + """Per-task output paths use stem_NNN.suffix naming (matches sequential count>1).""" + out = tmp_path / "scene.png" + runner.invoke( + app, + ["generate", "test prompt", "-m", "x.safetensors", "-c", "3", "-P", "3", "--seed", "1", "-o", str(out)], + ) + paths = sorted(str(c["output"]) for c in calls) + assert paths == [ + str(tmp_path / "scene_001.png"), + str(tmp_path / "scene_002.png"), + str(tmp_path / "scene_003.png"), + ] + + +def test_parallel_without_output_passes_none(tmp_path: Path, calls: list[dict[str, Any]]) -> None: + """When --output is omitted, each task gets output=None (no disk write planned).""" + runner.invoke( + app, + ["generate", "test prompt", "-m", "x.safetensors", "-c", "2", "-P", "2", "--seed", "1"], + ) + assert len(calls) == 2 + assert all(c["output"] is None for c in calls) + + +def test_parallel_files_actually_written(tmp_path: Path, calls: list[dict[str, Any]]) -> None: + """End-to-end: per-task stub writes its file → all N appear on disk. + + Guards against the bug where json_output=True short-circuits the save block + inside _run_generation. Each task must use the non-JSON code path. + """ + out = tmp_path / "shot.png" + runner.invoke( + app, + ["generate", "test prompt", "-m", "x.safetensors", "-c", "3", "-P", "3", "--seed", "1", "-o", str(out)], + ) + written = sorted(p.name for p in tmp_path.iterdir()) + assert written == ["shot_001.png", "shot_002.png", "shot_003.png"] + + +def test_parallel_summary_reports_success_count(tmp_path: Path, calls: list[dict[str, Any]]) -> None: + """Final summary line reports N/N success when all tasks complete.""" + out = tmp_path / "img.png" + result = runner.invoke( + app, + ["generate", "test prompt", "-m", "x.safetensors", "-c", "3", "-P", "2", "--seed", "1", "-o", str(out)], + ) + assert result.exit_code == 0 + assert "Generated 3/3 images" in result.output + + +def test_parallel_partial_failure_exits_nonzero( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """If one task raises, summary shows partial count and command exits non-zero.""" + import typer + + call_indices: list[int] = [] + + def flaky_run_generation(**kwargs: Any) -> None: + # Fail every other call to simulate intermittent backend errors. + idx = len(call_indices) + call_indices.append(idx) + if idx % 2 == 0: + raise typer.Exit(1) + out: Path | None = kwargs.get("output") + if out is not None: + out.parent.mkdir(parents=True, exist_ok=True) + out.write_bytes(b"ok") + + monkeypatch.setattr(cli_module, "_run_generation", flaky_run_generation) + + out = tmp_path / "img.png" + result = runner.invoke( + app, + ["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "2", "--seed", "1", "-o", str(out)], + ) + assert result.exit_code != 0 + # Two tasks failed; final summary should show 2/4. + assert "Generated 2/4 images" in result.output + + +# ----------------------------------------------------------------------------- +# --input integration +# ----------------------------------------------------------------------------- + + +def test_parallel_queue_from_yaml_input(tmp_path: Path, calls: list[dict[str, Any]]) -> None: + """parallel_queue can be set via --input YAML (mirrors other generate params).""" + out = tmp_path / "img.png" + yml = tmp_path / "spec.yml" + yml.write_text(f'prompt: from-yaml\nmodel: x.safetensors\ncount: 3\nparallel_queue: 3\nseed: 7\noutput: "{out}"\n') + result = runner.invoke(app, ["generate", "--input", str(yml)]) + assert result.exit_code == 0, result.output + assert len(calls) == 3 + assert sorted(c["seed"] for c in calls) == [7, 8, 9] + + +def test_cli_parallel_queue_overrides_yaml(tmp_path: Path, calls: list[dict[str, Any]]) -> None: + """CLI --parallel-queue wins over YAML's parallel_queue (standard precedence).""" + out = tmp_path / "img.png" + yml = tmp_path / "spec.yml" + yml.write_text(f'prompt: from-yaml\nmodel: x.safetensors\ncount: 2\nparallel_queue: 1\nseed: 10\noutput: "{out}"\n') + # YAML says P=1 (sequential), CLI overrides to P=2 (fanout) + result = runner.invoke(app, ["generate", "--input", str(yml), "-P", "2"]) + assert result.exit_code == 0, result.output + # Fanout path → 2 separate calls, each count=1 + assert len(calls) == 2 + assert all(c["count"] == 1 for c in calls) + + +# ----------------------------------------------------------------------------- +# Concurrency assertion +# ----------------------------------------------------------------------------- + + +def test_parallel_actually_runs_concurrently( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Sanity: P concurrent tasks really overlap in time (vs all-serial).""" + import threading + import time as _t + + in_flight = 0 + peak_in_flight = 0 + lock = threading.Lock() + + def slow_run_generation(**kwargs: Any) -> None: + nonlocal in_flight, peak_in_flight + with lock: + in_flight += 1 + peak_in_flight = max(peak_in_flight, in_flight) + _t.sleep(0.1) # 100ms — long enough to overlap, short enough for fast tests + with lock: + in_flight -= 1 + out: Path | None = kwargs.get("output") + if out is not None: + out.parent.mkdir(parents=True, exist_ok=True) + out.write_bytes(b"ok") + + monkeypatch.setattr(cli_module, "_run_generation", slow_run_generation) + + out = tmp_path / "img.png" + result = runner.invoke( + app, + ["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "4", "--seed", "1", "-o", str(out)], + ) + assert result.exit_code == 0, result.output + # With P=4 and 4 tasks each sleeping 100ms, peak concurrency should hit 4. + # Even allowing for thread-pool warmup quirks, ≥2 means parallelism is real. + assert peak_in_flight >= 2, f"peak_in_flight={peak_in_flight} (expected ≥2 for parallel)" diff --git a/tests/test_server.py b/tests/test_server.py index 339f7ed..babed70 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1237,7 +1237,7 @@ class TestDownloadBackgroundTasks: captured["api_key"] = api_key return {"file_id": 42, "sha256": "deadbeef", "linked": True, "cached": True, "error": None} - monkeypatch.setattr(download_routes_module, "Database", lambda: StubDB()) + monkeypatch.setattr(download_routes_module, "Database", StubDB) return captured def test_do_download_success(self, monkeypatch, tmp_path) -> None: @@ -1397,7 +1397,7 @@ class TestDownloadBackgroundTasks: def register_downloaded_file(self, *args, **kwargs): return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"} - monkeypatch.setattr(download_routes, "Database", lambda: FailingDB()) + monkeypatch.setattr(download_routes, "Database", FailingDB) dest_path = tmp_path / "model.safetensors" _do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})