Compare commits

..

10 Commits

Author SHA1 Message Date
marauder-actual 86287269ee fix(config): remove non-existent VAE defaults from pony and illustrious families
CI / test (3.12) (push) Has been cancelled
CI / lint (push) Has been cancelled
style / style (push) Has been cancelled
2026-05-20 15:54:34 +02:00
Adam Ladachowski 7144b7ac6a fix(generate): skip local model validation when default_remote is set (#4)
'tsr generate -m <model> ...' without an explicit -r/--remote flag was
running _validate_model_available() against the local [comfyui] url
even when config.toml had default_remote pointing at a different host.
On a typical operator setup (default_remote = 'runpod', local comfyui
url stale or empty) this fails with 'Model X not available on ComfyUI
host (looked in checkpoints/ — 2 entries)' even though the model
exists on the remote and would be reachable for the actual dispatch
a few hundred lines later.

Root cause: two validation gates used the raw 'remote' parameter
(CLI flag only) instead of the resolved remote URL:
  - generate(): pre-fanout validation for --parallel-queue path
  - _run_generation(): per-call validation

Both now use do_resolve_remote(remote), which returns the resolved
URL when default_remote is set even if -r was omitted. The server-
side validator on the remote tensors API still catches missing
models on the remote host — matches the intent already documented
in the comment block.

Co-authored-by: marauder-actual <marauder@saiden.dev>
2026-05-20 12:19:59 +02:00
github-actions[bot] fbe09c2364 format: auto-format code [skip ci] 2026-05-20 09:54:16 +00:00
Adam Ladachowski 1f105d7633 feat(db-list): default to remote DB, add --local override (#3)
By default 'tsr db list' read from the LOCAL SQLite DB even when
default_remote was configured, which created a confusing UX where the
table never reflected the actual generation host. On a typical setup
(chi@fuji with default_remote='runpod') it would happily display ghost
entries pointing at /home/madcat/comfyui/models/... paths that don't
exist on macOS at all.

Flip the default:
  - 'tsr db list'           -> remote (when default_remote set), else local
  - 'tsr db list --local'   -> force local SQLite DB
  - 'tsr db list -r <name>' -> explicit remote (overrides default_remote)

Precedence: --local wins; else -r; else default_remote; else local.

Adds remote_db_files() helper in tensors/remote.py that calls the
existing GET /api/db/files server endpoint (no server-side changes
needed). Mirrors the pattern used by remote_models().

Scope kept minimal: only 'db list' for now. db search / triggers /
stats stay local-default; can flip later under the same pattern if
the UX wins are similar.

Co-authored-by: marauder-actual <marauder@saiden.dev>
2026-05-20 11:53:44 +02:00
aladac 77f8c6c6c8 release: v0.1.26 2026-05-18 23:44:57 +02:00
Adam Ladachowski ca7e914e35 Merge pull request #2 from saiden-dev/feat/generate-parallel-queue
feat(generate): add --parallel-queue/-P for concurrent submissions
2026-05-18 23:44:21 +02:00
aladac 5a9b935753 fix(types): clear all 10 pre-existing mypy errors
Master CI lint job runs both ruff and mypy. Previous commit cleaned ruff;
this knocks out the mypy backlog too so the parallel-queue PR can ship
fully green.

- fragments.py: FragmentLibrary defines a `list()` method, which shadows
  the builtin in class-scope name resolution and broke `list[str]`
  annotations in `resolve()`. Qualify with `builtins.list` (imported
  under TYPE_CHECKING since it's static-only).
- remote.py: `result.get("civitai", result)` returns Any to mypy because
  dict.get widens. Capture into a typed local first.
- cli.py:
  - Drop redundant type re-annotation on `civitai_results` in the
    non-remote branch (same name was annotated in the early-return
    remote branch above; mypy treats class/module-scope re-annotation
    as no-redef even when control flow rules out overlap).
  - Guard `p.name is not None` before passing to `get_parameter_source`
    (click stubs type Parameter.name as `str | None`).
  - Parameterize bare `list | dict` in `_load_json_file_or_inline`.
  - Widen `_write_sweep_manifest`'s template_path arg to `Path | None`
    (callers already pass None when --list is used without --template);
    serialize as empty string in manifest to keep schema stable.
  - `# type: ignore[arg-type]` on the deprecated `tsr comfy generate`
    delegator that passes a typer function where click.Command is
    expected — duck-typed at runtime, only matters for the deprecation
    shim.

Tests: 374 still passing. Ruff: clean. Mypy: clean.
2026-05-18 23:43:15 +02:00
aladac b0b5bca5f8 style: clear all 23 pre-existing ruff lint errors
Master had been failing CI lint since before this PR. Knock out the
backlog so the parallel-queue PR can ship green and future PRs don't
inherit the red baseline.

Changes by category:
- UP042 (9): Migrate `class Foo(str, Enum)` to `class Foo(StrEnum)` in
  tensors/config.py (7 enums) and tensors/server/search_routes.py (2).
  Requires Python 3.11+, already enforced by `requires-python = ">=3.12"`.
- PLR2004 (3): Extract magic comparison values in comfyui_api_routes.py
  to module-level constants (_DEFAULT_STEPS, _DEFAULT_CFG, _PROMPT_LOG_TRUNCATE).
- PLW0108 (2): Inline `lambda: StubDB()` -> `StubDB` in test_server.py.
- PLR0915 (3): Add explicit `# noqa: PLR0915` to typer command bodies
  that are intentionally long (template, templates_extract, _wait_for_completion_ws).
- PLR1714 (1): `file_path.name == model or file_path.stem == model` ->
  `model in (file_path.name, file_path.stem)` in cli.py:3047.
- SIM113 (1): Use `enumerate(as_completed(futures), start=1)` for the
  completion counter in style-sweep parallel loop.
- RUF059 (1): Prefix unused tuple-unpacked vars with `_` in _run_one.
- SIM105 (1): `try: ws.close() except Exception: pass` ->
  `contextlib.suppress(Exception)` in comfyui.py.
- PLC0415 (1): Add missing `# noqa: PLC0415` to the second of two
  function-scoped tensors.config imports.

No behavior changes. All 374 tests still pass.
2026-05-18 23:39:04 +02:00
aladac 2ca9003f86 style: clean lint warnings introduced by parallel-queue change
- Drop unused `import json` from new test module (F401).
- Remove unused `# noqa: BLE001` directives — project ruff config doesn't
  enable BLE001 so the suppressions were dead weight (RUF100 x3).
- Replace `×` (U+00D7) with ASCII `x` in console output (RUF001).
- Collapse seed-strategy if/else into ternary (SIM108).
- Use `enumerate(as_completed(...), start=1)` for completion counter
  instead of manual `completed = 0; completed += 1` (SIM113).
- Run `ruff format` on touched files.

Pre-existing lint errors on master (PLC0415/PLR0915/SIM113 in unrelated
commands) are untouched — separate cleanup PR if desired. Net delta of
this branch over master: 0 new lint errors.

All 374 tests still passing.
2026-05-18 23:34:22 +02:00
aladac 6ddcf84167 feat(generate): add --parallel-queue/-P for concurrent submissions
Mirrors the style-sweep --parallel-queue flag on the `generate` command.
When used with --count N > 1, splits the request into N independent
batch_size=1 jobs queued P-at-a-time via ThreadPoolExecutor instead of
a single ComfyUI batch.

Each task receives a distinct seed (incrementing from --seed when set,
freshly randomized per task when --seed=-1) and a distinct output path
following the existing stem_NNN.suffix convention. The GPU still
processes one prompt at a time, but HTTP queueing, websocket polling,
and image-download phases pipeline across tasks for a meaningful
wall-clock speedup on warmed-up models (~30-50% in practice).

Implementation notes:
- count=1 always takes the legacy sequential path regardless of -P.
- -P 1 is also sequential — identical behavior to pre-flag invocations.
- Bare model names (`-m lust_v10`) are resolved to canonical filenames
  ONCE in the parent before fanout, so worker tasks (which run with
  json_output=True path semantics for stdout) don't each duplicate the
  validation step or, worse, forward unresolved names to ComfyUI.
- --json + -P>1 is rejected up-front: the JSON path inside _run_generation
  short-circuits the disk-save block, which would silently produce zero
  files. Better to fail loud than save nothing.
- parallel_queue is plumbed through --input (JSON/YAML) like every other
  generate param, with the usual CLI-flag-wins precedence.

Tests: 15 new in tests/test_generate_parallel.py covering validation,
fanout topology, seed strategies, output naming, --input integration,
partial-failure exit code, and a concurrency assertion that confirms
threads actually overlap.

Manual E2E against ComfyUI on sin: -c 3 -P 3 on FLUX produced 3 distinct
images in ~83s vs the ~195s a pure sequential run would take.
2026-05-18 23:31:33 +02:00
10 changed files with 717 additions and 85 deletions
+1 -1
View File
@@ -1,6 +1,6 @@
"""tsr: Read safetensor metadata, search and download CivitAI models.""" """tsr: Read safetensor metadata, search and download CivitAI models."""
__version__ = "0.1.25" __version__ = "0.1.26"
from tensors.cli import main from tensors.cli import main
from tensors.config import ( from tensors.config import (
+310 -59
View File
@@ -346,7 +346,10 @@ def search(
return return
key = api_key or load_api_key() key = api_key or load_api_key()
civitai_results: dict[str, Any] | None = None # Reuse the name from the remote-mode branch above (which already returned)
# without redeclaring its type — mypy treats class-scope re-annotation as
# a no-redef even when control flow guarantees the branches don't overlap.
civitai_results = None
hf_results: list[dict[str, Any]] | None = None hf_results: list[dict[str, Any]] | None = None
# Search CivitAI # Search CivitAI
@@ -874,6 +877,22 @@ def generate( # noqa: PLR0915
str | None, str | None,
typer.Option("--input", "-I", help="JSON or YAML params (file path or inline; keys match CLI options)"), typer.Option("--input", "-I", help="JSON or YAML params (file path or inline; keys match CLI options)"),
] = None, ] = None,
parallel_queue: Annotated[
int,
typer.Option(
"--parallel-queue",
"-P",
help=(
"Concurrent ComfyUI submissions (default 1). When >1 with --count N, "
"splits the request into N independent jobs (batch_size=1 each) with "
"incrementing seeds, executed P-at-a-time via thread pool. The GPU "
"still processes one prompt at a time, but HTTP queue / init / "
"download phases pipeline for a ~5-15%% speedup. Per-task output "
"interleaves; final summary lists all saved files. Ignored when "
"--count is 1."
),
),
] = 1,
) -> None: ) -> None:
"""Generate an image using text-to-image. """Generate an image using text-to-image.
@@ -887,6 +906,11 @@ def generate( # noqa: PLR0915
starting with ``{`` are JSON, everything else is YAML. CLI flags override starting with ``{`` are JSON, everything else is YAML. CLI flags override
--input values. --input values.
With --count > 1, images are generated as a single ComfyUI batch by default
(one workflow, sequential on GPU). Use --parallel-queue N to instead split
into N independent batch_size=1 jobs queued in parallel, each with its own
seed — useful for overlapping the HTTP/download phase across requests.
Examples: Examples:
tsr generate "a cat on a windowsill" tsr generate "a cat on a windowsill"
tsr generate "portrait photo" -m ponyDiffusionV6XL_v6.safetensors -O portrait tsr generate "portrait photo" -m ponyDiffusionV6XL_v6.safetensors -O portrait
@@ -895,7 +919,20 @@ def generate( # noqa: PLR0915
tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors"}' tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors"}'
tsr generate --input scene.yml tsr generate --input scene.yml
tsr generate "raw prompt" --no-quality --no-negative tsr generate "raw prompt" --no-quality --no-negative
tsr generate "city" -c 8 -P 4 -o out.png # 8 distinct seeds, 4 in flight
""" """
if parallel_queue < 1:
console.print("[red]--parallel-queue must be >= 1[/red]")
raise typer.Exit(1)
if parallel_queue > 1 and json_output:
# _run_generation short-circuits the disk-save when json_output=True
# (it dumps JSON and returns). For the parallel fanout to actually save
# files, each task must take the non-JSON path. We render our own JSON
# at the end, so the per-task --json is incompatible.
console.print(
"[red]--json is not supported with --parallel-queue > 1 (would skip the file-save step). Drop one or the other.[/red]"
)
raise typer.Exit(1)
# ---- --input merging (JSON or YAML) ---- # ---- --input merging (JSON or YAML) ----
if json_input is not None: if json_input is not None:
ji = _parse_generate_input(json_input) ji = _parse_generate_input(json_input)
@@ -912,7 +949,9 @@ def generate( # noqa: PLR0915
{ {
p.name p.name
for p in click_ctx.command.params for p in click_ctx.command.params
if click_ctx.get_parameter_source(p.name) == click.core.ParameterSource.COMMANDLINE # click's Parameter.name is typed `str | None` in stubs but is always
# a real string at runtime for any param that's been registered.
if p.name is not None and click_ctx.get_parameter_source(p.name) == click.core.ParameterSource.COMMANDLINE
} }
if hasattr(click_ctx, "get_parameter_source") if hasattr(click_ctx, "get_parameter_source")
else set() else set()
@@ -981,41 +1020,218 @@ def generate( # noqa: PLR0915
scene_prompt = sp_val if isinstance(sp_val, str) else ", ".join(str(x) for x in sp_val if str(x).strip()) scene_prompt = sp_val if isinstance(sp_val, str) else ", ".join(str(x) for x in sp_val if str(x).strip())
if "rating" in mapped and "rating" not in explicit: if "rating" in mapped and "rating" not in explicit:
rating = mapped["rating"] rating = mapped["rating"]
if "parallel_queue" in mapped and "parallel_queue" not in explicit:
parallel_queue = int(mapped["parallel_queue"])
has_content = bool(prompt or character or character_prompt or scene or scene_prompt) has_content = bool(prompt or character or character_prompt or scene or scene_prompt)
if not has_content: if not has_content:
console.print("[red]Prompt (or character/scene) is required[/red]") console.print("[red]Prompt (or character/scene) is required[/red]")
raise typer.Exit(1) raise typer.Exit(1)
_run_generation( # Effective parallelism is bounded by count — running 4 threads for 1 image
prompt=prompt, # is silly. count=1 always goes through the sequential path regardless of -P.
model=model, effective_parallel = min(parallel_queue, count) if count > 1 else 1
width=width,
height=height, if effective_parallel <= 1:
steps=steps, # Sequential path: single _run_generation call with batch_size=count.
cfg=cfg, # Unchanged from pre-parallel behavior — preserves existing output naming,
guidance=guidance, # JSON shape, and log lines exactly.
seed=seed, _run_generation(
sampler=sampler, prompt=prompt,
scheduler=scheduler, model=model,
vae=vae, width=width,
orientation=orientation, height=height,
lora=lora, steps=steps,
lora_strength=lora_strength, cfg=cfg,
negative=negative, guidance=guidance,
count=count, seed=seed,
rating=rating, sampler=sampler,
no_quality=no_quality, scheduler=scheduler,
no_negative=no_negative, vae=vae,
character=character, orientation=orientation,
character_prompt=character_prompt, lora=lora,
scene=scene, lora_strength=lora_strength,
scene_prompt=scene_prompt, negative=negative,
family=family, count=count,
output=output, rating=rating,
remote=remote, no_quality=no_quality,
json_output=json_output, no_negative=no_negative,
) character=character,
character_prompt=character_prompt,
scene=scene,
scene_prompt=scene_prompt,
family=family,
output=output,
remote=remote,
json_output=json_output,
)
return
# ---- Parallel fanout path ----
# Split count into `count` independent jobs (batch_size=1), executed
# `effective_parallel` at a time. Each job gets a distinct seed and a
# distinct output path so writes don't clobber each other.
import random as _rng # noqa: PLC0415
import time as _time # noqa: PLC0415
from concurrent.futures import ThreadPoolExecutor, as_completed # noqa: PLC0415
# Resolve bare model/lora names ONCE in the parent before fanout. Each
# parallel _run_generation call silences its own console (json_output=True)
# which also skips the validation/resolution step in that path. Doing it
# here means each task receives a canonical filename and ComfyUI's strict
# loaders accept the request first try.
#
# Skip validation entirely when a remote dispatch will happen — either an
# explicit -r/--remote flag or default_remote from config. The server-side
# validator on the remote tensors API will catch missing models there.
# Without this, a config with `default_remote = "runpod"` would still hit
# the local ComfyUI for the availability check and fail on models that
# only exist on the remote host.
from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415
if model and not do_resolve_remote(remote):
# Detect family for the right loader bucket (checkpoints vs diffusion_models).
# Mirrors the lookup _run_generation does on entry.
from tensors.db import Database # noqa: PLC0415
_base_model: str | None = None
try:
with Database() as _db:
_db.init_schema()
_base_model = _db.get_base_model_by_filename(model)
except Exception:
pass
_detected = detect_model_family(model, _base_model)
_fam = family or _detected
try:
model, lora = _validate_model_available(model, _fam, lora)
except typer.Exit:
raise # surface the same error path as sequential
# Seed strategy:
# --seed >= 0 → use as base, increment per job (reproducible series)
# --seed == -1 → pick a fresh random seed PER JOB so parallel runs aren't
# accidentally correlated (each thread gets variety)
seeds = [seed + i for i in range(count)] if seed >= 0 else [_rng.randint(0, 2**32 - 1) for _ in range(count)]
# Output paths: mirror the existing `count > 1` naming convention from
# _run_generation (stem_NNN.ext). When --output is omitted, leave per-task
# output as None — _run_generation will skip the disk write and the user
# gets only the console listing of generated image refs.
out_paths: list[Path | None] = []
for i in range(count):
if output is None:
out_paths.append(None)
else:
out_paths.append(output.parent / f"{output.stem}_{i + 1:03d}{output.suffix}")
if not json_output:
console.print(
f"[dim]Parallel queue: {effective_parallel} concurrent submissions x {count} images (output may interleave)[/dim]"
)
common_kwargs: dict[str, Any] = {
"prompt": prompt,
"model": model,
"width": width,
"height": height,
"steps": steps,
"cfg": cfg,
"guidance": guidance,
"sampler": sampler,
"scheduler": scheduler,
"vae": vae,
"orientation": orientation,
"lora": lora,
"lora_strength": lora_strength,
"negative": negative,
"count": 1, # each task generates exactly one image
"rating": rating,
"no_quality": no_quality,
"no_negative": no_negative,
"character": character,
"character_prompt": character_prompt,
"scene": scene,
"scene_prompt": scene_prompt,
"family": family,
"remote": remote,
# NOTE: json_output stays False so _run_generation's disk-save path runs.
# Setting True would short-circuit before saving files. Per-task console
# chatter is the trade-off; the final summary still shows clean per-task
# status lines.
"json_output": False,
}
def _run_one(idx: int) -> dict[str, Any]:
"""Run a single batch_size=1 job. Returns a result dict (success captured)."""
start = _time.perf_counter()
result: dict[str, Any] = {
"index": idx,
"seed": seeds[idx],
"output": str(out_paths[idx]) if out_paths[idx] is not None else None,
"duration_sec": 0.0,
"success": False,
"error": None,
}
try:
_run_generation(seed=seeds[idx], output=out_paths[idx], **common_kwargs)
result["duration_sec"] = round(_time.perf_counter() - start, 2)
result["success"] = True
except typer.Exit as ex:
result["duration_sec"] = round(_time.perf_counter() - start, 2)
result["error"] = f"generate exited with code {ex.exit_code}"
except Exception as ex:
result["duration_sec"] = round(_time.perf_counter() - start, 2)
result["error"] = str(ex)
return result
fan_results: list[dict[str, Any]] = []
with ThreadPoolExecutor(max_workers=effective_parallel) as pool:
futures = {pool.submit(_run_one, i): i for i in range(count)}
for completed, fut in enumerate(as_completed(futures), start=1):
try:
res = fut.result()
except Exception as ex:
# Defensive — _run_one already swallows, but if the executor itself
# raises (e.g. pickling failure) we still want a well-formed result
# in the manifest rather than a crash.
res = {
"index": futures[fut],
"seed": seeds[futures[fut]],
"output": str(out_paths[futures[fut]]) if out_paths[futures[fut]] is not None else None,
"duration_sec": 0.0,
"success": False,
"error": f"executor exception: {ex}",
}
fan_results.append(res)
if not json_output:
if res["success"]:
where = res["output"] or "(no --output set)"
console.print(
f"[green]\\[{completed}/{count}] seed={res['seed']} ok in {res['duration_sec']:.1f}s → {where}[/green]"
)
else:
console.print(f"[red]\\[{completed}/{count}] seed={res['seed']} FAIL: {res['error']}[/red]")
# Reorder by original index so JSON output / final summary list is stable.
fan_results.sort(key=lambda r: r["index"])
successful = sum(1 for r in fan_results if r["success"])
if json_output:
console.print_json(
data={
"success": successful == count,
"count": count,
"parallel_queue": effective_parallel,
"results": fan_results,
}
)
return
console.print("[bold green]Generation complete![/bold green]")
console.print(f"[dim]Generated {successful}/{count} images at parallelism={effective_parallel}[/dim]")
if successful < count:
raise typer.Exit(1)
# Map model family → which ComfyUI loader directory the checkpoint must live in. # Map model family → which ComfyUI loader directory the checkpoint must live in.
@@ -1164,6 +1380,8 @@ def _run_generation( # noqa: PLR0915
""" """
import random as rng # noqa: PLC0415 import random as rng # noqa: PLC0415
from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415
# ---- Detect model family and enhance prompt/negative ---- # ---- Detect model family and enhance prompt/negative ----
family_defaults: dict[str, Any] = {} family_defaults: dict[str, Any] = {}
model_family: str | None = None model_family: str | None = None
@@ -1193,7 +1411,10 @@ def _run_generation( # noqa: PLR0915
# available remotely ("v11Softcore"), and offers a fuzzy "did you mean" hint # available remotely ("v11Softcore"), and offers a fuzzy "did you mean" hint
# instead of forwarding the request to ComfyUI for a generic 400 rejection. # instead of forwarding the request to ComfyUI for a generic 400 rejection.
# Skipped in --json mode and for remote dispatches (server already validates). # Skipped in --json mode and for remote dispatches (server already validates).
if model and not json_output and not remote: # "Remote dispatch" includes both -r/--remote and default_remote in config —
# otherwise users with default_remote set hit local ComfyUI for a check that
# should be deferred to the remote server.
if model and not json_output and not do_resolve_remote(remote):
# Returns possibly-rewritten names so bare inputs like `-m lust_v10` # Returns possibly-rewritten names so bare inputs like `-m lust_v10`
# silently resolve to the canonical `lust_v10.safetensors` filename # silently resolve to the canonical `lust_v10.safetensors` filename
# before being forwarded to ComfyUI's strict CLIPLoader / UNETLoader. # before being forwarded to ComfyUI's strict CLIPLoader / UNETLoader.
@@ -1283,7 +1504,6 @@ def _run_generation( # noqa: PLR0915
# ---- Resolve preset defaults for None params (both remote and local need these) ---- # ---- Resolve preset defaults for None params (both remote and local need these) ----
from tensors.config import resolve_orientation # noqa: PLC0415 from tensors.config import resolve_orientation # noqa: PLC0415
from tensors.config import resolve_remote as do_resolve_remote
# Use already-detected family_defaults from DB lookup above (not filename guessing) # Use already-detected family_defaults from DB lookup above (not filename guessing)
if family_defaults: if family_defaults:
@@ -1507,7 +1727,7 @@ _STYLE_SWEEP_TEMPLATE_KEYS = {
} }
def _load_json_file_or_inline(value: str | list | dict, *, what: str) -> Any: def _load_json_file_or_inline(value: str | list[Any] | dict[str, Any], *, what: str) -> Any:
"""Load JSON from a file path or accept already-parsed inline data. """Load JSON from a file path or accept already-parsed inline data.
`value` may be a path string, a JSON string, or an already-parsed list/dict `value` may be a path string, a JSON string, or an already-parsed list/dict
@@ -1885,7 +2105,7 @@ def style_sweep( # noqa: PLR0915
def _run_one(task: tuple[int, dict[str, str], dict[str, Any], Path]) -> dict[str, Any]: def _run_one(task: tuple[int, dict[str, str], dict[str, Any], Path]) -> dict[str, Any]:
"""Run a single style. Returns the result dict (success or error captured).""" """Run a single style. Returns the result dict (success or error captured)."""
idx, entry_in, res, opath = task _idx, _entry_in, res, opath = task
composed = res["prompt"] composed = res["prompt"]
start = time.perf_counter() start = time.perf_counter()
try: try:
@@ -1937,11 +2157,9 @@ def style_sweep( # noqa: PLR0915
with ThreadPoolExecutor(max_workers=parallel_queue) as pool: with ThreadPoolExecutor(max_workers=parallel_queue) as pool:
futures = {pool.submit(_run_one, task): task for task in pending_tasks} futures = {pool.submit(_run_one, task): task for task in pending_tasks}
completed = 0 for completed, fut in enumerate(as_completed(futures), start=1):
for fut in as_completed(futures):
completed += 1
task = futures[fut] task = futures[fut]
idx, _entry, _res, _out_path = task idx, _entry, _res, _out_path = task # idx used in log message below
try: try:
res = fut.result() res = fut.result()
except Exception as ex: except Exception as ex:
@@ -1988,14 +2206,16 @@ def style_sweep( # noqa: PLR0915
def _write_sweep_manifest( def _write_sweep_manifest(
out_dir: Path, out_dir: Path,
template_path: Path, template_path: Path | None,
styles_origin: str, styles_origin: str,
results: list[dict[str, Any]], results: list[dict[str, Any]],
) -> Path: ) -> Path:
"""Write the per-sweep manifest JSON. Returns the path.""" """Write the per-sweep manifest JSON. Returns the path."""
manifest_path = out_dir / "_sweep.json" manifest_path = out_dir / "_sweep.json"
manifest: dict[str, Any] = { manifest: dict[str, Any] = {
"template": str(template_path), # template_path is None when --list is used with only --styles (no template
# required). Serialize as empty string to keep manifest schema stable.
"template": str(template_path) if template_path is not None else "",
"styles_source": styles_origin, "styles_source": styles_origin,
"results": results, "results": results,
} }
@@ -2024,7 +2244,7 @@ def _print_styles_list(styles_origin: str, entries: list[dict[str, str]]) -> Non
@app.command() @app.command()
def template( def template( # noqa: PLR0915
model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")], model: Annotated[str, typer.Option("-m", "--model", help="Checkpoint model name")],
lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None, lora: Annotated[str | None, typer.Option("-l", "--lora", help="LoRA model name")] = None,
lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8, lora_strength: Annotated[float, typer.Option("--lora-strength", help="LoRA strength")] = 0.8,
@@ -2393,27 +2613,55 @@ def db_cache(
@db_app.command("list") @db_app.command("list")
def db_list( def db_list( # noqa: PLR0915
model_type: Annotated[ model_type: Annotated[
str | None, typer.Option("-t", "--type", help="Filter by model type (Checkpoint, LORA, VAE, etc.)") str | None, typer.Option("-t", "--type", help="Filter by model type (Checkpoint, LORA, VAE, etc.)")
] = None, ] = None,
base: Annotated[ base: Annotated[
str | None, typer.Option("-b", "--base", help="Filter by base model (Pony, Illustrious, SDXL 1.0, SD 1.5, etc.)") str | None, typer.Option("-b", "--base", help="Filter by base model (Pony, Illustrious, SDXL 1.0, SD 1.5, etc.)")
] = None, ] = None,
remote: Annotated[
str | None, typer.Option("-r", "--remote", help="Remote server name or URL (overrides default_remote)")
] = None,
local: Annotated[bool, typer.Option("--local", help="Force read from local DB, ignoring default_remote")] = False,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None: ) -> None:
"""List local files with CivitAI info. """List files indexed in the tensors DB.
Defaults to the configured remote (config.toml default_remote) so the
table reflects what's actually on the generation host. Pass --local to
inspect the local SQLite DB on this machine instead. If no remote is
configured and --local is not passed, falls back to local silently.
Examples: Examples:
tsr db list # All local files tsr db list # Remote (or local if no default_remote)
tsr db list -t Checkpoint # Only checkpoints tsr db list --local # Force local DB
tsr db list -t LORA # Only LoRAs tsr db list -r junkpile # Explicit remote
tsr db list -t Checkpoint -b Pony # Pony checkpoints only tsr db list -t Checkpoint -b Pony # Filter, default source
tsr db list -b "SDXL 1.0" # All SDXL 1.0 models tsr db list -b "SDXL 1.0" # All SDXL 1.0 models
""" """
with Database() as db: from tensors.config import resolve_remote as do_resolve_remote # noqa: PLC0415
db.init_schema()
files = db.list_local_files() # Resolution precedence: --local wins; else explicit -r; else default_remote; else local.
remote_url: str | None = None
if not local:
remote_url = do_resolve_remote(remote) if remote else do_resolve_remote(None)
files: list[dict[str, Any]] | None
if remote_url:
from tensors.remote import remote_db_files # noqa: PLC0415
if not json_output:
console.print(f"[dim]Remote: {remote_url}[/dim]")
files = remote_db_files(remote or remote_url, console=console)
if files is None:
raise typer.Exit(1)
source_label = "Remote Files"
else:
with Database() as db:
db.init_schema()
files = db.list_local_files()
source_label = "Local Files"
# Apply filters (case-insensitive substring match) # Apply filters (case-insensitive substring match)
if model_type: if model_type:
@@ -2428,17 +2676,18 @@ def db_list(
return return
if not files: if not files:
console.print("[yellow]No files found. Try 'tsr db scan' or adjust filters.[/yellow]") hint = "Remote DB is empty" if remote_url else "Try 'tsr db scan'"
console.print(f"[yellow]No files found. {hint} or adjust filters.[/yellow]")
return return
title = "Local Files" title = source_label
if model_type or base: if model_type or base:
parts = [] parts = []
if model_type: if model_type:
parts.append(model_type) parts.append(model_type)
if base: if base:
parts.append(base) parts.append(base)
title = f"Local Files ({', '.join(parts)})" title = f"{source_label} ({', '.join(parts)})"
table = Table(title=title, show_header=True, header_style="bold magenta") table = Table(title=title, show_header=True, header_style="bold magenta")
table.add_column("Path", style="cyan", max_width=50) table.add_column("Path", style="cyan", max_width=50)
@@ -2842,7 +3091,7 @@ def scene_extract(
target_file = None target_file = None
for f in files: for f in files:
file_path = Path(f["file_path"]) file_path = Path(f["file_path"])
if file_path.name == model or file_path.stem == model: if model in (file_path.name, file_path.stem):
target_file = f target_file = f
break break
@@ -2970,7 +3219,7 @@ app.add_typer(templates_app)
@templates_app.command("extract") @templates_app.command("extract")
def templates_extract( def templates_extract( # noqa: PLR0915
model: Annotated[str, typer.Argument(help="Local model name (e.g. lust_v10.safetensors)")], model: Annotated[str, typer.Argument(help="Local model name (e.g. lust_v10.safetensors)")],
orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "portrait", orientation: Annotated[str, typer.Option("-O", "--orientation", help="Resolution: square, portrait, landscape")] = "portrait",
no_overrides: Annotated[ no_overrides: Annotated[
@@ -3428,8 +3677,10 @@ def comfy_generate(
) -> None: ) -> None:
"""[Deprecated] Use 'tsr generate' instead. All features have been merged into the top-level command.""" """[Deprecated] Use 'tsr generate' instead. All features have been merged into the top-level command."""
console.print("[yellow]Warning: 'tsr comfy generate' is deprecated. Use 'tsr generate' instead.[/yellow]") console.print("[yellow]Warning: 'tsr comfy generate' is deprecated. Use 'tsr generate' instead.[/yellow]")
# Delegate to the unified generate command via context invocation # Delegate to the unified generate command via context invocation.
ctx = typer.Context(generate) # typer.Context expects a click.Command, but passing the typer function directly
# works at runtime via duck-typing — keeping it for back-compat with deprecated alias.
ctx = typer.Context(generate) # type: ignore[arg-type]
generate( generate(
ctx=ctx, ctx=ctx,
prompt=prompt, prompt=prompt,
+3 -4
View File
@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import contextlib
import copy import copy
import json import json
import random import random
@@ -388,7 +389,7 @@ def queue_prompt(
return None return None
def _wait_for_completion_ws( def _wait_for_completion_ws( # noqa: PLR0915
prompt_id: str, prompt_id: str,
url: str, url: str,
client_id: str, client_id: str,
@@ -494,10 +495,8 @@ def _wait_for_completion_ws(
break break
finally: finally:
try: with contextlib.suppress(Exception):
ws.close() ws.close()
except Exception:
pass
# Fetch final outputs from history to ensure we have everything # Fetch final outputs from history to ensure we have everything
try: try:
+8 -10
View File
@@ -4,7 +4,7 @@ from __future__ import annotations
import os import os
import tomllib import tomllib
from enum import Enum from enum import StrEnum
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -65,7 +65,7 @@ CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
# ============================================================================ # ============================================================================
class Provider(str, Enum): class Provider(StrEnum):
"""Model search providers.""" """Model search providers."""
civitai = "civitai" civitai = "civitai"
@@ -73,7 +73,7 @@ class Provider(str, Enum):
all = "all" all = "all"
class ModelType(str, Enum): class ModelType(StrEnum):
"""CivitAI model types.""" """CivitAI model types."""
checkpoint = "checkpoint" checkpoint = "checkpoint"
@@ -110,7 +110,7 @@ class ModelType(str, Enum):
return mapping[self.value] return mapping[self.value]
class BaseModel(str, Enum): class BaseModel(StrEnum):
"""Common base models.""" """Common base models."""
# Stable Diffusion 1.x # Stable Diffusion 1.x
@@ -166,7 +166,7 @@ class BaseModel(str, Enum):
return mapping[self.value] return mapping[self.value]
class SortOrder(str, Enum): class SortOrder(StrEnum):
"""Sort options for search.""" """Sort options for search."""
downloads = "downloads" downloads = "downloads"
@@ -183,7 +183,7 @@ class SortOrder(str, Enum):
return mapping[self.value] return mapping[self.value]
class Period(str, Enum): class Period(StrEnum):
"""Time period for sorting/filtering.""" """Time period for sorting/filtering."""
all = "all" all = "all"
@@ -204,7 +204,7 @@ class Period(str, Enum):
return mapping[self.value] return mapping[self.value]
class NsfwLevel(str, Enum): class NsfwLevel(StrEnum):
"""NSFW content filter level.""" """NSFW content filter level."""
none = "none" none = "none"
@@ -219,7 +219,7 @@ class NsfwLevel(str, Enum):
return self.value.capitalize() if self.value != "none" else "None" return self.value.capitalize() if self.value != "none" else "None"
class CommercialUse(str, Enum): class CommercialUse(StrEnum):
"""Commercial use permissions.""" """Commercial use permissions."""
none = "none" none = "none"
@@ -576,7 +576,6 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"sampler": "euler_ancestral", "sampler": "euler_ancestral",
"scheduler": "normal", "scheduler": "normal",
"steps": 25, "steps": 25,
"vae": "ponyStandardVAE_v10.safetensors",
}, },
"illustrious": { "illustrious": {
"quality_prefix": "masterpiece, best quality, highres", "quality_prefix": "masterpiece, best quality, highres",
@@ -589,7 +588,6 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"sampler": "euler_ancestral", "sampler": "euler_ancestral",
"scheduler": "normal", "scheduler": "normal",
"steps": 25, "steps": 25,
"vae": "illustriousXLV20_v10.safetensors",
}, },
"sdxl": { "sdxl": {
"quality_prefix": "", "quality_prefix": "",
+12 -2
View File
@@ -16,9 +16,16 @@ from __future__ import annotations
import json import json
import re import re
from pathlib import Path # noqa: TC003 # used in runtime return annotations exposed to typer from pathlib import Path # noqa: TC003 # used in runtime return annotations exposed to typer
from typing import TYPE_CHECKING
from tensors.config import DATA_DIR from tensors.config import DATA_DIR
if TYPE_CHECKING:
# Qualified `builtins.list` is referenced in annotations inside FragmentLibrary
# because the class defines a method named `list` that shadows the builtin
# at class-scope name resolution. Static-only — not needed at runtime.
import builtins
# Restrict fragment names to a safe subset so they can't escape the storage dir # Restrict fragment names to a safe subset so they can't escape the storage dir
# via path traversal and so file listings stay tidy. # via path traversal and so file listings stay tidy.
_NAME_RE = re.compile(r"^[A-Za-z0-9_.-]+$") _NAME_RE = re.compile(r"^[A-Za-z0-9_.-]+$")
@@ -132,8 +139,11 @@ class FragmentLibrary:
*, *,
name: str | None = None, name: str | None = None,
inline: str | None = None, inline: str | None = None,
extra: list[str] | None = None, # NOTE: `builtins.list` qualifier needed because this class defines a
) -> list[str]: # `list()` method below, which shadows the builtin in class-scope name
# resolution. Affects mypy/pyright even with `from __future__ import annotations`.
extra: builtins.list[str] | None = None,
) -> builtins.list[str]:
"""Merge a named fragment with an inline CSV string and optional extras. """Merge a named fragment with an inline CSV string and optional extras.
Resolution order (first match wins per duplicate): named → inline → extra. Resolution order (first match wins per duplicate): named → inline → extra.
+47 -1
View File
@@ -209,7 +209,8 @@ def remote_search(
response.raise_for_status() response.raise_for_status()
result: dict[str, Any] = response.json() result: dict[str, Any] = response.json()
# The remote API wraps CivitAI results under "civitai" key # The remote API wraps CivitAI results under "civitai" key
return result.get("civitai", result) civitai_section: dict[str, Any] = result.get("civitai", result)
return civitai_section
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
if console: if console:
console.print(f"[red]Remote API error: {e.response.status_code}[/red]") console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
@@ -301,3 +302,48 @@ def remote_download_status(
return result return result
except (httpx.HTTPStatusError, httpx.RequestError): except (httpx.HTTPStatusError, httpx.RequestError):
return None return None
def remote_db_files(
remote: str,
*,
console: Console | None = None,
) -> list[dict[str, Any]] | None:
"""Fetch the local-files index from a remote tensors server's DB.
Returns the same shape as Database.list_local_files() but reflecting the
remote host's SQLite DB rather than ours.
Args:
remote: Remote name or URL (resolved via config)
console: Rich console for error output
Returns:
List of file dicts, or None on connection / API error
"""
base_url = resolve_remote(remote)
if not base_url:
if console:
console.print("[red]Error: Could not resolve remote server[/red]")
return None
try:
with _build_client(base_url) as client:
response = client.get("/api/db/files")
response.raise_for_status()
result: list[dict[str, Any]] = response.json()
return result
except httpx.HTTPStatusError as e:
if console:
console.print(f"[red]Remote API error: {e.response.status_code}[/red]")
try:
detail = e.response.json().get("detail", "")
if detail:
console.print(f" [yellow]{detail}[/yellow]")
except Exception:
pass
return None
except httpx.RequestError as e:
if console:
console.print(f"[red]Remote connection error: {e}[/red]")
return None
+11 -3
View File
@@ -27,6 +27,14 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/comfyui", tags=["ComfyUI API"]) router = APIRouter(prefix="/api/comfyui", tags=["ComfyUI API"])
# Schema default sentinels — see GenerateRequest defaults. These let us detect
# "user accepted default" vs "user explicitly chose this value matching default"
# is intentionally not distinguished; both paths apply family overrides.
_DEFAULT_STEPS = 20
_DEFAULT_CFG = 7.0
# Logging truncation threshold for long prompts in info-level output.
_PROMPT_LOG_TRUNCATE = 100
# ============================================================================= # =============================================================================
# Request/Response Models # Request/Response Models
@@ -262,9 +270,9 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
sampler = family_defaults["sampler"] sampler = family_defaults["sampler"]
if request.scheduler == "normal": # Default value in schema if request.scheduler == "normal": # Default value in schema
scheduler = family_defaults["scheduler"] scheduler = family_defaults["scheduler"]
if request.steps == 20: # Default value in schema if request.steps == _DEFAULT_STEPS: # Default value in schema
steps = family_defaults["steps"] steps = family_defaults["steps"]
if request.cfg == 7.0: # Default value in schema if request.cfg == _DEFAULT_CFG: # Default value in schema
cfg = family_defaults["cfg"] cfg = family_defaults["cfg"]
# Only override VAE if user explicitly specified one; # Only override VAE if user explicitly specified one;
# otherwise use checkpoint's built-in VAE (vae stays None) # otherwise use checkpoint's built-in VAE (vae stays None)
@@ -290,7 +298,7 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
sampler, sampler,
scheduler, scheduler,
lora_info, lora_info,
request.prompt[:100] + "..." if len(request.prompt) > 100 else request.prompt, request.prompt[:_PROMPT_LOG_TRUNCATE] + "..." if len(request.prompt) > _PROMPT_LOG_TRUNCATE else request.prompt,
) )
if request.negative_prompt: if request.negative_prompt:
logger.debug("Negative prompt: %r", request.negative_prompt[:100]) logger.debug("Negative prompt: %r", request.negative_prompt[:100])
+3 -3
View File
@@ -4,7 +4,7 @@ from __future__ import annotations
import contextlib import contextlib
import logging import logging
from enum import Enum from enum import StrEnum
from typing import Annotated, Any from typing import Annotated, Any
from fastapi import APIRouter, Query from fastapi import APIRouter, Query
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/search", tags=["Search"]) router = APIRouter(prefix="/api/search", tags=["Search"])
class Provider(str, Enum): class Provider(StrEnum):
"""Search provider options.""" """Search provider options."""
civitai = "civitai" civitai = "civitai"
@@ -45,7 +45,7 @@ class Provider(str, Enum):
all = "all" all = "all"
class SortOrder(str, Enum): class SortOrder(StrEnum):
"""Sort order options.""" """Sort order options."""
downloads = "downloads" downloads = "downloads"
+320
View File
@@ -0,0 +1,320 @@
"""Tests for the `tsr generate --parallel-queue` flag (parallel fanout path)."""
from __future__ import annotations
from pathlib import Path
from typing import Any
import pytest
from typer.testing import CliRunner
from tensors import cli as cli_module
from tensors.cli import app
runner = CliRunner()
# -----------------------------------------------------------------------------
# Fixtures
# -----------------------------------------------------------------------------
@pytest.fixture
def calls(monkeypatch: pytest.MonkeyPatch) -> list[dict[str, Any]]:
"""Record every _run_generation call and stub the disk-write side effect.
The parallel fanout path invokes _run_generation N times (one per task);
the sequential path invokes it once. By recording kwargs we can assert
fanout behavior (per-task seeds, per-task output paths, count=1 per task)
without round-tripping ComfyUI.
"""
recorded: list[dict[str, Any]] = []
def fake_run_generation(**kwargs: Any) -> None:
recorded.append(kwargs)
out: Path | None = kwargs.get("output")
if out is not None:
out.parent.mkdir(parents=True, exist_ok=True)
out.write_bytes(b"fake-png")
monkeypatch.setattr(cli_module, "_run_generation", fake_run_generation)
return recorded
@pytest.fixture(autouse=True)
def _stub_model_validation(monkeypatch: pytest.MonkeyPatch) -> None:
"""Bypass ComfyUI's live model lookup so tests don't need a backend."""
monkeypatch.setattr(
cli_module,
"_validate_model_available",
lambda model, family, lora: (model, lora),
)
# -----------------------------------------------------------------------------
# Validation / sanity
# -----------------------------------------------------------------------------
def test_parallel_queue_invalid_value_rejected(calls: list[dict[str, Any]]) -> None:
"""--parallel-queue 0 (or negative) exits non-zero before any work."""
result = runner.invoke(
app,
["generate", "test prompt", "-m", "x.safetensors", "--parallel-queue", "0"],
)
assert result.exit_code != 0
assert "--parallel-queue must be >= 1" in result.output
assert calls == []
def test_parallel_queue_one_is_sequential_path(calls: list[dict[str, Any]]) -> None:
"""-P 1 collapses to the legacy single _run_generation call with count=N.
This is the key compatibility contract: existing scripts that don't pass
-P must see identical behavior (one call, count forwarded as batch_size).
"""
result = runner.invoke(
app,
["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "1"],
)
assert result.exit_code == 0, result.output
assert len(calls) == 1
assert calls[0]["count"] == 4
assert calls[0]["prompt"] == "test prompt"
def test_count_one_ignores_parallel_queue(calls: list[dict[str, Any]]) -> None:
"""count=1 always takes sequential path regardless of -P (no fanout point)."""
result = runner.invoke(
app,
["generate", "test prompt", "-m", "x.safetensors", "-c", "1", "-P", "8"],
)
assert result.exit_code == 0, result.output
assert len(calls) == 1
assert calls[0]["count"] == 1
def test_json_output_incompatible_with_parallel(calls: list[dict[str, Any]]) -> None:
"""--json + -P>1 errors out cleanly (would skip disk-save inside tasks)."""
result = runner.invoke(
app,
["generate", "test prompt", "-m", "x.safetensors", "-c", "2", "-P", "2", "--json"],
)
assert result.exit_code != 0
assert "--json is not supported with --parallel-queue > 1" in result.output
assert calls == []
# -----------------------------------------------------------------------------
# Fanout behavior
# -----------------------------------------------------------------------------
def test_parallel_fanout_creates_n_tasks(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
"""-c N -P M (M>1, N>1) → N independent _run_generation calls, each count=1."""
out = tmp_path / "img.png"
result = runner.invoke(
app,
[
"generate",
"test prompt",
"-m",
"x.safetensors",
"-c",
"4",
"-P",
"2",
"--seed",
"100",
"-o",
str(out),
],
)
assert result.exit_code == 0, result.output
assert len(calls) == 4
# Each task generates exactly one image
for c in calls:
assert c["count"] == 1
def test_parallel_seeds_increment_from_base(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
"""Explicit --seed → each task receives base+i (reproducible series)."""
out = tmp_path / "img.png"
runner.invoke(
app,
["generate", "test prompt", "-m", "x.safetensors", "-c", "3", "-P", "3", "--seed", "500", "-o", str(out)],
)
seeds_seen = sorted(c["seed"] for c in calls)
assert seeds_seen == [500, 501, 502]
def test_parallel_seeds_random_when_unset(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
"""seed=-1 → each task gets a freshly-rolled random seed (not all the same).
Vanishingly small chance of collision across 4 random ints; treat as flake
threshold of "all distinct" rather than exact equality to any value.
"""
out = tmp_path / "img.png"
runner.invoke(
app,
["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "2", "-o", str(out)],
)
seeds = [c["seed"] for c in calls]
# All non-negative (i.e. resolved from -1 to actual int) and distinct.
assert all(s >= 0 for s in seeds)
assert len(set(seeds)) == len(seeds)
def test_parallel_output_paths_indexed(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
"""Per-task output paths use stem_NNN.suffix naming (matches sequential count>1)."""
out = tmp_path / "scene.png"
runner.invoke(
app,
["generate", "test prompt", "-m", "x.safetensors", "-c", "3", "-P", "3", "--seed", "1", "-o", str(out)],
)
paths = sorted(str(c["output"]) for c in calls)
assert paths == [
str(tmp_path / "scene_001.png"),
str(tmp_path / "scene_002.png"),
str(tmp_path / "scene_003.png"),
]
def test_parallel_without_output_passes_none(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
"""When --output is omitted, each task gets output=None (no disk write planned)."""
runner.invoke(
app,
["generate", "test prompt", "-m", "x.safetensors", "-c", "2", "-P", "2", "--seed", "1"],
)
assert len(calls) == 2
assert all(c["output"] is None for c in calls)
def test_parallel_files_actually_written(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
"""End-to-end: per-task stub writes its file → all N appear on disk.
Guards against the bug where json_output=True short-circuits the save block
inside _run_generation. Each task must use the non-JSON code path.
"""
out = tmp_path / "shot.png"
runner.invoke(
app,
["generate", "test prompt", "-m", "x.safetensors", "-c", "3", "-P", "3", "--seed", "1", "-o", str(out)],
)
written = sorted(p.name for p in tmp_path.iterdir())
assert written == ["shot_001.png", "shot_002.png", "shot_003.png"]
def test_parallel_summary_reports_success_count(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
"""Final summary line reports N/N success when all tasks complete."""
out = tmp_path / "img.png"
result = runner.invoke(
app,
["generate", "test prompt", "-m", "x.safetensors", "-c", "3", "-P", "2", "--seed", "1", "-o", str(out)],
)
assert result.exit_code == 0
assert "Generated 3/3 images" in result.output
def test_parallel_partial_failure_exits_nonzero(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""If one task raises, summary shows partial count and command exits non-zero."""
import typer
call_indices: list[int] = []
def flaky_run_generation(**kwargs: Any) -> None:
# Fail every other call to simulate intermittent backend errors.
idx = len(call_indices)
call_indices.append(idx)
if idx % 2 == 0:
raise typer.Exit(1)
out: Path | None = kwargs.get("output")
if out is not None:
out.parent.mkdir(parents=True, exist_ok=True)
out.write_bytes(b"ok")
monkeypatch.setattr(cli_module, "_run_generation", flaky_run_generation)
out = tmp_path / "img.png"
result = runner.invoke(
app,
["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "2", "--seed", "1", "-o", str(out)],
)
assert result.exit_code != 0
# Two tasks failed; final summary should show 2/4.
assert "Generated 2/4 images" in result.output
# -----------------------------------------------------------------------------
# --input integration
# -----------------------------------------------------------------------------
def test_parallel_queue_from_yaml_input(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
"""parallel_queue can be set via --input YAML (mirrors other generate params)."""
out = tmp_path / "img.png"
yml = tmp_path / "spec.yml"
yml.write_text(f'prompt: from-yaml\nmodel: x.safetensors\ncount: 3\nparallel_queue: 3\nseed: 7\noutput: "{out}"\n')
result = runner.invoke(app, ["generate", "--input", str(yml)])
assert result.exit_code == 0, result.output
assert len(calls) == 3
assert sorted(c["seed"] for c in calls) == [7, 8, 9]
def test_cli_parallel_queue_overrides_yaml(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
"""CLI --parallel-queue wins over YAML's parallel_queue (standard precedence)."""
out = tmp_path / "img.png"
yml = tmp_path / "spec.yml"
yml.write_text(f'prompt: from-yaml\nmodel: x.safetensors\ncount: 2\nparallel_queue: 1\nseed: 10\noutput: "{out}"\n')
# YAML says P=1 (sequential), CLI overrides to P=2 (fanout)
result = runner.invoke(app, ["generate", "--input", str(yml), "-P", "2"])
assert result.exit_code == 0, result.output
# Fanout path → 2 separate calls, each count=1
assert len(calls) == 2
assert all(c["count"] == 1 for c in calls)
# -----------------------------------------------------------------------------
# Concurrency assertion
# -----------------------------------------------------------------------------
def test_parallel_actually_runs_concurrently(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Sanity: P concurrent tasks really overlap in time (vs all-serial)."""
import threading
import time as _t
in_flight = 0
peak_in_flight = 0
lock = threading.Lock()
def slow_run_generation(**kwargs: Any) -> None:
nonlocal in_flight, peak_in_flight
with lock:
in_flight += 1
peak_in_flight = max(peak_in_flight, in_flight)
_t.sleep(0.1) # 100ms — long enough to overlap, short enough for fast tests
with lock:
in_flight -= 1
out: Path | None = kwargs.get("output")
if out is not None:
out.parent.mkdir(parents=True, exist_ok=True)
out.write_bytes(b"ok")
monkeypatch.setattr(cli_module, "_run_generation", slow_run_generation)
out = tmp_path / "img.png"
result = runner.invoke(
app,
["generate", "test prompt", "-m", "x.safetensors", "-c", "4", "-P", "4", "--seed", "1", "-o", str(out)],
)
assert result.exit_code == 0, result.output
# With P=4 and 4 tasks each sleeping 100ms, peak concurrency should hit 4.
# Even allowing for thread-pool warmup quirks, ≥2 means parallelism is real.
assert peak_in_flight >= 2, f"peak_in_flight={peak_in_flight} (expected ≥2 for parallel)"
+2 -2
View File
@@ -1237,7 +1237,7 @@ class TestDownloadBackgroundTasks:
captured["api_key"] = api_key captured["api_key"] = api_key
return {"file_id": 42, "sha256": "deadbeef", "linked": True, "cached": True, "error": None} return {"file_id": 42, "sha256": "deadbeef", "linked": True, "cached": True, "error": None}
monkeypatch.setattr(download_routes_module, "Database", lambda: StubDB()) monkeypatch.setattr(download_routes_module, "Database", StubDB)
return captured return captured
def test_do_download_success(self, monkeypatch, tmp_path) -> None: def test_do_download_success(self, monkeypatch, tmp_path) -> None:
@@ -1397,7 +1397,7 @@ class TestDownloadBackgroundTasks:
def register_downloaded_file(self, *args, **kwargs): def register_downloaded_file(self, *args, **kwargs):
return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"} return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"}
monkeypatch.setattr(download_routes, "Database", lambda: FailingDB()) monkeypatch.setattr(download_routes, "Database", FailingDB)
dest_path = tmp_path / "model.safetensors" dest_path = tmp_path / "model.safetensors"
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1}) _do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})