From c911abfe6941df2fa5e2b7e55249f72a949e022c Mon Sep 17 00:00:00 2001 From: aladac Date: Mon, 18 May 2026 21:16:11 +0200 Subject: [PATCH] feat(generate): accept YAML in --input alongside JSON MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `tsr generate --input ` previously only understood JSON, which was awkward for hand-authored template libraries (e.g. ~/Projects/draw/templates/ ships *.yml scene files with embedded newlines and unquoted keys that mirror the `tsr template` output shape). Behavior: - Files with .yml / .yaml extension parse as YAML; .json (or unknown extensions whose first non-whitespace char is '{' or '[') parse as JSON. - Inline strings starting with '{' still parse as JSON (regression-safe). - Inline strings without leading '{' now parse as YAML, enabling `tsr generate --input 'prompt: foo\nmodel: bar.safetensors'` without shell-quoting a JSON object. - All downstream key-mapping / CLI-override / character / scene / lora / count handling is identical to the JSON path — parsing only differs. Implementation: - New `_parse_generate_input(value)` helper in tensors/cli.py centralizes source detection (file vs inline), format selection (extension or content sniff), and rich-formatted error reporting via typer.Exit(1). - The pre-existing inline JSON merge block in `generate` is reduced to a single call to the helper. - Adds pyyaml>=6.0 as a runtime dep. It was already transitively pulled in by huggingface_hub, but we depend on it directly so the surface contract is explicit and survives a hub re-pin. - mypy override added for the yaml module (no upstream stubs in tree). Tests: - 20 new tests in tests/test_generate_input.py covering inline JSON, inline YAML, file by extension (.json/.yml/.yaml), unknown extension content sniffing, non-mapping rejection, malformed input handling, CLI-flag-wins-over-input precedence, and a full smoke against the exact draw template shape (with embedded newlines in the scene list). - 359 -> 379 total tests. Lint clean on changed lines. Co-Authored-By: OpenCode --- pyproject.toml | 5 + tensors/cli.py | 96 +++++++++++---- tests/test_generate_input.py | 218 +++++++++++++++++++++++++++++++++++ uv.lock | 2 + 4 files changed, 298 insertions(+), 23 deletions(-) create mode 100644 tests/test_generate_input.py diff --git a/pyproject.toml b/pyproject.toml index ee705b3..e94f792 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "websocket-client>=1.9.0", "huggingface_hub>=0.25.0", "sqlmodel>=0.0.33", + "pyyaml>=6.0", ] [project.optional-dependencies] @@ -113,6 +114,10 @@ ignore_missing_imports = true module = ["websockets.*"] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["yaml.*"] +ignore_missing_imports = true + [tool.pytest.ini_options] testpaths = ["tests"] addopts = "-v --cov=tensors --cov-report=term-missing" diff --git a/tensors/cli.py b/tensors/cli.py index de0e356..101eec6 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -80,6 +80,66 @@ MAX_QUEUE_DISPLAY = 10 MAX_MODEL_LIST_DISPLAY = 20 MAX_PROMPT_ID_DISPLAY = 36 +# File extensions that force YAML parsing for `tsr generate --input `. +_YAML_INPUT_EXTENSIONS = frozenset({".yml", ".yaml"}) + + +def _parse_generate_input(value: str) -> dict[str, Any]: + """Parse a ``--input`` argument into a dict of generation params. + + Accepts either: + * a path to a ``.json`` / ``.yml`` / ``.yaml`` file, + * a raw JSON object string (``{"prompt": ...}``), + * a raw YAML document string (anything else that doesn't start with ``{``). + + File extension wins when reading from disk. For inline strings we try JSON + first (current behaviour) and fall back to YAML so existing callers keep + working without surprises. + + Raises ``typer.Exit(1)`` with a rich error on every failure path so callers + don't need to repeat the diagnostics. + """ + import yaml # noqa: PLC0415 — keep yaml a soft import path + + # ---- locate source text + decide format ---- + path = Path(value) + if path.is_file(): + text = path.read_text() + suffix = path.suffix.lower() + if suffix in _YAML_INPUT_EXTENSIONS: + fmt = "yaml" + elif suffix == ".json": + fmt = "json" + else: + # Unknown extension: sniff the content. Leading '{' or '[' → JSON. + fmt = "json" if text.lstrip().startswith(("{", "[")) else "yaml" + elif value.lstrip().startswith("{"): + text = value + fmt = "json" + else: + # Last-resort: treat as inline YAML. This is unusual but lets the user + # pass ``--input 'prompt: foo\nmodel: bar.safetensors'`` without quoting + # a JSON object on the shell. + text = value + fmt = "yaml" + + # ---- parse ---- + parsed: Any + try: + parsed = json.loads(text) if fmt == "json" else yaml.safe_load(text) + except json.JSONDecodeError as e: + console.print(f"[red]Invalid JSON input:[/red] {e}") + raise typer.Exit(1) from e + except yaml.YAMLError as e: + console.print(f"[red]Invalid YAML input:[/red] {e}") + raise typer.Exit(1) from e + + if not isinstance(parsed, dict): + console.print(f"[red]{fmt.upper()} input must be a mapping/object[/red]") + raise typer.Exit(1) + + return parsed + def _cache_model_quietly(model_data: dict[str, Any]) -> None: """Cache model data to database without output.""" @@ -810,7 +870,10 @@ def generate( # noqa: PLR0915 output: Annotated[Path | None, typer.Option("-o", "--output", help="Save path (default: current dir)")] = None, remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, - json_input: Annotated[str | None, typer.Option("--input", "-I", help="JSON params (keys match CLI options)")] = None, + json_input: Annotated[ + str | None, + typer.Option("--input", "-I", help="JSON or YAML params (file path or inline; keys match CLI options)"), + ] = None, ) -> None: """Generate an image using text-to-image. @@ -818,7 +881,11 @@ def generate( # noqa: PLR0915 model family. All auto-detected values can be overridden with explicit flags. Calls ComfyUI directly when local, or the remote tensors API when --remote is given. - Accepts --input with a JSON object whose keys match CLI option names. CLI flags override JSON values. + Accepts --input with a JSON or YAML object whose keys match CLI option names. + Files ending in ``.yml`` / ``.yaml`` are parsed as YAML; ``.json`` (or any + other extension whose contents start with ``{``/``[``) as JSON. Inline strings + starting with ``{`` are JSON, everything else is YAML. CLI flags override + --input values. Examples: tsr generate "a cat on a windowsill" @@ -826,31 +893,14 @@ def generate( # noqa: PLR0915 tsr generate "cyberpunk city" -o output.png --count 4 tsr generate "landscape" --remote junkpile tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors"}' + tsr generate --input scene.yml tsr generate "raw prompt" --no-quality --no-negative """ - # ---- JSON input merging ---- + # ---- --input merging (JSON or YAML) ---- if json_input is not None: - # Support file paths and raw JSON strings - json_path = Path(json_input) - if json_path.is_file(): - json_text = json_path.read_text() - elif json_input.lstrip().startswith("{"): - json_text = json_input - else: - console.print(f"[red]Not a JSON string or file:[/red] {json_input}") - raise typer.Exit(1) + ji = _parse_generate_input(json_input) - try: - ji = json.loads(json_text) - except json.JSONDecodeError as e: - console.print(f"[red]Invalid JSON input:[/red] {e}") - raise typer.Exit(1) from e - - if not isinstance(ji, dict): - console.print("[red]JSON input must be an object[/red]") - raise typer.Exit(1) - - # Map JSON keys to parameter names (handle aliases) + # Map source keys to parameter names (handle aliases) key_map = {"negative_prompt": "negative", "lora_name": "lora"} mapped: dict[str, Any] = {} for k, v in ji.items(): diff --git a/tests/test_generate_input.py b/tests/test_generate_input.py new file mode 100644 index 0000000..00e9527 --- /dev/null +++ b/tests/test_generate_input.py @@ -0,0 +1,218 @@ +"""Tests for the ``tsr generate --input`` JSON/YAML parser. + +Covers the :func:`tensors.cli._parse_generate_input` helper directly (unit +level) and the end-to-end integration through the ``generate`` Typer command +(with ``_run_generation`` patched so nothing hits ComfyUI). +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +import pytest +import typer +from typer.testing import CliRunner + +from tensors import cli as cli_module +from tensors.cli import _parse_generate_input, app + +runner = CliRunner() + + +# ----------------------------------------------------------------------------- +# Unit tests: _parse_generate_input +# ----------------------------------------------------------------------------- + + +class TestParseGenerateInputInline: + """Inline string arguments (not file paths).""" + + def test_inline_json_object(self) -> None: + out = _parse_generate_input('{"prompt": "hi", "steps": 30}') + assert out == {"prompt": "hi", "steps": 30} + + def test_inline_yaml_mapping(self) -> None: + out = _parse_generate_input("prompt: hi\nsteps: 30\n") + assert out == {"prompt": "hi", "steps": 30} + + def test_inline_yaml_with_list(self) -> None: + out = _parse_generate_input("prompt: x\nscene:\n - foo\n - bar\n") + assert out == {"prompt": "x", "scene": ["foo", "bar"]} + + def test_inline_json_with_leading_whitespace(self) -> None: + out = _parse_generate_input(' {"prompt": "hi"}') + assert out == {"prompt": "hi"} + + def test_inline_non_mapping_yaml_rejected(self) -> None: + with pytest.raises(typer.Exit): + _parse_generate_input("- just\n- a list\n") + + def test_inline_non_mapping_json_rejected(self) -> None: + with pytest.raises(typer.Exit): + _parse_generate_input("[1, 2, 3]") + + def test_inline_invalid_yaml_rejected(self) -> None: + with pytest.raises(typer.Exit): + _parse_generate_input("prompt: [unterminated\n") + + def test_inline_invalid_json_falls_to_yaml_and_fails(self) -> None: + # Starts with '{' so JSON path is taken; malformed → Exit. + with pytest.raises(typer.Exit): + _parse_generate_input('{"prompt": "missing-close"') + + +class TestParseGenerateInputFiles: + """File path arguments resolved by extension.""" + + def test_json_file_by_extension(self, tmp_path: Path) -> None: + p = tmp_path / "scene.json" + p.write_text(json.dumps({"prompt": "from-json", "steps": 20})) + assert _parse_generate_input(str(p)) == {"prompt": "from-json", "steps": 20} + + def test_yaml_file_dot_yml(self, tmp_path: Path) -> None: + p = tmp_path / "scene.yml" + p.write_text("prompt: from-yml\nsteps: 25\n") + assert _parse_generate_input(str(p)) == {"prompt": "from-yml", "steps": 25} + + def test_yaml_file_dot_yaml(self, tmp_path: Path) -> None: + p = tmp_path / "scene.yaml" + p.write_text("prompt: from-yaml\n") + assert _parse_generate_input(str(p)) == {"prompt": "from-yaml"} + + def test_unknown_extension_sniffs_json(self, tmp_path: Path) -> None: + p = tmp_path / "scene.txt" + p.write_text('{"prompt": "sniffed"}') + assert _parse_generate_input(str(p)) == {"prompt": "sniffed"} + + def test_unknown_extension_sniffs_yaml(self, tmp_path: Path) -> None: + p = tmp_path / "scene.txt" + p.write_text("prompt: sniffed-yaml\n") + assert _parse_generate_input(str(p)) == {"prompt": "sniffed-yaml"} + + def test_yaml_file_with_full_draw_template(self, tmp_path: Path) -> None: + """Smoke test against the exact shape used by ~/Projects/draw/templates/.""" + p = tmp_path / "scene.yml" + p.write_text( + 'prompt: ""\n' + 'negative_prompt: ""\n' + 'model: "getphatFLUXReality_v5Hardcore.safetensors"\n' + "width: 832\n" + "height: 1216\n" + "steps: 35\n" + "cfg: 1.0\n" + "guidance: 4.0\n" + 'sampler: "dpmpp_2m"\n' + 'scheduler: "sgm_uniform"\n' + 'vae: "ae.safetensors"\n' + 'orientation: "portrait"\n' + "seed: -1\n" + "count: 1\n" + "scene:\n" + ' - "first element with embedded \\nnewline"\n' + ' - "second element"\n' + '_scene_name: "demo_01"\n' + '_family: "flux_unet"\n' + '_base_model: "Flux.1 D"\n' + ) + out = _parse_generate_input(str(p)) + assert out["model"] == "getphatFLUXReality_v5Hardcore.safetensors" + assert out["width"] == 832 + assert out["height"] == 1216 + assert out["steps"] == 35 + assert out["cfg"] == 1.0 + assert out["guidance"] == 4.0 + assert out["sampler"] == "dpmpp_2m" + assert out["scheduler"] == "sgm_uniform" + assert out["vae"] == "ae.safetensors" + assert out["orientation"] == "portrait" + assert out["seed"] == -1 + assert out["count"] == 1 + assert isinstance(out["scene"], list) + assert len(out["scene"]) == 2 + assert "embedded" in out["scene"][0] + + def test_malformed_yaml_file_rejected(self, tmp_path: Path) -> None: + p = tmp_path / "bad.yml" + p.write_text("prompt: [unterminated\n") + with pytest.raises(typer.Exit): + _parse_generate_input(str(p)) + + +# ----------------------------------------------------------------------------- +# Integration: generate --input through Typer +# ----------------------------------------------------------------------------- + + +@pytest.fixture +def captured(monkeypatch: pytest.MonkeyPatch) -> dict[str, Any]: + """Capture _run_generation kwargs without dispatching to ComfyUI.""" + sink: dict[str, Any] = {} + + def fake_run_generation(**kwargs: Any) -> None: + sink.update(kwargs) + + monkeypatch.setattr(cli_module, "_run_generation", fake_run_generation) + return sink + + +def test_generate_consumes_yaml_file(tmp_path: Path, captured: dict[str, Any]) -> None: + """``tsr generate --input scene.yml`` plumbs YAML values through.""" + yml = tmp_path / "scene.yml" + yml.write_text( + "prompt: a sunset\n" + 'model: "fluxmodel.safetensors"\n' + "steps: 28\n" + "scene:\n" + ' - "golden hour"\n' + ' - "wide angle"\n' + ) + result = runner.invoke(app, ["generate", "--input", str(yml)]) + assert result.exit_code == 0, result.output + assert captured["prompt"] == "a sunset" + assert captured["model"] == "fluxmodel.safetensors" + assert captured["steps"] == 28 + # YAML list under `scene` is joined into scene_prompt by existing logic. + assert captured["scene_prompt"] == "golden hour, wide angle" + + +def test_generate_yaml_then_cli_flag_wins(tmp_path: Path, captured: dict[str, Any]) -> None: + """Explicit CLI flags must override --input values (same contract as JSON).""" + yml = tmp_path / "scene.yml" + yml.write_text('prompt: from-yaml\nmodel: "yamlmodel.safetensors"\nsteps: 10\n') + result = runner.invoke(app, ["generate", "--input", str(yml), "--steps", "99"]) + assert result.exit_code == 0, result.output + assert captured["prompt"] == "from-yaml" + assert captured["model"] == "yamlmodel.safetensors" + assert captured["steps"] == 99 # CLI override wins + + +def test_generate_inline_yaml_string(captured: dict[str, Any]) -> None: + result = runner.invoke( + app, + ["generate", "--input", "prompt: inline-yaml\nmodel: m.safetensors\n"], + ) + assert result.exit_code == 0, result.output + assert captured["prompt"] == "inline-yaml" + assert captured["model"] == "m.safetensors" + + +def test_generate_inline_json_still_works(captured: dict[str, Any]) -> None: + """Regression guard for the original JSON contract.""" + result = runner.invoke( + app, + ["generate", "--input", '{"prompt": "inline-json", "model": "j.safetensors"}'], + ) + assert result.exit_code == 0, result.output + assert captured["prompt"] == "inline-json" + assert captured["model"] == "j.safetensors" + + +def test_generate_invalid_yaml_file_exits_nonzero(tmp_path: Path, captured: dict[str, Any]) -> None: + yml = tmp_path / "bad.yml" + yml.write_text("prompt: [oops\n") + result = runner.invoke(app, ["generate", "--input", str(yml)]) + assert result.exit_code != 0 + assert "Invalid YAML input" in result.output + assert captured == {} diff --git a/uv.lock b/uv.lock index 8de96ea..0981519 100644 --- a/uv.lock +++ b/uv.lock @@ -886,6 +886,7 @@ source = { editable = "." } dependencies = [ { name = "httpx" }, { name = "huggingface-hub" }, + { name = "pyyaml" }, { name = "rich" }, { name = "safetensors" }, { name = "sqlmodel" }, @@ -924,6 +925,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.27.0" }, { name = "huggingface-hub", specifier = ">=0.25.0" }, { name = "python-multipart", marker = "extra == 'server'", specifier = ">=0.0.9" }, + { name = "pyyaml", specifier = ">=6.0" }, { name = "rich", specifier = ">=13.0.0" }, { name = "safetensors", specifier = ">=0.4.0" }, { name = "scalar-fastapi", marker = "extra == 'server'", specifier = ">=1.6" },