Update 2026-03-20 09:07

This commit is contained in:
aladac
2026-03-20 09:07:19 +01:00
parent 420d260936
commit 372133edcc
4 changed files with 282 additions and 8 deletions
+117
View File
@@ -281,6 +281,123 @@ class TestEnums:
assert SortOrder.newest.to_api() == "Newest"
class TestModelFamilyDetection:
"""Tests for detect_model_family and get_model_generation_defaults."""
def test_detect_pony_from_base_model(self) -> None:
"""Test detecting Pony family from base_model field."""
from tensors.config import detect_model_family
assert detect_model_family("model.safetensors", "Pony") == "pony"
assert detect_model_family("anything.safetensors", "PONY") == "pony"
def test_detect_pony_from_filename(self) -> None:
"""Test detecting Pony family from filename."""
from tensors.config import detect_model_family
assert detect_model_family("ponyDiffusionV6XL.safetensors") == "pony"
assert detect_model_family("autismmixPony_v10.safetensors") == "pony"
def test_detect_illustrious_from_base_model(self) -> None:
"""Test detecting Illustrious family from base_model field."""
from tensors.config import detect_model_family
assert detect_model_family("model.safetensors", "Illustrious") == "illustrious"
def test_detect_illustrious_from_filename(self) -> None:
"""Test detecting Illustrious family from filename."""
from tensors.config import detect_model_family
assert detect_model_family("illustriousXL_v10.safetensors") == "illustrious"
assert detect_model_family("noobaiXL_v10.safetensors") == "illustrious"
def test_detect_flux_variants(self) -> None:
"""Test detecting Flux family variants."""
from tensors.config import detect_model_family
assert detect_model_family("flux1-dev.safetensors") == "flux"
assert detect_model_family("flux1-schnell.safetensors") == "flux_schnell"
assert detect_model_family("model.safetensors", "Flux.1 D") == "flux"
assert detect_model_family("model.safetensors", "Flux.1 S schnell") == "flux_schnell"
def test_detect_sdxl_variants(self) -> None:
"""Test detecting SDXL family variants."""
from tensors.config import detect_model_family
assert detect_model_family("juggernautXL_v9.safetensors") == "sdxl"
assert detect_model_family("sdxl_lightning_4step.safetensors") == "sdxl_lightning"
assert detect_model_family("sdxl_turbo.safetensors") == "sdxl_turbo"
assert detect_model_family("model.safetensors", "SDXL 1.0") == "sdxl"
assert detect_model_family("model.safetensors", "SDXL Lightning") == "sdxl_lightning"
assert detect_model_family("model.safetensors", "SDXL Turbo") == "sdxl_turbo"
def test_detect_sd15_variants(self) -> None:
"""Test detecting SD 1.5 family variants."""
from tensors.config import detect_model_family
assert detect_model_family("dreamshaper_v8.safetensors") == "sd15"
assert detect_model_family("sd15_lcm.safetensors") == "sd15_lcm"
assert detect_model_family("model.safetensors", "SD 1.5") == "sd15"
assert detect_model_family("model.safetensors", "SD 1.5 LCM") == "sd15_lcm"
def test_detect_unknown_returns_none(self) -> None:
"""Test that unknown models return None."""
from tensors.config import detect_model_family
assert detect_model_family("random_model.safetensors") is None
assert detect_model_family("unknown.safetensors", "Unknown") is None
def test_get_model_generation_defaults_pony(self) -> None:
"""Test getting generation defaults for Pony models."""
from tensors.config import get_model_generation_defaults
defaults = get_model_generation_defaults("ponyDiffusionV6XL.safetensors")
assert defaults["family"] == "pony"
assert defaults["sampler"] == "euler_ancestral"
assert defaults["scheduler"] == "normal"
assert defaults["steps"] == 25
assert defaults["cfg"] == 7.0
def test_get_model_generation_defaults_flux(self) -> None:
"""Test getting generation defaults for Flux models."""
from tensors.config import get_model_generation_defaults
defaults = get_model_generation_defaults("flux1-dev-fp8.safetensors")
assert defaults["family"] == "flux"
assert defaults["sampler"] == "euler"
assert defaults["scheduler"] == "simple"
assert defaults["cfg"] == 3.5
def test_get_model_generation_defaults_flux_schnell(self) -> None:
"""Test getting generation defaults for Flux Schnell models."""
from tensors.config import get_model_generation_defaults
defaults = get_model_generation_defaults("flux1-schnell.safetensors")
assert defaults["family"] == "flux_schnell"
assert defaults["steps"] == 4
assert defaults["cfg"] == 1.0
def test_get_model_generation_defaults_sdxl_lightning(self) -> None:
"""Test getting generation defaults for SDXL Lightning models."""
from tensors.config import get_model_generation_defaults
defaults = get_model_generation_defaults("sdxl_lightning_4step.safetensors")
assert defaults["family"] == "sdxl_lightning"
assert defaults["sampler"] == "euler"
assert defaults["scheduler"] == "sgm_uniform"
assert defaults["steps"] == 8
assert defaults["cfg"] == 2.0
def test_get_model_generation_defaults_unknown_falls_back_to_sdxl(self) -> None:
"""Test that unknown models fall back to SDXL defaults."""
from tensors.config import get_model_generation_defaults
defaults = get_model_generation_defaults("unknown_model.safetensors")
assert defaults["family"] is None
assert defaults["sampler"] == "dpmpp_2m"
assert defaults["scheduler"] == "karras"
class TestDisplayFormatters:
"""Tests for display formatting functions."""