feat(generate): validate model availability against live ComfyUI before queueing
Catches mismatches between local intent and what's actually loaded on the ComfyUI host. Replaces ComfyUI's generic 400 'prompt_outputs_failed_validation' with a clear "model X not available on host — did you mean Y?" suggestion. Why: when a user types `tsr generate -m getphatFLUXReality_v5Hardcore` but only v11Softcore is installed, they got a 30-line raw API error buried in node validation output. Now they get one red line plus three fuzzy-matched candidates from the actual loader bucket. Implementation: - Extends get_loaded_models() in comfyui.py to include the diffusion_models bucket (UNETLoader -> unet_name). Previously only checkpoints, loras, vae, clip, controlnet, upscale_models were exposed. - New _validate_model_available() helper in cli.py runs after family detection, before prompt enhancement. Maps family -> loader bucket: flux_unet / flux2_klein -> diffusion_models/, else checkpoints/. Uses difflib.get_close_matches for the "did you mean" hint. - Validates LoRA presence too when -l is passed. - Special hint: if the requested file IS in checkpoints/ but the family requires diffusion_models/, suggests the symlink command the user needs to run on the host. Common case for newly-uploaded UNet-only checkpoints. - Network failures are non-fatal — falls through to let ComfyUI surface the error itself rather than blocking on a stale endpoint. - Skipped in --json mode (machine callers) and --remote dispatches (the server validates remotely). 8 new tests covering: unknown model in checkpoints bucket, unknown in diffusion_models, flux2_klein routing, happy path, missing LoRA, network failure, symlink hint, and a source-level check that the diffusion_models bucket is wired into get_loaded_models. 259 -> 267 tests.
This commit is contained in:
@@ -926,6 +926,66 @@ def generate( # noqa: PLR0915
|
||||
)
|
||||
|
||||
|
||||
# Map model family → which ComfyUI loader directory the checkpoint must live in.
|
||||
# Used by _validate_model_available() to query the right slot from get_loaded_models().
|
||||
_FAMILY_TO_LOADER_BUCKET: dict[str, str] = {
|
||||
"flux_unet": "diffusion_models",
|
||||
"flux2_klein": "diffusion_models",
|
||||
}
|
||||
|
||||
|
||||
def _validate_model_available(model: str, family: str | None, lora: str | None) -> None:
|
||||
"""Verify model + LoRA exist on the live ComfyUI host before queueing.
|
||||
|
||||
Fails fast with typer.Exit(1) and a "did you mean" suggestion when the
|
||||
requested file isn't loaded. Bucket lookup respects family:
|
||||
- flux_unet / flux2_klein → diffusion_models/ (UNETLoader)
|
||||
- everything else → checkpoints/ (CheckpointLoaderSimple)
|
||||
|
||||
Network failures are non-fatal — we'd rather forward to ComfyUI and let its
|
||||
400 surface than block on a stale comfyui endpoint.
|
||||
"""
|
||||
from difflib import get_close_matches # noqa: PLC0415
|
||||
|
||||
from tensors.comfyui import get_loaded_models # noqa: PLC0415
|
||||
|
||||
try:
|
||||
loaded = get_loaded_models(console=None)
|
||||
except Exception:
|
||||
return # network down — let ComfyUI itself handle it
|
||||
if not loaded:
|
||||
return
|
||||
|
||||
bucket = _FAMILY_TO_LOADER_BUCKET.get(family or "", "checkpoints")
|
||||
available = loaded.get(bucket, [])
|
||||
if model not in available:
|
||||
console.print(f"[red]Model '{model}' not available on ComfyUI host[/red]")
|
||||
console.print(f"[dim](looked in {bucket}/ — {len(available)} entries)[/dim]")
|
||||
matches = get_close_matches(model, available, n=3, cutoff=0.5)
|
||||
if matches:
|
||||
console.print("[yellow]Did you mean:[/yellow]")
|
||||
for m in matches:
|
||||
console.print(f" [cyan]{m}[/cyan]")
|
||||
else:
|
||||
console.print(f"[dim]Run `tsr models` to see what's installed in {bucket}/.[/dim]")
|
||||
# Suggest symlink fix if the file exists in checkpoints/ but family wants diffusion_models/
|
||||
if bucket == "diffusion_models" and model in loaded.get("checkpoints", []):
|
||||
console.print(
|
||||
f"[yellow]Hint:[/yellow] '{model}' is in checkpoints/ but UNet-only checkpoints need to be in diffusion_models/. "
|
||||
f"On the ComfyUI host: [cyan]ln -s ../checkpoints/{model} <comfyui>/models/diffusion_models/{model}[/cyan]"
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
if lora and lora not in loaded.get("loras", []):
|
||||
console.print(f"[red]LoRA '{lora}' not available on ComfyUI host[/red]")
|
||||
matches = get_close_matches(lora, loaded.get("loras", []), n=3, cutoff=0.5)
|
||||
if matches:
|
||||
console.print("[yellow]Did you mean:[/yellow]")
|
||||
for m in matches:
|
||||
console.print(f" [cyan]{m}[/cyan]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def _run_generation( # noqa: PLR0915
|
||||
*,
|
||||
prompt: str,
|
||||
@@ -983,6 +1043,14 @@ def _run_generation( # noqa: PLR0915
|
||||
else:
|
||||
console.print(f"[dim]Detected model family: {model_family}[/dim]")
|
||||
|
||||
# ---- Validate the requested model exists on the target host ----
|
||||
# Catches mismatches between local intent ("v5Hardcore") and what's actually
|
||||
# available remotely ("v11Softcore"), and offers a fuzzy "did you mean" hint
|
||||
# instead of forwarding the request to ComfyUI for a generic 400 rejection.
|
||||
# Skipped in --json mode and for remote dispatches (server already validates).
|
||||
if model and not json_output and not remote:
|
||||
_validate_model_available(model, model_family, lora)
|
||||
|
||||
# Build enhanced prompt with quality prefix and LoRA trigger words
|
||||
prompt_parts: list[str] = []
|
||||
|
||||
|
||||
@@ -235,6 +235,7 @@ def get_loaded_models(url: str | None = None, console: Console | None = None) ->
|
||||
# Model type to node class and input name mapping
|
||||
model_types = {
|
||||
"checkpoints": ("CheckpointLoaderSimple", "ckpt_name"),
|
||||
"diffusion_models": ("UNETLoader", "unet_name"),
|
||||
"loras": ("LoraLoader", "lora_name"),
|
||||
"vae": ("VAELoader", "vae_name"),
|
||||
"clip": ("CLIPLoader", "clip_name"),
|
||||
|
||||
@@ -1158,3 +1158,142 @@ class TestCLI:
|
||||
result = runner.invoke(app, ["dl"])
|
||||
assert result.exit_code == 1
|
||||
assert "must specify" in result.stdout.lower()
|
||||
|
||||
|
||||
class TestValidateModelAvailable:
|
||||
"""Tests for the pre-flight model-availability check before queueing."""
|
||||
|
||||
def test_unknown_model_in_checkpoints_bucket(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Unknown model + fuzzy-match candidates — exits 1 with did-you-mean."""
|
||||
import typer # noqa: PLC0415
|
||||
|
||||
from tensors import cli as cli_module # noqa: PLC0415
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tensors.comfyui.get_loaded_models",
|
||||
lambda console=None: {
|
||||
"checkpoints": ["fluxRealVision_v25.safetensors", "ponyDiffusionV6XL.safetensors"],
|
||||
"loras": [],
|
||||
"diffusion_models": [],
|
||||
},
|
||||
)
|
||||
with pytest.raises(typer.Exit) as exc:
|
||||
cli_module._validate_model_available(
|
||||
"fluxRealVision_v99.safetensors", family="flux", lora=None
|
||||
)
|
||||
assert exc.value.exit_code == 1
|
||||
|
||||
def test_unknown_model_in_diffusion_models_bucket(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""flux_unet family looks in diffusion_models/, not checkpoints/."""
|
||||
import typer # noqa: PLC0415
|
||||
|
||||
from tensors import cli as cli_module # noqa: PLC0415
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tensors.comfyui.get_loaded_models",
|
||||
lambda console=None: {
|
||||
"checkpoints": [],
|
||||
"loras": [],
|
||||
"diffusion_models": ["getphat_v11.safetensors"],
|
||||
},
|
||||
)
|
||||
with pytest.raises(typer.Exit):
|
||||
cli_module._validate_model_available(
|
||||
"getphat_v99.safetensors", family="flux_unet", lora=None
|
||||
)
|
||||
|
||||
def test_flux2_klein_uses_diffusion_models_bucket(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""flux2_klein family also routes to diffusion_models/."""
|
||||
from tensors import cli as cli_module # noqa: PLC0415
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tensors.comfyui.get_loaded_models",
|
||||
lambda console=None: {
|
||||
"checkpoints": [],
|
||||
"loras": [],
|
||||
"diffusion_models": ["lust_v10.safetensors"],
|
||||
},
|
||||
)
|
||||
# Should NOT raise — file is present in diffusion_models/.
|
||||
cli_module._validate_model_available(
|
||||
"lust_v10.safetensors", family="flux2_klein", lora=None
|
||||
)
|
||||
|
||||
def test_present_model_passes_silently(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Happy path — model present, no exception."""
|
||||
from tensors import cli as cli_module # noqa: PLC0415
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tensors.comfyui.get_loaded_models",
|
||||
lambda console=None: {
|
||||
"checkpoints": ["model.safetensors"],
|
||||
"loras": [],
|
||||
"diffusion_models": [],
|
||||
},
|
||||
)
|
||||
cli_module._validate_model_available("model.safetensors", family="flux", lora=None)
|
||||
|
||||
def test_missing_lora_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Model present but LoRA missing — exit 1."""
|
||||
import typer # noqa: PLC0415
|
||||
|
||||
from tensors import cli as cli_module # noqa: PLC0415
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tensors.comfyui.get_loaded_models",
|
||||
lambda console=None: {
|
||||
"checkpoints": ["model.safetensors"],
|
||||
"loras": ["real_lora.safetensors"],
|
||||
"diffusion_models": [],
|
||||
},
|
||||
)
|
||||
with pytest.raises(typer.Exit):
|
||||
cli_module._validate_model_available(
|
||||
"model.safetensors", family="flux", lora="ghost_lora.safetensors"
|
||||
)
|
||||
|
||||
def test_network_failure_is_non_fatal(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""If get_loaded_models() raises, validation falls through silently."""
|
||||
from tensors import cli as cli_module # noqa: PLC0415
|
||||
|
||||
def _boom(console=None):
|
||||
raise ConnectionError("comfyui down")
|
||||
|
||||
monkeypatch.setattr("tensors.comfyui.get_loaded_models", _boom)
|
||||
# Must not raise — we fall through to let ComfyUI surface the failure itself.
|
||||
cli_module._validate_model_available("anything.safetensors", family=None, lora=None)
|
||||
|
||||
def test_symlink_hint_when_file_in_wrong_bucket(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""flux_unet checkpoint sitting in checkpoints/ → suggest symlinking."""
|
||||
import typer # noqa: PLC0415
|
||||
|
||||
from tensors import cli as cli_module # noqa: PLC0415
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tensors.comfyui.get_loaded_models",
|
||||
lambda console=None: {
|
||||
"checkpoints": ["new_unet_model.safetensors"],
|
||||
"loras": [],
|
||||
"diffusion_models": [],
|
||||
},
|
||||
)
|
||||
with pytest.raises(typer.Exit):
|
||||
cli_module._validate_model_available(
|
||||
"new_unet_model.safetensors", family="flux_unet", lora=None
|
||||
)
|
||||
|
||||
def test_get_loaded_models_includes_diffusion_models_bucket(self) -> None:
|
||||
"""The Comfy model-listing helper exposes the UNETLoader bucket.
|
||||
|
||||
Source-level check (no network): the model_types map inside get_loaded_models
|
||||
must contain a diffusion_models entry wired to UNETLoader, otherwise the
|
||||
validator's flux_unet / flux2_klein bucket lookup would silently return
|
||||
an empty list.
|
||||
"""
|
||||
import inspect # noqa: PLC0415
|
||||
|
||||
import tensors.comfyui as comfyui_module # noqa: PLC0415
|
||||
|
||||
src = inspect.getsource(comfyui_module.get_loaded_models)
|
||||
assert '"diffusion_models"' in src
|
||||
assert '"UNETLoader"' in src
|
||||
|
||||
Reference in New Issue
Block a user