Update 2026-03-20 09:07
This commit is contained in:
+119
-2
@@ -514,6 +514,9 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
|||||||
"height": 1024,
|
"height": 1024,
|
||||||
"cfg": 7.0,
|
"cfg": 7.0,
|
||||||
"clip_skip": 2,
|
"clip_skip": 2,
|
||||||
|
"sampler": "euler_ancestral",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"steps": 25,
|
||||||
},
|
},
|
||||||
"illustrious": {
|
"illustrious": {
|
||||||
"quality_prefix": "masterpiece, best quality, highres",
|
"quality_prefix": "masterpiece, best quality, highres",
|
||||||
@@ -521,6 +524,9 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
|||||||
"width": 1024,
|
"width": 1024,
|
||||||
"height": 1024,
|
"height": 1024,
|
||||||
"cfg": 6.0,
|
"cfg": 6.0,
|
||||||
|
"sampler": "euler_ancestral",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"steps": 25,
|
||||||
},
|
},
|
||||||
"sdxl": {
|
"sdxl": {
|
||||||
"quality_prefix": "",
|
"quality_prefix": "",
|
||||||
@@ -528,6 +534,29 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
|||||||
"width": 1024,
|
"width": 1024,
|
||||||
"height": 1024,
|
"height": 1024,
|
||||||
"cfg": 7.0,
|
"cfg": 7.0,
|
||||||
|
"sampler": "dpmpp_2m",
|
||||||
|
"scheduler": "karras",
|
||||||
|
"steps": 25,
|
||||||
|
},
|
||||||
|
"sdxl_lightning": {
|
||||||
|
"quality_prefix": "",
|
||||||
|
"negative_prompt": "ugly, deformed, bad anatomy, bad hands, extra fingers, missing fingers, blurry, watermark",
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
|
"cfg": 2.0,
|
||||||
|
"sampler": "euler",
|
||||||
|
"scheduler": "sgm_uniform",
|
||||||
|
"steps": 8, # Lightning models use fewer steps
|
||||||
|
},
|
||||||
|
"sdxl_turbo": {
|
||||||
|
"quality_prefix": "",
|
||||||
|
"negative_prompt": "", # Turbo models work best without negative prompts
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
|
"cfg": 1.0, # Very low CFG for turbo
|
||||||
|
"sampler": "euler_ancestral",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"steps": 4, # Turbo models use very few steps
|
||||||
},
|
},
|
||||||
"sd15": {
|
"sd15": {
|
||||||
"quality_prefix": "masterpiece, best quality",
|
"quality_prefix": "masterpiece, best quality",
|
||||||
@@ -538,6 +567,19 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
|||||||
"width": 512,
|
"width": 512,
|
||||||
"height": 512,
|
"height": 512,
|
||||||
"cfg": 7.0,
|
"cfg": 7.0,
|
||||||
|
"sampler": "dpmpp_2m",
|
||||||
|
"scheduler": "karras",
|
||||||
|
"steps": 20,
|
||||||
|
},
|
||||||
|
"sd15_lcm": {
|
||||||
|
"quality_prefix": "masterpiece, best quality",
|
||||||
|
"negative_prompt": "", # LCM works best with minimal negative
|
||||||
|
"width": 512,
|
||||||
|
"height": 512,
|
||||||
|
"cfg": 1.5,
|
||||||
|
"sampler": "lcm",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"steps": 6,
|
||||||
},
|
},
|
||||||
"flux": {
|
"flux": {
|
||||||
"quality_prefix": "",
|
"quality_prefix": "",
|
||||||
@@ -545,6 +587,19 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
|
|||||||
"width": 1024,
|
"width": 1024,
|
||||||
"height": 1024,
|
"height": 1024,
|
||||||
"cfg": 3.5,
|
"cfg": 3.5,
|
||||||
|
"sampler": "euler",
|
||||||
|
"scheduler": "simple",
|
||||||
|
"steps": 20,
|
||||||
|
},
|
||||||
|
"flux_schnell": {
|
||||||
|
"quality_prefix": "",
|
||||||
|
"negative_prompt": "",
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
|
"cfg": 1.0, # Schnell uses low CFG
|
||||||
|
"sampler": "euler",
|
||||||
|
"scheduler": "simple",
|
||||||
|
"steps": 4, # Schnell is a distilled model, very few steps
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -557,7 +612,8 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
|
|||||||
base_model: Optional CivitAI base_model field (e.g., "Pony", "SDXL 1.0")
|
base_model: Optional CivitAI base_model field (e.g., "Pony", "SDXL 1.0")
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Model family key (pony, illustrious, sdxl, sd15, flux) or None if unknown
|
Model family key (pony, illustrious, sdxl, sdxl_lightning, sdxl_turbo,
|
||||||
|
sd15, sd15_lcm, flux, flux_schnell) or None if unknown
|
||||||
"""
|
"""
|
||||||
name_lower = model_name.lower()
|
name_lower = model_name.lower()
|
||||||
base_lower = (base_model or "").lower()
|
base_lower = (base_model or "").lower()
|
||||||
@@ -568,20 +624,42 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
|
|||||||
return "pony"
|
return "pony"
|
||||||
if "illustrious" in base_lower:
|
if "illustrious" in base_lower:
|
||||||
return "illustrious"
|
return "illustrious"
|
||||||
|
# Flux variants (check specific variants before generic flux)
|
||||||
|
if "flux" in base_lower and "schnell" in base_lower:
|
||||||
|
return "flux_schnell"
|
||||||
if "flux" in base_lower:
|
if "flux" in base_lower:
|
||||||
return "flux"
|
return "flux"
|
||||||
|
# SD 1.5 variants
|
||||||
|
if "lcm" in base_lower and ("sd 1.5" in base_lower or "sd 1.4" in base_lower):
|
||||||
|
return "sd15_lcm"
|
||||||
if "sd 1.5" in base_lower or "sd 1.4" in base_lower:
|
if "sd 1.5" in base_lower or "sd 1.4" in base_lower:
|
||||||
return "sd15"
|
return "sd15"
|
||||||
|
# SDXL variants (check specific variants before generic sdxl)
|
||||||
|
if "sdxl" in base_lower and "lightning" in base_lower:
|
||||||
|
return "sdxl_lightning"
|
||||||
|
if "sdxl" in base_lower and "turbo" in base_lower:
|
||||||
|
return "sdxl_turbo"
|
||||||
if "sdxl" in base_lower:
|
if "sdxl" in base_lower:
|
||||||
return "sdxl"
|
return "sdxl"
|
||||||
|
|
||||||
# Fall back to filename heuristics
|
# Fall back to filename heuristics (check specific variants first)
|
||||||
if "pony" in name_lower:
|
if "pony" in name_lower:
|
||||||
return "pony"
|
return "pony"
|
||||||
if "illustrious" in name_lower or "noob" in name_lower:
|
if "illustrious" in name_lower or "noob" in name_lower:
|
||||||
return "illustrious"
|
return "illustrious"
|
||||||
|
# Flux variants
|
||||||
|
if "flux" in name_lower and "schnell" in name_lower:
|
||||||
|
return "flux_schnell"
|
||||||
if "flux" in name_lower:
|
if "flux" in name_lower:
|
||||||
return "flux"
|
return "flux"
|
||||||
|
# SDXL variants
|
||||||
|
if "lightning" in name_lower and any(x in name_lower for x in ["sdxl", "xl"]):
|
||||||
|
return "sdxl_lightning"
|
||||||
|
if "turbo" in name_lower and any(x in name_lower for x in ["sdxl", "xl"]):
|
||||||
|
return "sdxl_turbo"
|
||||||
|
# SD 1.5 variants
|
||||||
|
if "lcm" in name_lower and any(x in name_lower for x in ["sd15", "sd1.5", "sd_1.5"]):
|
||||||
|
return "sd15_lcm"
|
||||||
if any(x in name_lower for x in ["sd15", "sd1.5", "sd_1.5", "dreamshaper", "realistic", "deliberate", "anything"]):
|
if any(x in name_lower for x in ["sd15", "sd1.5", "sd_1.5", "dreamshaper", "realistic", "deliberate", "anything"]):
|
||||||
return "sd15"
|
return "sd15"
|
||||||
if any(x in name_lower for x in ["sdxl", "xl_"]):
|
if any(x in name_lower for x in ["sdxl", "xl_"]):
|
||||||
@@ -590,6 +668,45 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str |
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_generation_defaults(model_name: str, base_model: str | None = None) -> dict[str, Any]:
|
||||||
|
"""Get generation defaults for a model based on its family.
|
||||||
|
|
||||||
|
Detects the model family and returns appropriate default settings for:
|
||||||
|
- sampler, scheduler, steps, cfg, width, height
|
||||||
|
- quality_prefix, negative_prompt
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Filename of the model
|
||||||
|
base_model: Optional CivitAI base_model field
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with generation defaults. Falls back to global SDXL defaults if family unknown.
|
||||||
|
"""
|
||||||
|
family = detect_model_family(model_name, base_model)
|
||||||
|
|
||||||
|
# Get family-specific defaults or fall back to SDXL defaults
|
||||||
|
if family and family in MODEL_FAMILY_DEFAULTS:
|
||||||
|
defaults = dict(MODEL_FAMILY_DEFAULTS[family])
|
||||||
|
else:
|
||||||
|
# Default to SDXL settings for unknown models
|
||||||
|
defaults = dict(MODEL_FAMILY_DEFAULTS.get("sdxl", {}))
|
||||||
|
|
||||||
|
# Ensure all expected keys are present with fallbacks
|
||||||
|
defaults.setdefault("sampler", COMFYUI_DEFAULT_SAMPLER)
|
||||||
|
defaults.setdefault("scheduler", COMFYUI_DEFAULT_SCHEDULER)
|
||||||
|
defaults.setdefault("steps", COMFYUI_DEFAULT_STEPS)
|
||||||
|
defaults.setdefault("cfg", COMFYUI_DEFAULT_CFG)
|
||||||
|
defaults.setdefault("width", COMFYUI_DEFAULT_WIDTH)
|
||||||
|
defaults.setdefault("height", COMFYUI_DEFAULT_HEIGHT)
|
||||||
|
defaults.setdefault("quality_prefix", "")
|
||||||
|
defaults.setdefault("negative_prompt", "")
|
||||||
|
|
||||||
|
# Include the detected family for reference
|
||||||
|
defaults["family"] = family
|
||||||
|
|
||||||
|
return defaults
|
||||||
|
|
||||||
|
|
||||||
def get_comfyui_url() -> str:
|
def get_comfyui_url() -> str:
|
||||||
"""Get the ComfyUI server URL.
|
"""Get the ComfyUI server URL.
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ from tensors.comfyui import (
|
|||||||
get_system_stats,
|
get_system_stats,
|
||||||
queue_prompt,
|
queue_prompt,
|
||||||
)
|
)
|
||||||
|
from tensors.config import get_model_generation_defaults
|
||||||
|
from tensors.db import Database
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -224,14 +226,52 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
|||||||
|
|
||||||
This uses the built-in SDXL/Flux compatible workflow template.
|
This uses the built-in SDXL/Flux compatible workflow template.
|
||||||
For custom workflows, use the /workflow endpoint instead.
|
For custom workflows, use the /workflow endpoint instead.
|
||||||
|
|
||||||
|
Sampler and scheduler are auto-selected based on model family if not specified
|
||||||
|
(when using default values). Family detection uses the model filename and
|
||||||
|
database metadata.
|
||||||
"""
|
"""
|
||||||
|
# Get family-specific defaults if model is specified
|
||||||
|
sampler = request.sampler
|
||||||
|
scheduler = request.scheduler
|
||||||
|
steps = request.steps
|
||||||
|
cfg = request.cfg
|
||||||
|
|
||||||
|
if request.model:
|
||||||
|
# Look up base_model from database for better family detection
|
||||||
|
try:
|
||||||
|
db = Database()
|
||||||
|
base_model = db.get_base_model_by_filename(request.model)
|
||||||
|
except Exception:
|
||||||
|
base_model = None
|
||||||
|
|
||||||
|
# Get family defaults
|
||||||
|
family_defaults = get_model_generation_defaults(request.model, base_model)
|
||||||
|
detected_family = family_defaults.get("family")
|
||||||
|
|
||||||
|
# Apply family defaults only if request uses default values
|
||||||
|
# (allows explicit override by user)
|
||||||
|
if request.sampler == "euler": # Default value in schema
|
||||||
|
sampler = family_defaults["sampler"]
|
||||||
|
if request.scheduler == "normal": # Default value in schema
|
||||||
|
scheduler = family_defaults["scheduler"]
|
||||||
|
if request.steps == 20: # Default value in schema
|
||||||
|
steps = family_defaults["steps"]
|
||||||
|
if request.cfg == 7.0: # Default value in schema
|
||||||
|
cfg = family_defaults["cfg"]
|
||||||
|
|
||||||
|
logger.debug("Detected model family: %s (sampler=%s, scheduler=%s, steps=%d, cfg=%.1f)",
|
||||||
|
detected_family, sampler, scheduler, steps, cfg)
|
||||||
|
|
||||||
lora_info = f", lora={request.lora_name}@{request.lora_strength}" if request.lora_name else ""
|
lora_info = f", lora={request.lora_name}@{request.lora_strength}" if request.lora_name else ""
|
||||||
logger.info(
|
logger.info(
|
||||||
"Generate request: model=%s, size=%dx%d, steps=%d%s, prompt=%r",
|
"Generate request: model=%s, size=%dx%d, steps=%d, sampler=%s, scheduler=%s%s, prompt=%r",
|
||||||
request.model or "default",
|
request.model or "default",
|
||||||
request.width,
|
request.width,
|
||||||
request.height,
|
request.height,
|
||||||
request.steps,
|
steps,
|
||||||
|
sampler,
|
||||||
|
scheduler,
|
||||||
lora_info,
|
lora_info,
|
||||||
request.prompt[:100] + "..." if len(request.prompt) > 100 else request.prompt,
|
request.prompt[:100] + "..." if len(request.prompt) > 100 else request.prompt,
|
||||||
)
|
)
|
||||||
@@ -244,11 +284,11 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]:
|
|||||||
model=request.model,
|
model=request.model,
|
||||||
width=request.width,
|
width=request.width,
|
||||||
height=request.height,
|
height=request.height,
|
||||||
steps=request.steps,
|
steps=steps,
|
||||||
cfg=request.cfg,
|
cfg=cfg,
|
||||||
seed=request.seed,
|
seed=request.seed,
|
||||||
sampler=request.sampler,
|
sampler=sampler,
|
||||||
scheduler=request.scheduler,
|
scheduler=scheduler,
|
||||||
vae=request.vae,
|
vae=request.vae,
|
||||||
lora_name=request.lora_name,
|
lora_name=request.lora_name,
|
||||||
lora_strength=request.lora_strength,
|
lora_strength=request.lora_strength,
|
||||||
|
|||||||
@@ -281,6 +281,123 @@ class TestEnums:
|
|||||||
assert SortOrder.newest.to_api() == "Newest"
|
assert SortOrder.newest.to_api() == "Newest"
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelFamilyDetection:
|
||||||
|
"""Tests for detect_model_family and get_model_generation_defaults."""
|
||||||
|
|
||||||
|
def test_detect_pony_from_base_model(self) -> None:
|
||||||
|
"""Test detecting Pony family from base_model field."""
|
||||||
|
from tensors.config import detect_model_family
|
||||||
|
|
||||||
|
assert detect_model_family("model.safetensors", "Pony") == "pony"
|
||||||
|
assert detect_model_family("anything.safetensors", "PONY") == "pony"
|
||||||
|
|
||||||
|
def test_detect_pony_from_filename(self) -> None:
|
||||||
|
"""Test detecting Pony family from filename."""
|
||||||
|
from tensors.config import detect_model_family
|
||||||
|
|
||||||
|
assert detect_model_family("ponyDiffusionV6XL.safetensors") == "pony"
|
||||||
|
assert detect_model_family("autismmixPony_v10.safetensors") == "pony"
|
||||||
|
|
||||||
|
def test_detect_illustrious_from_base_model(self) -> None:
|
||||||
|
"""Test detecting Illustrious family from base_model field."""
|
||||||
|
from tensors.config import detect_model_family
|
||||||
|
|
||||||
|
assert detect_model_family("model.safetensors", "Illustrious") == "illustrious"
|
||||||
|
|
||||||
|
def test_detect_illustrious_from_filename(self) -> None:
|
||||||
|
"""Test detecting Illustrious family from filename."""
|
||||||
|
from tensors.config import detect_model_family
|
||||||
|
|
||||||
|
assert detect_model_family("illustriousXL_v10.safetensors") == "illustrious"
|
||||||
|
assert detect_model_family("noobaiXL_v10.safetensors") == "illustrious"
|
||||||
|
|
||||||
|
def test_detect_flux_variants(self) -> None:
|
||||||
|
"""Test detecting Flux family variants."""
|
||||||
|
from tensors.config import detect_model_family
|
||||||
|
|
||||||
|
assert detect_model_family("flux1-dev.safetensors") == "flux"
|
||||||
|
assert detect_model_family("flux1-schnell.safetensors") == "flux_schnell"
|
||||||
|
assert detect_model_family("model.safetensors", "Flux.1 D") == "flux"
|
||||||
|
assert detect_model_family("model.safetensors", "Flux.1 S schnell") == "flux_schnell"
|
||||||
|
|
||||||
|
def test_detect_sdxl_variants(self) -> None:
|
||||||
|
"""Test detecting SDXL family variants."""
|
||||||
|
from tensors.config import detect_model_family
|
||||||
|
|
||||||
|
assert detect_model_family("juggernautXL_v9.safetensors") == "sdxl"
|
||||||
|
assert detect_model_family("sdxl_lightning_4step.safetensors") == "sdxl_lightning"
|
||||||
|
assert detect_model_family("sdxl_turbo.safetensors") == "sdxl_turbo"
|
||||||
|
assert detect_model_family("model.safetensors", "SDXL 1.0") == "sdxl"
|
||||||
|
assert detect_model_family("model.safetensors", "SDXL Lightning") == "sdxl_lightning"
|
||||||
|
assert detect_model_family("model.safetensors", "SDXL Turbo") == "sdxl_turbo"
|
||||||
|
|
||||||
|
def test_detect_sd15_variants(self) -> None:
|
||||||
|
"""Test detecting SD 1.5 family variants."""
|
||||||
|
from tensors.config import detect_model_family
|
||||||
|
|
||||||
|
assert detect_model_family("dreamshaper_v8.safetensors") == "sd15"
|
||||||
|
assert detect_model_family("sd15_lcm.safetensors") == "sd15_lcm"
|
||||||
|
assert detect_model_family("model.safetensors", "SD 1.5") == "sd15"
|
||||||
|
assert detect_model_family("model.safetensors", "SD 1.5 LCM") == "sd15_lcm"
|
||||||
|
|
||||||
|
def test_detect_unknown_returns_none(self) -> None:
|
||||||
|
"""Test that unknown models return None."""
|
||||||
|
from tensors.config import detect_model_family
|
||||||
|
|
||||||
|
assert detect_model_family("random_model.safetensors") is None
|
||||||
|
assert detect_model_family("unknown.safetensors", "Unknown") is None
|
||||||
|
|
||||||
|
def test_get_model_generation_defaults_pony(self) -> None:
|
||||||
|
"""Test getting generation defaults for Pony models."""
|
||||||
|
from tensors.config import get_model_generation_defaults
|
||||||
|
|
||||||
|
defaults = get_model_generation_defaults("ponyDiffusionV6XL.safetensors")
|
||||||
|
assert defaults["family"] == "pony"
|
||||||
|
assert defaults["sampler"] == "euler_ancestral"
|
||||||
|
assert defaults["scheduler"] == "normal"
|
||||||
|
assert defaults["steps"] == 25
|
||||||
|
assert defaults["cfg"] == 7.0
|
||||||
|
|
||||||
|
def test_get_model_generation_defaults_flux(self) -> None:
|
||||||
|
"""Test getting generation defaults for Flux models."""
|
||||||
|
from tensors.config import get_model_generation_defaults
|
||||||
|
|
||||||
|
defaults = get_model_generation_defaults("flux1-dev-fp8.safetensors")
|
||||||
|
assert defaults["family"] == "flux"
|
||||||
|
assert defaults["sampler"] == "euler"
|
||||||
|
assert defaults["scheduler"] == "simple"
|
||||||
|
assert defaults["cfg"] == 3.5
|
||||||
|
|
||||||
|
def test_get_model_generation_defaults_flux_schnell(self) -> None:
|
||||||
|
"""Test getting generation defaults for Flux Schnell models."""
|
||||||
|
from tensors.config import get_model_generation_defaults
|
||||||
|
|
||||||
|
defaults = get_model_generation_defaults("flux1-schnell.safetensors")
|
||||||
|
assert defaults["family"] == "flux_schnell"
|
||||||
|
assert defaults["steps"] == 4
|
||||||
|
assert defaults["cfg"] == 1.0
|
||||||
|
|
||||||
|
def test_get_model_generation_defaults_sdxl_lightning(self) -> None:
|
||||||
|
"""Test getting generation defaults for SDXL Lightning models."""
|
||||||
|
from tensors.config import get_model_generation_defaults
|
||||||
|
|
||||||
|
defaults = get_model_generation_defaults("sdxl_lightning_4step.safetensors")
|
||||||
|
assert defaults["family"] == "sdxl_lightning"
|
||||||
|
assert defaults["sampler"] == "euler"
|
||||||
|
assert defaults["scheduler"] == "sgm_uniform"
|
||||||
|
assert defaults["steps"] == 8
|
||||||
|
assert defaults["cfg"] == 2.0
|
||||||
|
|
||||||
|
def test_get_model_generation_defaults_unknown_falls_back_to_sdxl(self) -> None:
|
||||||
|
"""Test that unknown models fall back to SDXL defaults."""
|
||||||
|
from tensors.config import get_model_generation_defaults
|
||||||
|
|
||||||
|
defaults = get_model_generation_defaults("unknown_model.safetensors")
|
||||||
|
assert defaults["family"] is None
|
||||||
|
assert defaults["sampler"] == "dpmpp_2m"
|
||||||
|
assert defaults["scheduler"] == "karras"
|
||||||
|
|
||||||
|
|
||||||
class TestDisplayFormatters:
|
class TestDisplayFormatters:
|
||||||
"""Tests for display formatting functions."""
|
"""Tests for display formatting functions."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user