feat(cli): resolve bare model/lora names to .safetensors silently
Before: `tsr generate -m lust_v10` failed validation with a 'did you mean' hint pointing at 'lust_v10.safetensors'. Annoying because the bare name is what the user typed and what `tsr models` prints back to them. After: when the exact name misses and there is exactly one suffixed variant in the target bucket (diffusion_models/ or checkpoints/), the validator prints a dim '[Resolved 'X' \u2192 'X.safetensors']' line and substitutes the canonical filename before forwarding to ComfyUI. Extensions tried in order: .safetensors, .ckpt, .gguf, .pt, .bin. Same logic applies to LoRA names. Ambiguous matches (multiple extensions present) still fail \u2014 the user must disambiguate explicitly. Names that already contain a dot are left alone. Function signature changes from -> None to -> tuple[str, str | None] so the caller can pick up the resolved names.
This commit is contained in:
+79
-27
@@ -978,9 +978,36 @@ _FAMILY_TO_LOADER_BUCKET: dict[str, str] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _validate_model_available(model: str, family: str | None, lora: str | None) -> None:
|
# Extensions tried in order when the user passes a bare name (no suffix) and
|
||||||
|
# the exact lookup misses. Safetensors first since that's the modern default.
|
||||||
|
_MODEL_EXTENSIONS: tuple[str, ...] = (".safetensors", ".ckpt", ".gguf", ".pt", ".bin")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_with_extension(name: str, available: list[str]) -> str | None:
|
||||||
|
"""If `name` is missing an extension but exactly one suffixed variant exists in
|
||||||
|
`available`, return that variant. Otherwise return None.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
_resolve_with_extension("lust_v10", ["lust_v10.safetensors", ...]) -> "lust_v10.safetensors"
|
||||||
|
_resolve_with_extension("lust_v10", ["lust_v10.safetensors", "lust_v10.ckpt"]) -> None # ambiguous
|
||||||
|
_resolve_with_extension("lust_v10.safetensors", [...]) -> None # already has ext
|
||||||
|
"""
|
||||||
|
if "." in name: # user already provided an extension; don't second-guess
|
||||||
|
return None
|
||||||
|
candidates = [f"{name}{ext}" for ext in _MODEL_EXTENSIONS if f"{name}{ext}" in available]
|
||||||
|
if len(candidates) == 1:
|
||||||
|
return candidates[0]
|
||||||
|
return None # zero matches → real miss; multiple matches → ambiguous, force user to disambiguate
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_model_available(model: str, family: str | None, lora: str | None) -> tuple[str, str | None]:
|
||||||
"""Verify model + LoRA exist on the live ComfyUI host before queueing.
|
"""Verify model + LoRA exist on the live ComfyUI host before queueing.
|
||||||
|
|
||||||
|
Returns (resolved_model, resolved_lora) so the caller can substitute the
|
||||||
|
canonical filename when the user passed a bare name without extension
|
||||||
|
(e.g. -m lust_v10 → lust_v10.safetensors). Otherwise returns the inputs
|
||||||
|
unchanged.
|
||||||
|
|
||||||
Fails fast with typer.Exit(1) and a "did you mean" suggestion when the
|
Fails fast with typer.Exit(1) and a "did you mean" suggestion when the
|
||||||
requested file isn't loaded. Bucket lookup respects family:
|
requested file isn't loaded. Bucket lookup respects family:
|
||||||
- flux_unet / flux2_klein → diffusion_models/ (UNETLoader)
|
- flux_unet / flux2_klein → diffusion_models/ (UNETLoader)
|
||||||
@@ -996,38 +1023,60 @@ def _validate_model_available(model: str, family: str | None, lora: str | None)
|
|||||||
try:
|
try:
|
||||||
loaded = get_loaded_models(console=None)
|
loaded = get_loaded_models(console=None)
|
||||||
except Exception:
|
except Exception:
|
||||||
return # network down — let ComfyUI itself handle it
|
return model, lora # network down — let ComfyUI itself handle it
|
||||||
if not loaded:
|
if not loaded:
|
||||||
return
|
return model, lora
|
||||||
|
|
||||||
bucket = _FAMILY_TO_LOADER_BUCKET.get(family or "", "checkpoints")
|
bucket = _FAMILY_TO_LOADER_BUCKET.get(family or "", "checkpoints")
|
||||||
available = loaded.get(bucket, [])
|
available = loaded.get(bucket, [])
|
||||||
if model not in available:
|
if model not in available:
|
||||||
console.print(f"[red]Model '{model}' not available on ComfyUI host[/red]")
|
# Try implicit-extension resolution before failing: bare names like
|
||||||
console.print(f"[dim](looked in {bucket}/ — {len(available)} entries)[/dim]")
|
# `lust_v10` should silently resolve to `lust_v10.safetensors` when
|
||||||
matches = get_close_matches(model, available, n=3, cutoff=0.5)
|
# there's an unambiguous match.
|
||||||
if matches:
|
resolved = _resolve_with_extension(model, available)
|
||||||
console.print("[yellow]Did you mean:[/yellow]")
|
if resolved is not None:
|
||||||
for m in matches:
|
console.print(f"[dim]Resolved '{model}' → '{resolved}'[/dim]")
|
||||||
console.print(f" [cyan]{m}[/cyan]")
|
model = resolved
|
||||||
else:
|
else:
|
||||||
console.print(f"[dim]Run `tsr models` to see what's installed in {bucket}/.[/dim]")
|
console.print(f"[red]Model '{model}' not available on ComfyUI host[/red]")
|
||||||
# Suggest symlink fix if the file exists in checkpoints/ but family wants diffusion_models/
|
console.print(f"[dim](looked in {bucket}/ — {len(available)} entries)[/dim]")
|
||||||
if bucket == "diffusion_models" and model in loaded.get("checkpoints", []):
|
matches = get_close_matches(model, available, n=3, cutoff=0.5)
|
||||||
console.print(
|
if matches:
|
||||||
f"[yellow]Hint:[/yellow] '{model}' is in checkpoints/ but UNet-only checkpoints need to be in diffusion_models/. "
|
console.print("[yellow]Did you mean:[/yellow]")
|
||||||
f"On the ComfyUI host: [cyan]ln -s ../checkpoints/{model} <comfyui>/models/diffusion_models/{model}[/cyan]"
|
for m in matches:
|
||||||
)
|
console.print(f" [cyan]{m}[/cyan]")
|
||||||
raise typer.Exit(1)
|
else:
|
||||||
|
console.print(f"[dim]Run `tsr models` to see what's installed in {bucket}/.[/dim]")
|
||||||
|
# Suggest symlink fix if the file exists in checkpoints/ but family wants diffusion_models/
|
||||||
|
if bucket == "diffusion_models" and (
|
||||||
|
model in loaded.get("checkpoints", [])
|
||||||
|
or _resolve_with_extension(model, loaded.get("checkpoints", [])) is not None
|
||||||
|
):
|
||||||
|
console.print(
|
||||||
|
f"[yellow]Hint:[/yellow] '{model}' is in checkpoints/ but UNet-only "
|
||||||
|
"checkpoints need to be in diffusion_models/. On the ComfyUI host: "
|
||||||
|
f"[cyan]ln -s ../checkpoints/{model} "
|
||||||
|
f"<comfyui>/models/diffusion_models/{model}[/cyan]"
|
||||||
|
)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
if lora and lora not in loaded.get("loras", []):
|
if lora is not None:
|
||||||
console.print(f"[red]LoRA '{lora}' not available on ComfyUI host[/red]")
|
loras_available = loaded.get("loras", [])
|
||||||
matches = get_close_matches(lora, loaded.get("loras", []), n=3, cutoff=0.5)
|
if lora not in loras_available:
|
||||||
if matches:
|
resolved_lora = _resolve_with_extension(lora, loras_available)
|
||||||
console.print("[yellow]Did you mean:[/yellow]")
|
if resolved_lora is not None:
|
||||||
for m in matches:
|
console.print(f"[dim]Resolved LoRA '{lora}' → '{resolved_lora}'[/dim]")
|
||||||
console.print(f" [cyan]{m}[/cyan]")
|
lora = resolved_lora
|
||||||
raise typer.Exit(1)
|
else:
|
||||||
|
console.print(f"[red]LoRA '{lora}' not available on ComfyUI host[/red]")
|
||||||
|
matches = get_close_matches(lora, loras_available, n=3, cutoff=0.5)
|
||||||
|
if matches:
|
||||||
|
console.print("[yellow]Did you mean:[/yellow]")
|
||||||
|
for m in matches:
|
||||||
|
console.print(f" [cyan]{m}[/cyan]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
return model, lora
|
||||||
|
|
||||||
|
|
||||||
def _run_generation( # noqa: PLR0915
|
def _run_generation( # noqa: PLR0915
|
||||||
@@ -1097,7 +1146,10 @@ def _run_generation( # noqa: PLR0915
|
|||||||
# instead of forwarding the request to ComfyUI for a generic 400 rejection.
|
# instead of forwarding the request to ComfyUI for a generic 400 rejection.
|
||||||
# Skipped in --json mode and for remote dispatches (server already validates).
|
# Skipped in --json mode and for remote dispatches (server already validates).
|
||||||
if model and not json_output and not remote:
|
if model and not json_output and not remote:
|
||||||
_validate_model_available(model, model_family, lora)
|
# Returns possibly-rewritten names so bare inputs like `-m lust_v10`
|
||||||
|
# silently resolve to the canonical `lust_v10.safetensors` filename
|
||||||
|
# before being forwarded to ComfyUI's strict CLIPLoader / UNETLoader.
|
||||||
|
model, lora = _validate_model_available(model, model_family, lora)
|
||||||
|
|
||||||
# Build enhanced prompt with quality prefix (no automatic LoRA trigger injection)
|
# Build enhanced prompt with quality prefix (no automatic LoRA trigger injection)
|
||||||
prompt_parts: list[str] = []
|
prompt_parts: list[str] = []
|
||||||
|
|||||||
Reference in New Issue
Block a user