Add model family-specific sampler, scheduler, and VAE defaults
- Add sampler/scheduler/steps/vae to MODEL_FAMILY_DEFAULTS for all families - Add zimage family detection for ZImageTurbo models - Flux and zimage families use ae.safetensors VAE - SD 1.5 families use checkpoint built-in VAE - SDXL families use sdxl_vae.safetensors - API auto-applies family defaults when request uses default values Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
+9
-2
@@ -775,8 +775,15 @@ def _build_workflow(
|
|||||||
workflow["6"]["inputs"]["text"] = prompt
|
workflow["6"]["inputs"]["text"] = prompt
|
||||||
workflow["7"]["inputs"]["text"] = negative_prompt
|
workflow["7"]["inputs"]["text"] = negative_prompt
|
||||||
|
|
||||||
# Set VAE
|
# Set VAE - use external VAE if specified, otherwise use checkpoint's built-in VAE
|
||||||
workflow["11"]["inputs"]["vae_name"] = vae or DEFAULT_VAE
|
if vae:
|
||||||
|
# Use external VAE loader (node 11)
|
||||||
|
workflow["11"]["inputs"]["vae_name"] = vae
|
||||||
|
else:
|
||||||
|
# Use VAE from checkpoint (node 4, output index 2) - works for SD 1.5 models
|
||||||
|
# Remove VAELoader node and connect VAEDecode directly to checkpoint
|
||||||
|
del workflow["11"]
|
||||||
|
workflow["8"]["inputs"]["vae"] = ["4", 2]
|
||||||
|
|
||||||
# Inject LoRA loader if specified
|
# Inject LoRA loader if specified
|
||||||
if lora_name:
|
if lora_name:
|
||||||
|
|||||||
+27
-1
@@ -517,6 +517,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
|||||||
"sampler": "euler_ancestral",
|
"sampler": "euler_ancestral",
|
||||||
"scheduler": "normal",
|
"scheduler": "normal",
|
||||||
"steps": 25,
|
"steps": 25,
|
||||||
|
"vae": "sdxl_vae.safetensors",
|
||||||
},
|
},
|
||||||
"illustrious": {
|
"illustrious": {
|
||||||
"quality_prefix": "masterpiece, best quality, highres",
|
"quality_prefix": "masterpiece, best quality, highres",
|
||||||
@@ -527,6 +528,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
|||||||
"sampler": "euler_ancestral",
|
"sampler": "euler_ancestral",
|
||||||
"scheduler": "normal",
|
"scheduler": "normal",
|
||||||
"steps": 25,
|
"steps": 25,
|
||||||
|
"vae": "sdxl_vae.safetensors",
|
||||||
},
|
},
|
||||||
"sdxl": {
|
"sdxl": {
|
||||||
"quality_prefix": "",
|
"quality_prefix": "",
|
||||||
@@ -537,6 +539,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
|||||||
"sampler": "dpmpp_2m",
|
"sampler": "dpmpp_2m",
|
||||||
"scheduler": "karras",
|
"scheduler": "karras",
|
||||||
"steps": 25,
|
"steps": 25,
|
||||||
|
"vae": "sdxl_vae.safetensors",
|
||||||
},
|
},
|
||||||
"sdxl_lightning": {
|
"sdxl_lightning": {
|
||||||
"quality_prefix": "",
|
"quality_prefix": "",
|
||||||
@@ -547,6 +550,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
|||||||
"sampler": "euler",
|
"sampler": "euler",
|
||||||
"scheduler": "sgm_uniform",
|
"scheduler": "sgm_uniform",
|
||||||
"steps": 8, # Lightning models use fewer steps
|
"steps": 8, # Lightning models use fewer steps
|
||||||
|
"vae": "sdxl_vae.safetensors",
|
||||||
},
|
},
|
||||||
"sdxl_turbo": {
|
"sdxl_turbo": {
|
||||||
"quality_prefix": "",
|
"quality_prefix": "",
|
||||||
@@ -557,6 +561,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
|||||||
"sampler": "euler_ancestral",
|
"sampler": "euler_ancestral",
|
||||||
"scheduler": "normal",
|
"scheduler": "normal",
|
||||||
"steps": 4, # Turbo models use very few steps
|
"steps": 4, # Turbo models use very few steps
|
||||||
|
"vae": "sdxl_vae.safetensors",
|
||||||
},
|
},
|
||||||
"sd15": {
|
"sd15": {
|
||||||
"quality_prefix": "masterpiece, best quality",
|
"quality_prefix": "masterpiece, best quality",
|
||||||
@@ -570,6 +575,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
|||||||
"sampler": "dpmpp_2m",
|
"sampler": "dpmpp_2m",
|
||||||
"scheduler": "karras",
|
"scheduler": "karras",
|
||||||
"steps": 20,
|
"steps": 20,
|
||||||
|
"vae": None, # Use checkpoint's built-in VAE
|
||||||
},
|
},
|
||||||
"sd15_lcm": {
|
"sd15_lcm": {
|
||||||
"quality_prefix": "masterpiece, best quality",
|
"quality_prefix": "masterpiece, best quality",
|
||||||
@@ -580,6 +586,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
|||||||
"sampler": "lcm",
|
"sampler": "lcm",
|
||||||
"scheduler": "normal",
|
"scheduler": "normal",
|
||||||
"steps": 6,
|
"steps": 6,
|
||||||
|
"vae": None, # Use checkpoint's built-in VAE
|
||||||
},
|
},
|
||||||
"flux": {
|
"flux": {
|
||||||
"quality_prefix": "",
|
"quality_prefix": "",
|
||||||
@@ -590,6 +597,7 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
|||||||
"sampler": "euler",
|
"sampler": "euler",
|
||||||
"scheduler": "simple",
|
"scheduler": "simple",
|
||||||
"steps": 20,
|
"steps": 20,
|
||||||
|
"vae": "ae.safetensors",
|
||||||
},
|
},
|
||||||
"flux_schnell": {
|
"flux_schnell": {
|
||||||
"quality_prefix": "",
|
"quality_prefix": "",
|
||||||
@@ -600,6 +608,18 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
|||||||
"sampler": "euler",
|
"sampler": "euler",
|
||||||
"scheduler": "simple",
|
"scheduler": "simple",
|
||||||
"steps": 4, # Schnell is a distilled model, very few steps
|
"steps": 4, # Schnell is a distilled model, very few steps
|
||||||
|
"vae": "ae.safetensors",
|
||||||
|
},
|
||||||
|
"zimage": {
|
||||||
|
"quality_prefix": "",
|
||||||
|
"negative_prompt": "", # Turbo models work best without negative prompts
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
|
"cfg": 1.0, # Very low CFG for turbo
|
||||||
|
"sampler": "euler",
|
||||||
|
"scheduler": "simple",
|
||||||
|
"steps": 4, # ZImageTurbo is a distilled model
|
||||||
|
"vae": "ae.safetensors",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -613,7 +633,7 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Model family key (pony, illustrious, sdxl, sdxl_lightning, sdxl_turbo,
|
Model family key (pony, illustrious, sdxl, sdxl_lightning, sdxl_turbo,
|
||||||
sd15, sd15_lcm, flux, flux_schnell) or None if unknown
|
sd15, sd15_lcm, flux, flux_schnell, zimage) or None if unknown
|
||||||
"""
|
"""
|
||||||
name_lower = model_name.lower()
|
name_lower = model_name.lower()
|
||||||
base_lower = (base_model or "").lower()
|
base_lower = (base_model or "").lower()
|
||||||
@@ -629,6 +649,9 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
|
|||||||
return "flux_schnell"
|
return "flux_schnell"
|
||||||
if "flux" in base_lower:
|
if "flux" in base_lower:
|
||||||
return "flux"
|
return "flux"
|
||||||
|
# ZImageTurbo
|
||||||
|
if "zimage" in base_lower:
|
||||||
|
return "zimage"
|
||||||
# SD 1.5 variants
|
# SD 1.5 variants
|
||||||
if "lcm" in base_lower and ("sd 1.5" in base_lower or "sd 1.4" in base_lower):
|
if "lcm" in base_lower and ("sd 1.5" in base_lower or "sd 1.4" in base_lower):
|
||||||
return "sd15_lcm"
|
return "sd15_lcm"
|
||||||
@@ -652,6 +675,9 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
|
|||||||
return "flux_schnell"
|
return "flux_schnell"
|
||||||
if "flux" in name_lower:
|
if "flux" in name_lower:
|
||||||
return "flux"
|
return "flux"
|
||||||
|
# ZImageTurbo
|
||||||
|
if "zimage" in name_lower:
|
||||||
|
return "zimage"
|
||||||
# SDXL variants
|
# SDXL variants
|
||||||
if "lightning" in name_lower and any(x in name_lower for x in ["sdxl", "xl"]):
|
if "lightning" in name_lower and any(x in name_lower for x in ["sdxl", "xl"]):
|
||||||
return "sdxl_lightning"
|
return "sdxl_lightning"
|
||||||
|
|||||||
@@ -236,6 +236,7 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
|||||||
scheduler = request.scheduler
|
scheduler = request.scheduler
|
||||||
steps = request.steps
|
steps = request.steps
|
||||||
cfg = request.cfg
|
cfg = request.cfg
|
||||||
|
vae = request.vae
|
||||||
|
|
||||||
if request.model:
|
if request.model:
|
||||||
# Look up base_model from database for better family detection
|
# Look up base_model from database for better family detection
|
||||||
@@ -259,14 +260,17 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
|||||||
steps = family_defaults["steps"]
|
steps = family_defaults["steps"]
|
||||||
if request.cfg == 7.0: # Default value in schema
|
if request.cfg == 7.0: # Default value in schema
|
||||||
cfg = family_defaults["cfg"]
|
cfg = family_defaults["cfg"]
|
||||||
|
if request.vae is None: # No VAE specified, use family default
|
||||||
|
vae = family_defaults.get("vae") # None means use checkpoint VAE
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Detected model family: %s (sampler=%s, scheduler=%s, steps=%d, cfg=%.1f)",
|
"Detected model family: %s (sampler=%s, scheduler=%s, steps=%d, cfg=%.1f, vae=%s)",
|
||||||
detected_family,
|
detected_family,
|
||||||
sampler,
|
sampler,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
cfg,
|
cfg,
|
||||||
|
vae or "checkpoint",
|
||||||
)
|
)
|
||||||
|
|
||||||
lora_info = f", lora={request.lora_name}@{request.lora_strength}" if request.lora_name else ""
|
lora_info = f", lora={request.lora_name}@{request.lora_strength}" if request.lora_name else ""
|
||||||
@@ -295,7 +299,7 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
|||||||
seed=request.seed,
|
seed=request.seed,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
vae=request.vae,
|
vae=vae,
|
||||||
lora_name=request.lora_name,
|
lora_name=request.lora_name,
|
||||||
lora_strength=request.lora_strength,
|
lora_strength=request.lora_strength,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -377,6 +377,36 @@ class TestModelFamilyDetection:
|
|||||||
assert defaults["steps"] == 4
|
assert defaults["steps"] == 4
|
||||||
assert defaults["cfg"] == 1.0
|
assert defaults["cfg"] == 1.0
|
||||||
|
|
||||||
|
def test_detect_zimage(self) -> None:
|
||||||
|
"""Test detecting ZImageTurbo family."""
|
||||||
|
from tensors.config import detect_model_family
|
||||||
|
|
||||||
|
assert detect_model_family("zimageturbo_v1.safetensors") == "zimage"
|
||||||
|
assert detect_model_family("ZIMAGE_xl.safetensors") == "zimage"
|
||||||
|
assert detect_model_family("model.safetensors", "ZImageTurbo") == "zimage"
|
||||||
|
|
||||||
|
def test_get_model_generation_defaults_zimage(self) -> None:
|
||||||
|
"""Test getting generation defaults for ZImageTurbo models."""
|
||||||
|
from tensors.config import get_model_generation_defaults
|
||||||
|
|
||||||
|
defaults = get_model_generation_defaults("zimageturbo_v1.safetensors")
|
||||||
|
assert defaults["family"] == "zimage"
|
||||||
|
assert defaults["sampler"] == "euler"
|
||||||
|
assert defaults["scheduler"] == "simple"
|
||||||
|
assert defaults["steps"] == 4
|
||||||
|
assert defaults["cfg"] == 1.0
|
||||||
|
assert defaults["vae"] == "ae.safetensors"
|
||||||
|
|
||||||
|
def test_flux_uses_ae_vae(self) -> None:
|
||||||
|
"""Test that Flux models use ae.safetensors VAE."""
|
||||||
|
from tensors.config import get_model_generation_defaults
|
||||||
|
|
||||||
|
defaults = get_model_generation_defaults("flux1-dev-fp8.safetensors")
|
||||||
|
assert defaults["vae"] == "ae.safetensors"
|
||||||
|
|
||||||
|
defaults_schnell = get_model_generation_defaults("flux1-schnell.safetensors")
|
||||||
|
assert defaults_schnell["vae"] == "ae.safetensors"
|
||||||
|
|
||||||
def test_get_model_generation_defaults_sdxl_lightning(self) -> None:
|
def test_get_model_generation_defaults_sdxl_lightning(self) -> None:
|
||||||
"""Test getting generation defaults for SDXL Lightning models."""
|
"""Test getting generation defaults for SDXL Lightning models."""
|
||||||
from tensors.config import get_model_generation_defaults
|
from tensors.config import get_model_generation_defaults
|
||||||
|
|||||||
Reference in New Issue
Block a user