Merge pull request #2 from saiden-dev/feat/generate-parallel-queue
feat(generate): add --parallel-queue/-P for concurrent submissions
This commit is contained in:
+225
-16
@@ -346,7 +346,10 @@ def search(
|
|||||||
return
|
return
|
||||||
|
|
||||||
key = api_key or load_api_key()
|
key = api_key or load_api_key()
|
||||||
civitai_results: dict[str, Any] | None = None
|
# Reuse the name from the remote-mode branch above (which already returned)
|
||||||
|
# without redeclaring its type — mypy reports a second annotation in the same
|
||||||
|
# scope as a no-redef error even when control flow guarantees the branches don't overlap.
|
||||||
|
civitai_results = None
|
||||||
hf_results: list[dict[str, Any]] | None = None
|
hf_results: list[dict[str, Any]] | None = None
|
||||||
|
|
||||||
# Search CivitAI
|
# Search CivitAI
|
||||||
@@ -874,6 +877,22 @@ def generate( # noqa: PLR0915
|
|||||||
str | None,
|
str | None,
|
||||||
typer.Option("--input", "-I", help="JSON or YAML params (file path or inline; keys match CLI options)"),
|
typer.Option("--input", "-I", help="JSON or YAML params (file path or inline; keys match CLI options)"),
|
||||||
] = None,
|
] = None,
|
||||||
|
parallel_queue: Annotated[
|
||||||
|
int,
|
||||||
|
typer.Option(
|
||||||
|
"--parallel-queue",
|
||||||
|
"-P",
|
||||||
|
help=(
|
||||||
|
"Concurrent ComfyUI submissions (default 1). When >1 with --count N, "
|
||||||
|
"splits the request into N independent jobs (batch_size=1 each) with "
|
||||||
|
"incrementing seeds, executed P-at-a-time via thread pool. The GPU "
|
||||||
|
"still processes one prompt at a time, but HTTP queue / init / "
|
||||||
|
"download phases pipeline for a ~5-15%% speedup. Per-task output "
|
||||||
|
"interleaves; final summary lists all saved files. Ignored when "
|
||||||
|
"--count is 1."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
] = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Generate an image using text-to-image.
|
"""Generate an image using text-to-image.
|
||||||
|
|
||||||
@@ -887,6 +906,11 @@ def generate( # noqa: PLR0915
|
|||||||
starting with ``{`` are JSON, everything else is YAML. CLI flags override
|
starting with ``{`` are JSON, everything else is YAML. CLI flags override
|
||||||
--input values.
|
--input values.
|
||||||
|
|
||||||
|
With --count > 1, images are generated as a single ComfyUI batch by default
|
||||||
|
(one workflow, sequential on GPU). Use --parallel-queue P to instead split
|
||||||
|
into --count independent batch_size=1 jobs queued up to P at a time, each with its own
|
||||||
|
seed — useful for overlapping the HTTP/download phase across requests.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
tsr generate "a cat on a windowsill"
|
tsr generate "a cat on a windowsill"
|
||||||
tsr generate "portrait photo" -m ponyDiffusionV6XL_v6.safetensors -O portrait
|
tsr generate "portrait photo" -m ponyDiffusionV6XL_v6.safetensors -O portrait
|
||||||
@@ -895,7 +919,20 @@ def generate( # noqa: PLR0915
|
|||||||
tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors"}'
|
tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors"}'
|
||||||
tsr generate --input scene.yml
|
tsr generate --input scene.yml
|
||||||
tsr generate "raw prompt" --no-quality --no-negative
|
tsr generate "raw prompt" --no-quality --no-negative
|
||||||
|
tsr generate "city" -c 8 -P 4 -o out.png # 8 distinct seeds, 4 in flight
|
||||||
"""
|
"""
|
||||||
|
if parallel_queue < 1:
|
||||||
|
console.print("[red]--parallel-queue must be >= 1[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
if parallel_queue > 1 and json_output:
|
||||||
|
# _run_generation short-circuits the disk-save when json_output=True
|
||||||
|
# (it dumps JSON and returns). For the parallel fanout to actually save
|
||||||
|
# files, each task must take the non-JSON path. We render our own JSON
|
||||||
|
# at the end, so the per-task --json is incompatible.
|
||||||
|
console.print(
|
||||||
|
"[red]--json is not supported with --parallel-queue > 1 (would skip the file-save step). Drop one or the other.[/red]"
|
||||||
|
)
|
||||||
|
raise typer.Exit(1)
|
||||||
# ---- --input merging (JSON or YAML) ----
|
# ---- --input merging (JSON or YAML) ----
|
||||||
if json_input is not None:
|
if json_input is not None:
|
||||||
ji = _parse_generate_input(json_input)
|
ji = _parse_generate_input(json_input)
|
||||||
@@ -912,7 +949,9 @@ def generate( # noqa: PLR0915
|
|||||||
{
|
{
|
||||||
p.name
|
p.name
|
||||||
for p in click_ctx.command.params
|
for p in click_ctx.command.params
|
||||||
if click_ctx.get_parameter_source(p.name) == click.core.ParameterSource.COMMANDLINE
|
# click's Parameter.name is typed `str | None` in stubs but is always
|
||||||
|
# a real string at runtime for any param that's been registered.
|
||||||
|
if p.name is not None and click_ctx.get_parameter_source(p.name) == click.core.ParameterSource.COMMANDLINE
|
||||||
}
|
}
|
||||||
if hasattr(click_ctx, "get_parameter_source")
|
if hasattr(click_ctx, "get_parameter_source")
|
||||||
else set()
|
else set()
|
||||||
@@ -981,12 +1020,22 @@ def generate( # noqa: PLR0915
|
|||||||
scene_prompt = sp_val if isinstance(sp_val, str) else ", ".join(str(x) for x in sp_val if str(x).strip())
|
scene_prompt = sp_val if isinstance(sp_val, str) else ", ".join(str(x) for x in sp_val if str(x).strip())
|
||||||
if "rating" in mapped and "rating" not in explicit:
|
if "rating" in mapped and "rating" not in explicit:
|
||||||
rating = mapped["rating"]
|
rating = mapped["rating"]
|
||||||
|
if "parallel_queue" in mapped and "parallel_queue" not in explicit:
|
||||||
|
parallel_queue = int(mapped["parallel_queue"])
|
||||||
|
|
||||||
has_content = bool(prompt or character or character_prompt or scene or scene_prompt)
|
has_content = bool(prompt or character or character_prompt or scene or scene_prompt)
|
||||||
if not has_content:
|
if not has_content:
|
||||||
console.print("[red]Prompt (or character/scene) is required[/red]")
|
console.print("[red]Prompt (or character/scene) is required[/red]")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
# Effective parallelism is bounded by count — running 4 threads for 1 image
|
||||||
|
# is silly. count=1 always goes through the sequential path regardless of -P.
|
||||||
|
effective_parallel = min(parallel_queue, count) if count > 1 else 1
|
||||||
|
|
||||||
|
if effective_parallel <= 1:
|
||||||
|
# Sequential path: single _run_generation call with batch_size=count.
|
||||||
|
# Unchanged from pre-parallel behavior — preserves existing output naming,
|
||||||
|
# JSON shape, and log lines exactly.
|
||||||
_run_generation(
|
_run_generation(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=model,
|
model=model,
|
||||||
@@ -1016,6 +1065,164 @@ def generate( # noqa: PLR0915
|
|||||||
remote=remote,
|
remote=remote,
|
||||||
json_output=json_output,
|
json_output=json_output,
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ---- Parallel fanout path ----
|
||||||
|
# Split count into `count` independent jobs (batch_size=1), executed
|
||||||
|
# `effective_parallel` at a time. Each job gets a distinct seed and a
|
||||||
|
# distinct output path so writes don't clobber each other.
|
||||||
|
import random as _rng # noqa: PLC0415
|
||||||
|
import time as _time # noqa: PLC0415
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed # noqa: PLC0415
|
||||||
|
|
||||||
|
# Resolve bare model/lora names ONCE in the parent before fanout rather than
|
||||||
|
# once per task. NOTE(review): the per-task kwargs below pass json_output=False,
|
||||||
|
# so confirm whether _run_generation still validates names itself. Doing it
|
||||||
|
# here means each task receives a canonical filename and ComfyUI's strict
|
||||||
|
# loaders accept the request first try.
|
||||||
|
if model and not remote:
|
||||||
|
# Detect family for the right loader bucket (checkpoints vs diffusion_models).
|
||||||
|
# Mirrors the lookup _run_generation does on entry.
|
||||||
|
from tensors.db import Database # noqa: PLC0415
|
||||||
|
|
||||||
|
_base_model: str | None = None
|
||||||
|
try:
|
||||||
|
with Database() as _db:
|
||||||
|
_db.init_schema()
|
||||||
|
_base_model = _db.get_base_model_by_filename(model)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
_detected = detect_model_family(model, _base_model)
|
||||||
|
_fam = family or _detected
|
||||||
|
try:
|
||||||
|
model, lora = _validate_model_available(model, _fam, lora)
|
||||||
|
except typer.Exit:
|
||||||
|
raise # surface the same error path as sequential
|
||||||
|
|
||||||
|
# Seed strategy:
|
||||||
|
# --seed >= 0 → use as base, increment per job (reproducible series)
|
||||||
|
# --seed == -1 → pick a fresh random seed PER JOB so parallel runs aren't
|
||||||
|
# accidentally correlated (each thread gets variety)
|
||||||
|
seeds = [seed + i for i in range(count)] if seed >= 0 else [_rng.randint(0, 2**32 - 1) for _ in range(count)]
|
||||||
|
|
||||||
|
# Output paths: mirror the existing `count > 1` naming convention from
|
||||||
|
# _run_generation (stem_NNN.ext). When --output is omitted, leave per-task
|
||||||
|
# output as None — _run_generation will skip the disk write and the user
|
||||||
|
# gets only the console listing of generated image refs.
|
||||||
|
out_paths: list[Path | None] = []
|
||||||
|
for i in range(count):
|
||||||
|
if output is None:
|
||||||
|
out_paths.append(None)
|
||||||
|
else:
|
||||||
|
out_paths.append(output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}")
|
||||||
|
|
||||||
|
if not json_output:
|
||||||
|
console.print(
|
||||||
|
f"[dim]Parallel queue: {effective_parallel} concurrent submissions x {count} images (output may interleave)[/dim]"
|
||||||
|
)
|
||||||
|
|
||||||
|
common_kwargs: dict[str, Any] = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"model": model,
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"steps": steps,
|
||||||
|
"cfg": cfg,
|
||||||
|
"guidance": guidance,
|
||||||
|
"sampler": sampler,
|
||||||
|
"scheduler": scheduler,
|
||||||
|
"vae": vae,
|
||||||
|
"orientation": orientation,
|
||||||
|
"lora": lora,
|
||||||
|
"lora_strength": lora_strength,
|
||||||
|
"negative": negative,
|
||||||
|
"count": 1, # each task generates exactly one image
|
||||||
|
"rating": rating,
|
||||||
|
"no_quality": no_quality,
|
||||||
|
"no_negative": no_negative,
|
||||||
|
"character": character,
|
||||||
|
"character_prompt": character_prompt,
|
||||||
|
"scene": scene,
|
||||||
|
"scene_prompt": scene_prompt,
|
||||||
|
"family": family,
|
||||||
|
"remote": remote,
|
||||||
|
# NOTE: json_output stays False so _run_generation's disk-save path runs.
|
||||||
|
# Setting True would short-circuit before saving files. Per-task console
|
||||||
|
# chatter is the trade-off; the final summary still shows clean per-task
|
||||||
|
# status lines.
|
||||||
|
"json_output": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _run_one(idx: int) -> dict[str, Any]:
    """Run job ``idx`` as one single-image generation and report the outcome.

    Calls ``_run_generation`` with this job's seed and output path on top of
    the shared ``common_kwargs``. ``typer.Exit`` and any other ``Exception``
    are captured in the result's ``error`` field instead of propagating.

    Returns:
        Dict with ``index``, ``seed``, ``output`` (path string or ``None``),
        ``duration_sec`` (rounded to 2 decimals), ``success`` and ``error``.
    """
    job_seed = seeds[idx]
    job_path = out_paths[idx]
    succeeded = False
    failure: str | None = None

    started = _time.perf_counter()
    try:
        _run_generation(seed=job_seed, output=job_path, **common_kwargs)
        succeeded = True
    except typer.Exit as ex:
        failure = f"generate exited with code {ex.exit_code}"
    except Exception as ex:
        failure = str(ex)
    # Measured once here; covers the success path and both handled failure paths.
    elapsed = round(_time.perf_counter() - started, 2)

    # Key order matches the manifest shape emitted in the JSON summary.
    return {
        "index": idx,
        "seed": job_seed,
        "output": None if job_path is None else str(job_path),
        "duration_sec": elapsed,
        "success": succeeded,
        "error": failure,
    }
|
||||||
|
|
||||||
|
fan_results: list[dict[str, Any]] = []
|
||||||
|
with ThreadPoolExecutor(max_workers=effective_parallel) as pool:
|
||||||
|
futures = {pool.submit(_run_one, i): i for i in range(count)}
|
||||||
|
for completed, fut in enumerate(as_completed(futures), start=1):
|
||||||
|
try:
|
||||||
|
res = fut.result()
|
||||||
|
except Exception as ex:
|
||||||
|
# Defensive — _run_one already swallows, but if the executor itself
|
||||||
|
# raises (e.g. a cancelled future) we still want a well-formed result
|
||||||
|
# in the manifest rather than a crash.
|
||||||
|
res = {
|
||||||
|
"index": futures[fut],
|
||||||
|
"seed": seeds[futures[fut]],
|
||||||
|
"output": str(out_paths[futures[fut]]) if out_paths[futures[fut]] is not None else None,
|
||||||
|
"duration_sec": 0.0,
|
||||||
|
"success": False,
|
||||||
|
"error": f"executor exception: {ex}",
|
||||||
|
}
|
||||||
|
fan_results.append(res)
|
||||||
|
if not json_output:
|
||||||
|
if res["success"]:
|
||||||
|
where = res["output"] or "(no --output set)"
|
||||||
|
console.print(
|
||||||
|
f"[green]\\[{completed}/{count}] seed={res['seed']} ok in {res['duration_sec']:.1f}s → {where}[/green]"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
console.print(f"[red]\\[{completed}/{count}] seed={res['seed']} FAIL: {res['error']}[/red]")
|
||||||
|
|
||||||
|
# Reorder by original index so JSON output / final summary list is stable.
|
||||||
|
fan_results.sort(key=lambda r: r["index"])
|
||||||
|
successful = sum(1 for r in fan_results if r["success"])
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(
|
||||||
|
data={
|
||||||
|
"success": successful == count,
|
||||||
|
"count": count,
|
||||||
|
"parallel_queue": effective_parallel,
|
||||||
|
"results": fan_results,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print("[bold green]Generation complete![/bold green]")
|
||||||
|
console.print(f"[dim]Generated {successful}/{count} images at parallelism={effective_parallel}[/dim]")
|
||||||
|
if successful < count:
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
|
||||||
# Map model family → which ComfyUI loader directory the checkpoint must live in.
|
# Map model family → which ComfyUI loader directory the checkpoint must live in.
|
||||||
@@ -1283,7 +1490,7 @@ def _run_generation( # noqa: PLR0915
|
|||||||
|
|
||||||
# ---- Resolve preset defaults for None params (both remote and local need these) ----
|
# ---- Resolve preset defaults for None params (both remote and local need these) ----
|
||||||
from tensors.config import resolve_orientation # noqa: PLC0415
|
from tensors.config import resolve_orientation # noqa: PLC0415
|
||||||
from tensors.config import resolve_remote as do_resolve_remote
|
from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415
|
||||||
|
|
||||||
# Use already-detected family_defaults from DB lookup above (not filename guessing)
|
# Use already-detected family_defaults from DB lookup above (not filename guessing)
|
||||||
if family_defaults:
|
if family_defaults:
|
||||||
@@ -1507,7 +1714,7 @@ _STYLE_SWEEP_TEMPLATE_KEYS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _load_json_file_or_inline(value: str | list | dict, *, what: str) -> Any:
|
def _load_json_file_or_inline(value: str | list[Any] | dict[str, Any], *, what: str) -> Any:
|
||||||
"""Load JSON from a file path or accept already-parsed inline data.
|
"""Load JSON from a file path or accept already-parsed inline data.
|
||||||
|
|
||||||
`value` may be a path string, a JSON string, or an already-parsed list/dict
|
`value` may be a path string, a JSON string, or an already-parsed list/dict
|
||||||
@@ -1885,7 +2092,7 @@ def style_sweep( # noqa: PLR0915
|
|||||||
|
|
||||||
def _run_one(task: tuple[int, dict[str, str], dict[str, Any], Path]) -> dict[str, Any]:
|
def _run_one(task: tuple[int, dict[str, str], dict[str, Any], Path]) -> dict[str, Any]:
|
||||||
"""Run a single style. Returns the result dict (success or error captured)."""
|
"""Run a single style. Returns the result dict (success or error captured)."""
|
||||||
idx, entry_in, res, opath = task
|
_idx, _entry_in, res, opath = task
|
||||||
composed = res["prompt"]
|
composed = res["prompt"]
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
@@ -1937,11 +2144,9 @@ def style_sweep( # noqa: PLR0915
|
|||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=parallel_queue) as pool:
|
with ThreadPoolExecutor(max_workers=parallel_queue) as pool:
|
||||||
futures = {pool.submit(_run_one, task): task for task in pending_tasks}
|
futures = {pool.submit(_run_one, task): task for task in pending_tasks}
|
||||||
completed = 0
|
for completed, fut in enumerate(as_completed(futures), start=1):
|
||||||
for fut in as_completed(futures):
|
|
||||||
completed += 1
|
|
||||||
task = futures[fut]
|
task = futures[fut]
|
||||||
idx, _entry, _res, _out_path = task
|
idx, _entry, _res, _out_path = task # idx used in log message below
|
||||||
try:
|
try:
|
||||||
res = fut.result()
|
res = fut.result()
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
@@ -1988,14 +2193,16 @@ def style_sweep( # noqa: PLR0915
|
|||||||
|
|
||||||
def _write_sweep_manifest(
|
def _write_sweep_manifest(
|
||||||
out_dir: Path,
|
out_dir: Path,
|
||||||
template_path: Path,
|
template_path: Path | None,
|
||||||
styles_origin: str,
|
styles_origin: str,
|
||||||
results: list[dict[str, Any]],
|
results: list[dict[str, Any]],
|
||||||
) -> Path:
|
) -> Path:
|
||||||
"""Write the per-sweep manifest JSON. Returns the path."""
|
"""Write the per-sweep manifest JSON. Returns the path."""
|
||||||
manifest_path = out_dir / "_sweep.json"
|
manifest_path = out_dir / "_sweep.json"
|
||||||
manifest: dict[str, Any] = {
|
manifest: dict[str, Any] = {
|
||||||
"template": str(template_path),
|
# template_path is None when --list is used with only --styles (no template
|
||||||
|
# required). Serialize as empty string to keep manifest schema stable.
|
||||||
|
"template": str(template_path) if template_path is not None else "",
|
||||||
"styles_source": styles_origin,
|
"styles_source": styles_origin,
|
||||||
"results": results,
|
"results": results,
|
||||||
}
|
}
|
||||||
@@ -2024,7 +2231,7 @@ def _print_styles_list(styles_origin: str, entries: list[dict[str, str]]) -> Non
|
|||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def template(
|
def template( # noqa: PLR0915
|
||||||
model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")],
|
model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")],
|
||||||
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
|
||||||
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
|
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
|
||||||
@@ -2842,7 +3049,7 @@ def scene_extract(
|
|||||||
target_file = None
|
target_file = None
|
||||||
for f in files:
|
for f in files:
|
||||||
file_path = Path(f["file_path"])
|
file_path = Path(f["file_path"])
|
||||||
if file_path.name == model or file_path.stem == model:
|
if model in (file_path.name, file_path.stem):
|
||||||
target_file = f
|
target_file = f
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -2970,7 +3177,7 @@ app.add_typer(templates_app)
|
|||||||
|
|
||||||
|
|
||||||
@templates_app.command("extract")
|
@templates_app.command("extract")
|
||||||
def templates_extract(
|
def templates_extract( # noqa: PLR0915
|
||||||
model: Annotated[str, typer.Argument(help="Local model name (e.g. lust_v10.safetensors)")],
|
model: Annotated[str, typer.Argument(help="Local model name (e.g. lust_v10.safetensors)")],
|
||||||
orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "portrait",
|
orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "portrait",
|
||||||
no_overrides: Annotated[
|
no_overrides: Annotated[
|
||||||
@@ -3428,8 +3635,10 @@ def comfy_generate(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""[Deprecated] Use 'tsr generate' instead. All features have been merged into the top-level command."""
|
"""[Deprecated] Use 'tsr generate' instead. All features have been merged into the top-level command."""
|
||||||
console.print("[yellow]Warning: 'tsr comfy generate' is deprecated. Use 'tsr generate' instead.[/yellow]")
|
console.print("[yellow]Warning: 'tsr comfy generate' is deprecated. Use 'tsr generate' instead.[/yellow]")
|
||||||
# Delegate to the unified generate command via context invocation
|
# Delegate to the unified generate command via context invocation.
|
||||||
ctx = typer.Context(generate)
|
# typer.Context expects a click.Command, but passing the typer function directly
|
||||||
|
# works at runtime via duck-typing — keeping it for back-compat with deprecated alias.
|
||||||
|
ctx = typer.Context(generate) # type: ignore[arg-type]
|
||||||
generate(
|
generate(
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
|||||||
+3
-4
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
@@ -388,7 +389,7 @@ def queue_prompt(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _wait_for_completion_ws(
|
def _wait_for_completion_ws( # noqa: PLR0915
|
||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
url: str,
|
url: str,
|
||||||
client_id: str,
|
client_id: str,
|
||||||
@@ -494,10 +495,8 @@ def _wait_for_completion_ws(
|
|||||||
break
|
break
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
try:
|
with contextlib.suppress(Exception):
|
||||||
ws.close()
|
ws.close()
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Fetch final outputs from history to ensure we have everything
|
# Fetch final outputs from history to ensure we have everything
|
||||||
try:
|
try:
|
||||||
|
|||||||
+8
-8
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import tomllib
|
import tomllib
|
||||||
from enum import Enum
|
from enum import StrEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -65,7 +65,7 @@ CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
|
|||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class Provider(str, Enum):
|
class Provider(StrEnum):
|
||||||
"""Model search providers."""
|
"""Model search providers."""
|
||||||
|
|
||||||
civitai = "civitai"
|
civitai = "civitai"
|
||||||
@@ -73,7 +73,7 @@ class Provider(str, Enum):
|
|||||||
all = "all"
|
all = "all"
|
||||||
|
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
class ModelType(StrEnum):
|
||||||
"""CivitAI model types."""
|
"""CivitAI model types."""
|
||||||
|
|
||||||
checkpoint = "checkpoint"
|
checkpoint = "checkpoint"
|
||||||
@@ -110,7 +110,7 @@ class ModelType(str, Enum):
|
|||||||
return mapping[self.value]
|
return mapping[self.value]
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(str, Enum):
|
class BaseModel(StrEnum):
|
||||||
"""Common base models."""
|
"""Common base models."""
|
||||||
|
|
||||||
# Stable Diffusion 1.x
|
# Stable Diffusion 1.x
|
||||||
@@ -166,7 +166,7 @@ class BaseModel(str, Enum):
|
|||||||
return mapping[self.value]
|
return mapping[self.value]
|
||||||
|
|
||||||
|
|
||||||
class SortOrder(str, Enum):
|
class SortOrder(StrEnum):
|
||||||
"""Sort options for search."""
|
"""Sort options for search."""
|
||||||
|
|
||||||
downloads = "downloads"
|
downloads = "downloads"
|
||||||
@@ -183,7 +183,7 @@ class SortOrder(str, Enum):
|
|||||||
return mapping[self.value]
|
return mapping[self.value]
|
||||||
|
|
||||||
|
|
||||||
class Period(str, Enum):
|
class Period(StrEnum):
|
||||||
"""Time period for sorting/filtering."""
|
"""Time period for sorting/filtering."""
|
||||||
|
|
||||||
all = "all"
|
all = "all"
|
||||||
@@ -204,7 +204,7 @@ class Period(str, Enum):
|
|||||||
return mapping[self.value]
|
return mapping[self.value]
|
||||||
|
|
||||||
|
|
||||||
class NsfwLevel(str, Enum):
|
class NsfwLevel(StrEnum):
|
||||||
"""NSFW content filter level."""
|
"""NSFW content filter level."""
|
||||||
|
|
||||||
none = "none"
|
none = "none"
|
||||||
@@ -219,7 +219,7 @@ class NsfwLevel(str, Enum):
|
|||||||
return self.value.capitalize() if self.value != "none" else "None"
|
return self.value.capitalize() if self.value != "none" else "None"
|
||||||
|
|
||||||
|
|
||||||
class CommercialUse(str, Enum):
|
class CommercialUse(StrEnum):
|
||||||
"""Commercial use permissions."""
|
"""Commercial use permissions."""
|
||||||
|
|
||||||
none = "none"
|
none = "none"
|
||||||
|
|||||||
+12
-2
@@ -16,9 +16,16 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from pathlib import Path # noqa: TC003 # used in runtime return annotations exposed to typer
|
from pathlib import Path # noqa: TC003 # used in runtime return annotations exposed to typer
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from tensors.config import DATA_DIR
|
from tensors.config import DATA_DIR
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# Qualified `builtins.list` is referenced in annotations inside FragmentLibrary
|
||||||
|
# because the class defines a method named `list` that shadows the builtin
|
||||||
|
# at class-scope name resolution. Static-only — not needed at runtime.
|
||||||
|
import builtins
|
||||||
|
|
||||||
# Restrict fragment names to a safe subset so they can't escape the storage dir
|
# Restrict fragment names to a safe subset so they can't escape the storage dir
|
||||||
# via path traversal and so file listings stay tidy.
|
# via path traversal and so file listings stay tidy.
|
||||||
_NAME_RE = re.compile(r"^[A-Za-z0-9_.-]+$")
|
_NAME_RE = re.compile(r"^[A-Za-z0-9_.-]+$")
|
||||||
@@ -132,8 +139,11 @@ class FragmentLibrary:
|
|||||||
*,
|
*,
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
inline: str | None = None,
|
inline: str | None = None,
|
||||||
extra: list[str] | None = None,
|
# NOTE: `builtins.list` qualifier needed because this class defines a
|
||||||
) -> list[str]:
|
# `list()` method below, which shadows the builtin in class-scope name
|
||||||
|
# resolution. Affects mypy/pyright even with `from __future__ import annotations`.
|
||||||
|
extra: builtins.list[str] | None = None,
|
||||||
|
) -> builtins.list[str]:
|
||||||
"""Merge a named fragment with an inline CSV string and optional extras.
|
"""Merge a named fragment with an inline CSV string and optional extras.
|
||||||
|
|
||||||
Resolution order (first match wins per duplicate): named → inline → extra.
|
Resolution order (first match wins per duplicate): named → inline → extra.
|
||||||
|
|||||||
+2
-1
@@ -209,7 +209,8 @@ def remote_search(
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result: dict[str, Any] = response.json()
|
result: dict[str, Any] = response.json()
|
||||||
# The remote API wraps CivitAI results under "civitai" key
|
# The remote API wraps CivitAI results under "civitai" key
|
||||||
return result.get("civitai", result)
|
civitai_section: dict[str, Any] = result.get("civitai", result)
|
||||||
|
return civitai_section
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
if console:
|
if console:
|
||||||
console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
|
console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
|
||||||
|
|||||||
@@ -27,6 +27,14 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
router = APIRouter(prefix="/api/comfyui", tags=["ComfyUI API"])
|
router = APIRouter(prefix="/api/comfyui", tags=["ComfyUI API"])
|
||||||
|
|
||||||
|
# Schema default sentinels — see GenerateRequest defaults. They flag requests that
|
||||||
|
# still carry a schema default; "user accepted default" vs "user explicitly chose a
|
||||||
|
# value matching default" is intentionally not distinguished: both apply family overrides.
|
||||||
|
_DEFAULT_STEPS = 20
|
||||||
|
_DEFAULT_CFG = 7.0
|
||||||
|
# Logging truncation threshold for long prompts in info-level output.
|
||||||
|
_PROMPT_LOG_TRUNCATE = 100
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Request/Response Models
|
# Request/Response Models
|
||||||
@@ -262,9 +270,9 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
|||||||
sampler = family_defaults["sampler"]
|
sampler = family_defaults["sampler"]
|
||||||
if request.scheduler == "normal": # Default value in schema
|
if request.scheduler == "normal": # Default value in schema
|
||||||
scheduler = family_defaults["scheduler"]
|
scheduler = family_defaults["scheduler"]
|
||||||
if request.steps == 20: # Default value in schema
|
if request.steps == _DEFAULT_STEPS: # Default value in schema
|
||||||
steps = family_defaults["steps"]
|
steps = family_defaults["steps"]
|
||||||
if request.cfg == 7.0: # Default value in schema
|
if request.cfg == _DEFAULT_CFG: # Default value in schema
|
||||||
cfg = family_defaults["cfg"]
|
cfg = family_defaults["cfg"]
|
||||||
# Only override VAE if user explicitly specified one;
|
# Only override VAE if user explicitly specified one;
|
||||||
# otherwise use checkpoint's built-in VAE (vae stays None)
|
# otherwise use checkpoint's built-in VAE (vae stays None)
|
||||||
@@ -290,7 +298,7 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
|||||||
sampler,
|
sampler,
|
||||||
scheduler,
|
scheduler,
|
||||||
lora_info,
|
lora_info,
|
||||||
request.prompt[:100] + "..." if len(request.prompt) > 100 else request.prompt,
|
request.prompt[:_PROMPT_LOG_TRUNCATE] + "..." if len(request.prompt) > _PROMPT_LOG_TRUNCATE else request.prompt,
|
||||||
)
|
)
|
||||||
if request.negative_prompt:
|
if request.negative_prompt:
|
||||||
logger.debug("Negative prompt: %r", request.negative_prompt[:100])
|
logger.debug("Negative prompt: %r", request.negative_prompt[:100])
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import StrEnum
|
||||||
from typing import Annotated, Any
|
from typing import Annotated, Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Query
|
from fastapi import APIRouter, Query
|
||||||
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
|
|||||||
router = APIRouter(prefix="/api/search", tags=["Search"])
|
router = APIRouter(prefix="/api/search", tags=["Search"])
|
||||||
|
|
||||||
|
|
||||||
class Provider(str, Enum):
|
class Provider(StrEnum):
|
||||||
"""Search provider options."""
|
"""Search provider options."""
|
||||||
|
|
||||||
civitai = "civitai"
|
civitai = "civitai"
|
||||||
@@ -45,7 +45,7 @@ class Provider(str, Enum):
|
|||||||
all = "all"
|
all = "all"
|
||||||
|
|
||||||
|
|
||||||
class SortOrder(str, Enum):
|
class SortOrder(StrEnum):
|
||||||
"""Sort order options."""
|
"""Sort order options."""
|
||||||
|
|
||||||
downloads = "downloads"
|
downloads = "downloads"
|
||||||
|
|||||||
@@ -0,0 +1,320 @@
|
|||||||
|
"""Tests for the `tsr generate --parallel-queue` flag (parallel fanout path)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from tensors import cli as cli_module
|
||||||
|
from tensors.cli import app
|
||||||
|
|
||||||
|
# Shared Typer test runner; every test in this module drives the CLI via runner.invoke(app, [...]).
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def calls(monkeypatch: pytest.MonkeyPatch) -> list[dict[str, Any]]:
    """Replace ``cli_module._run_generation`` with a recording stub.

    Every invocation's keyword arguments are appended to the returned list,
    so a test can inspect how many times the CLI dispatched generation and
    with which ``seed`` / ``output`` / ``count`` values. When an ``output``
    path is supplied, the stub also creates the parent directory and writes
    a small placeholder file there, standing in for the real image save.

    Returns:
        The (initially empty) list that accumulates one kwargs dict per call.
    """
    captured: list[dict[str, Any]] = []

    def recording_stub(**kwargs: Any) -> None:
        captured.append(kwargs)
        target: Path | None = kwargs.get("output")
        if target is None:
            # No output path requested: nothing to materialise on disk.
            return
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(b"fake-png")

    monkeypatch.setattr(cli_module, "_run_generation", recording_stub)
    return captured
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _stub_model_validation(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace ``cli_module._validate_model_available`` with a pass-through.

    Applied automatically to every test in this module. The replacement
    returns the ``(model, lora)`` pair it was given, unchanged, so no test
    depends on whatever lookup the real validator performs.
    """

    def passthrough(model, family, lora):
        # ``family`` is accepted for signature compatibility but not used.
        return (model, lora)

    monkeypatch.setattr(cli_module, "_validate_model_available", passthrough)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Validation / sanity
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallel_queue_invalid_value_rejected(calls: list[dict[str, Any]]) -> None:
    """``--parallel-queue 0`` fails with a clear message and dispatches no generation."""
    argv = ["generate", "test prompt", "-m", "x.safetensors", "--parallel-queue", "0"]
    outcome = runner.invoke(app, argv)

    assert outcome.exit_code != 0
    assert "--parallel-queue must be >= 1" in outcome.output
    # The stubbed _run_generation must never have been reached.
    assert calls == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallel_queue_one_is_sequential_path(calls: list[dict[str, Any]]) -> None:
    """``-P 1`` with ``-c 4`` yields a single ``_run_generation`` call carrying count=4.

    This pins the backward-compatibility contract: with a parallel queue of
    one, the request is not split into per-image tasks; the full count and
    the prompt are forwarded in one call.
    """
    argv = ["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "1"]
    outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0, outcome.output
    assert len(calls) == 1
    only_call = calls[0]
    assert only_call["count"] == 4
    assert only_call["prompt"] == "test prompt"
|
||||||
|
|
||||||
|
|
||||||
|
def test_count_one_ignores_parallel_queue(calls: list[dict[str, Any]]) -> None:
    """With ``-c 1`` the ``-P`` value is irrelevant: exactly one call, count=1."""
    argv = ["generate", "test prompt", "-m", "x.safetensors", "-c", "1", "-P", "8"]
    outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0, outcome.output
    assert len(calls) == 1
    assert calls[0]["count"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_output_incompatible_with_parallel(calls: list[dict[str, Any]]) -> None:
    """Combining ``--json`` with ``-P 2`` is rejected before any generation runs."""
    argv = ["generate", "test prompt", "-m", "x.safetensors", "-c", "2", "-P", "2", "--json"]
    outcome = runner.invoke(app, argv)

    assert outcome.exit_code != 0
    assert "--json is not supported with --parallel-queue > 1" in outcome.output
    # Rejection must happen up front: the stub records nothing.
    assert calls == []
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Fanout behavior
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallel_fanout_creates_n_tasks(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """``-c 4 -P 2`` fans out into four ``_run_generation`` calls, each with count=1."""
    out = tmp_path / "img.png"
    argv = ["generate", "test prompt", "-m", "x.safetensors"]
    argv += ["-c", "4", "-P", "2"]
    argv += ["--seed", "100", "-o", str(out)]

    outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0, outcome.output
    assert len(calls) == 4
    # Every fanned-out task is responsible for exactly one image.
    assert all(task["count"] == 1 for task in calls)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallel_seeds_increment_from_base(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """With an explicit ``--seed 500`` and three tasks, the seeds are 500, 501, 502.

    Seeds are compared after sorting because the tasks run on a thread pool
    and may reach the stub in any order.
    """
    out = tmp_path / "img.png"
    result = runner.invoke(
        app,
        ["generate", "test prompt", "-m", "x.safetensors", "-c", "3", "-P", "3", "--seed", "500", "-o", str(out)],
    )
    # Surface the CLI output on failure instead of an opaque seed-list mismatch.
    assert result.exit_code == 0, result.output
    seeds_seen = sorted(c["seed"] for c in calls)
    assert seeds_seen == [500, 501, 502]
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallel_seeds_random_when_unset(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """Without ``--seed``, every task receives its own non-negative seed.

    The four seeds must be pairwise distinct. A collision between four
    independently rolled seeds is treated as negligibly unlikely, so the
    check is "all distinct" rather than equality with any specific value.

    The exit code and call count are asserted first: without them the seed
    checks would pass vacuously on an empty ``calls`` list if the CLI failed
    before dispatching any task.
    """
    out = tmp_path / "img.png"
    result = runner.invoke(
        app,
        ["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "2", "-o", str(out)],
    )
    assert result.exit_code == 0, result.output
    seeds = [c["seed"] for c in calls]
    assert len(seeds) == 4
    # All non-negative (i.e. resolved to an actual int rather than a -1 sentinel) and distinct.
    assert all(s >= 0 for s in seeds)
    assert len(set(seeds)) == len(seeds)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallel_output_paths_indexed(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """Each task's output path is ``<stem>_NNN<suffix>`` with a 1-based, zero-padded index.

    Paths are compared after sorting because task completion order on the
    thread pool is not fixed.
    """
    out = tmp_path / "scene.png"
    result = runner.invoke(
        app,
        ["generate", "test prompt", "-m", "x.safetensors", "-c", "3", "-P", "3", "--seed", "1", "-o", str(out)],
    )
    # Surface the CLI output on failure instead of an opaque path-list mismatch.
    assert result.exit_code == 0, result.output
    paths = sorted(str(c["output"]) for c in calls)
    assert paths == [
        str(tmp_path / "scene_001.png"),
        str(tmp_path / "scene_002.png"),
        str(tmp_path / "scene_003.png"),
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallel_without_output_passes_none(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """Omitting ``--output`` means every fanned-out task is called with ``output=None``."""
    argv = ["generate", "test prompt", "-m", "x.safetensors", "-c", "2", "-P", "2", "--seed", "1"]
    runner.invoke(app, argv)

    assert len(calls) == 2
    outputs = [task["output"] for task in calls]
    assert all(value is None for value in outputs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallel_files_actually_written(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """All three indexed files exist on disk after a ``-c 3 -P 3`` run.

    ``_run_generation`` is stubbed by the ``calls`` fixture, and the stub
    writes a placeholder to whatever ``output`` path it receives. This test
    therefore verifies that the fanout hands each task a distinct, correctly
    indexed path and that nothing else is left in the directory; it does not
    exercise the real save logic inside ``_run_generation``.
    """
    out = tmp_path / "shot.png"
    result = runner.invoke(
        app,
        ["generate", "test prompt", "-m", "x.safetensors", "-c", "3", "-P", "3", "--seed", "1", "-o", str(out)],
    )
    # Surface the CLI output on failure instead of an opaque directory-listing mismatch.
    assert result.exit_code == 0, result.output
    written = sorted(p.name for p in tmp_path.iterdir())
    assert written == ["shot_001.png", "shot_002.png", "shot_003.png"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallel_summary_reports_success_count(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """When every task succeeds, the CLI output contains ``Generated 3/3 images``."""
    out = tmp_path / "img.png"
    argv = ["generate", "test prompt", "-m", "x.safetensors"]
    argv.extend(["-c", "3", "-P", "2", "--seed", "1", "-o", str(out)])

    outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0
    assert "Generated 3/3 images" in outcome.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallel_partial_failure_exits_nonzero(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """If some tasks raise, the summary shows the partial count and the exit code is non-zero.

    The stub fails on every even-numbered call (0-based), so with four tasks
    exactly two fail and the summary must read ``Generated 2/4 images``.
    """
    import threading

    import typer

    call_indices: list[int] = []
    # The stub is invoked from pool threads. Reading len() and appending must
    # happen atomically, otherwise two threads can observe the same index and
    # the number of simulated failures is no longer exactly two.
    index_lock = threading.Lock()

    def flaky_run_generation(**kwargs: Any) -> None:
        # Fail every other call to simulate intermittent backend errors.
        with index_lock:
            idx = len(call_indices)
            call_indices.append(idx)
        if idx % 2 == 0:
            raise typer.Exit(1)
        out: Path | None = kwargs.get("output")
        if out is not None:
            out.parent.mkdir(parents=True, exist_ok=True)
            out.write_bytes(b"ok")

    monkeypatch.setattr(cli_module, "_run_generation", flaky_run_generation)

    out = tmp_path / "img.png"
    result = runner.invoke(
        app,
        ["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "2", "--seed", "1", "-o", str(out)],
    )
    assert result.exit_code != 0
    # Two tasks failed; final summary should show 2/4.
    assert "Generated 2/4 images" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# --input integration
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallel_queue_from_yaml_input(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """A ``parallel_queue`` key in the ``--input`` YAML file triggers the fanout path.

    With ``count: 3``, ``parallel_queue: 3`` and ``seed: 7`` in the spec, the
    stub must be called three times with seeds 7, 8 and 9.
    """
    out = tmp_path / "img.png"
    spec_file = tmp_path / "spec.yml"
    spec_file.write_text(f'prompt: from-yaml\nmodel: x.safetensors\ncount: 3\nparallel_queue: 3\nseed: 7\noutput: "{out}"\n')

    outcome = runner.invoke(app, ["generate", "--input", str(spec_file)])

    assert outcome.exit_code == 0, outcome.output
    assert len(calls) == 3
    observed_seeds = sorted(task["seed"] for task in calls)
    assert observed_seeds == [7, 8, 9]
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_parallel_queue_overrides_yaml(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
    """A ``-P`` flag on the command line takes precedence over the YAML ``parallel_queue``.

    The spec requests ``parallel_queue: 1`` with ``count: 2``; passing
    ``-P 2`` on the CLI must switch to the fanout path, observable as two
    separate calls that each carry count=1.
    """
    out = tmp_path / "img.png"
    spec_file = tmp_path / "spec.yml"
    spec_file.write_text(f'prompt: from-yaml\nmodel: x.safetensors\ncount: 2\nparallel_queue: 1\nseed: 10\noutput: "{out}"\n')

    # The YAML alone would run sequentially; the CLI flag forces fanout.
    argv = ["generate", "--input", str(spec_file), "-P", "2"]
    outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0, outcome.output
    assert len(calls) == 2
    per_task_counts = [task["count"] for task in calls]
    assert all(n == 1 for n in per_task_counts)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Concurrency assertion
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallel_actually_runs_concurrently(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that fanned-out tasks overlap in time rather than running one by one.

    The stub increments a shared in-flight counter, sleeps briefly, then
    decrements it. The highest value the counter reaches is the observed
    concurrency; anything of two or more shows the tasks overlapped.
    """
    import threading
    import time

    in_flight = 0
    peak_in_flight = 0
    counter_guard = threading.Lock()

    def slow_run_generation(**kwargs: Any) -> None:
        nonlocal in_flight, peak_in_flight
        with counter_guard:
            in_flight += 1
            if in_flight > peak_in_flight:
                peak_in_flight = in_flight
        # 100ms: long enough for tasks to overlap, short enough to keep the test fast.
        time.sleep(0.1)
        with counter_guard:
            in_flight -= 1
        destination: Path | None = kwargs.get("output")
        if destination is None:
            return
        destination.parent.mkdir(parents=True, exist_ok=True)
        destination.write_bytes(b"ok")

    monkeypatch.setattr(cli_module, "_run_generation", slow_run_generation)

    out = tmp_path / "img.png"
    argv = ["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "4", "--seed", "1", "-o", str(out)]
    outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0, outcome.output
    # Four tasks with P=4 could all be in flight at once; requiring only two
    # leaves headroom for thread start-up jitter while still proving overlap.
    assert peak_in_flight >= 2, f"peak_in_flight={peak_in_flight} (expected ≥2 for parallel)"
|
||||||
@@ -1237,7 +1237,7 @@ class TestDownloadBackgroundTasks:
|
|||||||
captured["api_key"] = api_key
|
captured["api_key"] = api_key
|
||||||
return {"file_id": 42, "sha256": "deadbeef", "linked": True, "cached": True, "error": None}
|
return {"file_id": 42, "sha256": "deadbeef", "linked": True, "cached": True, "error": None}
|
||||||
|
|
||||||
monkeypatch.setattr(download_routes_module, "Database", lambda: StubDB())
|
monkeypatch.setattr(download_routes_module, "Database", StubDB)
|
||||||
return captured
|
return captured
|
||||||
|
|
||||||
def test_do_download_success(self, monkeypatch, tmp_path) -> None:
|
def test_do_download_success(self, monkeypatch, tmp_path) -> None:
|
||||||
@@ -1397,7 +1397,7 @@ class TestDownloadBackgroundTasks:
|
|||||||
def register_downloaded_file(self, *args, **kwargs):
|
def register_downloaded_file(self, *args, **kwargs):
|
||||||
return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"}
|
return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"}
|
||||||
|
|
||||||
monkeypatch.setattr(download_routes, "Database", lambda: FailingDB())
|
monkeypatch.setattr(download_routes, "Database", FailingDB)
|
||||||
|
|
||||||
dest_path = tmp_path / "model.safetensors"
|
dest_path = tmp_path / "model.safetensors"
|
||||||
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||||
|
|||||||
Reference in New Issue
Block a user