Update 2026-03-20 09:07
This commit is contained in:
@@ -281,6 +281,123 @@ class TestEnums:
|
||||
assert SortOrder.newest.to_api() == "Newest"
|
||||
|
||||
|
||||
class TestModelFamilyDetection:
|
||||
"""Tests for detect_model_family and get_model_generation_defaults."""
|
||||
|
||||
def test_detect_pony_from_base_model(self) -> None:
|
||||
"""Test detecting Pony family from base_model field."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("model.safetensors", "Pony") == "pony"
|
||||
assert detect_model_family("anything.safetensors", "PONY") == "pony"
|
||||
|
||||
def test_detect_pony_from_filename(self) -> None:
|
||||
"""Test detecting Pony family from filename."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("ponyDiffusionV6XL.safetensors") == "pony"
|
||||
assert detect_model_family("autismmixPony_v10.safetensors") == "pony"
|
||||
|
||||
def test_detect_illustrious_from_base_model(self) -> None:
|
||||
"""Test detecting Illustrious family from base_model field."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("model.safetensors", "Illustrious") == "illustrious"
|
||||
|
||||
def test_detect_illustrious_from_filename(self) -> None:
|
||||
"""Test detecting Illustrious family from filename."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("illustriousXL_v10.safetensors") == "illustrious"
|
||||
assert detect_model_family("noobaiXL_v10.safetensors") == "illustrious"
|
||||
|
||||
def test_detect_flux_variants(self) -> None:
|
||||
"""Test detecting Flux family variants."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("flux1-dev.safetensors") == "flux"
|
||||
assert detect_model_family("flux1-schnell.safetensors") == "flux_schnell"
|
||||
assert detect_model_family("model.safetensors", "Flux.1 D") == "flux"
|
||||
assert detect_model_family("model.safetensors", "Flux.1 S schnell") == "flux_schnell"
|
||||
|
||||
def test_detect_sdxl_variants(self) -> None:
|
||||
"""Test detecting SDXL family variants."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("juggernautXL_v9.safetensors") == "sdxl"
|
||||
assert detect_model_family("sdxl_lightning_4step.safetensors") == "sdxl_lightning"
|
||||
assert detect_model_family("sdxl_turbo.safetensors") == "sdxl_turbo"
|
||||
assert detect_model_family("model.safetensors", "SDXL 1.0") == "sdxl"
|
||||
assert detect_model_family("model.safetensors", "SDXL Lightning") == "sdxl_lightning"
|
||||
assert detect_model_family("model.safetensors", "SDXL Turbo") == "sdxl_turbo"
|
||||
|
||||
def test_detect_sd15_variants(self) -> None:
|
||||
"""Test detecting SD 1.5 family variants."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("dreamshaper_v8.safetensors") == "sd15"
|
||||
assert detect_model_family("sd15_lcm.safetensors") == "sd15_lcm"
|
||||
assert detect_model_family("model.safetensors", "SD 1.5") == "sd15"
|
||||
assert detect_model_family("model.safetensors", "SD 1.5 LCM") == "sd15_lcm"
|
||||
|
||||
def test_detect_unknown_returns_none(self) -> None:
|
||||
"""Test that unknown models return None."""
|
||||
from tensors.config import detect_model_family
|
||||
|
||||
assert detect_model_family("random_model.safetensors") is None
|
||||
assert detect_model_family("unknown.safetensors", "Unknown") is None
|
||||
|
||||
def test_get_model_generation_defaults_pony(self) -> None:
|
||||
"""Test getting generation defaults for Pony models."""
|
||||
from tensors.config import get_model_generation_defaults
|
||||
|
||||
defaults = get_model_generation_defaults("ponyDiffusionV6XL.safetensors")
|
||||
assert defaults["family"] == "pony"
|
||||
assert defaults["sampler"] == "euler_ancestral"
|
||||
assert defaults["scheduler"] == "normal"
|
||||
assert defaults["steps"] == 25
|
||||
assert defaults["cfg"] == 7.0
|
||||
|
||||
def test_get_model_generation_defaults_flux(self) -> None:
|
||||
"""Test getting generation defaults for Flux models."""
|
||||
from tensors.config import get_model_generation_defaults
|
||||
|
||||
defaults = get_model_generation_defaults("flux1-dev-fp8.safetensors")
|
||||
assert defaults["family"] == "flux"
|
||||
assert defaults["sampler"] == "euler"
|
||||
assert defaults["scheduler"] == "simple"
|
||||
assert defaults["cfg"] == 3.5
|
||||
|
||||
def test_get_model_generation_defaults_flux_schnell(self) -> None:
|
||||
"""Test getting generation defaults for Flux Schnell models."""
|
||||
from tensors.config import get_model_generation_defaults
|
||||
|
||||
defaults = get_model_generation_defaults("flux1-schnell.safetensors")
|
||||
assert defaults["family"] == "flux_schnell"
|
||||
assert defaults["steps"] == 4
|
||||
assert defaults["cfg"] == 1.0
|
||||
|
||||
def test_get_model_generation_defaults_sdxl_lightning(self) -> None:
|
||||
"""Test getting generation defaults for SDXL Lightning models."""
|
||||
from tensors.config import get_model_generation_defaults
|
||||
|
||||
defaults = get_model_generation_defaults("sdxl_lightning_4step.safetensors")
|
||||
assert defaults["family"] == "sdxl_lightning"
|
||||
assert defaults["sampler"] == "euler"
|
||||
assert defaults["scheduler"] == "sgm_uniform"
|
||||
assert defaults["steps"] == 8
|
||||
assert defaults["cfg"] == 2.0
|
||||
|
||||
def test_get_model_generation_defaults_unknown_falls_back_to_sdxl(self) -> None:
|
||||
"""Test that unknown models fall back to SDXL defaults."""
|
||||
from tensors.config import get_model_generation_defaults
|
||||
|
||||
defaults = get_model_generation_defaults("unknown_model.safetensors")
|
||||
assert defaults["family"] is None
|
||||
assert defaults["sampler"] == "dpmpp_2m"
|
||||
assert defaults["scheduler"] == "karras"
|
||||
|
||||
|
||||
class TestDisplayFormatters:
|
||||
"""Tests for display formatting functions."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user