diff --git a/.coverage b/.coverage index edbdd6f..d59af52 100644 Binary files a/.coverage and b/.coverage differ diff --git a/tensors/config.py b/tensors/config.py index 14ed4d4..455f15a 100644 --- a/tensors/config.py +++ b/tensors/config.py @@ -514,6 +514,9 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "height": 1024, "cfg": 7.0, "clip_skip": 2, + "sampler": "euler_ancestral", + "scheduler": "normal", + "steps": 25, }, "illustrious": { "quality_prefix": "masterpiece, best quality, highres", @@ -521,6 +524,9 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "width": 1024, "height": 1024, "cfg": 6.0, + "sampler": "euler_ancestral", + "scheduler": "normal", + "steps": 25, }, "sdxl": { "quality_prefix": "", @@ -528,6 +534,29 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "width": 1024, "height": 1024, "cfg": 7.0, + "sampler": "dpmpp_2m", + "scheduler": "karras", + "steps": 25, + }, + "sdxl_lightning": { + "quality_prefix": "", + "negative_prompt": "ugly, deformed, bad anatomy, bad hands, extra fingers, missing fingers, blurry, watermark", + "width": 1024, + "height": 1024, + "cfg": 2.0, + "sampler": "euler", + "scheduler": "sgm_uniform", + "steps": 8, # Lightning models use fewer steps + }, + "sdxl_turbo": { + "quality_prefix": "", + "negative_prompt": "", # Turbo models work best without negative prompts + "width": 1024, + "height": 1024, + "cfg": 1.0, # Very low CFG for turbo + "sampler": "euler_ancestral", + "scheduler": "normal", + "steps": 4, # Turbo models use very few steps }, "sd15": { "quality_prefix": "masterpiece, best quality", @@ -538,6 +567,19 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "width": 512, "height": 512, "cfg": 7.0, + "sampler": "dpmpp_2m", + "scheduler": "karras", + "steps": 20, + }, + "sd15_lcm": { + "quality_prefix": "masterpiece, best quality", + "negative_prompt": "", # LCM works best with minimal negative + "width": 512, + "height": 512, + "cfg": 1.5, + "sampler": "lcm", + "scheduler": "normal", + "steps": 6, }, "flux": { "quality_prefix": "", @@ -545,6 +587,19 @@ MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = { "width": 1024, "height": 1024, "cfg": 3.5, + "sampler": "euler", + "scheduler": "simple", + "steps": 20, + }, + "flux_schnell": { + "quality_prefix": "", + "negative_prompt": "", + "width": 1024, + "height": 1024, + "cfg": 1.0, # Schnell uses low CFG + "sampler": "euler", + "scheduler": "simple", + "steps": 4, # Schnell is a distilled model, very few steps }, } @@ -557,7 +612,8 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str | base_model: Optional CivitAI base_model field (e.g., "Pony", "SDXL 1.0") Returns: - Model family key (pony, illustrious, sdxl, sd15, flux) or None if unknown + Model family key (pony, illustrious, sdxl, sdxl_lightning, sdxl_turbo, + sd15, sd15_lcm, flux, flux_schnell) or None if unknown """ name_lower = model_name.lower() base_lower = (base_model or "").lower() @@ -568,20 +624,42 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str | return "pony" if "illustrious" in base_lower: return "illustrious" + # Flux variants (check specific variants before generic flux) + if "flux" in base_lower and "schnell" in base_lower: + return "flux_schnell" if "flux" in base_lower: return "flux" + # SD 1.5 variants + if "lcm" in base_lower and ("sd 1.5" in base_lower or "sd 1.4" in base_lower): + return "sd15_lcm" if "sd 1.5" in base_lower or "sd 1.4" in base_lower: return "sd15" + # SDXL variants (check specific variants before generic sdxl) + if "sdxl" in base_lower and "lightning" in base_lower: + return "sdxl_lightning" + if "sdxl" in base_lower and "turbo" in base_lower: + return "sdxl_turbo" if "sdxl" in base_lower: return "sdxl" - # Fall back to filename heuristics + # Fall back to filename heuristics (check specific variants first) if "pony" in name_lower: return "pony" if "illustrious" in name_lower or "noob" in name_lower: return "illustrious" + # Flux variants + if "flux" in name_lower and "schnell" in name_lower: + return "flux_schnell" if "flux" in name_lower: return "flux" + # SDXL variants + if "lightning" in name_lower and any(x in name_lower for x in ["sdxl", "xl"]): + return "sdxl_lightning" + if "turbo" in name_lower and any(x in name_lower for x in ["sdxl", "xl"]): + return "sdxl_turbo" + # SD 1.5 variants + if "lcm" in name_lower and any(x in name_lower for x in ["sd15", "sd1.5", "sd_1.5"]): + return "sd15_lcm" if any(x in name_lower for x in ["sd15", "sd1.5", "sd_1.5", "dreamshaper", "realistic", "deliberate", "anything"]): return "sd15" if any(x in name_lower for x in ["sdxl", "xl_"]): @@ -590,6 +668,45 @@ def detect_model_family(model_name: str, base_model: str | None = None) -> str | return None +def get_model_generation_defaults(model_name: str, base_model: str | None = None) -> dict[str, Any]: + """Get generation defaults for a model based on its family. + + Detects the model family and returns appropriate default settings for: + - sampler, scheduler, steps, cfg, width, height + - quality_prefix, negative_prompt + + Args: + model_name: Filename of the model + base_model: Optional CivitAI base_model field + + Returns: + Dict with generation defaults. Falls back to global SDXL defaults if family unknown. + """ + family = detect_model_family(model_name, base_model) + + # Get family-specific defaults or fall back to SDXL defaults + if family and family in MODEL_FAMILY_DEFAULTS: + defaults = dict(MODEL_FAMILY_DEFAULTS[family]) + else: + # Default to SDXL settings for unknown models + defaults = dict(MODEL_FAMILY_DEFAULTS.get("sdxl", {})) + + # Ensure all expected keys are present with fallbacks + defaults.setdefault("sampler", COMFYUI_DEFAULT_SAMPLER) + defaults.setdefault("scheduler", COMFYUI_DEFAULT_SCHEDULER) + defaults.setdefault("steps", COMFYUI_DEFAULT_STEPS) + defaults.setdefault("cfg", COMFYUI_DEFAULT_CFG) + defaults.setdefault("width", COMFYUI_DEFAULT_WIDTH) + defaults.setdefault("height", COMFYUI_DEFAULT_HEIGHT) + defaults.setdefault("quality_prefix", "") + defaults.setdefault("negative_prompt", "") + + # Include the detected family for reference + defaults["family"] = family + + return defaults + + def get_comfyui_url() -> str: """Get the ComfyUI server URL. diff --git a/tensors/server/comfyui_api_routes.py b/tensors/server/comfyui_api_routes.py index 8d89caf..5cb1aef 100644 --- a/tensors/server/comfyui_api_routes.py +++ b/tensors/server/comfyui_api_routes.py @@ -20,6 +20,8 @@ from tensors.comfyui import ( get_system_stats, queue_prompt, ) +from tensors.config import get_model_generation_defaults +from tensors.db import Database logger = logging.getLogger(__name__) @@ -224,14 +226,52 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]: This uses the built-in SDXL/Flux compatible workflow template. For custom workflows, use the /workflow endpoint instead. + + Sampler and scheduler are auto-selected based on model family if not specified + (when using default values). Family detection uses the model filename and + database metadata. """ + # Get family-specific defaults if model is specified + sampler = request.sampler + scheduler = request.scheduler + steps = request.steps + cfg = request.cfg + + if request.model: + # Look up base_model from database for better family detection + try: + db = Database() + base_model = db.get_base_model_by_filename(request.model) + except Exception: + base_model = None + + # Get family defaults + family_defaults = get_model_generation_defaults(request.model, base_model) + detected_family = family_defaults.get("family") + + # Apply family defaults only if request uses default values + # (allows explicit override by user) + if request.sampler == "euler": # Default value in schema + sampler = family_defaults["sampler"] + if request.scheduler == "normal": # Default value in schema + scheduler = family_defaults["scheduler"] + if request.steps == 20: # Default value in schema + steps = family_defaults["steps"] + if request.cfg == 7.0: # Default value in schema + cfg = family_defaults["cfg"] + + logger.debug("Detected model family: %s (sampler=%s, scheduler=%s, steps=%d, cfg=%.1f)", + detected_family, sampler, scheduler, steps, cfg) + lora_info = f", lora={request.lora_name}@{request.lora_strength}" if request.lora_name else "" logger.info( - "Generate request: model=%s, size=%dx%d, steps=%d%s, prompt=%r", + "Generate request: model=%s, size=%dx%d, steps=%d, sampler=%s, scheduler=%s%s, prompt=%r", request.model or "default", request.width, request.height, - request.steps, + steps, + sampler, + scheduler, lora_info, request.prompt[:100] + "..." if len(request.prompt) > 100 else request.prompt, ) @@ -244,11 +284,11 @@ def comfyui_generate(request: GenerateRequest) -> dict[str, Any]: model=request.model, width=request.width, height=request.height, - steps=request.steps, - cfg=request.cfg, + steps=steps, + cfg=cfg, seed=request.seed, - sampler=request.sampler, - scheduler=request.scheduler, + sampler=sampler, + scheduler=scheduler, vae=request.vae, lora_name=request.lora_name, lora_strength=request.lora_strength, diff --git a/tests/test_tensors.py b/tests/test_tensors.py index 621647c..1c6afc7 100644 --- a/tests/test_tensors.py +++ b/tests/test_tensors.py @@ -281,6 +281,123 @@ class TestEnums: assert SortOrder.newest.to_api() == "Newest" +class TestModelFamilyDetection: + """Tests for detect_model_family and get_model_generation_defaults.""" + + def test_detect_pony_from_base_model(self) -> None: + """Test detecting Pony family from base_model field.""" + from tensors.config import detect_model_family + + assert detect_model_family("model.safetensors", "Pony") == "pony" + assert detect_model_family("anything.safetensors", "PONY") == "pony" + + def test_detect_pony_from_filename(self) -> None: + """Test detecting Pony family from filename.""" + from tensors.config import detect_model_family + + assert detect_model_family("ponyDiffusionV6XL.safetensors") == "pony" + assert detect_model_family("autismmixPony_v10.safetensors") == "pony" + + def test_detect_illustrious_from_base_model(self) -> None: + """Test detecting Illustrious family from base_model field.""" + from tensors.config import detect_model_family + + assert detect_model_family("model.safetensors", "Illustrious") == "illustrious" + + def test_detect_illustrious_from_filename(self) -> None: + """Test detecting Illustrious family from filename.""" + from tensors.config import detect_model_family + + assert detect_model_family("illustriousXL_v10.safetensors") == "illustrious" + assert detect_model_family("noobaiXL_v10.safetensors") == "illustrious" + + def test_detect_flux_variants(self) -> None: + """Test detecting Flux family variants.""" + from tensors.config import detect_model_family + + assert detect_model_family("flux1-dev.safetensors") == "flux" + assert detect_model_family("flux1-schnell.safetensors") == "flux_schnell" + assert detect_model_family("model.safetensors", "Flux.1 D") == "flux" + assert detect_model_family("model.safetensors", "Flux.1 S schnell") == "flux_schnell" + + def test_detect_sdxl_variants(self) -> None: + """Test detecting SDXL family variants.""" + from tensors.config import detect_model_family + + assert detect_model_family("juggernautXL_v9.safetensors") == "sdxl" + assert detect_model_family("sdxl_lightning_4step.safetensors") == "sdxl_lightning" + assert detect_model_family("sdxl_turbo.safetensors") == "sdxl_turbo" + assert detect_model_family("model.safetensors", "SDXL 1.0") == "sdxl" + assert detect_model_family("model.safetensors", "SDXL Lightning") == "sdxl_lightning" + assert detect_model_family("model.safetensors", "SDXL Turbo") == "sdxl_turbo" + + def test_detect_sd15_variants(self) -> None: + """Test detecting SD 1.5 family variants.""" + from tensors.config import detect_model_family + + assert detect_model_family("dreamshaper_v8.safetensors") == "sd15" + assert detect_model_family("sd15_lcm.safetensors") == "sd15_lcm" + assert detect_model_family("model.safetensors", "SD 1.5") == "sd15" + assert detect_model_family("model.safetensors", "SD 1.5 LCM") == "sd15_lcm" + + def test_detect_unknown_returns_none(self) -> None: + """Test that unknown models return None.""" + from tensors.config import detect_model_family + + assert detect_model_family("random_model.safetensors") is None + assert detect_model_family("unknown.safetensors", "Unknown") is None + + def test_get_model_generation_defaults_pony(self) -> None: + """Test getting generation defaults for Pony models.""" + from tensors.config import get_model_generation_defaults + + defaults = get_model_generation_defaults("ponyDiffusionV6XL.safetensors") + assert defaults["family"] == "pony" + assert defaults["sampler"] == "euler_ancestral" + assert defaults["scheduler"] == "normal" + assert defaults["steps"] == 25 + assert defaults["cfg"] == 7.0 + + def test_get_model_generation_defaults_flux(self) -> None: + """Test getting generation defaults for Flux models.""" + from tensors.config import get_model_generation_defaults + + defaults = get_model_generation_defaults("flux1-dev-fp8.safetensors") + assert defaults["family"] == "flux" + assert defaults["sampler"] == "euler" + assert defaults["scheduler"] == "simple" + assert defaults["cfg"] == 3.5 + + def test_get_model_generation_defaults_flux_schnell(self) -> None: + """Test getting generation defaults for Flux Schnell models.""" + from tensors.config import get_model_generation_defaults + + defaults = get_model_generation_defaults("flux1-schnell.safetensors") + assert defaults["family"] == "flux_schnell" + assert defaults["steps"] == 4 + assert defaults["cfg"] == 1.0 + + def test_get_model_generation_defaults_sdxl_lightning(self) -> None: + """Test getting generation defaults for SDXL Lightning models.""" + from tensors.config import get_model_generation_defaults + + defaults = get_model_generation_defaults("sdxl_lightning_4step.safetensors") + assert defaults["family"] == "sdxl_lightning" + assert defaults["sampler"] == "euler" + assert defaults["scheduler"] == "sgm_uniform" + assert defaults["steps"] == 8 + assert defaults["cfg"] == 2.0 + + def test_get_model_generation_defaults_unknown_falls_back_to_sdxl(self) -> None: + """Test that unknown models fall back to SDXL defaults.""" + from tensors.config import get_model_generation_defaults + + defaults = get_model_generation_defaults("unknown_model.safetensors") + assert defaults["family"] is None + assert defaults["sampler"] == "dpmpp_2m" + assert defaults["scheduler"] == "karras" + + class TestDisplayFormatters: """Tests for display formatting functions."""