Merge pull request #2 from saiden-dev/feat/generate-parallel-queue

feat(generate): add --parallel-queue/-P for concurrent submissions
This commit is contained in:
Adam Ladachowski
2026-05-18 23:44:21 +02:00
committed by GitHub
9 changed files with 615 additions and 68 deletions
+254 -45
View File
@@ -346,7 +346,10 @@ def search(
return return
key = api_key or load_api_key() key = api_key or load_api_key()
civitai_results: dict[str, Any] | None = None # Reuse the name from the remote-mode branch above (which already returned)
# without redeclaring its type — mypy treats class-scope re-annotation as
# a no-redef even when control flow guarantees the branches don't overlap.
civitai_results = None
hf_results: list[dict[str, Any]] | None = None hf_results: list[dict[str, Any]] | None = None
# Search CivitAI # Search CivitAI
@@ -874,6 +877,22 @@ def generate( # noqa: PLR0915
str | None, str | None,
typer.Option("--input", "-I", help="JSON or YAML params (file path or inline; keys match CLI options)"), typer.Option("--input", "-I", help="JSON or YAML params (file path or inline; keys match CLI options)"),
] = None, ] = None,
parallel_queue: Annotated[
int,
typer.Option(
"--parallel-queue",
"-P",
help=(
"Concurrent ComfyUI submissions (default 1). When >1 with --count N, "
"splits the request into N independent jobs (batch_size=1 each) with "
"incrementing seeds, executed P-at-a-time via thread pool. The GPU "
"still processes one prompt at a time, but HTTP queue / init / "
"download phases pipeline for a ~5-15%% speedup. Per-task output "
"interleaves; final summary lists all saved files. Ignored when "
"--count is 1."
),
),
] = 1,
) -> None: ) -> None:
"""Generate an image using text-to-image. """Generate an image using text-to-image.
@@ -887,6 +906,11 @@ def generate( # noqa: PLR0915
starting with ``{`` are JSON, everything else is YAML. CLI flags override starting with ``{`` are JSON, everything else is YAML. CLI flags override
--input values. --input values.
With --count > 1, images are generated as a single ComfyUI batch by default
(one workflow, sequential on GPU). Use --parallel-queue P (P > 1) to instead
split the request into one independent batch_size=1 job per image, run up to
P at a time, each with its own seed — useful for overlapping the
HTTP/download phase across requests.
Examples: Examples:
tsr generate "a cat on a windowsill" tsr generate "a cat on a windowsill"
tsr generate "portrait photo" -m ponyDiffusionV6XL_v6.safetensors -O portrait tsr generate "portrait photo" -m ponyDiffusionV6XL_v6.safetensors -O portrait
@@ -895,7 +919,20 @@ def generate( # noqa: PLR0915
tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors"}' tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors"}'
tsr generate --input scene.yml tsr generate --input scene.yml
tsr generate "raw prompt" --no-quality --no-negative tsr generate "raw prompt" --no-quality --no-negative
tsr generate "city" -c 8 -P 4 -o out.png # 8 distinct seeds, 4 in flight
""" """
if parallel_queue < 1:
console.print("[red]--parallel-queue must be >= 1[/red]")
raise typer.Exit(1)
if parallel_queue > 1 and json_output:
# _run_generation short-circuits the disk-save when json_output=True
# (it dumps JSON and returns). For the parallel fanout to actually save
# files, each task must take the non-JSON path. We render our own JSON
# at the end, so the per-task --json is incompatible.
console.print(
"[red]--json is not supported with --parallel-queue > 1 (would skip the file-save step). Drop one or the other.[/red]"
)
raise typer.Exit(1)
# ---- --input merging (JSON or YAML) ---- # ---- --input merging (JSON or YAML) ----
if json_input is not None: if json_input is not None:
ji = _parse_generate_input(json_input) ji = _parse_generate_input(json_input)
@@ -912,7 +949,9 @@ def generate( # noqa: PLR0915
{ {
p.name p.name
for p in click_ctx.command.params for p in click_ctx.command.params
if click_ctx.get_parameter_source(p.name) == click.core.ParameterSource.COMMANDLINE # click's Parameter.name is typed `str | None` in stubs but is always
# a real string at runtime for any param that's been registered.
if p.name is not None and click_ctx.get_parameter_source(p.name) == click.core.ParameterSource.COMMANDLINE
} }
if hasattr(click_ctx, "get_parameter_source") if hasattr(click_ctx, "get_parameter_source")
else set() else set()
@@ -981,41 +1020,209 @@ def generate( # noqa: PLR0915
scene_prompt = sp_val if isinstance(sp_val, str) else ", ".join(str(x) for x in sp_val if str(x).strip()) scene_prompt = sp_val if isinstance(sp_val, str) else ", ".join(str(x) for x in sp_val if str(x).strip())
if "rating" in mapped and "rating" not in explicit: if "rating" in mapped and "rating" not in explicit:
rating = mapped["rating"] rating = mapped["rating"]
if "parallel_queue" in mapped and "parallel_queue" not in explicit:
parallel_queue = int(mapped["parallel_queue"])
has_content = bool(prompt or character or character_prompt or scene or scene_prompt) has_content = bool(prompt or character or character_prompt or scene or scene_prompt)
if not has_content: if not has_content:
console.print("[red]Prompt (or character/scene) is required[/red]") console.print("[red]Prompt (or character/scene) is required[/red]")
raise typer.Exit(1) raise typer.Exit(1)
_run_generation( # Effective parallelism is bounded by count — running 4 threads for 1 image
prompt=prompt, # is silly. count=1 always goes through the sequential path regardless of -P.
model=model, effective_parallel = min(parallel_queue, count) if count > 1 else 1
width=width,
height=height, if effective_parallel <= 1:
steps=steps, # Sequential path: single _run_generation call with batch_size=count.
cfg=cfg, # Unchanged from pre-parallel behavior — preserves existing output naming,
guidance=guidance, # JSON shape, and log lines exactly.
seed=seed, _run_generation(
sampler=sampler, prompt=prompt,
scheduler=scheduler, model=model,
vae=vae, width=width,
orientation=orientation, height=height,
lora=lora, steps=steps,
lora_strength=lora_strength, cfg=cfg,
negative=negative, guidance=guidance,
count=count, seed=seed,
rating=rating, sampler=sampler,
no_quality=no_quality, scheduler=scheduler,
no_negative=no_negative, vae=vae,
character=character, orientation=orientation,
character_prompt=character_prompt, lora=lora,
scene=scene, lora_strength=lora_strength,
scene_prompt=scene_prompt, negative=negative,
family=family, count=count,
output=output, rating=rating,
remote=remote, no_quality=no_quality,
json_output=json_output, no_negative=no_negative,
) character=character,
character_prompt=character_prompt,
scene=scene,
scene_prompt=scene_prompt,
family=family,
output=output,
remote=remote,
json_output=json_output,
)
return
# ---- Parallel fanout path ----
# Split count into `count` independent jobs (batch_size=1), executed
# `effective_parallel` at a time. Each job gets a distinct seed and a
# distinct output path so writes don't clobber each other.
import random as _rng # noqa: PLC0415
import time as _time # noqa: PLC0415
from concurrent.futures import ThreadPoolExecutor, as_completed # noqa: PLC0415
# Resolve bare model/lora names ONCE in the parent before fanout, so an
# unknown name fails here (before any task is queued) and each task receives
# a canonical filename that ComfyUI's strict loaders accept first try.
# NOTE: tasks run with json_output=False (see common_kwargs), so they do not
# take _run_generation's JSON short-circuit path.
if model and not remote:
# Detect family for the right loader bucket (checkpoints vs diffusion_models).
# Mirrors the lookup _run_generation does on entry.
from tensors.db import Database # noqa: PLC0415
_base_model: str | None = None
try:
with Database() as _db:
_db.init_schema()
_base_model = _db.get_base_model_by_filename(model)
except Exception:
pass
_detected = detect_model_family(model, _base_model)
_fam = family or _detected
try:
model, lora = _validate_model_available(model, _fam, lora)
except typer.Exit:
raise # surface the same error path as sequential
# Seed strategy:
# --seed >= 0 → use as base, increment per job (reproducible series)
# --seed == -1 → pick a fresh random seed PER JOB so parallel runs aren't
# accidentally correlated (each thread gets variety)
seeds = [seed + i for i in range(count)] if seed >= 0 else [_rng.randint(0, 2**32 - 1) for _ in range(count)]
# Output paths: mirror the existing `count > 1` naming convention from
# _run_generation (stem_NNN.ext). When --output is omitted, leave per-task
# output as None — _run_generation will skip the disk write and the user
# gets only the console listing of generated image refs.
out_paths: list[Path | None] = []
for i in range(count):
if output is None:
out_paths.append(None)
else:
out_paths.append(output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}")
if not json_output:
console.print(
f"[dim]Parallel queue: {effective_parallel} concurrent submissions x {count} images (output may interleave)[/dim]"
)
common_kwargs: dict[str, Any] = {
"prompt": prompt,
"model": model,
"width": width,
"height": height,
"steps": steps,
"cfg": cfg,
"guidance": guidance,
"sampler": sampler,
"scheduler": scheduler,
"vae": vae,
"orientation": orientation,
"lora": lora,
"lora_strength": lora_strength,
"negative": negative,
"count": 1, # each task generates exactly one image
"rating": rating,
"no_quality": no_quality,
"no_negative": no_negative,
"character": character,
"character_prompt": character_prompt,
"scene": scene,
"scene_prompt": scene_prompt,
"family": family,
"remote": remote,
# NOTE: json_output stays False so _run_generation's disk-save path runs.
# Setting True would short-circuit before saving files. Per-task console
# chatter is the trade-off; the final summary still shows clean per-task
# status lines.
"json_output": False,
}
def _run_one(idx: int) -> dict[str, Any]:
"""Run a single batch_size=1 job. Returns a result dict (success captured)."""
start = _time.perf_counter()
result: dict[str, Any] = {
"index": idx,
"seed": seeds[idx],
"output": str(out_paths[idx]) if out_paths[idx] is not None else None,
"duration_sec": 0.0,
"success": False,
"error": None,
}
try:
_run_generation(seed=seeds[idx], output=out_paths[idx], **common_kwargs)
result["duration_sec"] = round(_time.perf_counter() - start, 2)
result["success"] = True
except typer.Exit as ex:
result["duration_sec"] = round(_time.perf_counter() - start, 2)
result["error"] = f"generate exited with code {ex.exit_code}"
except Exception as ex:
result["duration_sec"] = round(_time.perf_counter() - start, 2)
result["error"] = str(ex)
return result
fan_results: list[dict[str, Any]] = []
with ThreadPoolExecutor(max_workers=effective_parallel) as pool:
futures = {pool.submit(_run_one, i): i for i in range(count)}
for completed, fut in enumerate(as_completed(futures), start=1):
try:
res = fut.result()
except Exception as ex:
# Defensive — _run_one already swallows, but if the executor itself
# raises (e.g. pickling failure) we still want a well-formed result
# in the manifest rather than a crash.
res = {
"index": futures[fut],
"seed": seeds[futures[fut]],
"output": str(out_paths[futures[fut]]) if out_paths[futures[fut]] is not None else None,
"duration_sec": 0.0,
"success": False,
"error": f"executor exception: {ex}",
}
fan_results.append(res)
if not json_output:
if res["success"]:
where = res["output"] or "(no --output set)"
console.print(
f"[green]\\[{completed}/{count}] seed={res['seed']} ok in {res['duration_sec']:.1f}s → {where}[/green]"
)
else:
console.print(f"[red]\\[{completed}/{count}] seed={res['seed']} FAIL: {res['error']}[/red]")
# Reorder by original index so JSON output / final summary list is stable.
fan_results.sort(key=lambda r: r["index"])
successful = sum(1 for r in fan_results if r["success"])
if json_output:
console.print_json(
data={
"success": successful == count,
"count": count,
"parallel_queue": effective_parallel,
"results": fan_results,
}
)
return
console.print("[bold green]Generation complete![/bold green]")
console.print(f"[dim]Generated {successful}/{count} images at parallelism={effective_parallel}[/dim]")
if successful < count:
raise typer.Exit(1)
# Map model family → which ComfyUI loader directory the checkpoint must live in. # Map model family → which ComfyUI loader directory the checkpoint must live in.
@@ -1283,7 +1490,7 @@ def _run_generation( # noqa: PLR0915
# ---- Resolve preset defaults for None params (both remote and local need these) ---- # ---- Resolve preset defaults for None params (both remote and local need these) ----
from tensors.config import resolve_orientation # noqa: PLC0415 from tensors.config import resolve_orientation # noqa: PLC0415
from tensors.config import resolve_remote as do_resolve_remote from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415
# Use already-detected family_defaults from DB lookup above (not filename guessing) # Use already-detected family_defaults from DB lookup above (not filename guessing)
if family_defaults: if family_defaults:
@@ -1507,7 +1714,7 @@ _STYLE_SWEEP_TEMPLATE_KEYS = {
} }
def _load_json_file_or_inline(value: str | list | dict, *, what: str) -> Any: def _load_json_file_or_inline(value: str | list[Any] | dict[str, Any], *, what: str) -> Any:
"""Load JSON from a file path or accept already-parsed inline data. """Load JSON from a file path or accept already-parsed inline data.
`value` may be a path string, a JSON string, or an already-parsed list/dict `value` may be a path string, a JSON string, or an already-parsed list/dict
@@ -1885,7 +2092,7 @@ def style_sweep( # noqa: PLR0915
def _run_one(task: tuple[int, dict[str, str], dict[str, Any], Path]) -> dict[str, Any]: def _run_one(task: tuple[int, dict[str, str], dict[str, Any], Path]) -> dict[str, Any]:
"""Run a single style. Returns the result dict (success or error captured).""" """Run a single style. Returns the result dict (success or error captured)."""
idx, entry_in, res, opath = task _idx, _entry_in, res, opath = task
composed = res["prompt"] composed = res["prompt"]
start = time.perf_counter() start = time.perf_counter()
try: try:
@@ -1937,11 +2144,9 @@ def style_sweep( # noqa: PLR0915
with ThreadPoolExecutor(max_workers=parallel_queue) as pool: with ThreadPoolExecutor(max_workers=parallel_queue) as pool:
futures = {pool.submit(_run_one, task): task for task in pending_tasks} futures = {pool.submit(_run_one, task): task for task in pending_tasks}
completed = 0 for completed, fut in enumerate(as_completed(futures), start=1):
for fut in as_completed(futures):
completed += 1
task = futures[fut] task = futures[fut]
idx, _entry, _res, _out_path = task idx, _entry, _res, _out_path = task # idx used in log message below
try: try:
res = fut.result() res = fut.result()
except Exception as ex: except Exception as ex:
@@ -1988,14 +2193,16 @@ def style_sweep( # noqa: PLR0915
def _write_sweep_manifest( def _write_sweep_manifest(
out_dir: Path, out_dir: Path,
template_path: Path, template_path: Path | None,
styles_origin: str, styles_origin: str,
results: list[dict[str, Any]], results: list[dict[str, Any]],
) -> Path: ) -> Path:
"""Write the per-sweep manifest JSON. Returns the path.""" """Write the per-sweep manifest JSON. Returns the path."""
manifest_path = out_dir / "_sweep.json" manifest_path = out_dir / "_sweep.json"
manifest: dict[str, Any] = { manifest: dict[str, Any] = {
"template": str(template_path), # template_path is None when --list is used with only --styles (no template
# required). Serialize as empty string to keep manifest schema stable.
"template": str(template_path) if template_path is not None else "",
"styles_source": styles_origin, "styles_source": styles_origin,
"results": results, "results": results,
} }
@@ -2024,7 +2231,7 @@ def _print_styles_list(styles_origin: str, entries: list[dict[str, str]]) -> Non
@app.command() @app.command()
def template( def template( # noqa: PLR0915
model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")], model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")],
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None, lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8, lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
@@ -2842,7 +3049,7 @@ def scene_extract(
target_file = None target_file = None
for f in files: for f in files:
file_path = Path(f["file_path"]) file_path = Path(f["file_path"])
if file_path.name == model or file_path.stem == model: if model in (file_path.name, file_path.stem):
target_file = f target_file = f
break break
@@ -2970,7 +3177,7 @@ app.add_typer(templates_app)
@templates_app.command("extract") @templates_app.command("extract")
def templates_extract( def templates_extract( # noqa: PLR0915
model: Annotated[str, typer.Argument(help="Local model name (e.g. lust_v10.safetensors)")], model: Annotated[str, typer.Argument(help="Local model name (e.g. lust_v10.safetensors)")],
orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "portrait", orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "portrait",
no_overrides: Annotated[ no_overrides: Annotated[
@@ -3428,8 +3635,10 @@ def comfy_generate(
) -> None: ) -> None:
"""[Deprecated] Use 'tsr generate' instead. All features have been merged into the top-level command.""" """[Deprecated] Use 'tsr generate' instead. All features have been merged into the top-level command."""
console.print("[yellow]Warning: 'tsr comfy generate' is deprecated. Use 'tsr generate' instead.[/yellow]") console.print("[yellow]Warning: 'tsr comfy generate' is deprecated. Use 'tsr generate' instead.[/yellow]")
# Delegate to the unified generate command via context invocation # Delegate to the unified generate command via context invocation.
ctx = typer.Context(generate) # typer.Context expects a click.Command, but passing the typer function directly
# works at runtime via duck-typing — keeping it for back-compat with deprecated alias.
ctx = typer.Context(generate) # type: ignore[arg-type]
generate( generate(
ctx=ctx, ctx=ctx,
prompt=prompt, prompt=prompt,
+3 -4
View File
@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import contextlib
import copy import copy
import json import json
import random import random
@@ -388,7 +389,7 @@ def queue_prompt(
return None return None
def _wait_for_completion_ws( def _wait_for_completion_ws( # noqa: PLR0915
prompt_id: str, prompt_id: str,
url: str, url: str,
client_id: str, client_id: str,
@@ -494,10 +495,8 @@ def _wait_for_completion_ws(
break break
finally: finally:
try: with contextlib.suppress(Exception):
ws.close() ws.close()
except Exception:
pass
# Fetch final outputs from history to ensure we have everything # Fetch final outputs from history to ensure we have everything
try: try:
+8 -8
View File
@@ -4,7 +4,7 @@ from __future__ import annotations
import os import os
import tomllib import tomllib
from enum import Enum from enum import StrEnum
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -65,7 +65,7 @@ CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
# ============================================================================ # ============================================================================
class Provider(str, Enum): class Provider(StrEnum):
"""Model search providers.""" """Model search providers."""
civitai = "civitai" civitai = "civitai"
@@ -73,7 +73,7 @@ class Provider(str, Enum):
all = "all" all = "all"
class ModelType(str, Enum): class ModelType(StrEnum):
"""CivitAI model types.""" """CivitAI model types."""
checkpoint = "checkpoint" checkpoint = "checkpoint"
@@ -110,7 +110,7 @@ class ModelType(str, Enum):
return mapping[self.value] return mapping[self.value]
class BaseModel(str, Enum): class BaseModel(StrEnum):
"""Common base models.""" """Common base models."""
# Stable Diffusion 1.x # Stable Diffusion 1.x
@@ -166,7 +166,7 @@ class BaseModel(str, Enum):
return mapping[self.value] return mapping[self.value]
class SortOrder(str, Enum): class SortOrder(StrEnum):
"""Sort options for search.""" """Sort options for search."""
downloads = "downloads" downloads = "downloads"
@@ -183,7 +183,7 @@ class SortOrder(str, Enum):
return mapping[self.value] return mapping[self.value]
class Period(str, Enum): class Period(StrEnum):
"""Time period for sorting/filtering.""" """Time period for sorting/filtering."""
all = "all" all = "all"
@@ -204,7 +204,7 @@ class Period(str, Enum):
return mapping[self.value] return mapping[self.value]
class NsfwLevel(str, Enum): class NsfwLevel(StrEnum):
"""NSFW content filter level.""" """NSFW content filter level."""
none = "none" none = "none"
@@ -219,7 +219,7 @@ class NsfwLevel(str, Enum):
return self.value.capitalize() if self.value != "none" else "None" return self.value.capitalize() if self.value != "none" else "None"
class CommercialUse(str, Enum): class CommercialUse(StrEnum):
"""Commercial use permissions.""" """Commercial use permissions."""
none = "none" none = "none"
+12 -2
View File
@@ -16,9 +16,16 @@ from __future__ import annotations
import json import json
import re import re
from pathlib import Path # noqa: TC003 # used in runtime return annotations exposed to typer from pathlib import Path # noqa: TC003 # used in runtime return annotations exposed to typer
from typing import TYPE_CHECKING
from tensors.config import DATA_DIR from tensors.config import DATA_DIR
if TYPE_CHECKING:
# Qualified `builtins.list` is referenced in annotations inside FragmentLibrary
# because the class defines a method named `list` that shadows the builtin
# at class-scope name resolution. Static-only — not needed at runtime.
import builtins
# Restrict fragment names to a safe subset so they can't escape the storage dir # Restrict fragment names to a safe subset so they can't escape the storage dir
# via path traversal and so file listings stay tidy. # via path traversal and so file listings stay tidy.
_NAME_RE = re.compile(r"^[A-Za-z0-9_.-]+$") _NAME_RE = re.compile(r"^[A-Za-z0-9_.-]+$")
@@ -132,8 +139,11 @@ class FragmentLibrary:
*, *,
name: str | None = None, name: str | None = None,
inline: str | None = None, inline: str | None = None,
extra: list[str] | None = None, # NOTE: `builtins.list` qualifier needed because this class defines a
) -> list[str]: # `list()` method below, which shadows the builtin in class-scope name
# resolution. Affects mypy/pyright even with `from __future__ import annotations`.
extra: builtins.list[str] | None = None,
) -> builtins.list[str]:
"""Merge a named fragment with an inline CSV string and optional extras. """Merge a named fragment with an inline CSV string and optional extras.
Resolution order (first match wins per duplicate): named → inline → extra. Resolution order (first match wins per duplicate): named → inline → extra.
+2 -1
View File
@@ -209,7 +209,8 @@ def remote_search(
response.raise_for_status() response.raise_for_status()
result: dict[str, Any] = response.json() result: dict[str, Any] = response.json()
# The remote API wraps CivitAI results under "civitai" key # The remote API wraps CivitAI results under "civitai" key
return result.get("civitai", result) civitai_section: dict[str, Any] = result.get("civitai", result)
return civitai_section
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
if console: if console:
console.print(f"[red]Remote API error: {e.response.status_code}[/red]") console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
+11 -3
View File
@@ -27,6 +27,14 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/comfyui", tags=["ComfyUI API"]) router = APIRouter(prefix="/api/comfyui", tags=["ComfyUI API"])
# Schema default sentinels — see GenerateRequest defaults. A request value equal
# to one of these is treated as "user accepted the default". An explicit choice
# that happens to match the default is intentionally not distinguished; both
# cases apply family overrides.
_DEFAULT_STEPS = 20
_DEFAULT_CFG = 7.0
# Logging truncation threshold for long prompts in info-level output.
_PROMPT_LOG_TRUNCATE = 100
# ============================================================================= # =============================================================================
# Request/Response Models # Request/Response Models
@@ -262,9 +270,9 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
sampler = family_defaults["sampler"] sampler = family_defaults["sampler"]
if request.scheduler == "normal": # Default value in schema if request.scheduler == "normal": # Default value in schema
scheduler = family_defaults["scheduler"] scheduler = family_defaults["scheduler"]
if request.steps == 20: # Default value in schema if request.steps == _DEFAULT_STEPS: # Default value in schema
steps = family_defaults["steps"] steps = family_defaults["steps"]
if request.cfg == 7.0: # Default value in schema if request.cfg == _DEFAULT_CFG: # Default value in schema
cfg = family_defaults["cfg"] cfg = family_defaults["cfg"]
# Only override VAE if user explicitly specified one; # Only override VAE if user explicitly specified one;
# otherwise use checkpoint's built-in VAE (vae stays None) # otherwise use checkpoint's built-in VAE (vae stays None)
@@ -290,7 +298,7 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
sampler, sampler,
scheduler, scheduler,
lora_info, lora_info,
request.prompt[:100] + "..." if len(request.prompt) > 100 else request.prompt, request.prompt[:_PROMPT_LOG_TRUNCATE] + "..." if len(request.prompt) > _PROMPT_LOG_TRUNCATE else request.prompt,
) )
if request.negative_prompt: if request.negative_prompt:
logger.debug("Negative prompt: %r", request.negative_prompt[:100]) logger.debug("Negative prompt: %r", request.negative_prompt[:100])
+3 -3
View File
@@ -4,7 +4,7 @@ from __future__ import annotations
import contextlib import contextlib
import logging import logging
from enum import Enum from enum import StrEnum
from typing import Annotated, Any from typing import Annotated, Any
from fastapi import APIRouter, Query from fastapi import APIRouter, Query
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/search", tags=["Search"]) router = APIRouter(prefix="/api/search", tags=["Search"])
class Provider(str, Enum): class Provider(StrEnum):
"""Search provider options.""" """Search provider options."""
civitai = "civitai" civitai = "civitai"
@@ -45,7 +45,7 @@ class Provider(str, Enum):
all = "all" all = "all"
class SortOrder(str, Enum): class SortOrder(StrEnum):
"""Sort order options.""" """Sort order options."""
downloads = "downloads" downloads = "downloads"
+320
View File
@@ -0,0 +1,320 @@
"""Tests for the `tsr generate --parallel-queue` flag (parallel fanout path)."""
from __future__ import annotations
from pathlib import Path
from typing import Any
import pytest
from typer.testing import CliRunner
from tensors import cli as cli_module
from tensors.cli import app
runner = CliRunner()
# -----------------------------------------------------------------------------
# Fixtures
# -----------------------------------------------------------------------------
@pytest.fixture
def calls(monkeypatch: pytest.MonkeyPatch) -> list[dict[str, Any]]:
    """Capture the kwargs of every ``_run_generation`` call made by the CLI.

    ``cli_module._run_generation`` is replaced with a stub that appends its
    keyword arguments to the returned list and, when an ``output`` path is
    given, writes a small placeholder file there. The fanout path calls
    ``_run_generation`` once per task and the sequential path calls it once,
    so the captured kwargs are enough to check per-task seeds, output paths
    and ``count`` without talking to ComfyUI.
    """
    captured: list[dict[str, Any]] = []

    def _recording_stub(**kwargs: Any) -> None:
        captured.append(kwargs)
        target: Path | None = kwargs.get("output")
        if target is None:
            return
        # Mimic the disk-write side effect so on-disk assertions have something to see.
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(b"fake-png")

    monkeypatch.setattr(cli_module, "_run_generation", _recording_stub)
    return captured
@pytest.fixture(autouse=True)
def _stub_model_validation(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace ``_validate_model_available`` with a pass-through for every test.

    The stub returns the ``(model, lora)`` pair it was given unchanged, so the
    tests in this module do not need a running backend to resolve model names.
    """

    def _passthrough(model: Any, family: Any, lora: Any) -> tuple[Any, Any]:
        return (model, lora)

    monkeypatch.setattr(cli_module, "_validate_model_available", _passthrough)
# -----------------------------------------------------------------------------
# Validation / sanity
# -----------------------------------------------------------------------------
def test_parallel_queue_invalid_value_rejected(calls: list[dict[str, Any]]) -> None:
    """``--parallel-queue 0`` is rejected with an error before any generation runs."""
    argv = ["generate", "test prompt", "-m", "x.safetensors", "--parallel-queue", "0"]
    outcome = runner.invoke(app, argv)

    assert outcome.exit_code != 0
    assert "--parallel-queue must be >= 1" in outcome.output
    # The guard must fire before _run_generation is ever reached.
    assert calls == []
def test_parallel_queue_one_is_sequential_path(calls: list[dict[str, Any]]) -> None:
    """With ``-P 1`` the command makes a single ``_run_generation`` call with count=N.

    Compatibility contract: ``-P`` defaults to 1, so scripts that never pass it
    must keep getting one call with ``count`` forwarded untouched.
    """
    argv = ["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "1"]
    outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0, outcome.output
    assert len(calls) == 1
    sole_call = calls[0]
    assert sole_call["count"] == 4
    assert sole_call["prompt"] == "test prompt"
def test_count_one_ignores_parallel_queue(calls: list[dict[str, Any]]) -> None:
    """A single-image request stays on the sequential path even when ``-P`` is large."""
    argv = ["generate", "test prompt", "-m", "x.safetensors", "-c", "1", "-P", "8"]
    outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0, outcome.output
    assert len(calls) == 1
    sole_call = calls[0]
    assert sole_call["count"] == 1
def test_json_output_incompatible_with_parallel(calls: list[dict[str, Any]]) -> None:
    """Combining ``--json`` with ``-P`` > 1 fails fast with an explanatory message."""
    argv = ["generate", "test prompt", "-m", "x.safetensors", "-c", "2", "-P", "2", "--json"]
    outcome = runner.invoke(app, argv)

    assert outcome.exit_code != 0
    assert "--json is not supported with --parallel-queue > 1" in outcome.output
    # Rejected up front: no task may have been started.
    assert calls == []
# -----------------------------------------------------------------------------
# Fanout behavior
# -----------------------------------------------------------------------------
def test_parallel_fanout_creates_n_tasks(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """``-c 4 -P 2`` fans out into four ``_run_generation`` calls, each with count=1."""
    destination = tmp_path / "img.png"
    argv = [
        "generate",
        "test prompt",
        "-m",
        "x.safetensors",
        "-c",
        "4",
        "-P",
        "2",
        "--seed",
        "100",
        "-o",
        str(destination),
    ]
    outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0, outcome.output
    assert len(calls) == 4
    # Every task is its own single-image job.
    assert [task["count"] for task in calls] == [1, 1, 1, 1]
def test_parallel_seeds_increment_from_base(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """An explicit ``--seed`` gives task ``i`` the seed ``base + i`` (reproducible series)."""
    out = tmp_path / "img.png"
    result = runner.invoke(
        app,
        ["generate", "test prompt", "-m", "x.safetensors", "-c", "3", "-P", "3", "--seed", "500", "-o", str(out)],
    )
    # Fail with the CLI output if the command itself broke, instead of an
    # uninformative seed-list mismatch further down.
    assert result.exit_code == 0, result.output
    # Tasks run on a thread pool, so the recording order is not fixed; compare sorted.
    seeds_seen = sorted(c["seed"] for c in calls)
    assert seeds_seen == [500, 501, 502]
def test_parallel_seeds_random_when_unset(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """Without ``--seed`` every task gets its own freshly rolled, non-negative seed.

    Relies on the omitted ``--seed`` resolving to the "random" sentinel (-1).
    A collision between 4 random ints is vanishingly unlikely, so the check is
    "all distinct" rather than equality with any particular value.
    """
    out = tmp_path / "img.png"
    result = runner.invoke(
        app,
        ["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "2", "-o", str(out)],
    )
    assert result.exit_code == 0, result.output
    seeds = [c["seed"] for c in calls]
    # Without this length check the assertions below pass vacuously when no
    # task ran at all (all([]) is True and len(set([])) == len([])).
    assert len(seeds) == 4
    # All non-negative (i.e. resolved from -1 to an actual int) and distinct.
    assert all(s >= 0 for s in seeds)
    assert len(set(seeds)) == len(seeds)
def test_parallel_output_paths_indexed(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """Per-task output paths use stem_NNN.suffix naming.

    For ``-o scene.png -c 3`` the recorded calls must target
    ``scene_001.png`` .. ``scene_003.png`` in the same directory.
    """
    out = tmp_path / "scene.png"
    result = runner.invoke(
        app,
        ["generate", "test prompt", "-m", "x.safetensors", "-c", "3", "-P", "3", "--seed", "1", "-o", str(out)],
    )
    # Fail with the CLI output rather than a bare list mismatch if the command errored.
    assert result.exit_code == 0, result.output
    # Sorted because the order in which tasks are recorded is not fixed.
    paths = sorted(str(c["output"]) for c in calls)
    assert paths == [
        str(tmp_path / "scene_001.png"),
        str(tmp_path / "scene_002.png"),
        str(tmp_path / "scene_003.png"),
    ]
def test_parallel_without_output_passes_none(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """Omitting --output means every fanned-out task is handed ``output=None``."""
    argv = ["generate", "test prompt", "-m", "x.safetensors", "-c", "2", "-P", "2", "--seed", "1"]
    runner.invoke(app, argv)

    recorded_outputs = [call["output"] for call in calls]
    # Two tasks were requested, so two calls must have been recorded.
    assert len(calls) == 2
    # None of them may carry a destination path.
    assert not any(dest is not None for dest in recorded_outputs)
def test_parallel_files_actually_written(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """End-to-end: per-task stub writes its file → all N appear on disk.

    Guards against the bug where json_output=True short-circuits the save block
    inside _run_generation. Each task must use the non-JSON code path.

    NOTE(review): relies on the ``calls`` fixture's stub creating each task's
    output file; the fixture is defined outside this section.
    """
    out = tmp_path / "shot.png"
    result = runner.invoke(
        app,
        ["generate", "test prompt", "-m", "x.safetensors", "-c", "3", "-P", "3", "--seed", "1", "-o", str(out)],
    )
    # Report the CLI output if the command itself failed, instead of only a
    # confusing directory-listing mismatch.
    assert result.exit_code == 0, result.output
    written = sorted(p.name for p in tmp_path.iterdir())
    assert written == ["shot_001.png", "shot_002.png", "shot_003.png"]
def test_parallel_summary_reports_success_count(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """When every task completes, the command exits 0 and prints an N/N summary."""
    destination = tmp_path / "img.png"
    argv = ["generate", "test prompt", "-m", "x.safetensors"]
    argv.extend(["-c", "3", "-P", "2", "--seed", "1", "-o", str(destination)])

    outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0
    # Three requested, three succeeded.
    expected_summary = "Generated 3/3 images"
    assert expected_summary in outcome.output
def test_parallel_partial_failure_exits_nonzero(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """If one task raises, summary shows partial count and command exits non-zero.

    ``_run_generation`` is replaced by a stub that fails on every even-numbered
    call (0, 2) and writes a placeholder file on the odd ones (1, 3), so exactly
    two of the four tasks succeed.
    """
    import threading

    import typer

    calls_started = 0
    # With -P 2 the stub is expected to be entered from several worker threads.
    # Handing out indices under a lock keeps the fail/succeed split
    # deterministic; an unsynchronised "read length, then append" could give
    # two tasks the same index and change the number of failures.
    index_lock = threading.Lock()

    def flaky_run_generation(**kwargs: Any) -> None:
        nonlocal calls_started
        with index_lock:
            idx = calls_started
            calls_started += 1
        # Fail every other call to simulate intermittent backend errors.
        if idx % 2 == 0:
            raise typer.Exit(1)
        out: Path | None = kwargs.get("output")
        if out is not None:
            out.parent.mkdir(parents=True, exist_ok=True)
            out.write_bytes(b"ok")

    monkeypatch.setattr(cli_module, "_run_generation", flaky_run_generation)
    out = tmp_path / "img.png"
    result = runner.invoke(
        app,
        ["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "2", "--seed", "1", "-o", str(out)],
    )
    assert result.exit_code != 0
    # Two tasks failed; final summary should show 2/4.
    assert "Generated 2/4 images" in result.output
# -----------------------------------------------------------------------------
# --input integration
# -----------------------------------------------------------------------------
def test_parallel_queue_from_yaml_input(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """``parallel_queue`` supplied through an --input YAML file triggers the fan-out."""
    target = tmp_path / "img.png"
    spec_file = tmp_path / "spec.yml"
    # The spec carries every parameter; nothing else is passed on the command line.
    spec_file.write_text(
        "prompt: from-yaml\n"
        "model: x.safetensors\n"
        "count: 3\n"
        "parallel_queue: 3\n"
        "seed: 7\n"
        f'output: "{target}"\n'
    )

    outcome = runner.invoke(app, ["generate", "--input", str(spec_file)])

    assert outcome.exit_code == 0, outcome.output
    assert len(calls) == 3
    # Seeds count up from the YAML-provided base of 7.
    recorded_seeds = [call["seed"] for call in calls]
    recorded_seeds.sort()
    assert recorded_seeds == [7, 8, 9]
def test_cli_parallel_queue_overrides_yaml(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """A ``-P`` flag on the command line takes precedence over ``parallel_queue`` in the YAML."""
    target = tmp_path / "img.png"
    spec_file = tmp_path / "spec.yml"
    # The YAML asks for P=1 (sequential) ...
    spec_file.write_text(
        "prompt: from-yaml\n"
        "model: x.safetensors\n"
        "count: 2\n"
        "parallel_queue: 1\n"
        "seed: 10\n"
        f'output: "{target}"\n'
    )

    # ... while the CLI flag raises it to P=2 (fan-out).
    argv = ["generate", "--input", str(spec_file), "-P", "2"]
    outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0, outcome.output
    # Fan-out path: two separate calls, one image each.
    assert len(calls) == 2
    per_call_counts = [call["count"] for call in calls]
    assert all(n == 1 for n in per_call_counts)
# -----------------------------------------------------------------------------
# Concurrency assertion
# -----------------------------------------------------------------------------
def test_parallel_actually_runs_concurrently(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Sanity check that tasks overlap in time instead of running one after another.

    ``_run_generation`` is swapped for a stub that sleeps briefly while counting
    how many invocations are active at once; the highest count seen is asserted.
    """
    import threading
    import time

    # Shared counters, only touched while holding ``guard``.
    tally = {"active": 0, "peak": 0}
    guard = threading.Lock()

    def slow_run_generation(**kwargs: Any) -> None:
        with guard:
            tally["active"] += 1
            tally["peak"] = max(tally["peak"], tally["active"])
        time.sleep(0.1)  # 100ms — long enough to overlap, short enough for fast tests
        with guard:
            tally["active"] -= 1
        destination: Path | None = kwargs.get("output")
        if destination is None:
            return
        destination.parent.mkdir(parents=True, exist_ok=True)
        destination.write_bytes(b"ok")

    monkeypatch.setattr(cli_module, "_run_generation", slow_run_generation)

    target = tmp_path / "img.png"
    argv = ["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "4", "--seed", "1", "-o", str(target)]
    outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0, outcome.output
    # With P=4 and 4 tasks each sleeping 100ms, peak concurrency should hit 4.
    # Even allowing for thread-pool warmup quirks, ≥2 means parallelism is real.
    peak_in_flight = tally["peak"]
    assert peak_in_flight >= 2, f"peak_in_flight={peak_in_flight} (expected ≥2 for parallel)"
+2 -2
View File
@@ -1237,7 +1237,7 @@ class TestDownloadBackgroundTasks:
captured["api_key"] = api_key captured["api_key"] = api_key
return {"file_id": 42, "sha256": "deadbeef", "linked": True, "cached": True, "error": None} return {"file_id": 42, "sha256": "deadbeef", "linked": True, "cached": True, "error": None}
monkeypatch.setattr(download_routes_module, "Database", lambda: StubDB()) monkeypatch.setattr(download_routes_module, "Database", StubDB)
return captured return captured
def test_do_download_success(self, monkeypatch, tmp_path) -> None: def test_do_download_success(self, monkeypatch, tmp_path) -> None:
@@ -1397,7 +1397,7 @@ class TestDownloadBackgroundTasks:
def register_downloaded_file(self, *args, **kwargs): def register_downloaded_file(self, *args, **kwargs):
return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"} return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"}
monkeypatch.setattr(download_routes, "Database", lambda: FailingDB()) monkeypatch.setattr(download_routes, "Database", FailingDB)
dest_path = tmp_path / "model.safetensors" dest_path = tmp_path / "model.safetensors"
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1}) _do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})