feat(cli): resolve bare model/lora names to .safetensors silently

Before: `tsr generate -m lust_v10` failed validation with a 'did you mean'
hint pointing at 'lust_v10.safetensors'. Annoying because the bare name is
what the user typed and what `tsr models` prints back to them.

After: when the exact name misses and there is exactly one suffixed variant
in the target bucket (diffusion_models/ or checkpoints/), the validator
prints a dim '[Resolved 'X' \u2192 'X.safetensors']' line and substitutes the
canonical filename before forwarding to ComfyUI. Extensions tried in order:
.safetensors, .ckpt, .gguf, .pt, .bin. Same logic applies to LoRA names.

Ambiguous matches (multiple extensions present) still fail \u2014 the user must
disambiguate explicitly. Names that already contain a dot are left alone.

Function signature changes from -> None to -> tuple[str, str | None] so the
caller can pick up the resolved names.
This commit is contained in:
chi
2026-05-18 15:07:41 +02:00
parent fc5ea1a44b
commit 60066f3ec2
+61 -9
View File
@@ -978,9 +978,36 @@ _FAMILY_TO_LOADER_BUCKET: dict[str, str] = {
}
def _validate_model_available(model: str, family: str | None, lora: str | None) -> None:
# Extensions tried in order when the user passes a bare name (no suffix) and
# the exact lookup misses. Safetensors first since that's the modern default.
_MODEL_EXTENSIONS: tuple[str, ...] = (".safetensors", ".ckpt", ".gguf", ".pt", ".bin")
def _resolve_with_extension(name: str, available: list[str]) -> str | None:
"""If `name` is missing an extension but exactly one suffixed variant exists in
`available`, return that variant. Otherwise return None.
Examples:
_resolve_with_extension("lust_v10", ["lust_v10.safetensors", ...]) -> "lust_v10.safetensors"
_resolve_with_extension("lust_v10", ["lust_v10.safetensors", "lust_v10.ckpt"]) -> None # ambiguous
_resolve_with_extension("lust_v10.safetensors", [...]) -> None # already has ext
"""
if "." in name: # user already provided an extension; don't second-guess
return None
candidates = [f"{name}{ext}" for ext in _MODEL_EXTENSIONS if f"{name}{ext}" in available]
if len(candidates) == 1:
return candidates[0]
return None # zero matches → real miss; multiple matches → ambiguous, force user to disambiguate
def _validate_model_available(model: str, family: str | None, lora: str | None) -> tuple[str, str | None]:
"""Verify model + LoRA exist on the live ComfyUI host before queueing.
Returns (resolved_model, resolved_lora) so the caller can substitute the
canonical filename when the user passed a bare name without extension
(e.g. -m lust_v10 → lust_v10.safetensors). Otherwise returns the inputs
unchanged.
Fails fast with typer.Exit(1) and a "did you mean" suggestion when the
requested file isn't loaded. Bucket lookup respects family:
- flux_unet / flux2_klein → diffusion_models/ (UNETLoader)
@@ -996,13 +1023,21 @@ def _validate_model_available(model: str, family: str | None, lora: str | None)
try:
loaded = get_loaded_models(console=None)
except Exception:
return # network down — let ComfyUI itself handle it
return model, lora # network down — let ComfyUI itself handle it
if not loaded:
return
return model, lora
bucket = _FAMILY_TO_LOADER_BUCKET.get(family or "", "checkpoints")
available = loaded.get(bucket, [])
if model not in available:
# Try implicit-extension resolution before failing: bare names like
# `lust_v10` should silently resolve to `lust_v10.safetensors` when
# there's an unambiguous match.
resolved = _resolve_with_extension(model, available)
if resolved is not None:
console.print(f"[dim]Resolved '{model}''{resolved}'[/dim]")
model = resolved
else:
console.print(f"[red]Model '{model}' not available on ComfyUI host[/red]")
console.print(f"[dim](looked in {bucket}/ — {len(available)} entries)[/dim]")
matches = get_close_matches(model, available, n=3, cutoff=0.5)
@@ -1013,22 +1048,36 @@ def _validate_model_available(model: str, family: str | None, lora: str | None)
else:
console.print(f"[dim]Run `tsr models` to see what's installed in {bucket}/.[/dim]")
# Suggest symlink fix if the file exists in checkpoints/ but family wants diffusion_models/
if bucket == "diffusion_models" and model in loaded.get("checkpoints", []):
if bucket == "diffusion_models" and (
model in loaded.get("checkpoints", [])
or _resolve_with_extension(model, loaded.get("checkpoints", [])) is not None
):
console.print(
f"[yellow]Hint:[/yellow] '{model}' is in checkpoints/ but UNet-only checkpoints need to be in diffusion_models/. "
f"On the ComfyUI host: [cyan]ln -s ../checkpoints/{model} <comfyui>/models/diffusion_models/{model}[/cyan]"
f"[yellow]Hint:[/yellow] '{model}' is in checkpoints/ but UNet-only "
"checkpoints need to be in diffusion_models/. On the ComfyUI host: "
f"[cyan]ln -s ../checkpoints/{model} "
f"<comfyui>/models/diffusion_models/{model}[/cyan]"
)
raise typer.Exit(1)
if lora and lora not in loaded.get("loras", []):
if lora is not None:
loras_available = loaded.get("loras", [])
if lora not in loras_available:
resolved_lora = _resolve_with_extension(lora, loras_available)
if resolved_lora is not None:
console.print(f"[dim]Resolved LoRA '{lora}''{resolved_lora}'[/dim]")
lora = resolved_lora
else:
console.print(f"[red]LoRA '{lora}' not available on ComfyUI host[/red]")
matches = get_close_matches(lora, loaded.get("loras", []), n=3, cutoff=0.5)
matches = get_close_matches(lora, loras_available, n=3, cutoff=0.5)
if matches:
console.print("[yellow]Did you mean:[/yellow]")
for m in matches:
console.print(f" [cyan]{m}[/cyan]")
raise typer.Exit(1)
return model, lora
def _run_generation( # noqa: PLR0915
*,
@@ -1097,7 +1146,10 @@ def _run_generation( # noqa: PLR0915
# instead of forwarding the request to ComfyUI for a generic 400 rejection.
# Skipped in --json mode and for remote dispatches (server already validates).
if model and not json_output and not remote:
_validate_model_available(model, model_family, lora)
# Returns possibly-rewritten names so bare inputs like `-m lust_v10`
# silently resolve to the canonical `lust_v10.safetensors` filename
# before being forwarded to ComfyUI's strict CLIPLoader / UNETLoader.
model, lora = _validate_model_available(model, model_family, lora)
# Build enhanced prompt with quality prefix (no automatic LoRA trigger injection)
prompt_parts: list[str] = []