Add model family-specific sampler, scheduler, and VAE defaults

- Add sampler/scheduler/steps/vae to MODEL_FAMILY_DEFAULTS for all families
- Add zimage family detection for ZImageTurbo models
- Flux and zimage families use ae.safetensors VAE
- SD 1.5 families use checkpoint built-in VAE
- SDXL families use sdxl_vae.safetensors
- API auto-applies family defaults when request uses default values

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
aladac
2026-03-20 09:21:50 +01:00
parent 3432e5cb99
commit 2a704aa677
4 changed files with 72 additions and 5 deletions
+9 -2
View File
@@ -775,8 +775,15 @@ def _build_workflow(
workflow["6"]["inputs"]["text"] = prompt workflow["6"]["inputs"]["text"] = prompt
workflow["7"]["inputs"]["text"] = negative_prompt workflow["7"]["inputs"]["text"] = negative_prompt
# Set VAE # Set VAE - use external VAE if specified, otherwise use checkpoint's built-in VAE
workflow["11"]["inputs"]["vae_name"] = vae or DEFAULT_VAE if vae:
# Use external VAE loader (node 11)
workflow["11"]["inputs"]["vae_name"] = vae
else:
# Use VAE from checkpoint (node 4, output index 2) - works for SD 1.5 models
# Remove VAELoader node and connect VAEDecode directly to checkpoint
del workflow["11"]
workflow["8"]["inputs"]["vae"] = ["4", 2]
# Inject LoRA loader if specified # Inject LoRA loader if specified
if lora_name: if lora_name:
+27 -1
View File
@@ -517,6 +517,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"sampler": "euler_ancestral", "sampler": "euler_ancestral",
"scheduler": "normal", "scheduler": "normal",
"steps": 25, "steps": 25,
"vae": "sdxl_vae.safetensors",
}, },
"illustrious": { "illustrious": {
"quality_prefix": "masterpiece, best quality, highres", "quality_prefix": "masterpiece, best quality, highres",
@@ -527,6 +528,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"sampler": "euler_ancestral", "sampler": "euler_ancestral",
"scheduler": "normal", "scheduler": "normal",
"steps": 25, "steps": 25,
"vae": "sdxl_vae.safetensors",
}, },
"sdxl": { "sdxl": {
"quality_prefix": "", "quality_prefix": "",
@@ -537,6 +539,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"sampler": "dpmpp_2m", "sampler": "dpmpp_2m",
"scheduler": "karras", "scheduler": "karras",
"steps": 25, "steps": 25,
"vae": "sdxl_vae.safetensors",
}, },
"sdxl_lightning": { "sdxl_lightning": {
"quality_prefix": "", "quality_prefix": "",
@@ -547,6 +550,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"sampler": "euler", "sampler": "euler",
"scheduler": "sgm_uniform", "scheduler": "sgm_uniform",
"steps": 8, # Lightning models use fewer steps "steps": 8, # Lightning models use fewer steps
"vae": "sdxl_vae.safetensors",
}, },
"sdxl_turbo": { "sdxl_turbo": {
"quality_prefix": "", "quality_prefix": "",
@@ -557,6 +561,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"sampler": "euler_ancestral", "sampler": "euler_ancestral",
"scheduler": "normal", "scheduler": "normal",
"steps": 4, # Turbo models use very few steps "steps": 4, # Turbo models use very few steps
"vae": "sdxl_vae.safetensors",
}, },
"sd15": { "sd15": {
"quality_prefix": "masterpiece, best quality", "quality_prefix": "masterpiece, best quality",
@@ -570,6 +575,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"sampler": "dpmpp_2m", "sampler": "dpmpp_2m",
"scheduler": "karras", "scheduler": "karras",
"steps": 20, "steps": 20,
"vae": None, # Use checkpoint's built-in VAE
}, },
"sd15_lcm": { "sd15_lcm": {
"quality_prefix": "masterpiece, best quality", "quality_prefix": "masterpiece, best quality",
@@ -580,6 +586,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"sampler": "lcm", "sampler": "lcm",
"scheduler": "normal", "scheduler": "normal",
"steps": 6, "steps": 6,
"vae": None, # Use checkpoint's built-in VAE
}, },
"flux": { "flux": {
"quality_prefix": "", "quality_prefix": "",
@@ -590,6 +597,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"sampler": "euler", "sampler": "euler",
"scheduler": "simple", "scheduler": "simple",
"steps": 20, "steps": 20,
"vae": "ae.safetensors",
}, },
"flux_schnell": { "flux_schnell": {
"quality_prefix": "", "quality_prefix": "",
@@ -600,6 +608,18 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"sampler": "euler", "sampler": "euler",
"scheduler": "simple", "scheduler": "simple",
"steps": 4, # Schnell is a distilled model, very few steps "steps": 4, # Schnell is a distilled model, very few steps
"vae": "ae.safetensors",
},
"zimage": {
"quality_prefix": "",
"negative_prompt": "", # Turbo models work best without negative prompts
"width": 1024,
"height": 1024,
"cfg": 1.0, # Very low CFG for turbo
"sampler": "euler",
"scheduler": "simple",
"steps": 4, # ZImageTurbo is a distilled model
"vae": "ae.safetensors",
}, },
} }
@@ -613,7 +633,7 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
Returns: Returns:
Model family key (pony, illustrious, sdxl, sdxl_lightning, sdxl_turbo, Model family key (pony, illustrious, sdxl, sdxl_lightning, sdxl_turbo,
sd15, sd15_lcm, flux, flux_schnell) or None if unknown sd15, sd15_lcm, flux, flux_schnell, zimage) or None if unknown
""" """
name_lower = model_name.lower() name_lower = model_name.lower()
base_lower = (base_model or "").lower() base_lower = (base_model or "").lower()
@@ -629,6 +649,9 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
return "flux_schnell" return "flux_schnell"
if "flux" in base_lower: if "flux" in base_lower:
return "flux" return "flux"
# ZImageTurbo
if "zimage" in base_lower:
return "zimage"
# SD 1.5 variants # SD 1.5 variants
if "lcm" in base_lower and ("sd 1.5" in base_lower or "sd 1.4" in base_lower): if "lcm" in base_lower and ("sd 1.5" in base_lower or "sd 1.4" in base_lower):
return "sd15_lcm" return "sd15_lcm"
@@ -652,6 +675,9 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
return "flux_schnell" return "flux_schnell"
if "flux" in name_lower: if "flux" in name_lower:
return "flux" return "flux"
# ZImageTurbo
if "zimage" in name_lower:
return "zimage"
# SDXL variants # SDXL variants
if "lightning" in name_lower and any(x in name_lower for x in ["sdxl", "xl"]): if "lightning" in name_lower and any(x in name_lower for x in ["sdxl", "xl"]):
return "sdxl_lightning" return "sdxl_lightning"
+6 -2
View File
@@ -236,6 +236,7 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
scheduler = request.scheduler scheduler = request.scheduler
steps = request.steps steps = request.steps
cfg = request.cfg cfg = request.cfg
vae = request.vae
if request.model: if request.model:
# Look up base_model from database for better family detection # Look up base_model from database for better family detection
@@ -259,14 +260,17 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
steps = family_defaults["steps"] steps = family_defaults["steps"]
if request.cfg == 7.0: # Default value in schema if request.cfg == 7.0: # Default value in schema
cfg = family_defaults["cfg"] cfg = family_defaults["cfg"]
if request.vae is None: # No VAE specified, use family default
vae = family_defaults.get("vae") # None means use checkpoint VAE
logger.debug( logger.debug(
"Detected model family: %s (sampler=%s, scheduler=%s, steps=%d, cfg=%.1f)", "Detected model family: %s (sampler=%s, scheduler=%s, steps=%d, cfg=%.1f, vae=%s)",
detected_family, detected_family,
sampler, sampler,
scheduler, scheduler,
steps, steps,
cfg, cfg,
vae or "checkpoint",
) )
lora_info = f", lora={request.lora_name}@{request.lora_strength}" if request.lora_name else "" lora_info = f", lora={request.lora_name}@{request.lora_strength}" if request.lora_name else ""
@@ -295,7 +299,7 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
seed=request.seed, seed=request.seed,
sampler=sampler, sampler=sampler,
scheduler=scheduler, scheduler=scheduler,
vae=request.vae, vae=vae,
lora_name=request.lora_name, lora_name=request.lora_name,
lora_strength=request.lora_strength, lora_strength=request.lora_strength,
) )
+30
View File
@@ -377,6 +377,36 @@ class TestModelFamilyDetection:
assert defaults["steps"] == 4 assert defaults["steps"] == 4
assert defaults["cfg"] == 1.0 assert defaults["cfg"] == 1.0
def test_detect_zimage(self) -> None:
"""Test detecting ZImageTurbo family."""
from tensors.config import detect_model_family
assert detect_model_family("zimageturbo_v1.safetensors") == "zimage"
assert detect_model_family("ZIMAGE_xl.safetensors") == "zimage"
assert detect_model_family("model.safetensors", "ZImageTurbo") == "zimage"
def test_get_model_generation_defaults_zimage(self) -> None:
"""Test getting generation defaults for ZImageTurbo models."""
from tensors.config import get_model_generation_defaults
defaults = get_model_generation_defaults("zimageturbo_v1.safetensors")
assert defaults["family"] == "zimage"
assert defaults["sampler"] == "euler"
assert defaults["scheduler"] == "simple"
assert defaults["steps"] == 4
assert defaults["cfg"] == 1.0
assert defaults["vae"] == "ae.safetensors"
def test_flux_uses_ae_vae(self) -> None:
"""Test that Flux models use ae.safetensors VAE."""
from tensors.config import get_model_generation_defaults
defaults = get_model_generation_defaults("flux1-dev-fp8.safetensors")
assert defaults["vae"] == "ae.safetensors"
defaults_schnell = get_model_generation_defaults("flux1-schnell.safetensors")
assert defaults_schnell["vae"] == "ae.safetensors"
def test_get_model_generation_defaults_sdxl_lightning(self) -> None: def test_get_model_generation_defaults_sdxl_lightning(self) -> None:
"""Test getting generation defaults for SDXL Lightning models.""" """Test getting generation defaults for SDXL Lightning models."""
from tensors.config import get_model_generation_defaults from tensors.config import get_model_generation_defaults