feat(cli): resolve bare model/lora names to .safetensors silently
Before: `tsr generate -m lust_v10` failed validation with a 'did you mean' hint pointing at 'lust_v10.safetensors'. Annoying because the bare name is what the user typed and what `tsr models` prints back to them. After: when the exact name misses and there is exactly one suffixed variant in the target bucket (diffusion_models/ or checkpoints/), the validator prints a dim '[Resolved 'X' \u2192 'X.safetensors']' line and substitutes the canonical filename before forwarding to ComfyUI. Extensions tried in order: .safetensors, .ckpt, .gguf, .pt, .bin. Same logic applies to LoRA names. Ambiguous matches (multiple extensions present) still fail \u2014 the user must disambiguate explicitly. Names that already contain a dot are left alone. Function signature changes from -> None to -> tuple[str, str | None] so the caller can pick up the resolved names.
This commit is contained in:
+61
-9
@@ -978,9 +978,36 @@ _FAMILY_TO_LOADER_BUCKET: dict[str, str] = {
|
||||
}
|
||||
|
||||
|
||||
def _validate_model_available(model: str, family: str | None, lora: str | None) -> None:
|
||||
# Extensions tried in order when the user passes a bare name (no suffix) and
|
||||
# the exact lookup misses. Safetensors first since that's the modern default.
|
||||
_MODEL_EXTENSIONS: tuple[str, ...] = (".safetensors", ".ckpt", ".gguf", ".pt", ".bin")
|
||||
|
||||
|
||||
def _resolve_with_extension(name: str, available: list[str]) -> str | None:
|
||||
"""If `name` is missing an extension but exactly one suffixed variant exists in
|
||||
`available`, return that variant. Otherwise return None.
|
||||
|
||||
Examples:
|
||||
_resolve_with_extension("lust_v10", ["lust_v10.safetensors", ...]) -> "lust_v10.safetensors"
|
||||
_resolve_with_extension("lust_v10", ["lust_v10.safetensors", "lust_v10.ckpt"]) -> None # ambiguous
|
||||
_resolve_with_extension("lust_v10.safetensors", [...]) -> None # already has ext
|
||||
"""
|
||||
if "." in name: # user already provided an extension; don't second-guess
|
||||
return None
|
||||
candidates = [f"{name}{ext}" for ext in _MODEL_EXTENSIONS if f"{name}{ext}" in available]
|
||||
if len(candidates) == 1:
|
||||
return candidates[0]
|
||||
return None # zero matches → real miss; multiple matches → ambiguous, force user to disambiguate
|
||||
|
||||
|
||||
def _validate_model_available(model: str, family: str | None, lora: str | None) -> tuple[str, str | None]:
|
||||
"""Verify model + LoRA exist on the live ComfyUI host before queueing.
|
||||
|
||||
Returns (resolved_model, resolved_lora) so the caller can substitute the
|
||||
canonical filename when the user passed a bare name without extension
|
||||
(e.g. -m lust_v10 → lust_v10.safetensors). Otherwise returns the inputs
|
||||
unchanged.
|
||||
|
||||
Fails fast with typer.Exit(1) and a "did you mean" suggestion when the
|
||||
requested file isn't loaded. Bucket lookup respects family:
|
||||
- flux_unet / flux2_klein → diffusion_models/ (UNETLoader)
|
||||
@@ -996,13 +1023,21 @@ def _validate_model_available(model: str, family: str | None, lora: str | None)
|
||||
try:
|
||||
loaded = get_loaded_models(console=None)
|
||||
except Exception:
|
||||
return # network down — let ComfyUI itself handle it
|
||||
return model, lora # network down — let ComfyUI itself handle it
|
||||
if not loaded:
|
||||
return
|
||||
return model, lora
|
||||
|
||||
bucket = _FAMILY_TO_LOADER_BUCKET.get(family or "", "checkpoints")
|
||||
available = loaded.get(bucket, [])
|
||||
if model not in available:
|
||||
# Try implicit-extension resolution before failing: bare names like
|
||||
# `lust_v10` should silently resolve to `lust_v10.safetensors` when
|
||||
# there's an unambiguous match.
|
||||
resolved = _resolve_with_extension(model, available)
|
||||
if resolved is not None:
|
||||
console.print(f"[dim]Resolved '{model}' → '{resolved}'[/dim]")
|
||||
model = resolved
|
||||
else:
|
||||
console.print(f"[red]Model '{model}' not available on ComfyUI host[/red]")
|
||||
console.print(f"[dim](looked in {bucket}/ — {len(available)} entries)[/dim]")
|
||||
matches = get_close_matches(model, available, n=3, cutoff=0.5)
|
||||
@@ -1013,22 +1048,36 @@ def _validate_model_available(model: str, family: str | None, lora: str | None)
|
||||
else:
|
||||
console.print(f"[dim]Run `tsr models` to see what's installed in {bucket}/.[/dim]")
|
||||
# Suggest symlink fix if the file exists in checkpoints/ but family wants diffusion_models/
|
||||
if bucket == "diffusion_models" and model in loaded.get("checkpoints", []):
|
||||
if bucket == "diffusion_models" and (
|
||||
model in loaded.get("checkpoints", [])
|
||||
or _resolve_with_extension(model, loaded.get("checkpoints", [])) is not None
|
||||
):
|
||||
console.print(
|
||||
f"[yellow]Hint:[/yellow] '{model}' is in checkpoints/ but UNet-only checkpoints need to be in diffusion_models/. "
|
||||
f"On the ComfyUI host: [cyan]ln -s ../checkpoints/{model} <comfyui>/models/diffusion_models/{model}[/cyan]"
|
||||
f"[yellow]Hint:[/yellow] '{model}' is in checkpoints/ but UNet-only "
|
||||
"checkpoints need to be in diffusion_models/. On the ComfyUI host: "
|
||||
f"[cyan]ln -s ../checkpoints/{model} "
|
||||
f"<comfyui>/models/diffusion_models/{model}[/cyan]"
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
if lora and lora not in loaded.get("loras", []):
|
||||
if lora is not None:
|
||||
loras_available = loaded.get("loras", [])
|
||||
if lora not in loras_available:
|
||||
resolved_lora = _resolve_with_extension(lora, loras_available)
|
||||
if resolved_lora is not None:
|
||||
console.print(f"[dim]Resolved LoRA '{lora}' → '{resolved_lora}'[/dim]")
|
||||
lora = resolved_lora
|
||||
else:
|
||||
console.print(f"[red]LoRA '{lora}' not available on ComfyUI host[/red]")
|
||||
matches = get_close_matches(lora, loaded.get("loras", []), n=3, cutoff=0.5)
|
||||
matches = get_close_matches(lora, loras_available, n=3, cutoff=0.5)
|
||||
if matches:
|
||||
console.print("[yellow]Did you mean:[/yellow]")
|
||||
for m in matches:
|
||||
console.print(f" [cyan]{m}[/cyan]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
return model, lora
|
||||
|
||||
|
||||
def _run_generation( # noqa: PLR0915
|
||||
*,
|
||||
@@ -1097,7 +1146,10 @@ def _run_generation( # noqa: PLR0915
|
||||
# instead of forwarding the request to ComfyUI for a generic 400 rejection.
|
||||
# Skipped in --json mode and for remote dispatches (server already validates).
|
||||
if model and not json_output and not remote:
|
||||
_validate_model_available(model, model_family, lora)
|
||||
# Returns possibly-rewritten names so bare inputs like `-m lust_v10`
|
||||
# silently resolve to the canonical `lust_v10.safetensors` filename
|
||||
# before being forwarded to ComfyUI's strict CLIPLoader / UNETLoader.
|
||||
model, lora = _validate_model_available(model, model_family, lora)
|
||||
|
||||
# Build enhanced prompt with quality prefix (no automatic LoRA trigger injection)
|
||||
prompt_parts: list[str] = []
|
||||
|
||||
Reference in New Issue
Block a user