diff --git a/tensors/cli.py b/tensors/cli.py index 46958fe..4c3e957 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -978,9 +978,36 @@ _FAMILY_TO_LOADER_BUCKET: dict[str, str] = { } -def _validate_model_available(model: str, family: str | None, lora: str | None) -> None: +# Extensions tried in order when the user passes a bare name (no suffix) and +# the exact lookup misses. Safetensors first since that's the modern default. +_MODEL_EXTENSIONS: tuple[str, ...] = (".safetensors", ".ckpt", ".gguf", ".pt", ".bin") + + +def _resolve_with_extension(name: str, available: list[str]) -> str | None: + """If `name` is missing an extension but exactly one suffixed variant exists in + `available`, return that variant. Otherwise return None. + + Examples: + _resolve_with_extension("lust_v10", ["lust_v10.safetensors", ...]) -> "lust_v10.safetensors" + _resolve_with_extension("lust_v10", ["lust_v10.safetensors", "lust_v10.ckpt"]) -> None # ambiguous + _resolve_with_extension("lust_v10.safetensors", [...]) -> None # already has ext + """ + if "." in name: # user already provided an extension; don't second-guess + return None + candidates = [f"{name}{ext}" for ext in _MODEL_EXTENSIONS if f"{name}{ext}" in available] + if len(candidates) == 1: + return candidates[0] + return None # zero matches → real miss; multiple matches → ambiguous, force user to disambiguate + + +def _validate_model_available(model: str, family: str | None, lora: str | None) -> tuple[str, str | None]: """Verify model + LoRA exist on the live ComfyUI host before queueing. + Returns (resolved_model, resolved_lora) so the caller can substitute the + canonical filename when the user passed a bare name without extension + (e.g. -m lust_v10 → lust_v10.safetensors). Otherwise returns the inputs + unchanged. + Fails fast with typer.Exit(1) and a "did you mean" suggestion when the requested file isn't loaded. Bucket lookup respects family: - flux_unet / flux2_klein → diffusion_models/ (UNETLoader) @@ -996,38 +1023,60 @@ def _validate_model_available(model: str, family: str | None, lora: str | None) try: loaded = get_loaded_models(console=None) except Exception: - return # network down — let ComfyUI itself handle it + return model, lora # network down — let ComfyUI itself handle it if not loaded: - return + return model, lora bucket = _FAMILY_TO_LOADER_BUCKET.get(family or "", "checkpoints") available = loaded.get(bucket, []) if model not in available: - console.print(f"[red]Model '{model}' not available on ComfyUI host[/red]") - console.print(f"[dim](looked in {bucket}/ — {len(available)} entries)[/dim]") - matches = get_close_matches(model, available, n=3, cutoff=0.5) - if matches: - console.print("[yellow]Did you mean:[/yellow]") - for m in matches: - console.print(f" [cyan]{m}[/cyan]") + # Try implicit-extension resolution before failing: bare names like + # `lust_v10` should silently resolve to `lust_v10.safetensors` when + # there's an unambiguous match. + resolved = _resolve_with_extension(model, available) + if resolved is not None: + console.print(f"[dim]Resolved '{model}' → '{resolved}'[/dim]") + model = resolved else: - console.print(f"[dim]Run `tsr models` to see what's installed in {bucket}/.[/dim]") - # Suggest symlink fix if the file exists in checkpoints/ but family wants diffusion_models/ - if bucket == "diffusion_models" and model in loaded.get("checkpoints", []): - console.print( - f"[yellow]Hint:[/yellow] '{model}' is in checkpoints/ but UNet-only checkpoints need to be in diffusion_models/. " - f"On the ComfyUI host: [cyan]ln -s ../checkpoints/{model} /models/diffusion_models/{model}[/cyan]" - ) - raise typer.Exit(1) + console.print(f"[red]Model '{model}' not available on ComfyUI host[/red]") + console.print(f"[dim](looked in {bucket}/ — {len(available)} entries)[/dim]") + matches = get_close_matches(model, available, n=3, cutoff=0.5) + if matches: + console.print("[yellow]Did you mean:[/yellow]") + for m in matches: + console.print(f" [cyan]{m}[/cyan]") + else: + console.print(f"[dim]Run `tsr models` to see what's installed in {bucket}/.[/dim]") + # Suggest symlink fix if the file exists in checkpoints/ but family wants diffusion_models/ + if bucket == "diffusion_models" and ( + model in loaded.get("checkpoints", []) + or _resolve_with_extension(model, loaded.get("checkpoints", [])) is not None + ): + console.print( + f"[yellow]Hint:[/yellow] '{model}' is in checkpoints/ but UNet-only " + "checkpoints need to be in diffusion_models/. On the ComfyUI host: " + f"[cyan]ln -s ../checkpoints/{model} " + f"/models/diffusion_models/{model}[/cyan]" + ) + raise typer.Exit(1) - if lora and lora not in loaded.get("loras", []): - console.print(f"[red]LoRA '{lora}' not available on ComfyUI host[/red]") - matches = get_close_matches(lora, loaded.get("loras", []), n=3, cutoff=0.5) - if matches: - console.print("[yellow]Did you mean:[/yellow]") - for m in matches: - console.print(f" [cyan]{m}[/cyan]") - raise typer.Exit(1) + if lora is not None: + loras_available = loaded.get("loras", []) + if lora not in loras_available: + resolved_lora = _resolve_with_extension(lora, loras_available) + if resolved_lora is not None: + console.print(f"[dim]Resolved LoRA '{lora}' → '{resolved_lora}'[/dim]") + lora = resolved_lora + else: + console.print(f"[red]LoRA '{lora}' not available on ComfyUI host[/red]") + matches = get_close_matches(lora, loras_available, n=3, cutoff=0.5) + if matches: + console.print("[yellow]Did you mean:[/yellow]") + for m in matches: + console.print(f" [cyan]{m}[/cyan]") + raise typer.Exit(1) + + return model, lora def _run_generation( # noqa: PLR0915 @@ -1097,7 +1146,10 @@ def _run_generation( # noqa: PLR0915 # instead of forwarding the request to ComfyUI for a generic 400 rejection. # Skipped in --json mode and for remote dispatches (server already validates). if model and not json_output and not remote: - _validate_model_available(model, model_family, lora) + # Returns possibly-rewritten names so bare inputs like `-m lust_v10` + # silently resolve to the canonical `lust_v10.safetensors` filename + # before being forwarded to ComfyUI's strict CLIPLoader / UNETLoader. + model, lora = _validate_model_available(model, model_family, lora) # Build enhanced prompt with quality prefix (no automatic LoRA trigger injection) prompt_parts: list[str] = []