Update 2026-03-20 09:07

This commit is contained in:
aladac
2026-03-20 09:07:19 +01:00
parent 420d260936
commit 372133edcc
4 changed files with 282 additions and 8 deletions
BIN
View File
Binary file not shown.
+119 -2
View File
@@ -514,6 +514,9 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"height": 1024, "height": 1024,
"cfg": 7.0, "cfg": 7.0,
"clip_skip": 2, "clip_skip": 2,
"sampler": "euler_ancestral",
"scheduler": "normal",
"steps": 25,
}, },
"illustrious": { "illustrious": {
"quality_prefix": "masterpiece, best quality, highres", "quality_prefix": "masterpiece, best quality, highres",
@@ -521,6 +524,9 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"width": 1024, "width": 1024,
"height": 1024, "height": 1024,
"cfg": 6.0, "cfg": 6.0,
"sampler": "euler_ancestral",
"scheduler": "normal",
"steps": 25,
}, },
"sdxl": { "sdxl": {
"quality_prefix": "", "quality_prefix": "",
@@ -528,6 +534,29 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"width": 1024, "width": 1024,
"height": 1024, "height": 1024,
"cfg": 7.0, "cfg": 7.0,
"sampler": "dpmpp_2m",
"scheduler": "karras",
"steps": 25,
},
"sdxl_lightning": {
"quality_prefix": "",
"negative_prompt": "ugly, deformed, bad anatomy, bad hands, extra fingers, missing fingers, blurry, watermark",
"width": 1024,
"height": 1024,
"cfg": 2.0,
"sampler": "euler",
"scheduler": "sgm_uniform",
"steps": 8, # Lightning models use fewer steps
},
"sdxl_turbo": {
"quality_prefix": "",
"negative_prompt": "", # Turbo models work best without negative prompts
"width": 1024,
"height": 1024,
"cfg": 1.0, # Very low CFG for turbo
"sampler": "euler_ancestral",
"scheduler": "normal",
"steps": 4, # Turbo models use very few steps
}, },
"sd15": { "sd15": {
"quality_prefix": "masterpiece, best quality", "quality_prefix": "masterpiece, best quality",
@@ -538,6 +567,19 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"width": 512, "width": 512,
"height": 512, "height": 512,
"cfg": 7.0, "cfg": 7.0,
"sampler": "dpmpp_2m",
"scheduler": "karras",
"steps": 20,
},
"sd15_lcm": {
"quality_prefix": "masterpiece, best quality",
"negative_prompt": "", # LCM works best with minimal negative
"width": 512,
"height": 512,
"cfg": 1.5,
"sampler": "lcm",
"scheduler": "normal",
"steps": 6,
}, },
"flux": { "flux": {
"quality_prefix": "", "quality_prefix": "",
@@ -545,6 +587,19 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
"width": 1024, "width": 1024,
"height": 1024, "height": 1024,
"cfg": 3.5, "cfg": 3.5,
"sampler": "euler",
"scheduler": "simple",
"steps": 20,
},
"flux_schnell": {
"quality_prefix": "",
"negative_prompt": "",
"width": 1024,
"height": 1024,
"cfg": 1.0, # Schnell uses low CFG
"sampler": "euler",
"scheduler": "simple",
"steps": 4, # Schnell is a distilled model, very few steps
}, },
} }
@@ -557,7 +612,8 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
base_model: Optional CivitAI base_model field (e.g., "Pony", "SDXL 1.0") base_model: Optional CivitAI base_model field (e.g., "Pony", "SDXL 1.0")
Returns: Returns:
Model family key (pony, illustrious, sdxl, sd15, flux) or None if unknown Model family key (pony, illustrious, sdxl, sdxl_lightning, sdxl_turbo,
sd15, sd15_lcm, flux, flux_schnell) or None if unknown
""" """
name_lower = model_name.lower() name_lower = model_name.lower()
base_lower = (base_model or "").lower() base_lower = (base_model or "").lower()
@@ -568,20 +624,42 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
return "pony" return "pony"
if "illustrious" in base_lower: if "illustrious" in base_lower:
return "illustrious" return "illustrious"
# Flux variants (check specific variants before generic flux)
if "flux" in base_lower and "schnell" in base_lower:
return "flux_schnell"
if "flux" in base_lower: if "flux" in base_lower:
return "flux" return "flux"
# SD 1.5 variants
if "lcm" in base_lower and ("sd 1.5" in base_lower or "sd 1.4" in base_lower):
return "sd15_lcm"
if "sd 1.5" in base_lower or "sd 1.4" in base_lower: if "sd 1.5" in base_lower or "sd 1.4" in base_lower:
return "sd15" return "sd15"
# SDXL variants (check specific variants before generic sdxl)
if "sdxl" in base_lower and "lightning" in base_lower:
return "sdxl_lightning"
if "sdxl" in base_lower and "turbo" in base_lower:
return "sdxl_turbo"
if "sdxl" in base_lower: if "sdxl" in base_lower:
return "sdxl" return "sdxl"
# Fall back to filename heuristics # Fall back to filename heuristics (check specific variants first)
if "pony" in name_lower: if "pony" in name_lower:
return "pony" return "pony"
if "illustrious" in name_lower or "noob" in name_lower: if "illustrious" in name_lower or "noob" in name_lower:
return "illustrious" return "illustrious"
# Flux variants
if "flux" in name_lower and "schnell" in name_lower:
return "flux_schnell"
if "flux" in name_lower: if "flux" in name_lower:
return "flux" return "flux"
# SDXL variants
if "lightning" in name_lower and any(x in name_lower for x in ["sdxl", "xl"]):
return "sdxl_lightning"
if "turbo" in name_lower and any(x in name_lower for x in ["sdxl", "xl"]):
return "sdxl_turbo"
# SD 1.5 variants
if "lcm" in name_lower and any(x in name_lower for x in ["sd15", "sd1.5", "sd_1.5"]):
return "sd15_lcm"
if any(x in name_lower for x in ["sd15", "sd1.5", "sd_1.5", "dreamshaper", "realistic", "deliberate", "anything"]): if any(x in name_lower for x in ["sd15", "sd1.5", "sd_1.5", "dreamshaper", "realistic", "deliberate", "anything"]):
return "sd15" return "sd15"
if any(x in name_lower for x in ["sdxl", "xl_"]): if any(x in name_lower for x in ["sdxl", "xl_"]):
@@ -590,6 +668,45 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
return None return None
def get_model_generation_defaults(model_name: str, base_model: str | None = None) -> dict[str, Any]:
"""Get generation defaults for a model based on its family.
Detects the model family and returns appropriate default settings for:
- sampler, scheduler, steps, cfg, width, height
- quality_prefix, negative_prompt
Args:
model_name: Filename of the model
base_model: Optional CivitAI base_model field
Returns:
Dict with generation defaults. Falls back to global SDXL defaults if family unknown.
"""
family = detect_model_family(model_name, base_model)
# Get family-specific defaults or fall back to SDXL defaults
if family and family in MODEL_FAMILY_DEFAULTS:
defaults = dict(MODEL_FAMILY_DEFAULTS[family])
else:
# Default to SDXL settings for unknown models
defaults = dict(MODEL_FAMILY_DEFAULTS.get("sdxl", {}))
# Ensure all expected keys are present with fallbacks
defaults.setdefault("sampler", COMFYUI_DEFAULT_SAMPLER)
defaults.setdefault("scheduler", COMFYUI_DEFAULT_SCHEDULER)
defaults.setdefault("steps", COMFYUI_DEFAULT_STEPS)
defaults.setdefault("cfg", COMFYUI_DEFAULT_CFG)
defaults.setdefault("width", COMFYUI_DEFAULT_WIDTH)
defaults.setdefault("height", COMFYUI_DEFAULT_HEIGHT)
defaults.setdefault("quality_prefix", "")
defaults.setdefault("negative_prompt", "")
# Include the detected family for reference
defaults["family"] = family
return defaults
def get_comfyui_url() -> str: def get_comfyui_url() -> str:
"""Get the ComfyUI server URL. """Get the ComfyUI server URL.
+46 -6
View File
@@ -20,6 +20,8 @@ from tensors.comfyui import (
get_system_stats, get_system_stats,
queue_prompt, queue_prompt,
) )
from tensors.config import get_model_generation_defaults
from tensors.db import Database
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -224,14 +226,52 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
This uses the built-in SDXL/Flux compatible workflow template. This uses the built-in SDXL/Flux compatible workflow template.
For custom workflows, use the /workflow endpoint instead. For custom workflows, use the /workflow endpoint instead.
Sampler and scheduler are auto-selected based on model family if not specified
(when using default values). Family detection uses the model filename and
database metadata.
""" """
# Get family-specific defaults if model is specified
sampler = request.sampler
scheduler = request.scheduler
steps = request.steps
cfg = request.cfg
if request.model:
# Look up base_model from database for better family detection
try:
db = Database()
base_model = db.get_base_model_by_filename(request.model)
except Exception:
base_model = None
# Get family defaults
family_defaults = get_model_generation_defaults(request.model, base_model)
detected_family = family_defaults.get("family")
# Apply family defaults only if request uses default values
# (allows explicit override by user)
if request.sampler == "euler": # Default value in schema
sampler = family_defaults["sampler"]
if request.scheduler == "normal": # Default value in schema
scheduler = family_defaults["scheduler"]
if request.steps == 20: # Default value in schema
steps = family_defaults["steps"]
if request.cfg == 7.0: # Default value in schema
cfg = family_defaults["cfg"]
logger.debug("Detected model family: %s (sampler=%s, scheduler=%s, steps=%d, cfg=%.1f)",
detected_family, sampler, scheduler, steps, cfg)
lora_info = f", lora={request.lora_name}@{request.lora_strength}" if request.lora_name else "" lora_info = f", lora={request.lora_name}@{request.lora_strength}" if request.lora_name else ""
logger.info( logger.info(
"Generate request: model=%s, size=%dx%d, steps=%d%s, prompt=%r", "Generate request: model=%s, size=%dx%d, steps=%d, sampler=%s, scheduler=%s%s, prompt=%r",
request.model or "default", request.model or "default",
request.width, request.width,
request.height, request.height,
request.steps, steps,
sampler,
scheduler,
lora_info, lora_info,
request.prompt[:100] + "..." if len(request.prompt) > 100 else request.prompt, request.prompt[:100] + "..." if len(request.prompt) > 100 else request.prompt,
) )
@@ -244,11 +284,11 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
model=request.model, model=request.model,
width=request.width, width=request.width,
height=request.height, height=request.height,
steps=request.steps, steps=steps,
cfg=request.cfg, cfg=cfg,
seed=request.seed, seed=request.seed,
sampler=request.sampler, sampler=sampler,
scheduler=request.scheduler, scheduler=scheduler,
vae=request.vae, vae=request.vae,
lora_name=request.lora_name, lora_name=request.lora_name,
lora_strength=request.lora_strength, lora_strength=request.lora_strength,
+117
View File
@@ -281,6 +281,123 @@ class TestEnums:
assert SortOrder.newest.to_api() == "Newest" assert SortOrder.newest.to_api() == "Newest"
class TestModelFamilyDetection:
"""Tests for detect_model_family and get_model_generation_defaults."""
def test_detect_pony_from_base_model(self) -> None:
"""Test detecting Pony family from base_model field."""
from tensors.config import detect_model_family
assert detect_model_family("model.safetensors", "Pony") == "pony"
assert detect_model_family("anything.safetensors", "PONY") == "pony"
def test_detect_pony_from_filename(self) -> None:
"""Test detecting Pony family from filename."""
from tensors.config import detect_model_family
assert detect_model_family("ponyDiffusionV6XL.safetensors") == "pony"
assert detect_model_family("autismmixPony_v10.safetensors") == "pony"
def test_detect_illustrious_from_base_model(self) -> None:
"""Test detecting Illustrious family from base_model field."""
from tensors.config import detect_model_family
assert detect_model_family("model.safetensors", "Illustrious") == "illustrious"
def test_detect_illustrious_from_filename(self) -> None:
"""Test detecting Illustrious family from filename."""
from tensors.config import detect_model_family
assert detect_model_family("illustriousXL_v10.safetensors") == "illustrious"
assert detect_model_family("noobaiXL_v10.safetensors") == "illustrious"
def test_detect_flux_variants(self) -> None:
"""Test detecting Flux family variants."""
from tensors.config import detect_model_family
assert detect_model_family("flux1-dev.safetensors") == "flux"
assert detect_model_family("flux1-schnell.safetensors") == "flux_schnell"
assert detect_model_family("model.safetensors", "Flux.1 D") == "flux"
assert detect_model_family("model.safetensors", "Flux.1 S schnell") == "flux_schnell"
def test_detect_sdxl_variants(self) -> None:
"""Test detecting SDXL family variants."""
from tensors.config import detect_model_family
assert detect_model_family("juggernautXL_v9.safetensors") == "sdxl"
assert detect_model_family("sdxl_lightning_4step.safetensors") == "sdxl_lightning"
assert detect_model_family("sdxl_turbo.safetensors") == "sdxl_turbo"
assert detect_model_family("model.safetensors", "SDXL 1.0") == "sdxl"
assert detect_model_family("model.safetensors", "SDXL Lightning") == "sdxl_lightning"
assert detect_model_family("model.safetensors", "SDXL Turbo") == "sdxl_turbo"
def test_detect_sd15_variants(self) -> None:
"""Test detecting SD 1.5 family variants."""
from tensors.config import detect_model_family
assert detect_model_family("dreamshaper_v8.safetensors") == "sd15"
assert detect_model_family("sd15_lcm.safetensors") == "sd15_lcm"
assert detect_model_family("model.safetensors", "SD 1.5") == "sd15"
assert detect_model_family("model.safetensors", "SD 1.5 LCM") == "sd15_lcm"
def test_detect_unknown_returns_none(self) -> None:
"""Test that unknown models return None."""
from tensors.config import detect_model_family
assert detect_model_family("random_model.safetensors") is None
assert detect_model_family("unknown.safetensors", "Unknown") is None
def test_get_model_generation_defaults_pony(self) -> None:
"""Test getting generation defaults for Pony models."""
from tensors.config import get_model_generation_defaults
defaults = get_model_generation_defaults("ponyDiffusionV6XL.safetensors")
assert defaults["family"] == "pony"
assert defaults["sampler"] == "euler_ancestral"
assert defaults["scheduler"] == "normal"
assert defaults["steps"] == 25
assert defaults["cfg"] == 7.0
def test_get_model_generation_defaults_flux(self) -> None:
"""Test getting generation defaults for Flux models."""
from tensors.config import get_model_generation_defaults
defaults = get_model_generation_defaults("flux1-dev-fp8.safetensors")
assert defaults["family"] == "flux"
assert defaults["sampler"] == "euler"
assert defaults["scheduler"] == "simple"
assert defaults["cfg"] == 3.5
def test_get_model_generation_defaults_flux_schnell(self) -> None:
"""Test getting generation defaults for Flux Schnell models."""
from tensors.config import get_model_generation_defaults
defaults = get_model_generation_defaults("flux1-schnell.safetensors")
assert defaults["family"] == "flux_schnell"
assert defaults["steps"] == 4
assert defaults["cfg"] == 1.0
def test_get_model_generation_defaults_sdxl_lightning(self) -> None:
"""Test getting generation defaults for SDXL Lightning models."""
from tensors.config import get_model_generation_defaults
defaults = get_model_generation_defaults("sdxl_lightning_4step.safetensors")
assert defaults["family"] == "sdxl_lightning"
assert defaults["sampler"] == "euler"
assert defaults["scheduler"] == "sgm_uniform"
assert defaults["steps"] == 8
assert defaults["cfg"] == 2.0
def test_get_model_generation_defaults_unknown_falls_back_to_sdxl(self) -> None:
"""Test that unknown models fall back to SDXL defaults."""
from tensors.config import get_model_generation_defaults
defaults = get_model_generation_defaults("unknown_model.safetensors")
assert defaults["family"] is None
assert defaults["sampler"] == "dpmpp_2m"
assert defaults["scheduler"] == "karras"
class TestDisplayFormatters: class TestDisplayFormatters:
"""Tests for display formatting functions.""" """Tests for display formatting functions."""