Add model family-specific sampler, scheduler, and VAE defaults
- Add sampler/scheduler/steps/vae to MODEL_FAMILY_DEFAULTS for all families - Add zimage family detection for ZImageTurbo models - Flux and zimage families use ae.safetensors VAE - SD 1.5 families use checkpoint built-in VAE - SDXL families use sdxl_vae.safetensors - API auto-applies family defaults when request uses default values Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
+9
-2
@@ -775,8 +775,15 @@ def _build_workflow(
|
||||
workflow["6"]["inputs"]["text"] = prompt
|
||||
workflow["7"]["inputs"]["text"] = negative_prompt
|
||||
|
||||
# Set VAE
|
||||
workflow["11"]["inputs"]["vae_name"] = vae or DEFAULT_VAE
|
||||
# Set VAE - use external VAE if specified, otherwise use checkpoint's built-in VAE
|
||||
if vae:
|
||||
# Use external VAE loader (node 11)
|
||||
workflow["11"]["inputs"]["vae_name"] = vae
|
||||
else:
|
||||
# Use VAE from checkpoint (node 4, output index 2) - works for SD 1.5 models
|
||||
# Remove VAELoader node and connect VAEDecode directly to checkpoint
|
||||
del workflow["11"]
|
||||
workflow["8"]["inputs"]["vae"] = ["4", 2]
|
||||
|
||||
# Inject LoRA loader if specified
|
||||
if lora_name:
|
||||
|
||||
+27
-1
@@ -517,6 +517,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"sampler": "euler_ancestral",
|
||||
"scheduler": "normal",
|
||||
"steps": 25,
|
||||
"vae": "sdxl_vae.safetensors",
|
||||
},
|
||||
"illustrious": {
|
||||
"quality_prefix": "masterpiece, best quality, highres",
|
||||
@@ -527,6 +528,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"sampler": "euler_ancestral",
|
||||
"scheduler": "normal",
|
||||
"steps": 25,
|
||||
"vae": "sdxl_vae.safetensors",
|
||||
},
|
||||
"sdxl": {
|
||||
"quality_prefix": "",
|
||||
@@ -537,6 +539,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"sampler": "dpmpp_2m",
|
||||
"scheduler": "karras",
|
||||
"steps": 25,
|
||||
"vae": "sdxl_vae.safetensors",
|
||||
},
|
||||
"sdxl_lightning": {
|
||||
"quality_prefix": "",
|
||||
@@ -547,6 +550,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"sampler": "euler",
|
||||
"scheduler": "sgm_uniform",
|
||||
"steps": 8, # Lightning models use fewer steps
|
||||
"vae": "sdxl_vae.safetensors",
|
||||
},
|
||||
"sdxl_turbo": {
|
||||
"quality_prefix": "",
|
||||
@@ -557,6 +561,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"sampler": "euler_ancestral",
|
||||
"scheduler": "normal",
|
||||
"steps": 4, # Turbo models use very few steps
|
||||
"vae": "sdxl_vae.safetensors",
|
||||
},
|
||||
"sd15": {
|
||||
"quality_prefix": "masterpiece, best quality",
|
||||
@@ -570,6 +575,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"sampler": "dpmpp_2m",
|
||||
"scheduler": "karras",
|
||||
"steps": 20,
|
||||
"vae": None, # Use checkpoint's built-in VAE
|
||||
},
|
||||
"sd15_lcm": {
|
||||
"quality_prefix": "masterpiece, best quality",
|
||||
@@ -580,6 +586,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"sampler": "lcm",
|
||||
"scheduler": "normal",
|
||||
"steps": 6,
|
||||
"vae": None, # Use checkpoint's built-in VAE
|
||||
},
|
||||
"flux": {
|
||||
"quality_prefix": "",
|
||||
@@ -590,6 +597,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"sampler": "euler",
|
||||
"scheduler": "simple",
|
||||
"steps": 20,
|
||||
"vae": "ae.safetensors",
|
||||
},
|
||||
"flux_schnell": {
|
||||
"quality_prefix": "",
|
||||
@@ -600,6 +608,18 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"sampler": "euler",
|
||||
"scheduler": "simple",
|
||||
"steps": 4, # Schnell is a distilled model, very few steps
|
||||
"vae": "ae.safetensors",
|
||||
},
|
||||
"zimage": {
|
||||
"quality_prefix": "",
|
||||
"negative_prompt": "", # Turbo models work best without negative prompts
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"cfg": 1.0, # Very low CFG for turbo
|
||||
"sampler": "euler",
|
||||
"scheduler": "simple",
|
||||
"steps": 4, # ZImageTurbo is a distilled model
|
||||
"vae": "ae.safetensors",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -613,7 +633,7 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
|
||||
|
||||
Returns:
|
||||
Model family key (pony, illustrious, sdxl, sdxl_lightning, sdxl_turbo,
|
||||
sd15, sd15_lcm, flux, flux_schnell) or None if unknown
|
||||
sd15, sd15_lcm, flux, flux_schnell, zimage) or None if unknown
|
||||
"""
|
||||
name_lower = model_name.lower()
|
||||
base_lower = (base_model or "").lower()
|
||||
@@ -629,6 +649,9 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
|
||||
return "flux_schnell"
|
||||
if "flux" in base_lower:
|
||||
return "flux"
|
||||
# ZImageTurbo
|
||||
if "zimage" in base_lower:
|
||||
return "zimage"
|
||||
# SD 1.5 variants
|
||||
if "lcm" in base_lower and ("sd 1.5" in base_lower or "sd 1.4" in base_lower):
|
||||
return "sd15_lcm"
|
||||
@@ -652,6 +675,9 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
|
||||
return "flux_schnell"
|
||||
if "flux" in name_lower:
|
||||
return "flux"
|
||||
# ZImageTurbo
|
||||
if "zimage" in name_lower:
|
||||
return "zimage"
|
||||
# SDXL variants
|
||||
if "lightning" in name_lower and any(x in name_lower for x in ["sdxl", "xl"]):
|
||||
return "sdxl_lightning"
|
||||
|
||||
@@ -236,6 +236,7 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
||||
scheduler = request.scheduler
|
||||
steps = request.steps
|
||||
cfg = request.cfg
|
||||
vae = request.vae
|
||||
|
||||
if request.model:
|
||||
# Look up base_model from database for better family detection
|
||||
@@ -259,14 +260,17 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
||||
steps = family_defaults["steps"]
|
||||
if request.cfg == 7.0: # Default value in schema
|
||||
cfg = family_defaults["cfg"]
|
||||
if request.vae is None: # No VAE specified, use family default
|
||||
vae = family_defaults.get("vae") # None means use checkpoint VAE
|
||||
|
||||
logger.debug(
|
||||
"Detected model family: %s (sampler=%s, scheduler=%s, steps=%d, cfg=%.1f)",
|
||||
"Detected model family: %s (sampler=%s, scheduler=%s, steps=%d, cfg=%.1f, vae=%s)",
|
||||
detected_family,
|
||||
sampler,
|
||||
scheduler,
|
||||
steps,
|
||||
cfg,
|
||||
vae or "checkpoint",
|
||||
)
|
||||
|
||||
lora_info = f", lora={request.lora_name}@{request.lora_strength}" if request.lora_name else ""
|
||||
@@ -295,7 +299,7 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
||||
seed=request.seed,
|
||||
sampler=sampler,
|
||||
scheduler=scheduler,
|
||||
vae=request.vae,
|
||||
vae=vae,
|
||||
lora_name=request.lora_name,
|
||||
lora_strength=request.lora_strength,
|
||||
)
|
||||
|
||||
@@ -377,6 +377,36 @@ class TestModelFamilyDetection:
|
||||
assert defaults["steps"] == 4
|
||||
assert defaults["cfg"] == 1.0
|
||||
|
||||
def test_detect_zimage(self) -> None:
|
||||
"""Test detecting ZImageTurbo family."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("zimageturbo_v1.safetensors") == "zimage"
|
||||
assert detect_model_family("ZIMAGE_xl.safetensors") == "zimage"
|
||||
assert detect_model_family("model.safetensors", "ZImageTurbo") == "zimage"
|
||||
|
||||
def test_get_model_generation_defaults_zimage(self) -> None:
|
||||
"""Test getting generation defaults for ZImageTurbo models."""
|
||||
from tensors.config import get_model_generation_defaults
|
||||
|
||||
defaults = get_model_generation_defaults("zimageturbo_v1.safetensors")
|
||||
assert defaults["family"] == "zimage"
|
||||
assert defaults["sampler"] == "euler"
|
||||
assert defaults["scheduler"] == "simple"
|
||||
assert defaults["steps"] == 4
|
||||
assert defaults["cfg"] == 1.0
|
||||
assert defaults["vae"] == "ae.safetensors"
|
||||
|
||||
def test_flux_uses_ae_vae(self) -> None:
|
||||
"""Test that Flux models use ae.safetensors VAE."""
|
||||
from tensors.config import get_model_generation_defaults
|
||||
|
||||
defaults = get_model_generation_defaults("flux1-dev-fp8.safetensors")
|
||||
assert defaults["vae"] == "ae.safetensors"
|
||||
|
||||
defaults_schnell = get_model_generation_defaults("flux1-schnell.safetensors")
|
||||
assert defaults_schnell["vae"] == "ae.safetensors"
|
||||
|
||||
def test_get_model_generation_defaults_sdxl_lightning(self) -> None:
|
||||
"""Test getting generation defaults for SDXL Lightning models."""
|
||||
from tensors.config import get_model_generation_defaults
|
||||
|
||||
Reference in New Issue
Block a user