Merge pull request #1 from saiden-dev/feat/generate-yaml-input
feat(generate): accept YAML in --input alongside JSON
This commit is contained in:
@@ -12,6 +12,7 @@ dependencies = [
|
|||||||
"websocket-client>=1.9.0",
|
"websocket-client>=1.9.0",
|
||||||
"huggingface_hub>=0.25.0",
|
"huggingface_hub>=0.25.0",
|
||||||
"sqlmodel>=0.0.33",
|
"sqlmodel>=0.0.33",
|
||||||
|
"pyyaml>=6.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@@ -113,6 +114,10 @@ ignore_missing_imports = true
|
|||||||
module = ["websockets.*"]
|
module = ["websockets.*"]
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = ["yaml.*"]
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
addopts = "-v --cov=tensors --cov-report=term-missing"
|
addopts = "-v --cov=tensors --cov-report=term-missing"
|
||||||
|
|||||||
+73
-23
@@ -80,6 +80,66 @@ MAX_QUEUE_DISPLAY = 10
|
|||||||
MAX_MODEL_LIST_DISPLAY = 20
|
MAX_MODEL_LIST_DISPLAY = 20
|
||||||
MAX_PROMPT_ID_DISPLAY = 36
|
MAX_PROMPT_ID_DISPLAY = 36
|
||||||
|
|
||||||
|
# File extensions that force YAML parsing for `tsr generate --input <file>`.
|
||||||
|
_YAML_INPUT_EXTENSIONS = frozenset({".yml", ".yaml"})
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_generate_input(value: str) -> dict[str, Any]:
|
||||||
|
"""Parse a ``--input`` argument into a dict of generation params.
|
||||||
|
|
||||||
|
Accepts either:
|
||||||
|
* a path to a ``.json`` / ``.yml`` / ``.yaml`` file,
|
||||||
|
* a raw JSON object string (``{"prompt": ...}``),
|
||||||
|
* a raw YAML document string (anything else that doesn't start with ``{``).
|
||||||
|
|
||||||
|
File extension wins when reading from disk. For inline strings we try JSON
|
||||||
|
first (current behaviour) and fall back to YAML so existing callers keep
|
||||||
|
working without surprises.
|
||||||
|
|
||||||
|
Raises ``typer.Exit(1)`` with a rich error on every failure path so callers
|
||||||
|
don't need to repeat the diagnostics.
|
||||||
|
"""
|
||||||
|
import yaml # noqa: PLC0415 — keep yaml a soft import path
|
||||||
|
|
||||||
|
# ---- locate source text + decide format ----
|
||||||
|
path = Path(value)
|
||||||
|
if path.is_file():
|
||||||
|
text = path.read_text()
|
||||||
|
suffix = path.suffix.lower()
|
||||||
|
if suffix in _YAML_INPUT_EXTENSIONS:
|
||||||
|
fmt = "yaml"
|
||||||
|
elif suffix == ".json":
|
||||||
|
fmt = "json"
|
||||||
|
else:
|
||||||
|
# Unknown extension: sniff the content. Leading '{' or '[' → JSON.
|
||||||
|
fmt = "json" if text.lstrip().startswith(("{", "[")) else "yaml"
|
||||||
|
elif value.lstrip().startswith("{"):
|
||||||
|
text = value
|
||||||
|
fmt = "json"
|
||||||
|
else:
|
||||||
|
# Last-resort: treat as inline YAML. This is unusual but lets the user
|
||||||
|
# pass ``--input 'prompt: foo\nmodel: bar.safetensors'`` without quoting
|
||||||
|
# a JSON object on the shell.
|
||||||
|
text = value
|
||||||
|
fmt = "yaml"
|
||||||
|
|
||||||
|
# ---- parse ----
|
||||||
|
parsed: Any
|
||||||
|
try:
|
||||||
|
parsed = json.loads(text) if fmt == "json" else yaml.safe_load(text)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
console.print(f"[red]Invalid JSON input:[/red] {e}")
|
||||||
|
raise typer.Exit(1) from e
|
||||||
|
except yaml.YAMLError as e:
|
||||||
|
console.print(f"[red]Invalid YAML input:[/red] {e}")
|
||||||
|
raise typer.Exit(1) from e
|
||||||
|
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
console.print(f"[red]{fmt.upper()} input must be a mapping/object[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
def _cache_model_quietly(model_data: dict[str, Any]) -> None:
|
def _cache_model_quietly(model_data: dict[str, Any]) -> None:
|
||||||
"""Cache model data to database without output."""
|
"""Cache model data to database without output."""
|
||||||
@@ -810,7 +870,10 @@ def generate( # noqa: PLR0915
|
|||||||
output: Annotated[Path | None, typer.Option("-o", "--output", help="Save path (default: current dir)")] = None,
|
output: Annotated[Path | None, typer.Option("-o", "--output", help="Save path (default: current dir)")] = None,
|
||||||
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
remote: Annotated[str | None, typer.Option("-r", "--remote", help="Remote server name or URL")] = None,
|
||||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
json_input: Annotated[str | None, typer.Option("--input", "-I", help="JSON params (keys match CLI options)")] = None,
|
json_input: Annotated[
|
||||||
|
str | None,
|
||||||
|
typer.Option("--input", "-I", help="JSON or YAML params (file path or inline; keys match CLI options)"),
|
||||||
|
] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Generate an image using text-to-image.
|
"""Generate an image using text-to-image.
|
||||||
|
|
||||||
@@ -818,7 +881,11 @@ def generate( # noqa: PLR0915
|
|||||||
model family. All auto-detected values can be overridden with explicit flags.
|
model family. All auto-detected values can be overridden with explicit flags.
|
||||||
|
|
||||||
Calls ComfyUI directly when local, or the remote tensors API when --remote is given.
|
Calls ComfyUI directly when local, or the remote tensors API when --remote is given.
|
||||||
Accepts --input with a JSON object whose keys match CLI option names. CLI flags override JSON values.
|
Accepts --input with a JSON or YAML object whose keys match CLI option names.
|
||||||
|
Files ending in ``.yml`` / ``.yaml`` are parsed as YAML; ``.json`` (or any
|
||||||
|
other extension whose contents start with ``{``/``[``) as JSON. Inline strings
|
||||||
|
starting with ``{`` are JSON, everything else is YAML. CLI flags override
|
||||||
|
--input values.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
tsr generate "a cat on a windowsill"
|
tsr generate "a cat on a windowsill"
|
||||||
@@ -826,31 +893,14 @@ def generate( # noqa: PLR0915
|
|||||||
tsr generate "cyberpunk city" -o output.png --count 4
|
tsr generate "cyberpunk city" -o output.png --count 4
|
||||||
tsr generate "landscape" --remote junkpile
|
tsr generate "landscape" --remote junkpile
|
||||||
tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors"}'
|
tsr generate --input '{"prompt": "a mech", "model": "flux1-dev-fp8.safetensors"}'
|
||||||
|
tsr generate --input scene.yml
|
||||||
tsr generate "raw prompt" --no-quality --no-negative
|
tsr generate "raw prompt" --no-quality --no-negative
|
||||||
"""
|
"""
|
||||||
# ---- JSON input merging ----
|
# ---- --input merging (JSON or YAML) ----
|
||||||
if json_input is not None:
|
if json_input is not None:
|
||||||
# Support file paths and raw JSON strings
|
ji = _parse_generate_input(json_input)
|
||||||
json_path = Path(json_input)
|
|
||||||
if json_path.is_file():
|
|
||||||
json_text = json_path.read_text()
|
|
||||||
elif json_input.lstrip().startswith("{"):
|
|
||||||
json_text = json_input
|
|
||||||
else:
|
|
||||||
console.print(f"[red]Not a JSON string or file:[/red] {json_input}")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
try:
|
# Map source keys to parameter names (handle aliases)
|
||||||
ji = json.loads(json_text)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
console.print(f"[red]Invalid JSON input:[/red] {e}")
|
|
||||||
raise typer.Exit(1) from e
|
|
||||||
|
|
||||||
if not isinstance(ji, dict):
|
|
||||||
console.print("[red]JSON input must be an object[/red]")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
# Map JSON keys to parameter names (handle aliases)
|
|
||||||
key_map = {"negative_prompt": "negative", "lora_name": "lora"}
|
key_map = {"negative_prompt": "negative", "lora_name": "lora"}
|
||||||
mapped: dict[str, Any] = {}
|
mapped: dict[str, Any] = {}
|
||||||
for k, v in ji.items():
|
for k, v in ji.items():
|
||||||
|
|||||||
@@ -0,0 +1,218 @@
|
|||||||
|
"""Tests for the ``tsr generate --input`` JSON/YAML parser.
|
||||||
|
|
||||||
|
Covers the :func:`tensors.cli._parse_generate_input` helper directly (unit
|
||||||
|
level) and the end-to-end integration through the ``generate`` Typer command
|
||||||
|
(with ``_run_generation`` patched so nothing hits ComfyUI).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import typer
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from tensors import cli as cli_module
|
||||||
|
from tensors.cli import _parse_generate_input, app
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Unit tests: _parse_generate_input
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseGenerateInputInline:
|
||||||
|
"""Inline string arguments (not file paths)."""
|
||||||
|
|
||||||
|
def test_inline_json_object(self) -> None:
|
||||||
|
out = _parse_generate_input('{"prompt": "hi", "steps": 30}')
|
||||||
|
assert out == {"prompt": "hi", "steps": 30}
|
||||||
|
|
||||||
|
def test_inline_yaml_mapping(self) -> None:
|
||||||
|
out = _parse_generate_input("prompt: hi\nsteps: 30\n")
|
||||||
|
assert out == {"prompt": "hi", "steps": 30}
|
||||||
|
|
||||||
|
def test_inline_yaml_with_list(self) -> None:
|
||||||
|
out = _parse_generate_input("prompt: x\nscene:\n - foo\n - bar\n")
|
||||||
|
assert out == {"prompt": "x", "scene": ["foo", "bar"]}
|
||||||
|
|
||||||
|
def test_inline_json_with_leading_whitespace(self) -> None:
|
||||||
|
out = _parse_generate_input(' {"prompt": "hi"}')
|
||||||
|
assert out == {"prompt": "hi"}
|
||||||
|
|
||||||
|
def test_inline_non_mapping_yaml_rejected(self) -> None:
|
||||||
|
with pytest.raises(typer.Exit):
|
||||||
|
_parse_generate_input("- just\n- a list\n")
|
||||||
|
|
||||||
|
def test_inline_non_mapping_json_rejected(self) -> None:
|
||||||
|
with pytest.raises(typer.Exit):
|
||||||
|
_parse_generate_input("[1, 2, 3]")
|
||||||
|
|
||||||
|
def test_inline_invalid_yaml_rejected(self) -> None:
|
||||||
|
with pytest.raises(typer.Exit):
|
||||||
|
_parse_generate_input("prompt: [unterminated\n")
|
||||||
|
|
||||||
|
def test_inline_invalid_json_falls_to_yaml_and_fails(self) -> None:
|
||||||
|
# Starts with '{' so JSON path is taken; malformed → Exit.
|
||||||
|
with pytest.raises(typer.Exit):
|
||||||
|
_parse_generate_input('{"prompt": "missing-close"')
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseGenerateInputFiles:
|
||||||
|
"""File path arguments resolved by extension."""
|
||||||
|
|
||||||
|
def test_json_file_by_extension(self, tmp_path: Path) -> None:
|
||||||
|
p = tmp_path / "scene.json"
|
||||||
|
p.write_text(json.dumps({"prompt": "from-json", "steps": 20}))
|
||||||
|
assert _parse_generate_input(str(p)) == {"prompt": "from-json", "steps": 20}
|
||||||
|
|
||||||
|
def test_yaml_file_dot_yml(self, tmp_path: Path) -> None:
|
||||||
|
p = tmp_path / "scene.yml"
|
||||||
|
p.write_text("prompt: from-yml\nsteps: 25\n")
|
||||||
|
assert _parse_generate_input(str(p)) == {"prompt": "from-yml", "steps": 25}
|
||||||
|
|
||||||
|
def test_yaml_file_dot_yaml(self, tmp_path: Path) -> None:
|
||||||
|
p = tmp_path / "scene.yaml"
|
||||||
|
p.write_text("prompt: from-yaml\n")
|
||||||
|
assert _parse_generate_input(str(p)) == {"prompt": "from-yaml"}
|
||||||
|
|
||||||
|
def test_unknown_extension_sniffs_json(self, tmp_path: Path) -> None:
|
||||||
|
p = tmp_path / "scene.txt"
|
||||||
|
p.write_text('{"prompt": "sniffed"}')
|
||||||
|
assert _parse_generate_input(str(p)) == {"prompt": "sniffed"}
|
||||||
|
|
||||||
|
def test_unknown_extension_sniffs_yaml(self, tmp_path: Path) -> None:
|
||||||
|
p = tmp_path / "scene.txt"
|
||||||
|
p.write_text("prompt: sniffed-yaml\n")
|
||||||
|
assert _parse_generate_input(str(p)) == {"prompt": "sniffed-yaml"}
|
||||||
|
|
||||||
|
def test_yaml_file_with_full_draw_template(self, tmp_path: Path) -> None:
|
||||||
|
"""Smoke test against the exact shape used by ~/Projects/draw/templates/."""
|
||||||
|
p = tmp_path / "scene.yml"
|
||||||
|
p.write_text(
|
||||||
|
'prompt: ""\n'
|
||||||
|
'negative_prompt: ""\n'
|
||||||
|
'model: "getphatFLUXReality_v5Hardcore.safetensors"\n'
|
||||||
|
"width: 832\n"
|
||||||
|
"height: 1216\n"
|
||||||
|
"steps: 35\n"
|
||||||
|
"cfg: 1.0\n"
|
||||||
|
"guidance: 4.0\n"
|
||||||
|
'sampler: "dpmpp_2m"\n'
|
||||||
|
'scheduler: "sgm_uniform"\n'
|
||||||
|
'vae: "ae.safetensors"\n'
|
||||||
|
'orientation: "portrait"\n'
|
||||||
|
"seed: -1\n"
|
||||||
|
"count: 1\n"
|
||||||
|
"scene:\n"
|
||||||
|
' - "first element with embedded \\nnewline"\n'
|
||||||
|
' - "second element"\n'
|
||||||
|
'_scene_name: "demo_01"\n'
|
||||||
|
'_family: "flux_unet"\n'
|
||||||
|
'_base_model: "Flux.1 D"\n'
|
||||||
|
)
|
||||||
|
out = _parse_generate_input(str(p))
|
||||||
|
assert out["model"] == "getphatFLUXReality_v5Hardcore.safetensors"
|
||||||
|
assert out["width"] == 832
|
||||||
|
assert out["height"] == 1216
|
||||||
|
assert out["steps"] == 35
|
||||||
|
assert out["cfg"] == 1.0
|
||||||
|
assert out["guidance"] == 4.0
|
||||||
|
assert out["sampler"] == "dpmpp_2m"
|
||||||
|
assert out["scheduler"] == "sgm_uniform"
|
||||||
|
assert out["vae"] == "ae.safetensors"
|
||||||
|
assert out["orientation"] == "portrait"
|
||||||
|
assert out["seed"] == -1
|
||||||
|
assert out["count"] == 1
|
||||||
|
assert isinstance(out["scene"], list)
|
||||||
|
assert len(out["scene"]) == 2
|
||||||
|
assert "embedded" in out["scene"][0]
|
||||||
|
|
||||||
|
def test_malformed_yaml_file_rejected(self, tmp_path: Path) -> None:
|
||||||
|
p = tmp_path / "bad.yml"
|
||||||
|
p.write_text("prompt: [unterminated\n")
|
||||||
|
with pytest.raises(typer.Exit):
|
||||||
|
_parse_generate_input(str(p))
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Integration: generate --input through Typer
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def captured(monkeypatch: pytest.MonkeyPatch) -> dict[str, Any]:
|
||||||
|
"""Capture _run_generation kwargs without dispatching to ComfyUI."""
|
||||||
|
sink: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def fake_run_generation(**kwargs: Any) -> None:
|
||||||
|
sink.update(kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(cli_module, "_run_generation", fake_run_generation)
|
||||||
|
return sink
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_consumes_yaml_file(tmp_path: Path, captured: dict[str, Any]) -> None:
|
||||||
|
"""``tsr generate --input scene.yml`` plumbs YAML values through."""
|
||||||
|
yml = tmp_path / "scene.yml"
|
||||||
|
yml.write_text(
|
||||||
|
"prompt: a sunset\n"
|
||||||
|
'model: "fluxmodel.safetensors"\n'
|
||||||
|
"steps: 28\n"
|
||||||
|
"scene:\n"
|
||||||
|
' - "golden hour"\n'
|
||||||
|
' - "wide angle"\n'
|
||||||
|
)
|
||||||
|
result = runner.invoke(app, ["generate", "--input", str(yml)])
|
||||||
|
assert result.exit_code == 0, result.output
|
||||||
|
assert captured["prompt"] == "a sunset"
|
||||||
|
assert captured["model"] == "fluxmodel.safetensors"
|
||||||
|
assert captured["steps"] == 28
|
||||||
|
# YAML list under `scene` is joined into scene_prompt by existing logic.
|
||||||
|
assert captured["scene_prompt"] == "golden hour, wide angle"
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_yaml_then_cli_flag_wins(tmp_path: Path, captured: dict[str, Any]) -> None:
|
||||||
|
"""Explicit CLI flags must override --input values (same contract as JSON)."""
|
||||||
|
yml = tmp_path / "scene.yml"
|
||||||
|
yml.write_text('prompt: from-yaml\nmodel: "yamlmodel.safetensors"\nsteps: 10\n')
|
||||||
|
result = runner.invoke(app, ["generate", "--input", str(yml), "--steps", "99"])
|
||||||
|
assert result.exit_code == 0, result.output
|
||||||
|
assert captured["prompt"] == "from-yaml"
|
||||||
|
assert captured["model"] == "yamlmodel.safetensors"
|
||||||
|
assert captured["steps"] == 99 # CLI override wins
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_inline_yaml_string(captured: dict[str, Any]) -> None:
|
||||||
|
result = runner.invoke(
|
||||||
|
app,
|
||||||
|
["generate", "--input", "prompt: inline-yaml\nmodel: m.safetensors\n"],
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0, result.output
|
||||||
|
assert captured["prompt"] == "inline-yaml"
|
||||||
|
assert captured["model"] == "m.safetensors"
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_inline_json_still_works(captured: dict[str, Any]) -> None:
|
||||||
|
"""Regression guard for the original JSON contract."""
|
||||||
|
result = runner.invoke(
|
||||||
|
app,
|
||||||
|
["generate", "--input", '{"prompt": "inline-json", "model": "j.safetensors"}'],
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0, result.output
|
||||||
|
assert captured["prompt"] == "inline-json"
|
||||||
|
assert captured["model"] == "j.safetensors"
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_invalid_yaml_file_exits_nonzero(tmp_path: Path, captured: dict[str, Any]) -> None:
|
||||||
|
yml = tmp_path / "bad.yml"
|
||||||
|
yml.write_text("prompt: [oops\n")
|
||||||
|
result = runner.invoke(app, ["generate", "--input", str(yml)])
|
||||||
|
assert result.exit_code != 0
|
||||||
|
assert "Invalid YAML input" in result.output
|
||||||
|
assert captured == {}
|
||||||
@@ -886,6 +886,7 @@ source = { editable = "." }
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "httpx" },
|
{ name = "httpx" },
|
||||||
{ name = "huggingface-hub" },
|
{ name = "huggingface-hub" },
|
||||||
|
{ name = "pyyaml" },
|
||||||
{ name = "rich" },
|
{ name = "rich" },
|
||||||
{ name = "safetensors" },
|
{ name = "safetensors" },
|
||||||
{ name = "sqlmodel" },
|
{ name = "sqlmodel" },
|
||||||
@@ -924,6 +925,7 @@ requires-dist = [
|
|||||||
{ name = "httpx", specifier = ">=0.27.0" },
|
{ name = "httpx", specifier = ">=0.27.0" },
|
||||||
{ name = "huggingface-hub", specifier = ">=0.25.0" },
|
{ name = "huggingface-hub", specifier = ">=0.25.0" },
|
||||||
{ name = "python-multipart", marker = "extra == 'server'", specifier = ">=0.0.9" },
|
{ name = "python-multipart", marker = "extra == 'server'", specifier = ">=0.0.9" },
|
||||||
|
{ name = "pyyaml", specifier = ">=6.0" },
|
||||||
{ name = "rich", specifier = ">=13.0.0" },
|
{ name = "rich", specifier = ">=13.0.0" },
|
||||||
{ name = "safetensors", specifier = ">=0.4.0" },
|
{ name = "safetensors", specifier = ">=0.4.0" },
|
||||||
{ name = "scalar-fastapi", marker = "extra == 'server'", specifier = ">=1.6" },
|
{ name = "scalar-fastapi", marker = "extra == 'server'", specifier = ">=1.6" },
|
||||||
|
|||||||
Reference in New Issue
Block a user