Update 2026-03-20 09:07
This commit is contained in:
+119
-2
@@ -514,6 +514,9 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"height": 1024,
|
||||
"cfg": 7.0,
|
||||
"clip_skip": 2,
|
||||
"sampler": "euler_ancestral",
|
||||
"scheduler": "normal",
|
||||
"steps": 25,
|
||||
},
|
||||
"illustrious": {
|
||||
"quality_prefix": "masterpiece, best quality, highres",
|
||||
@@ -521,6 +524,9 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"cfg": 6.0,
|
||||
"sampler": "euler_ancestral",
|
||||
"scheduler": "normal",
|
||||
"steps": 25,
|
||||
},
|
||||
"sdxl": {
|
||||
"quality_prefix": "",
|
||||
@@ -528,6 +534,29 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"cfg": 7.0,
|
||||
"sampler": "dpmpp_2m",
|
||||
"scheduler": "karras",
|
||||
"steps": 25,
|
||||
},
|
||||
"sdxl_lightning": {
|
||||
"quality_prefix": "",
|
||||
"negative_prompt": "ugly, deformed, bad anatomy, bad hands, extra fingers, missing fingers, blurry, watermark",
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"cfg": 2.0,
|
||||
"sampler": "euler",
|
||||
"scheduler": "sgm_uniform",
|
||||
"steps": 8, # Lightning models use fewer steps
|
||||
},
|
||||
"sdxl_turbo": {
|
||||
"quality_prefix": "",
|
||||
"negative_prompt": "", # Turbo models work best without negative prompts
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"cfg": 1.0, # Very low CFG for turbo
|
||||
"sampler": "euler_ancestral",
|
||||
"scheduler": "normal",
|
||||
"steps": 4, # Turbo models use very few steps
|
||||
},
|
||||
"sd15": {
|
||||
"quality_prefix": "masterpiece, best quality",
|
||||
@@ -538,6 +567,19 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"cfg": 7.0,
|
||||
"sampler": "dpmpp_2m",
|
||||
"scheduler": "karras",
|
||||
"steps": 20,
|
||||
},
|
||||
"sd15_lcm": {
|
||||
"quality_prefix": "masterpiece, best quality",
|
||||
"negative_prompt": "", # LCM works best with minimal negative
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"cfg": 1.5,
|
||||
"sampler": "lcm",
|
||||
"scheduler": "normal",
|
||||
"steps": 6,
|
||||
},
|
||||
"flux": {
|
||||
"quality_prefix": "",
|
||||
@@ -545,6 +587,19 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"cfg": 3.5,
|
||||
"sampler": "euler",
|
||||
"scheduler": "simple",
|
||||
"steps": 20,
|
||||
},
|
||||
"flux_schnell": {
|
||||
"quality_prefix": "",
|
||||
"negative_prompt": "",
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"cfg": 1.0, # Schnell uses low CFG
|
||||
"sampler": "euler",
|
||||
"scheduler": "simple",
|
||||
"steps": 4, # Schnell is a distilled model, very few steps
|
||||
},
|
||||
}
|
||||
|
||||
@@ -557,7 +612,8 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
|
||||
base_model: Optional CivitAI base_model field (e.g., "Pony", "SDXL 1.0")
|
||||
|
||||
Returns:
|
||||
Model family key (pony, illustrious, sdxl, sd15, flux) or None if unknown
|
||||
Model family key (pony, illustrious, sdxl, sdxl_lightning, sdxl_turbo,
|
||||
sd15, sd15_lcm, flux, flux_schnell) or None if unknown
|
||||
"""
|
||||
name_lower = model_name.lower()
|
||||
base_lower = (base_model or "").lower()
|
||||
@@ -568,20 +624,42 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
|
||||
return "pony"
|
||||
if "illustrious" in base_lower:
|
||||
return "illustrious"
|
||||
# Flux variants (check specific variants before generic flux)
|
||||
if "flux" in base_lower and "schnell" in base_lower:
|
||||
return "flux_schnell"
|
||||
if "flux" in base_lower:
|
||||
return "flux"
|
||||
# SD 1.5 variants
|
||||
if "lcm" in base_lower and ("sd 1.5" in base_lower or "sd 1.4" in base_lower):
|
||||
return "sd15_lcm"
|
||||
if "sd 1.5" in base_lower or "sd 1.4" in base_lower:
|
||||
return "sd15"
|
||||
# SDXL variants (check specific variants before generic sdxl)
|
||||
if "sdxl" in base_lower and "lightning" in base_lower:
|
||||
return "sdxl_lightning"
|
||||
if "sdxl" in base_lower and "turbo" in base_lower:
|
||||
return "sdxl_turbo"
|
||||
if "sdxl" in base_lower:
|
||||
return "sdxl"
|
||||
|
||||
# Fall back to filename heuristics
|
||||
# Fall back to filename heuristics (check specific variants first)
|
||||
if "pony" in name_lower:
|
||||
return "pony"
|
||||
if "illustrious" in name_lower or "noob" in name_lower:
|
||||
return "illustrious"
|
||||
# Flux variants
|
||||
if "flux" in name_lower and "schnell" in name_lower:
|
||||
return "flux_schnell"
|
||||
if "flux" in name_lower:
|
||||
return "flux"
|
||||
# SDXL variants
|
||||
if "lightning" in name_lower and any(x in name_lower for x in ["sdxl", "xl"]):
|
||||
return "sdxl_lightning"
|
||||
if "turbo" in name_lower and any(x in name_lower for x in ["sdxl", "xl"]):
|
||||
return "sdxl_turbo"
|
||||
# SD 1.5 variants
|
||||
if "lcm" in name_lower and any(x in name_lower for x in ["sd15", "sd1.5", "sd_1.5"]):
|
||||
return "sd15_lcm"
|
||||
if any(x in name_lower for x in ["sd15", "sd1.5", "sd_1.5", "dreamshaper", "realistic", "deliberate", "anything"]):
|
||||
return "sd15"
|
||||
if any(x in name_lower for x in ["sdxl", "xl_"]):
|
||||
@@ -590,6 +668,45 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
|
||||
return None
|
||||
|
||||
|
||||
def get_model_generation_defaults(model_name: str, base_model: str | None = None) -> dict[str, Any]:
|
||||
"""Get generation defaults for a model based on its family.
|
||||
|
||||
Detects the model family and returns appropriate default settings for:
|
||||
- sampler, scheduler, steps, cfg, width, height
|
||||
- quality_prefix, negative_prompt
|
||||
|
||||
Args:
|
||||
model_name: Filename of the model
|
||||
base_model: Optional CivitAI base_model field
|
||||
|
||||
Returns:
|
||||
Dict with generation defaults. Falls back to global SDXL defaults if family unknown.
|
||||
"""
|
||||
family = detect_model_family(model_name, base_model)
|
||||
|
||||
# Get family-specific defaults or fall back to SDXL defaults
|
||||
if family and family in MODEL_FAMILY_DEFAULTS:
|
||||
defaults = dict(MODEL_FAMILY_DEFAULTS[family])
|
||||
else:
|
||||
# Default to SDXL settings for unknown models
|
||||
defaults = dict(MODEL_FAMILY_DEFAULTS.get("sdxl", {}))
|
||||
|
||||
# Ensure all expected keys are present with fallbacks
|
||||
defaults.setdefault("sampler", COMFYUI_DEFAULT_SAMPLER)
|
||||
defaults.setdefault("scheduler", COMFYUI_DEFAULT_SCHEDULER)
|
||||
defaults.setdefault("steps", COMFYUI_DEFAULT_STEPS)
|
||||
defaults.setdefault("cfg", COMFYUI_DEFAULT_CFG)
|
||||
defaults.setdefault("width", COMFYUI_DEFAULT_WIDTH)
|
||||
defaults.setdefault("height", COMFYUI_DEFAULT_HEIGHT)
|
||||
defaults.setdefault("quality_prefix", "")
|
||||
defaults.setdefault("negative_prompt", "")
|
||||
|
||||
# Include the detected family for reference
|
||||
defaults["family"] = family
|
||||
|
||||
return defaults
|
||||
|
||||
|
||||
def get_comfyui_url() -> str:
|
||||
"""Get the ComfyUI server URL.
|
||||
|
||||
|
||||
@@ -20,6 +20,8 @@ from tensors.comfyui import (
|
||||
get_system_stats,
|
||||
queue_prompt,
|
||||
)
|
||||
from tensors.config import get_model_generation_defaults
|
||||
from tensors.db import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -224,14 +226,52 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
||||
|
||||
This uses the built-in SDXL/Flux compatible workflow template.
|
||||
For custom workflows, use the /workflow endpoint instead.
|
||||
|
||||
Sampler and scheduler are auto-selected based on model family if not specified
|
||||
(when using default values). Family detection uses the model filename and
|
||||
database metadata.
|
||||
"""
|
||||
# Get family-specific defaults if model is specified
|
||||
sampler = request.sampler
|
||||
scheduler = request.scheduler
|
||||
steps = request.steps
|
||||
cfg = request.cfg
|
||||
|
||||
if request.model:
|
||||
# Look up base_model from database for better family detection
|
||||
try:
|
||||
db = Database()
|
||||
base_model = db.get_base_model_by_filename(request.model)
|
||||
except Exception:
|
||||
base_model = None
|
||||
|
||||
# Get family defaults
|
||||
family_defaults = get_model_generation_defaults(request.model, base_model)
|
||||
detected_family = family_defaults.get("family")
|
||||
|
||||
# Apply family defaults only if request uses default values
|
||||
# (allows explicit override by user)
|
||||
if request.sampler == "euler": # Default value in schema
|
||||
sampler = family_defaults["sampler"]
|
||||
if request.scheduler == "normal": # Default value in schema
|
||||
scheduler = family_defaults["scheduler"]
|
||||
if request.steps == 20: # Default value in schema
|
||||
steps = family_defaults["steps"]
|
||||
if request.cfg == 7.0: # Default value in schema
|
||||
cfg = family_defaults["cfg"]
|
||||
|
||||
logger.debug("Detected model family: %s (sampler=%s, scheduler=%s, steps=%d, cfg=%.1f)",
|
||||
detected_family, sampler, scheduler, steps, cfg)
|
||||
|
||||
lora_info = f", lora={request.lora_name}@{request.lora_strength}" if request.lora_name else ""
|
||||
logger.info(
|
||||
"Generate request: model=%s, size=%dx%d, steps=%d%s, prompt=%r",
|
||||
"Generate request: model=%s, size=%dx%d, steps=%d, sampler=%s, scheduler=%s%s, prompt=%r",
|
||||
request.model or "default",
|
||||
request.width,
|
||||
request.height,
|
||||
request.steps,
|
||||
steps,
|
||||
sampler,
|
||||
scheduler,
|
||||
lora_info,
|
||||
request.prompt[:100] + "..." if len(request.prompt) > 100 else request.prompt,
|
||||
)
|
||||
@@ -244,11 +284,11 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
||||
model=request.model,
|
||||
width=request.width,
|
||||
height=request.height,
|
||||
steps=request.steps,
|
||||
cfg=request.cfg,
|
||||
steps=steps,
|
||||
cfg=cfg,
|
||||
seed=request.seed,
|
||||
sampler=request.sampler,
|
||||
scheduler=request.scheduler,
|
||||
sampler=sampler,
|
||||
scheduler=scheduler,
|
||||
vae=request.vae,
|
||||
lora_name=request.lora_name,
|
||||
lora_strength=request.lora_strength,
|
||||
|
||||
@@ -281,6 +281,123 @@ class TestEnums:
|
||||
assert SortOrder.newest.to_api() == "Newest"
|
||||
|
||||
|
||||
class TestModelFamilyDetection:
|
||||
"""Tests for detect_model_family and get_model_generation_defaults."""
|
||||
|
||||
def test_detect_pony_from_base_model(self) -> None:
|
||||
"""Test detecting Pony family from base_model field."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("model.safetensors", "Pony") == "pony"
|
||||
assert detect_model_family("anything.safetensors", "PONY") == "pony"
|
||||
|
||||
def test_detect_pony_from_filename(self) -> None:
|
||||
"""Test detecting Pony family from filename."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("ponyDiffusionV6XL.safetensors") == "pony"
|
||||
assert detect_model_family("autismmixPony_v10.safetensors") == "pony"
|
||||
|
||||
def test_detect_illustrious_from_base_model(self) -> None:
|
||||
"""Test detecting Illustrious family from base_model field."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("model.safetensors", "Illustrious") == "illustrious"
|
||||
|
||||
def test_detect_illustrious_from_filename(self) -> None:
|
||||
"""Test detecting Illustrious family from filename."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("illustriousXL_v10.safetensors") == "illustrious"
|
||||
assert detect_model_family("noobaiXL_v10.safetensors") == "illustrious"
|
||||
|
||||
def test_detect_flux_variants(self) -> None:
|
||||
"""Test detecting Flux family variants."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("flux1-dev.safetensors") == "flux"
|
||||
assert detect_model_family("flux1-schnell.safetensors") == "flux_schnell"
|
||||
assert detect_model_family("model.safetensors", "Flux.1 D") == "flux"
|
||||
assert detect_model_family("model.safetensors", "Flux.1 S schnell") == "flux_schnell"
|
||||
|
||||
def test_detect_sdxl_variants(self) -> None:
|
||||
"""Test detecting SDXL family variants."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("juggernautXL_v9.safetensors") == "sdxl"
|
||||
assert detect_model_family("sdxl_lightning_4step.safetensors") == "sdxl_lightning"
|
||||
assert detect_model_family("sdxl_turbo.safetensors") == "sdxl_turbo"
|
||||
assert detect_model_family("model.safetensors", "SDXL 1.0") == "sdxl"
|
||||
assert detect_model_family("model.safetensors", "SDXL Lightning") == "sdxl_lightning"
|
||||
assert detect_model_family("model.safetensors", "SDXL Turbo") == "sdxl_turbo"
|
||||
|
||||
def test_detect_sd15_variants(self) -> None:
|
||||
"""Test detecting SD 1.5 family variants."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("dreamshaper_v8.safetensors") == "sd15"
|
||||
assert detect_model_family("sd15_lcm.safetensors") == "sd15_lcm"
|
||||
assert detect_model_family("model.safetensors", "SD 1.5") == "sd15"
|
||||
assert detect_model_family("model.safetensors", "SD 1.5 LCM") == "sd15_lcm"
|
||||
|
||||
def test_detect_unknown_returns_none(self) -> None:
|
||||
"""Test that unknown models return None."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("random_model.safetensors") is None
|
||||
assert detect_model_family("unknown.safetensors", "Unknown") is None
|
||||
|
||||
def test_get_model_generation_defaults_pony(self) -> None:
|
||||
"""Test getting generation defaults for Pony models."""
|
||||
from tensors.config import get_model_generation_defaults
|
||||
|
||||
defaults = get_model_generation_defaults("ponyDiffusionV6XL.safetensors")
|
||||
assert defaults["family"] == "pony"
|
||||
assert defaults["sampler"] == "euler_ancestral"
|
||||
assert defaults["scheduler"] == "normal"
|
||||
assert defaults["steps"] == 25
|
||||
assert defaults["cfg"] == 7.0
|
||||
|
||||
def test_get_model_generation_defaults_flux(self) -> None:
|
||||
"""Test getting generation defaults for Flux models."""
|
||||
from tensors.config import get_model_generation_defaults
|
||||
|
||||
defaults = get_model_generation_defaults("flux1-dev-fp8.safetensors")
|
||||
assert defaults["family"] == "flux"
|
||||
assert defaults["sampler"] == "euler"
|
||||
assert defaults["scheduler"] == "simple"
|
||||
assert defaults["cfg"] == 3.5
|
||||
|
||||
def test_get_model_generation_defaults_flux_schnell(self) -> None:
|
||||
"""Test getting generation defaults for Flux Schnell models."""
|
||||
from tensors.config import get_model_generation_defaults
|
||||
|
||||
defaults = get_model_generation_defaults("flux1-schnell.safetensors")
|
||||
assert defaults["family"] == "flux_schnell"
|
||||
assert defaults["steps"] == 4
|
||||
assert defaults["cfg"] == 1.0
|
||||
|
||||
def test_get_model_generation_defaults_sdxl_lightning(self) -> None:
|
||||
"""Test getting generation defaults for SDXL Lightning models."""
|
||||
from tensors.config import get_model_generation_defaults
|
||||
|
||||
defaults = get_model_generation_defaults("sdxl_lightning_4step.safetensors")
|
||||
assert defaults["family"] == "sdxl_lightning"
|
||||
assert defaults["sampler"] == "euler"
|
||||
assert defaults["scheduler"] == "sgm_uniform"
|
||||
assert defaults["steps"] == 8
|
||||
assert defaults["cfg"] == 2.0
|
||||
|
||||
def test_get_model_generation_defaults_unknown_falls_back_to_sdxl(self) -> None:
|
||||
"""Test that unknown models fall back to SDXL defaults."""
|
||||
from tensors.config import get_model_generation_defaults
|
||||
|
||||
defaults = get_model_generation_defaults("unknown_model.safetensors")
|
||||
assert defaults["family"] is None
|
||||
assert defaults["sampler"] == "dpmpp_2m"
|
||||
assert defaults["scheduler"] == "karras"
|
||||
|
||||
|
||||
class TestDisplayFormatters:
|
||||
"""Tests for display formatting functions."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user