feat(cli): add style-sweep command for batched style variation
New `tsr style-sweep` command renders one image per style suffix from a
template JSON, composing prompt = template.prompt + ', ' + style.suffix
and writing to {output_dir}/{slug}.png.
- Template JSON mirrors `generate --input` keys plus output_dir + styles.
- Styles source can be a path or inline list/object on either CLI or
template. Relative styles paths in the template resolve against the
template's directory (so templates can ship with their styles file).
- Skips existing outputs by default (--no-skip-existing to force).
- --dry-run prints planned prompts/paths without invoking generate.
- --limit N caps the sweep for fast iteration.
- --continue-on-error keeps going on individual failures; final exit code
is non-zero if any style failed and failed slugs are reported.
- --remote propagates to the underlying generation, same as `generate`.
- Writes a manifest {output_dir}/_sweep.json with per-style results
(slug, prompt, output, seed, duration_sec, success, error).
Delegates to the `_run_generation` helper extracted from `generate`.
This commit is contained in:
+370
@@ -1220,6 +1220,376 @@ def _run_generation( # noqa: PLR0915
|
||||
console.print(f"[dim]Prompt ID: {all_results[0]['prompt_id']}[/dim]")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Style Sweep
|
||||
# =============================================================================
|
||||
|
||||
|
||||
# Keys that style-sweep templates accept (mirror of `generate --input` keys, plus
|
||||
# two sweep-specific keys: output_dir and styles).
|
||||
_STYLE_SWEEP_TEMPLATE_KEYS = {
|
||||
"prompt",
|
||||
"model",
|
||||
"width",
|
||||
"height",
|
||||
"steps",
|
||||
"cfg",
|
||||
"guidance",
|
||||
"seed",
|
||||
"sampler",
|
||||
"scheduler",
|
||||
"vae",
|
||||
"lora",
|
||||
"lora_strength",
|
||||
"negative",
|
||||
"negative_prompt",
|
||||
"orientation",
|
||||
"no_quality",
|
||||
"no_negative",
|
||||
"rating",
|
||||
"family",
|
||||
"remote",
|
||||
# sweep-specific
|
||||
"output_dir",
|
||||
"styles",
|
||||
}
|
||||
|
||||
|
||||
def _load_json_file_or_inline(value: str | list | dict, *, what: str) -> Any:
|
||||
"""Load JSON from a file path or accept already-parsed inline data.
|
||||
|
||||
`value` may be a path string, a JSON string, or an already-parsed list/dict
|
||||
(e.g. when read out of a template). Raises typer.Exit on failure.
|
||||
"""
|
||||
if isinstance(value, (list, dict)):
|
||||
return value
|
||||
if not isinstance(value, str):
|
||||
console.print(f"[red]Invalid {what} value (expected path, JSON string, or inline data)[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
path = Path(value)
|
||||
if path.is_file():
|
||||
try:
|
||||
return json.loads(path.read_text())
|
||||
except json.JSONDecodeError as e:
|
||||
console.print(f"[red]Invalid JSON in {what} file {path}:[/red] {e}")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
stripped = value.lstrip()
|
||||
if stripped.startswith(("{", "[")):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
console.print(f"[red]Invalid inline JSON for {what}:[/red] {e}")
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
console.print(f"[red]{what.capitalize()} is neither a readable file nor inline JSON:[/red] {value}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def _normalize_styles(styles_data: Any) -> list[dict[str, str]]:
|
||||
"""Coerce styles data into a flat list of {slug, suffix} dicts."""
|
||||
if isinstance(styles_data, dict):
|
||||
entries = styles_data.get("styles")
|
||||
if entries is None:
|
||||
console.print("[red]Styles object missing 'styles' key[/red]")
|
||||
raise typer.Exit(1)
|
||||
elif isinstance(styles_data, list):
|
||||
entries = styles_data
|
||||
else:
|
||||
console.print("[red]Styles data must be an object with 'styles' key or a list[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if not isinstance(entries, list) or not entries:
|
||||
console.print("[red]Styles list is empty or not a list[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
normalized: list[dict[str, str]] = []
|
||||
for i, entry in enumerate(entries):
|
||||
if not isinstance(entry, dict):
|
||||
console.print(f"[red]Style entry #{i} is not an object[/red]")
|
||||
raise typer.Exit(1)
|
||||
slug = entry.get("slug")
|
||||
suffix = entry.get("suffix")
|
||||
if not slug or not isinstance(slug, str):
|
||||
console.print(f"[red]Style entry #{i} missing/invalid 'slug'[/red]")
|
||||
raise typer.Exit(1)
|
||||
if suffix is None or not isinstance(suffix, str):
|
||||
console.print(f"[red]Style entry #{i} ({slug}) missing/invalid 'suffix'[/red]")
|
||||
raise typer.Exit(1)
|
||||
normalized.append({"slug": slug, "suffix": suffix})
|
||||
return normalized
|
||||
|
||||
|
||||
@app.command(name="style-sweep")
|
||||
def style_sweep( # noqa: PLR0915
|
||||
template: Annotated[
|
||||
Path,
|
||||
typer.Option("--template", "-t", help="Path to template JSON (mirrors `generate --input` keys + output_dir/styles)"),
|
||||
],
|
||||
styles: Annotated[
|
||||
str | None,
|
||||
typer.Option("--styles", help="Override styles source: path to JSON or inline JSON list/object"),
|
||||
] = None,
|
||||
output_dir: Annotated[
|
||||
Path | None,
|
||||
typer.Option("--output-dir", help="Override output directory from template"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int | None,
|
||||
typer.Option("--limit", help="Stop after N styles (useful for testing)"),
|
||||
] = None,
|
||||
skip_existing: Annotated[
|
||||
bool,
|
||||
typer.Option("--skip-existing/--no-skip-existing", help="Skip styles whose output file already exists"),
|
||||
] = True,
|
||||
dry_run: Annotated[
|
||||
bool,
|
||||
typer.Option("--dry-run", help="Print planned prompts/paths without invoking generate"),
|
||||
] = False,
|
||||
continue_on_error: Annotated[
|
||||
bool,
|
||||
typer.Option("--continue-on-error/--abort-on-error", help="Keep going after individual style failures"),
|
||||
] = True,
|
||||
remote: Annotated[
|
||||
str | None,
|
||||
typer.Option("-r", "--remote", help="Remote server name or URL (overrides template)"),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""Sweep a base prompt across a list of style suffixes, one image per style.
|
||||
|
||||
Loads a template JSON with the base prompt + generation params, plus a styles
|
||||
JSON listing {slug, suffix} entries. For each style, composes
|
||||
"{prompt}, {suffix}" and renders to {output_dir}/{slug}.png.
|
||||
|
||||
Writes a manifest at {output_dir}/_sweep.json with per-style results.
|
||||
|
||||
Examples:
|
||||
tsr style-sweep --template woman-black-dress.json
|
||||
tsr style-sweep -t template.json --styles styles.json --limit 3
|
||||
tsr style-sweep -t template.json --dry-run
|
||||
tsr style-sweep -t template.json --remote junkpile
|
||||
"""
|
||||
# ---- Load template ----
|
||||
if not template.is_file():
|
||||
console.print(f"[red]Template file not found:[/red] {template}")
|
||||
raise typer.Exit(1)
|
||||
try:
|
||||
tpl_data = json.loads(template.read_text())
|
||||
except json.JSONDecodeError as e:
|
||||
console.print(f"[red]Invalid JSON in template {template}:[/red] {e}")
|
||||
raise typer.Exit(1) from e
|
||||
if not isinstance(tpl_data, dict):
|
||||
console.print("[red]Template JSON must be an object[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Warn on unknown keys (don't error — forward-compat)
|
||||
unknown = {k for k in tpl_data if not k.startswith("_") and k not in _STYLE_SWEEP_TEMPLATE_KEYS}
|
||||
if unknown:
|
||||
console.print(f"[yellow]Unknown template keys ignored:[/yellow] {sorted(unknown)}")
|
||||
|
||||
base_prompt = tpl_data.get("prompt")
|
||||
if not base_prompt or not isinstance(base_prompt, str):
|
||||
console.print("[red]Template missing required 'prompt' string[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# ---- Resolve styles source ----
|
||||
# Relative paths inside the template are resolved against the template's
|
||||
# directory (so templates can ship next to their styles files).
|
||||
tpl_dir = template.resolve().parent
|
||||
|
||||
def _resolve_relative_to_template(val: str) -> str:
|
||||
p = Path(val)
|
||||
if not p.is_absolute() and not p.exists():
|
||||
alt = tpl_dir / p
|
||||
if alt.exists():
|
||||
return str(alt)
|
||||
return val
|
||||
|
||||
styles_source: Any
|
||||
styles_origin: str
|
||||
if styles is not None:
|
||||
styles_origin = styles
|
||||
styles_source = _load_json_file_or_inline(styles, what="styles")
|
||||
elif "styles" in tpl_data:
|
||||
tpl_styles = tpl_data["styles"]
|
||||
if isinstance(tpl_styles, list):
|
||||
styles_origin = "<inline in template>"
|
||||
styles_source = tpl_styles
|
||||
else:
|
||||
resolved = _resolve_relative_to_template(tpl_styles)
|
||||
styles_origin = resolved
|
||||
styles_source = _load_json_file_or_inline(resolved, what="styles")
|
||||
else:
|
||||
console.print("[red]No styles specified (use --styles or set 'styles' in template)[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
style_entries = _normalize_styles(styles_source)
|
||||
if limit is not None:
|
||||
if limit < 0:
|
||||
console.print("[red]--limit must be >= 0[/red]")
|
||||
raise typer.Exit(1)
|
||||
style_entries = style_entries[:limit]
|
||||
|
||||
# ---- Resolve output directory ----
|
||||
out_dir: Path
|
||||
if output_dir is not None:
|
||||
out_dir = output_dir
|
||||
elif "output_dir" in tpl_data:
|
||||
out_dir = Path(tpl_data["output_dir"])
|
||||
else:
|
||||
console.print("[red]No output_dir specified (use --output-dir or set 'output_dir' in template)[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if not dry_run:
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ---- Resolve generate params from template ----
|
||||
def _t(key: str, *, cast: Any = None, default: Any = None) -> Any:
|
||||
val = tpl_data.get(key, default)
|
||||
if val is None or cast is None:
|
||||
return val
|
||||
try:
|
||||
return cast(val)
|
||||
except (TypeError, ValueError):
|
||||
return val
|
||||
|
||||
# Accept both "negative" and "negative_prompt" keys
|
||||
negative_val = tpl_data.get("negative", tpl_data.get("negative_prompt", "")) or ""
|
||||
|
||||
gen_remote = remote if remote is not None else tpl_data.get("remote")
|
||||
|
||||
# ---- Execute sweep ----
|
||||
import time # noqa: PLC0415
|
||||
|
||||
total = len(style_entries)
|
||||
console.print(f"[bold]Style sweep:[/bold] {total} styles → {out_dir}")
|
||||
console.print(f"[dim]Template: {template}[/dim]")
|
||||
console.print(f"[dim]Styles: {styles_origin}[/dim]")
|
||||
if dry_run:
|
||||
console.print("[yellow]DRY RUN — no generation calls will be made[/yellow]")
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
failed_slugs: list[str] = []
|
||||
|
||||
for i, entry in enumerate(style_entries, start=1):
|
||||
slug = entry["slug"]
|
||||
suffix = entry["suffix"]
|
||||
composed_prompt = f"{base_prompt}, {suffix}"
|
||||
out_path = out_dir / f"{slug}.png"
|
||||
|
||||
result: dict[str, Any] = {
|
||||
"slug": slug,
|
||||
"prompt": composed_prompt,
|
||||
"output": str(out_path),
|
||||
"seed": _t("seed", cast=int, default=-1),
|
||||
"duration_sec": 0.0,
|
||||
"success": False,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
# Skip if exists
|
||||
if skip_existing and out_path.exists():
|
||||
console.print(f"[dim]\\[{i}/{total}] {slug} skip (exists)[/dim]")
|
||||
result["success"] = True
|
||||
result["skipped"] = True
|
||||
results.append(result)
|
||||
continue
|
||||
|
||||
if dry_run:
|
||||
console.print(f"\\[{i}/{total}] {slug}")
|
||||
console.print(f" [dim]prompt:[/dim] {composed_prompt}")
|
||||
console.print(f" [dim]output:[/dim] {out_path}")
|
||||
result["success"] = True
|
||||
result["dry_run"] = True
|
||||
results.append(result)
|
||||
continue
|
||||
|
||||
start = time.perf_counter()
|
||||
try:
|
||||
_run_generation(
|
||||
prompt=composed_prompt,
|
||||
model=_t("model"),
|
||||
width=_t("width", cast=int),
|
||||
height=_t("height", cast=int),
|
||||
steps=_t("steps", cast=int),
|
||||
cfg=_t("cfg", cast=float),
|
||||
guidance=_t("guidance", cast=float),
|
||||
seed=_t("seed", cast=int, default=-1),
|
||||
sampler=_t("sampler"),
|
||||
scheduler=_t("scheduler"),
|
||||
vae=_t("vae"),
|
||||
orientation=_t("orientation", default="square"),
|
||||
lora=_t("lora"),
|
||||
lora_strength=_t("lora_strength", cast=float, default=0.8),
|
||||
negative=negative_val,
|
||||
count=1,
|
||||
rating=_t("rating"),
|
||||
no_quality=bool(_t("no_quality", default=False)),
|
||||
no_negative=bool(_t("no_negative", default=False)),
|
||||
family=_t("family"),
|
||||
output=out_path,
|
||||
remote=gen_remote,
|
||||
json_output=False,
|
||||
)
|
||||
duration = time.perf_counter() - start
|
||||
result["duration_sec"] = round(duration, 2)
|
||||
result["success"] = True
|
||||
console.print(f"[green]\\[{i}/{total}] {slug} ok in {duration:.1f}s[/green]")
|
||||
except typer.Exit as e:
|
||||
duration = time.perf_counter() - start
|
||||
result["duration_sec"] = round(duration, 2)
|
||||
err_msg = f"generate exited with code {e.exit_code}"
|
||||
result["error"] = err_msg
|
||||
failed_slugs.append(slug)
|
||||
console.print(f"[red]\\[{i}/{total}] {slug} FAIL: {err_msg}[/red]")
|
||||
if not continue_on_error:
|
||||
results.append(result)
|
||||
_write_sweep_manifest(out_dir, template, styles_origin, results)
|
||||
raise
|
||||
except Exception as e:
|
||||
duration = time.perf_counter() - start
|
||||
result["duration_sec"] = round(duration, 2)
|
||||
result["error"] = str(e)
|
||||
failed_slugs.append(slug)
|
||||
console.print(f"[red]\\[{i}/{total}] {slug} FAIL: {e}[/red]")
|
||||
if not continue_on_error:
|
||||
results.append(result)
|
||||
_write_sweep_manifest(out_dir, template, styles_origin, results)
|
||||
raise typer.Exit(1) from e
|
||||
|
||||
results.append(result)
|
||||
|
||||
# ---- Manifest ----
|
||||
if not dry_run:
|
||||
manifest_path = _write_sweep_manifest(out_dir, template, styles_origin, results)
|
||||
console.print(f"[dim]Manifest: {manifest_path}[/dim]")
|
||||
|
||||
# ---- Summary ----
|
||||
successful = sum(1 for r in results if r.get("success"))
|
||||
console.print(f"[bold]Sweep complete:[/bold] {successful}/{len(results)} ok")
|
||||
if failed_slugs:
|
||||
console.print(f"[red]Failed slugs ({len(failed_slugs)}):[/red] {', '.join(failed_slugs)}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def _write_sweep_manifest(
|
||||
out_dir: Path,
|
||||
template_path: Path,
|
||||
styles_origin: str,
|
||||
results: list[dict[str, Any]],
|
||||
) -> Path:
|
||||
"""Write the per-sweep manifest JSON. Returns the path."""
|
||||
manifest_path = out_dir / "_sweep.json"
|
||||
manifest: dict[str, Any] = {
|
||||
"template": str(template_path),
|
||||
"styles_source": styles_origin,
|
||||
"results": results,
|
||||
}
|
||||
manifest_path.write_text(json.dumps(manifest, indent=2) + "\n")
|
||||
return manifest_path
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Template Dump
|
||||
# =============================================================================
|
||||
|
||||
@@ -0,0 +1,337 @@
|
||||
"""Tests for the `tsr style-sweep` command."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from tensors import cli as cli_module
|
||||
from tensors.cli import app
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _write_template(
|
||||
tmp_path: Path,
|
||||
*,
|
||||
output_dir: Path | str,
|
||||
styles: Any = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
) -> Path:
|
||||
"""Write a minimal template JSON file and return its path."""
|
||||
body: dict[str, Any] = {
|
||||
"prompt": "a portrait of a person",
|
||||
"model": "test-model.safetensors",
|
||||
"seed": 12345,
|
||||
"orientation": "portrait",
|
||||
"output_dir": str(output_dir),
|
||||
}
|
||||
if styles is not None:
|
||||
body["styles"] = styles
|
||||
if extra:
|
||||
body.update(extra)
|
||||
path = tmp_path / "template.json"
|
||||
path.write_text(json.dumps(body))
|
||||
return path
|
||||
|
||||
|
||||
def _write_styles_file(tmp_path: Path, entries: list[dict[str, str]]) -> Path:
|
||||
"""Write a styles JSON file (object form with 'styles' key)."""
|
||||
path = tmp_path / "styles.json"
|
||||
path.write_text(json.dumps({"name": "test", "description": "", "styles": entries}))
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def calls(monkeypatch: pytest.MonkeyPatch) -> list[dict[str, Any]]:
|
||||
"""Patch `_run_generation` to record calls and create the output file."""
|
||||
recorded: list[dict[str, Any]] = []
|
||||
|
||||
def fake_run_generation(**kwargs: Any) -> None:
|
||||
recorded.append(kwargs)
|
||||
out: Path | None = kwargs.get("output")
|
||||
if out is not None:
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
out.write_bytes(b"fake-png")
|
||||
|
||||
monkeypatch.setattr(cli_module, "_run_generation", fake_run_generation)
|
||||
return recorded
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_loads_template_and_styles_from_files(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
|
||||
"""Template + external styles file → N generate calls with composed prompts."""
|
||||
out_dir = tmp_path / "out"
|
||||
styles_file = _write_styles_file(
|
||||
tmp_path,
|
||||
[
|
||||
{"slug": "01-foo", "suffix": "in the style of Foo"},
|
||||
{"slug": "02-bar", "suffix": "in the style of Bar"},
|
||||
{"slug": "03-baz", "suffix": "in the style of Baz"},
|
||||
],
|
||||
)
|
||||
tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file))
|
||||
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tpl)])
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert len(calls) == 3
|
||||
assert calls[0]["prompt"] == "a portrait of a person, in the style of Foo"
|
||||
assert calls[1]["prompt"] == "a portrait of a person, in the style of Bar"
|
||||
assert calls[2]["prompt"] == "a portrait of a person, in the style of Baz"
|
||||
# Each call writes to {output_dir}/{slug}.png
|
||||
assert calls[0]["output"] == out_dir / "01-foo.png"
|
||||
assert calls[2]["output"] == out_dir / "03-baz.png"
|
||||
# Template values propagated
|
||||
assert calls[0]["model"] == "test-model.safetensors"
|
||||
assert calls[0]["seed"] == 12345
|
||||
assert calls[0]["orientation"] == "portrait"
|
||||
|
||||
|
||||
def test_skip_existing(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
|
||||
"""Pre-existing output file → that slug is skipped."""
|
||||
out_dir = tmp_path / "out"
|
||||
out_dir.mkdir()
|
||||
(out_dir / "01-foo.png").write_bytes(b"already here")
|
||||
|
||||
styles_file = _write_styles_file(
|
||||
tmp_path,
|
||||
[
|
||||
{"slug": "01-foo", "suffix": "Foo"},
|
||||
{"slug": "02-bar", "suffix": "Bar"},
|
||||
],
|
||||
)
|
||||
tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file))
|
||||
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tpl)])
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
# Only 02-bar should have been generated
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["output"] == out_dir / "02-bar.png"
|
||||
assert "skip" in result.output.lower()
|
||||
|
||||
|
||||
def test_limit(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
|
||||
"""--limit 2 caps the sweep at 2 styles."""
|
||||
out_dir = tmp_path / "out"
|
||||
styles_file = _write_styles_file(
|
||||
tmp_path,
|
||||
[
|
||||
{"slug": f"{i:02d}-style", "suffix": f"style {i}"} for i in range(1, 6)
|
||||
],
|
||||
)
|
||||
tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file))
|
||||
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tpl), "--limit", "2"])
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert len(calls) == 2
|
||||
assert calls[0]["output"].name == "01-style.png"
|
||||
assert calls[1]["output"].name == "02-style.png"
|
||||
|
||||
|
||||
def test_dry_run(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
|
||||
"""--dry-run prints plan but does not invoke generate."""
|
||||
out_dir = tmp_path / "out"
|
||||
styles_file = _write_styles_file(
|
||||
tmp_path,
|
||||
[
|
||||
{"slug": "01-foo", "suffix": "Foo style"},
|
||||
{"slug": "02-bar", "suffix": "Bar style"},
|
||||
],
|
||||
)
|
||||
tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file))
|
||||
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tpl), "--dry-run"])
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert len(calls) == 0
|
||||
assert "DRY RUN" in result.output
|
||||
assert "01-foo" in result.output
|
||||
assert "Foo style" in result.output
|
||||
# No manifest written on dry-run
|
||||
assert not (out_dir / "_sweep.json").exists()
|
||||
|
||||
|
||||
def test_inline_styles_list(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
|
||||
"""Styles can be passed inline as a list inside the template."""
|
||||
out_dir = tmp_path / "out"
|
||||
inline_styles = [
|
||||
{"slug": "alpha", "suffix": "Alpha suffix"},
|
||||
{"slug": "beta", "suffix": "Beta suffix"},
|
||||
]
|
||||
tpl = _write_template(tmp_path, output_dir=out_dir, styles=inline_styles)
|
||||
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tpl)])
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert len(calls) == 2
|
||||
assert calls[0]["prompt"].endswith("Alpha suffix")
|
||||
assert calls[1]["prompt"].endswith("Beta suffix")
|
||||
|
||||
|
||||
def test_manifest_written(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
|
||||
"""A successful sweep produces {output_dir}/_sweep.json with expected keys."""
|
||||
out_dir = tmp_path / "out"
|
||||
styles_file = _write_styles_file(
|
||||
tmp_path,
|
||||
[
|
||||
{"slug": "01-foo", "suffix": "Foo"},
|
||||
{"slug": "02-bar", "suffix": "Bar"},
|
||||
],
|
||||
)
|
||||
tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file))
|
||||
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tpl)])
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
manifest_path = out_dir / "_sweep.json"
|
||||
assert manifest_path.exists()
|
||||
|
||||
manifest = json.loads(manifest_path.read_text())
|
||||
assert manifest["template"] == str(tpl)
|
||||
assert manifest["styles_source"] == str(styles_file)
|
||||
assert len(manifest["results"]) == 2
|
||||
|
||||
first = manifest["results"][0]
|
||||
for key in ("slug", "prompt", "output", "seed", "duration_sec", "success", "error"):
|
||||
assert key in first, f"missing manifest key {key}"
|
||||
assert first["slug"] == "01-foo"
|
||||
assert first["success"] is True
|
||||
assert first["error"] is None
|
||||
assert first["seed"] == 12345
|
||||
|
||||
|
||||
def test_continue_on_error(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""One failed style does not abort the sweep; manifest records the error."""
|
||||
out_dir = tmp_path / "out"
|
||||
styles_file = _write_styles_file(
|
||||
tmp_path,
|
||||
[
|
||||
{"slug": "01-ok", "suffix": "ok one"},
|
||||
{"slug": "02-bad", "suffix": "bad one"},
|
||||
{"slug": "03-ok", "suffix": "ok two"},
|
||||
],
|
||||
)
|
||||
tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file))
|
||||
|
||||
def fake_run_generation(**kwargs: Any) -> None:
|
||||
out: Path = kwargs["output"]
|
||||
if "02-bad" in out.name:
|
||||
raise RuntimeError("simulated failure")
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
out.write_bytes(b"fake")
|
||||
|
||||
monkeypatch.setattr(cli_module, "_run_generation", fake_run_generation)
|
||||
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tpl)])
|
||||
|
||||
# Sweep finished but exit code non-zero because one slug failed
|
||||
assert result.exit_code == 1, result.output
|
||||
assert "02-bad" in result.output
|
||||
assert "FAIL" in result.output
|
||||
|
||||
manifest = json.loads((out_dir / "_sweep.json").read_text())
|
||||
assert len(manifest["results"]) == 3
|
||||
statuses = {r["slug"]: r for r in manifest["results"]}
|
||||
assert statuses["01-ok"]["success"] is True
|
||||
assert statuses["02-bad"]["success"] is False
|
||||
assert "simulated failure" in statuses["02-bad"]["error"]
|
||||
assert statuses["03-ok"]["success"] is True
|
||||
|
||||
|
||||
def test_abort_on_error_stops_immediately(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""--abort-on-error aborts at the first failure."""
|
||||
out_dir = tmp_path / "out"
|
||||
styles_file = _write_styles_file(
|
||||
tmp_path,
|
||||
[
|
||||
{"slug": "01-bad", "suffix": "bad"},
|
||||
{"slug": "02-skipped", "suffix": "never reached"},
|
||||
],
|
||||
)
|
||||
tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file))
|
||||
|
||||
seen: list[str] = []
|
||||
|
||||
def fake_run_generation(**kwargs: Any) -> None:
|
||||
seen.append(Path(kwargs["output"]).name)
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(cli_module, "_run_generation", fake_run_generation)
|
||||
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tpl), "--abort-on-error"])
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert seen == ["01-bad.png"]
|
||||
# Manifest was still written (with the one failed entry)
|
||||
manifest = json.loads((out_dir / "_sweep.json").read_text())
|
||||
assert len(manifest["results"]) == 1
|
||||
assert manifest["results"][0]["slug"] == "01-bad"
|
||||
|
||||
|
||||
def test_missing_template_file_errors(tmp_path: Path) -> None:
|
||||
"""A non-existent template path yields a clean error exit."""
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tmp_path / "nope.json")])
|
||||
assert result.exit_code != 0
|
||||
assert "not found" in result.output.lower()
|
||||
|
||||
|
||||
def test_missing_styles_errors(tmp_path: Path) -> None:
|
||||
"""A template without styles (and no --styles) errors out."""
|
||||
out_dir = tmp_path / "out"
|
||||
tpl_body = {
|
||||
"prompt": "a portrait",
|
||||
"output_dir": str(out_dir),
|
||||
}
|
||||
tpl = tmp_path / "template.json"
|
||||
tpl.write_text(json.dumps(tpl_body))
|
||||
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tpl)])
|
||||
assert result.exit_code != 0
|
||||
assert "styles" in result.output.lower()
|
||||
|
||||
|
||||
def test_cli_output_dir_overrides_template(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
|
||||
"""--output-dir on the CLI overrides the template's output_dir."""
|
||||
tpl_out = tmp_path / "from-template"
|
||||
cli_out = tmp_path / "from-cli"
|
||||
styles_file = _write_styles_file(tmp_path, [{"slug": "x", "suffix": "X"}])
|
||||
tpl = _write_template(tmp_path, output_dir=tpl_out, styles=str(styles_file))
|
||||
|
||||
result = runner.invoke(
|
||||
app, ["style-sweep", "--template", str(tpl), "--output-dir", str(cli_out)]
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert calls[0]["output"] == cli_out / "x.png"
|
||||
assert (cli_out / "_sweep.json").exists()
|
||||
assert not tpl_out.exists()
|
||||
|
||||
|
||||
def test_remote_override(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
|
||||
"""--remote propagates through to _run_generation."""
|
||||
out_dir = tmp_path / "out"
|
||||
styles_file = _write_styles_file(tmp_path, [{"slug": "x", "suffix": "X"}])
|
||||
tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file))
|
||||
|
||||
result = runner.invoke(
|
||||
app, ["style-sweep", "--template", str(tpl), "--remote", "junkpile"]
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert calls[0]["remote"] == "junkpile"
|
||||
Reference in New Issue
Block a user