From 2a704aa67751cefda7c9754f7bc20a128105e3de Mon Sep 17 00:00:00 2001 From: aladac Date: Fri, 20 Mar 2026 09:21:50 +0100 Subject: [PATCH] Add model family-specific sampler, scheduler, and VAE defaults - Add sampler/scheduler/steps/vae to MODEL_FAMILY_DEFAULTS for all families - Add zimage family detection for ZImageTurbo models - Flux and zimage families use ae.safetensors VAE - SD 1.5 families use checkpoint built-in VAE - SDXL families use sdxl_vae.safetensors - API auto-applies family defaults when request uses default values Co-Authored-By: Claude Opus 4.5 --- tensors/comfyui.py | 11 ++++++++-- tensors/config.py | 28 +++++++++++++++++++++++++- tensors/server/comfyui_api_routes.py | 8 ++++++-- tests/test_tensors.py | 30 ++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 5 deletions(-) diff --git a/tensors/comfyui.py b/tensors/comfyui.py index 4298457..1a60b3f 100644 --- a/tensors/comfyui.py +++ b/tensors/comfyui.py @@ -775,8 +775,15 @@ def _build_workflow( workflow["6"]["inputs"]["text"] = prompt workflow["7"]["inputs"]["text"] = negative_prompt - # Set VAE - workflow["11"]["inputs"]["vae_name"] = vae or DEFAULT_VAE + # Set VAE - use external VAE if specified, otherwise use checkpoint's built-in VAE + if vae: + # Use external VAE loader (node 11) + workflow["11"]["inputs"]["vae_name"] = vae + else: + # Use VAE from checkpoint (node 4, output index 2) - works for SD 1.5 models + # Remove VAELoader node and connect VAEDecode directly to checkpoint + del workflow["11"] + workflow["8"]["inputs"]["vae"] = ["4", 2] # Inject LoRA loader if specified if lora_name: diff --git a/tensors/config.py b/tensors/config.py index 455f15a..8f8dbfc 100644 --- a/tensors/config.py +++ b/tensors/config.py @@ -517,6 +517,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "sampler": "euler_ancestral", "scheduler": "normal", "steps": 25, + "vae": "sdxl_vae.safetensors", }, "illustrious": { "quality_prefix": "masterpiece, best quality, highres", @@ -527,6 +528,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "sampler": "euler_ancestral", "scheduler": "normal", "steps": 25, + "vae": "sdxl_vae.safetensors", }, "sdxl": { "quality_prefix": "", @@ -537,6 +539,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "sampler": "dpmpp_2m", "scheduler": "karras", "steps": 25, + "vae": "sdxl_vae.safetensors", }, "sdxl_lightning": { "quality_prefix": "", @@ -547,6 +550,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "sampler": "euler", "scheduler": "sgm_uniform", "steps": 8, # Lightning models use fewer steps + "vae": "sdxl_vae.safetensors", }, "sdxl_turbo": { "quality_prefix": "", @@ -557,6 +561,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "sampler": "euler_ancestral", "scheduler": "normal", "steps": 4, # Turbo models use very few steps + "vae": "sdxl_vae.safetensors", }, "sd15": { "quality_prefix": "masterpiece, best quality", @@ -570,6 +575,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "sampler": "dpmpp_2m", "scheduler": "karras", "steps": 20, + "vae": None, # Use checkpoint's built-in VAE }, "sd15_lcm": { "quality_prefix": "masterpiece, best quality", @@ -580,6 +586,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "sampler": "lcm", "scheduler": "normal", "steps": 6, + "vae": None, # Use checkpoint's built-in VAE }, "flux": { "quality_prefix": "", @@ -590,6 +597,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "sampler": "euler", "scheduler": "simple", "steps": 20, + "vae": "ae.safetensors", }, "flux_schnell": { "quality_prefix": "", @@ -600,6 +608,18 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "sampler": "euler", "scheduler": "simple", "steps": 4, # Schnell is a distilled model, very few steps + "vae": "ae.safetensors", + }, + "zimage": { + "quality_prefix": "", + "negative_prompt": "", # Turbo models work best without negative prompts + "width": 1024, + "height": 1024, + "cfg": 1.0, # Very low CFG for turbo + "sampler": "euler", + "scheduler": "simple", + "steps": 4, # ZImageTurbo is a distilled model + "vae": "ae.safetensors", }, } @@ -613,7 +633,7 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str | Returns: Model family key (pony, illustrious, sdxl, sdxl_lightning, sdxl_turbo, - sd15, sd15_lcm, flux, flux_schnell) or None if unknown + sd15, sd15_lcm, flux, flux_schnell, zimage) or None if unknown """ name_lower = model_name.lower() base_lower = (base_model or "").lower() @@ -629,6 +649,9 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str | return "flux_schnell" if "flux" in base_lower: return "flux" + # ZImageTurbo + if "zimage" in base_lower: + return "zimage" # SD 1.5 variants if "lcm" in base_lower and ("sd 1.5" in base_lower or "sd 1.4" in base_lower): return "sd15_lcm" @@ -652,6 +675,9 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str | return "flux_schnell" if "flux" in name_lower: return "flux" + # ZImageTurbo + if "zimage" in name_lower: + return "zimage" # SDXL variants if "lightning" in name_lower and any(x in name_lower for x in ["sdxl", "xl"]): return "sdxl_lightning" diff --git a/tensors/server/comfyui_api_routes.py b/tensors/server/comfyui_api_routes.py index 415bd2c..3be8ce7 100644 --- a/tensors/server/comfyui_api_routes.py +++ b/tensors/server/comfyui_api_routes.py @@ -236,6 +236,7 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]: scheduler = request.scheduler steps = request.steps cfg = request.cfg + vae = request.vae if request.model: # Look up base_model from database for better family detection @@ -259,14 +260,17 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]: steps = family_defaults["steps"] if request.cfg == 7.0: # Default value in schema cfg = family_defaults["cfg"] + if request.vae is None: # No VAE specified, use family default + vae = family_defaults.get("vae") # None means use checkpoint VAE logger.debug( - "Detected model family: %s (sampler=%s, scheduler=%s, steps=%d, cfg=%.1f)", + "Detected model family: %s (sampler=%s, scheduler=%s, steps=%d, cfg=%.1f, vae=%s)", detected_family, sampler, scheduler, steps, cfg, + vae or "checkpoint", ) lora_info = f", lora={request.lora_name}@{request.lora_strength}" if request.lora_name else "" @@ -295,7 +299,7 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]: seed=request.seed, sampler=sampler, scheduler=scheduler, - vae=request.vae, + vae=vae, lora_name=request.lora_name, lora_strength=request.lora_strength, ) diff --git a/tests/test_tensors.py b/tests/test_tensors.py index 1c6afc7..fcdb285 100644 --- a/tests/test_tensors.py +++ b/tests/test_tensors.py @@ -377,6 +377,36 @@ class TestModelFamilyDetection: assert defaults["steps"] == 4 assert defaults["cfg"] == 1.0 + def test_detect_zimage(self) -> None: + """Test detecting ZImageTurbo family.""" + from tensors.config import detect_model_family + + assert detect_model_family("zimageturbo_v1.safetensors") == "zimage" + assert detect_model_family("ZIMAGE_xl.safetensors") == "zimage" + assert detect_model_family("model.safetensors", "ZImageTurbo") == "zimage" + + def test_get_model_generation_defaults_zimage(self) -> None: + """Test getting generation defaults for ZImageTurbo models.""" + from tensors.config import get_model_generation_defaults + + defaults = get_model_generation_defaults("zimageturbo_v1.safetensors") + assert defaults["family"] == "zimage" + assert defaults["sampler"] == "euler" + assert defaults["scheduler"] == "simple" + assert defaults["steps"] == 4 + assert defaults["cfg"] == 1.0 + assert defaults["vae"] == "ae.safetensors" + + def test_flux_uses_ae_vae(self) -> None: + """Test that Flux models use ae.safetensors VAE.""" + from tensors.config import get_model_generation_defaults + + defaults = get_model_generation_defaults("flux1-dev-fp8.safetensors") + assert defaults["vae"] == "ae.safetensors" + + defaults_schnell = get_model_generation_defaults("flux1-schnell.safetensors") + assert defaults_schnell["vae"] == "ae.safetensors" + def test_get_model_generation_defaults_sdxl_lightning(self) -> None: """Test getting generation defaults for SDXL Lightning models.""" from tensors.config import get_model_generation_defaults