"""Tests for the `tsr style-sweep` command.""" from __future__ import annotations import json from pathlib import Path from typing import Any import pytest from typer.testing import CliRunner from tensors import cli as cli_module from tensors.cli import app runner = CliRunner() # ----------------------------------------------------------------------------- # Helpers # ----------------------------------------------------------------------------- def _write_template( tmp_path: Path, *, output_dir: Path | str, styles: Any = None, extra: dict[str, Any] | None = None, ) -> Path: """Write a minimal template JSON file and return its path.""" body: dict[str, Any] = { "prompt": "a portrait of a person", "model": "test-model.safetensors", "seed": 12345, "orientation": "portrait", "output_dir": str(output_dir), } if styles is not None: body["styles"] = styles if extra: body.update(extra) path = tmp_path / "template.json" path.write_text(json.dumps(body)) return path def _write_styles_file(tmp_path: Path, entries: list[dict[str, str]]) -> Path: """Write a styles JSON file (object form with 'styles' key).""" path = tmp_path / "styles.json" path.write_text(json.dumps({"name": "test", "description": "", "styles": entries})) return path @pytest.fixture def calls(monkeypatch: pytest.MonkeyPatch) -> list[dict[str, Any]]: """Patch `_run_generation` to record calls and create the output file.""" recorded: list[dict[str, Any]] = [] def fake_run_generation(**kwargs: Any) -> None: recorded.append(kwargs) out: Path | None = kwargs.get("output") if out is not None: out.parent.mkdir(parents=True, exist_ok=True) out.write_bytes(b"fake-png") monkeypatch.setattr(cli_module, "_run_generation", fake_run_generation) return recorded # ----------------------------------------------------------------------------- # Tests # ----------------------------------------------------------------------------- def test_loads_template_and_styles_from_files(tmp_path: Path, calls: list[dict[str, Any]]) -> None: """Template + external styles file → N generate calls with composed prompts.""" out_dir = tmp_path / "out" styles_file = _write_styles_file( tmp_path, [ {"slug": "01-foo", "suffix": "in the style of Foo"}, {"slug": "02-bar", "suffix": "in the style of Bar"}, {"slug": "03-baz", "suffix": "in the style of Baz"}, ], ) tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file)) result = runner.invoke(app, ["style-sweep", "--template", str(tpl)]) assert result.exit_code == 0, result.output assert len(calls) == 3 assert calls[0]["prompt"] == "a portrait of a person, in the style of Foo" assert calls[1]["prompt"] == "a portrait of a person, in the style of Bar" assert calls[2]["prompt"] == "a portrait of a person, in the style of Baz" # Each call writes to {output_dir}/{slug}.png assert calls[0]["output"] == out_dir / "01-foo.png" assert calls[2]["output"] == out_dir / "03-baz.png" # Template values propagated assert calls[0]["model"] == "test-model.safetensors" assert calls[0]["seed"] == 12345 assert calls[0]["orientation"] == "portrait" def test_skip_existing(tmp_path: Path, calls: list[dict[str, Any]]) -> None: """Pre-existing output file → that slug is skipped.""" out_dir = tmp_path / "out" out_dir.mkdir() (out_dir / "01-foo.png").write_bytes(b"already here") styles_file = _write_styles_file( tmp_path, [ {"slug": "01-foo", "suffix": "Foo"}, {"slug": "02-bar", "suffix": "Bar"}, ], ) tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file)) result = runner.invoke(app, ["style-sweep", "--template", str(tpl)]) assert result.exit_code == 0, result.output # Only 02-bar should have been generated assert len(calls) == 1 assert calls[0]["output"] == out_dir / "02-bar.png" assert "skip" in result.output.lower() def test_limit(tmp_path: Path, calls: list[dict[str, Any]]) -> None: """--limit 2 caps the sweep at 2 styles.""" out_dir = tmp_path / "out" styles_file = _write_styles_file( tmp_path, [ {"slug": f"{i:02d}-style", "suffix": f"style {i}"} for i in range(1, 6) ], ) tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file)) result = runner.invoke(app, ["style-sweep", "--template", str(tpl), "--limit", "2"]) assert result.exit_code == 0, result.output assert len(calls) == 2 assert calls[0]["output"].name == "01-style.png" assert calls[1]["output"].name == "02-style.png" def test_dry_run(tmp_path: Path, calls: list[dict[str, Any]]) -> None: """--dry-run prints plan but does not invoke generate.""" out_dir = tmp_path / "out" styles_file = _write_styles_file( tmp_path, [ {"slug": "01-foo", "suffix": "Foo style"}, {"slug": "02-bar", "suffix": "Bar style"}, ], ) tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file)) result = runner.invoke(app, ["style-sweep", "--template", str(tpl), "--dry-run"]) assert result.exit_code == 0, result.output assert len(calls) == 0 assert "DRY RUN" in result.output assert "01-foo" in result.output assert "Foo style" in result.output # No manifest written on dry-run assert not (out_dir / "_sweep.json").exists() def test_inline_styles_list(tmp_path: Path, calls: list[dict[str, Any]]) -> None: """Styles can be passed inline as a list inside the template.""" out_dir = tmp_path / "out" inline_styles = [ {"slug": "alpha", "suffix": "Alpha suffix"}, {"slug": "beta", "suffix": "Beta suffix"}, ] tpl = _write_template(tmp_path, output_dir=out_dir, styles=inline_styles) result = runner.invoke(app, ["style-sweep", "--template", str(tpl)]) assert result.exit_code == 0, result.output assert len(calls) == 2 assert calls[0]["prompt"].endswith("Alpha suffix") assert calls[1]["prompt"].endswith("Beta suffix") def test_manifest_written(tmp_path: Path, calls: list[dict[str, Any]]) -> None: """A successful sweep produces {output_dir}/_sweep.json with expected keys.""" out_dir = tmp_path / "out" styles_file = _write_styles_file( tmp_path, [ {"slug": "01-foo", "suffix": "Foo"}, {"slug": "02-bar", "suffix": "Bar"}, ], ) tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file)) result = runner.invoke(app, ["style-sweep", "--template", str(tpl)]) assert result.exit_code == 0, result.output manifest_path = out_dir / "_sweep.json" assert manifest_path.exists() manifest = json.loads(manifest_path.read_text()) assert manifest["template"] == str(tpl) assert manifest["styles_source"] == str(styles_file) assert len(manifest["results"]) == 2 first = manifest["results"][0] for key in ("slug", "prompt", "output", "seed", "duration_sec", "success", "error"): assert key in first, f"missing manifest key {key}" assert first["slug"] == "01-foo" assert first["success"] is True assert first["error"] is None assert first["seed"] == 12345 def test_continue_on_error(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """One failed style does not abort the sweep; manifest records the error.""" out_dir = tmp_path / "out" styles_file = _write_styles_file( tmp_path, [ {"slug": "01-ok", "suffix": "ok one"}, {"slug": "02-bad", "suffix": "bad one"}, {"slug": "03-ok", "suffix": "ok two"}, ], ) tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file)) def fake_run_generation(**kwargs: Any) -> None: out: Path = kwargs["output"] if "02-bad" in out.name: raise RuntimeError("simulated failure") out.parent.mkdir(parents=True, exist_ok=True) out.write_bytes(b"fake") monkeypatch.setattr(cli_module, "_run_generation", fake_run_generation) result = runner.invoke(app, ["style-sweep", "--template", str(tpl)]) # Sweep finished but exit code non-zero because one slug failed assert result.exit_code == 1, result.output assert "02-bad" in result.output assert "FAIL" in result.output manifest = json.loads((out_dir / "_sweep.json").read_text()) assert len(manifest["results"]) == 3 statuses = {r["slug"]: r for r in manifest["results"]} assert statuses["01-ok"]["success"] is True assert statuses["02-bad"]["success"] is False assert "simulated failure" in statuses["02-bad"]["error"] assert statuses["03-ok"]["success"] is True def test_abort_on_error_stops_immediately(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """--abort-on-error aborts at the first failure.""" out_dir = tmp_path / "out" styles_file = _write_styles_file( tmp_path, [ {"slug": "01-bad", "suffix": "bad"}, {"slug": "02-skipped", "suffix": "never reached"}, ], ) tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file)) seen: list[str] = [] def fake_run_generation(**kwargs: Any) -> None: seen.append(Path(kwargs["output"]).name) raise RuntimeError("boom") monkeypatch.setattr(cli_module, "_run_generation", fake_run_generation) result = runner.invoke(app, ["style-sweep", "--template", str(tpl), "--abort-on-error"]) assert result.exit_code != 0 assert seen == ["01-bad.png"] # Manifest was still written (with the one failed entry) manifest = json.loads((out_dir / "_sweep.json").read_text()) assert len(manifest["results"]) == 1 assert manifest["results"][0]["slug"] == "01-bad" def test_missing_template_file_errors(tmp_path: Path) -> None: """A non-existent template path yields a clean error exit.""" result = runner.invoke(app, ["style-sweep", "--template", str(tmp_path / "nope.json")]) assert result.exit_code != 0 assert "not found" in result.output.lower() def test_missing_styles_errors(tmp_path: Path) -> None: """A template without styles (and no --styles) errors out.""" out_dir = tmp_path / "out" tpl_body = { "prompt": "a portrait", "output_dir": str(out_dir), } tpl = tmp_path / "template.json" tpl.write_text(json.dumps(tpl_body)) result = runner.invoke(app, ["style-sweep", "--template", str(tpl)]) assert result.exit_code != 0 assert "styles" in result.output.lower() def test_cli_output_dir_overrides_template(tmp_path: Path, calls: list[dict[str, Any]]) -> None: """--output-dir on the CLI overrides the template's output_dir.""" tpl_out = tmp_path / "from-template" cli_out = tmp_path / "from-cli" styles_file = _write_styles_file(tmp_path, [{"slug": "x", "suffix": "X"}]) tpl = _write_template(tmp_path, output_dir=tpl_out, styles=str(styles_file)) result = runner.invoke( app, ["style-sweep", "--template", str(tpl), "--output-dir", str(cli_out)] ) assert result.exit_code == 0, result.output assert calls[0]["output"] == cli_out / "x.png" assert (cli_out / "_sweep.json").exists() assert not tpl_out.exists() def test_remote_override(tmp_path: Path, calls: list[dict[str, Any]]) -> None: """--remote propagates through to _run_generation.""" out_dir = tmp_path / "out" styles_file = _write_styles_file(tmp_path, [{"slug": "x", "suffix": "X"}]) tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file)) result = runner.invoke( app, ["style-sweep", "--template", str(tpl), "--remote", "junkpile"] ) assert result.exit_code == 0, result.output assert calls[0]["remote"] == "junkpile"