format: auto-format code [skip ci]
This commit is contained in:
+1
-3
@@ -1456,9 +1456,7 @@ def style_sweep( # noqa: PLR0915
|
||||
# Template is required for generation, but optional when --list is paired
|
||||
# with an explicit --styles source.
|
||||
if template is None and not (list_styles and styles is not None):
|
||||
console.print(
|
||||
"[red]--template is required (or use --list with --styles to inspect a styles file)[/red]"
|
||||
)
|
||||
console.print("[red]--template is required (or use --list with --styles to inspect a styles file)[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# ---- Load template (if provided) ----
|
||||
|
||||
+6
-6
@@ -739,12 +739,12 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
# Add new patterns here as we encounter them — order doesn't matter, first
|
||||
# match wins.
|
||||
FLUX_UNET_ONLY_PATTERNS: tuple[str, ...] = (
|
||||
"lust_", # lust_v10.safetensors (Flux.2 Klein 9B-base)
|
||||
"lust_", # lust_v10.safetensors (Flux.2 Klein 9B-base)
|
||||
# Note: bare "lust" would falsely match "illustrious" — keep the underscore.
|
||||
"cyberrealisticflux", # cyberrealisticFlux_v25.safetensors
|
||||
"getphatflux", # getphatFLUXReality_v11Softcore.safetensors
|
||||
"moodydesire", # moodyDesireMix_v20PRO.safetensors
|
||||
"fcfluxpony", # fcFluxPonyPerfectBase_fcFluxPerfectBase.safetensors
|
||||
"getphatflux", # getphatFLUXReality_v11Softcore.safetensors
|
||||
"moodydesire", # moodyDesireMix_v20PRO.safetensors
|
||||
"fcfluxpony", # fcFluxPonyPerfectBase_fcFluxPerfectBase.safetensors
|
||||
)
|
||||
|
||||
|
||||
@@ -759,8 +759,8 @@ def _is_flux_unet_only(name_lower: str) -> bool:
|
||||
# filename patterns are a fallback for checkpoints with missing/wrong DB
|
||||
# metadata. Filename match wins over base_model.
|
||||
FLUX2_KLEIN_PATTERNS: tuple[str, ...] = (
|
||||
"lust_", # lust_v10.safetensors
|
||||
"moodydesire", # moodyDesireMix_v20PRO.safetensors
|
||||
"lust_", # lust_v10.safetensors
|
||||
"moodydesire", # moodyDesireMix_v20PRO.safetensors
|
||||
)
|
||||
|
||||
|
||||
|
||||
+21
-38
@@ -130,9 +130,7 @@ def test_limit(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
|
||||
out_dir = tmp_path / "out"
|
||||
styles_file = _write_styles_file(
|
||||
tmp_path,
|
||||
[
|
||||
{"slug": f"{i:02d}-style", "suffix": f"style {i}"} for i in range(1, 6)
|
||||
],
|
||||
[{"slug": f"{i:02d}-style", "suffix": f"style {i}"} for i in range(1, 6)],
|
||||
)
|
||||
tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file))
|
||||
|
||||
@@ -313,9 +311,7 @@ def test_cli_output_dir_overrides_template(tmp_path: Path, calls: list[dict[str,
|
||||
styles_file = _write_styles_file(tmp_path, [{"slug": "x", "suffix": "X"}])
|
||||
tpl = _write_template(tmp_path, output_dir=tpl_out, styles=str(styles_file))
|
||||
|
||||
result = runner.invoke(
|
||||
app, ["style-sweep", "--template", str(tpl), "--output-dir", str(cli_out)]
|
||||
)
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tpl), "--output-dir", str(cli_out)])
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert calls[0]["output"] == cli_out / "x.png"
|
||||
@@ -329,9 +325,7 @@ def test_remote_override(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
|
||||
styles_file = _write_styles_file(tmp_path, [{"slug": "x", "suffix": "X"}])
|
||||
tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file))
|
||||
|
||||
result = runner.invoke(
|
||||
app, ["style-sweep", "--template", str(tpl), "--remote", "junkpile"]
|
||||
)
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tpl), "--remote", "junkpile"])
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert calls[0]["remote"] == "junkpile"
|
||||
@@ -346,9 +340,7 @@ def test_list_flag_prints_slugs(tmp_path: Path, calls: list[dict[str, Any]]) ->
|
||||
"""--list prints all slugs and does not call generate."""
|
||||
out_dir = tmp_path / "out"
|
||||
slugs = [f"{i:02d}-style" for i in range(1, 5)]
|
||||
styles_file = _write_styles_file(
|
||||
tmp_path, [{"slug": s, "suffix": f"suffix for {s}"} for s in slugs]
|
||||
)
|
||||
styles_file = _write_styles_file(tmp_path, [{"slug": s, "suffix": f"suffix for {s}"} for s in slugs])
|
||||
tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file))
|
||||
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tpl), "--list"])
|
||||
@@ -366,14 +358,10 @@ def test_list_flag_prints_slugs(tmp_path: Path, calls: list[dict[str, Any]]) ->
|
||||
def test_list_with_limit(tmp_path: Path, calls: list[dict[str, Any]]) -> None:
|
||||
"""--list --limit N restricts the table to the first N entries."""
|
||||
out_dir = tmp_path / "out"
|
||||
styles_file = _write_styles_file(
|
||||
tmp_path, [{"slug": f"{i:02d}-x", "suffix": f"s{i}"} for i in range(1, 6)]
|
||||
)
|
||||
styles_file = _write_styles_file(tmp_path, [{"slug": f"{i:02d}-x", "suffix": f"s{i}"} for i in range(1, 6)])
|
||||
tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file))
|
||||
|
||||
result = runner.invoke(
|
||||
app, ["style-sweep", "--template", str(tpl), "--list", "--limit", "2"]
|
||||
)
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tpl), "--list", "--limit", "2"])
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "01-x" in result.output
|
||||
@@ -392,9 +380,7 @@ def test_list_without_template(tmp_path: Path, calls: list[dict[str, Any]]) -> N
|
||||
],
|
||||
)
|
||||
|
||||
result = runner.invoke(
|
||||
app, ["style-sweep", "--styles", str(styles_file), "--list"]
|
||||
)
|
||||
result = runner.invoke(app, ["style-sweep", "--styles", str(styles_file), "--list"])
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "alpha" in result.output
|
||||
@@ -405,13 +391,9 @@ def test_list_without_template(tmp_path: Path, calls: list[dict[str, Any]]) -> N
|
||||
def test_list_long_suffix_truncated(tmp_path: Path) -> None:
|
||||
"""Long suffixes are truncated with an ellipsis."""
|
||||
long_suffix = "very long " * 20 # ~200 chars
|
||||
styles_file = _write_styles_file(
|
||||
tmp_path, [{"slug": "long", "suffix": long_suffix}]
|
||||
)
|
||||
styles_file = _write_styles_file(tmp_path, [{"slug": "long", "suffix": long_suffix}])
|
||||
|
||||
result = runner.invoke(
|
||||
app, ["style-sweep", "--styles", str(styles_file), "--list"]
|
||||
)
|
||||
result = runner.invoke(app, ["style-sweep", "--styles", str(styles_file), "--list"])
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "long" in result.output
|
||||
@@ -438,9 +420,7 @@ def test_style_filter_single(tmp_path: Path, calls: list[dict[str, Any]]) -> Non
|
||||
)
|
||||
tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file))
|
||||
|
||||
result = runner.invoke(
|
||||
app, ["style-sweep", "--template", str(tpl), "--style", "02-bar"]
|
||||
)
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tpl), "--style", "02-bar"])
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert len(calls) == 1
|
||||
@@ -467,9 +447,12 @@ def test_style_filter_multiple(tmp_path: Path, calls: list[dict[str, Any]]) -> N
|
||||
app,
|
||||
[
|
||||
"style-sweep",
|
||||
"--template", str(tpl),
|
||||
"-S", "03-c",
|
||||
"-S", "01-a",
|
||||
"--template",
|
||||
str(tpl),
|
||||
"-S",
|
||||
"03-c",
|
||||
"-S",
|
||||
"01-a",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -491,9 +474,7 @@ def test_style_filter_unknown_slug(tmp_path: Path, calls: list[dict[str, Any]])
|
||||
)
|
||||
tpl = _write_template(tmp_path, output_dir=out_dir, styles=str(styles_file))
|
||||
|
||||
result = runner.invoke(
|
||||
app, ["style-sweep", "--template", str(tpl), "--style", "99-nope"]
|
||||
)
|
||||
result = runner.invoke(app, ["style-sweep", "--template", str(tpl), "--style", "99-nope"])
|
||||
|
||||
assert result.exit_code == 1, result.output
|
||||
assert "99-nope" in result.output
|
||||
@@ -519,9 +500,11 @@ def test_style_filter_with_list(tmp_path: Path, calls: list[dict[str, Any]]) ->
|
||||
app,
|
||||
[
|
||||
"style-sweep",
|
||||
"--styles", str(styles_file),
|
||||
"--styles",
|
||||
str(styles_file),
|
||||
"--list",
|
||||
"--style", "02-bar",
|
||||
"--style",
|
||||
"02-bar",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
+9
-29
@@ -358,24 +358,16 @@ class TestModelFamilyDetection:
|
||||
"""fcFluxPony*.safetensors → flux_unet (intercepts flux + fluxpony)."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert (
|
||||
detect_model_family("fcFluxPonyPerfectBase_fcFluxPerfectBase.safetensors")
|
||||
== "flux_unet"
|
||||
)
|
||||
assert detect_model_family("fcFluxPonyPerfectBase_fcFluxPerfectBase.safetensors") == "flux_unet"
|
||||
|
||||
def test_detect_flux_unet_overrides_base_model(self) -> None:
|
||||
"""Filename UNet-only pattern wins over a (likely wrong) CivitAI base_model tag."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
# cyberrealisticFlux: filename pattern wins over wrong "Pony" tag → flux_unet.
|
||||
assert (
|
||||
detect_model_family("cyberrealisticFlux_v25.safetensors", "Pony") == "flux_unet"
|
||||
)
|
||||
assert detect_model_family("cyberrealisticFlux_v25.safetensors", "Pony") == "flux_unet"
|
||||
# getphat: filename pattern wins over wrong "SDXL 1.0" tag → flux_unet.
|
||||
assert (
|
||||
detect_model_family("getphatFLUXReality_v11.safetensors", "SDXL 1.0")
|
||||
== "flux_unet"
|
||||
)
|
||||
assert detect_model_family("getphatFLUXReality_v11.safetensors", "SDXL 1.0") == "flux_unet"
|
||||
|
||||
def test_flux_unet_family_defaults_has_external_clip(self) -> None:
|
||||
"""flux_unet preset advertises external_clip + clip filenames."""
|
||||
@@ -750,9 +742,7 @@ class TestFluxUnetWorkflowBuilder:
|
||||
"""flux_unet locks KSampler.cfg to 1.0 and exposes the FluxGuidance dial."""
|
||||
from tensors.comfyui import _build_workflow
|
||||
|
||||
wf = _build_workflow(
|
||||
prompt="a cat", model="getphatFLUXReality_v11.safetensors", cfg=7.5
|
||||
)
|
||||
wf = _build_workflow(prompt="a cat", model="getphatFLUXReality_v11.safetensors", cfg=7.5)
|
||||
assert wf["160"]["inputs"]["cfg"] == 1.0
|
||||
# The caller's cfg=7.5 should re-route to FluxGuidance (same precedence as plain flux)
|
||||
assert wf["140"]["inputs"]["guidance"] == 7.5
|
||||
@@ -1178,9 +1168,7 @@ class TestValidateModelAvailable:
|
||||
},
|
||||
)
|
||||
with pytest.raises(typer.Exit) as exc:
|
||||
cli_module._validate_model_available(
|
||||
"fluxRealVision_v99.safetensors", family="flux", lora=None
|
||||
)
|
||||
cli_module._validate_model_available("fluxRealVision_v99.safetensors", family="flux", lora=None)
|
||||
assert exc.value.exit_code == 1
|
||||
|
||||
def test_unknown_model_in_diffusion_models_bucket(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@@ -1198,9 +1186,7 @@ class TestValidateModelAvailable:
|
||||
},
|
||||
)
|
||||
with pytest.raises(typer.Exit):
|
||||
cli_module._validate_model_available(
|
||||
"getphat_v99.safetensors", family="flux_unet", lora=None
|
||||
)
|
||||
cli_module._validate_model_available("getphat_v99.safetensors", family="flux_unet", lora=None)
|
||||
|
||||
def test_flux2_klein_uses_diffusion_models_bucket(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""flux2_klein family also routes to diffusion_models/."""
|
||||
@@ -1215,9 +1201,7 @@ class TestValidateModelAvailable:
|
||||
},
|
||||
)
|
||||
# Should NOT raise — file is present in diffusion_models/.
|
||||
cli_module._validate_model_available(
|
||||
"lust_v10.safetensors", family="flux2_klein", lora=None
|
||||
)
|
||||
cli_module._validate_model_available("lust_v10.safetensors", family="flux2_klein", lora=None)
|
||||
|
||||
def test_present_model_passes_silently(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Happy path — model present, no exception."""
|
||||
@@ -1248,9 +1232,7 @@ class TestValidateModelAvailable:
|
||||
},
|
||||
)
|
||||
with pytest.raises(typer.Exit):
|
||||
cli_module._validate_model_available(
|
||||
"model.safetensors", family="flux", lora="ghost_lora.safetensors"
|
||||
)
|
||||
cli_module._validate_model_available("model.safetensors", family="flux", lora="ghost_lora.safetensors")
|
||||
|
||||
def test_network_failure_is_non_fatal(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""If get_loaded_models() raises, validation falls through silently."""
|
||||
@@ -1278,9 +1260,7 @@ class TestValidateModelAvailable:
|
||||
},
|
||||
)
|
||||
with pytest.raises(typer.Exit):
|
||||
cli_module._validate_model_available(
|
||||
"new_unet_model.safetensors", family="flux_unet", lora=None
|
||||
)
|
||||
cli_module._validate_model_available("new_unet_model.safetensors", family="flux_unet", lora=None)
|
||||
|
||||
def test_get_loaded_models_includes_diffusion_models_bucket(self) -> None:
|
||||
"""The Comfy model-listing helper exposes the UNETLoader bucket.
|
||||
|
||||
Reference in New Issue
Block a user