Files
tensors/tests/test_tensors.py
T
aladac 338a7fe267 fix(generate): dispatch hybrid Flux models to Flux workflow
Models like gonzalomoXLFluxPony are architecturally Flux, but CivitAI
tags them as 'Pony', causing the SDXL workflow to be sent to ComfyUI,
where it fails validation. The filename now overrides base_model when it
contains 'flux'.

Also adds:
- Full Flux Dev/Schnell workflow template (ModelSamplingFlux,
  FluxGuidance, ConditioningZeroOut, EmptySD3LatentImage); KSampler
  cfg locked to 1.0, caller cfg routed to FluxGuidance
- --family/-F flag to manually override family detection
- queue_prompt now surfaces ComfyUI node_errors from 400 responses
- Tests for Flux workflow builder (8 cases) and updated family defaults
2026-05-17 15:50:25 +02:00

884 lines
36 KiB
Python

"""Tests for tensors module."""
from __future__ import annotations
import struct
from pathlib import Path
from typing import Any
import httpx
import pytest
import respx
from rich.console import Console
from typer.testing import CliRunner
from tensors import config
from tensors.api import (
download_model,
fetch_civitai_by_hash,
fetch_civitai_model,
fetch_civitai_model_version,
search_civitai,
)
from tensors.cli import app
from tensors.config import (
BaseModel,
ModelType,
SortOrder,
get_default_output_path,
get_model_paths,
load_api_key,
load_config,
save_config,
)
from tensors.display import (
_format_count,
_format_size,
display_civitai_data,
display_file_info,
display_local_metadata,
display_model_info,
display_search_results,
)
from tensors.safetensor import get_base_name, read_safetensor_metadata
# Shared Typer test runner; the CLI tests invoke ``app`` through it.
runner: CliRunner = CliRunner()
class TestReadSafetensorMetadata:
    """Exercise read_safetensor_metadata on a valid file and on malformed ones."""

    def test_reads_valid_safetensor(self, temp_safetensor: Path) -> None:
        """A well-formed file yields its metadata, tensor count and header size."""
        parsed = read_safetensor_metadata(temp_safetensor)
        for required in ("metadata", "tensor_count", "header_size"):
            assert required in parsed
        assert parsed["metadata"]["test_key"] == "test_value"
        # The fixture stores metadata only, so no tensors are listed.
        assert parsed["tensor_count"] == 0

    def test_raises_on_short_file(self, tmp_path: Path) -> None:
        """A file too small to hold the header-length prefix raises ValueError."""
        stub = tmp_path / "short.safetensors"
        stub.write_bytes(b"short")
        with pytest.raises(ValueError, match="too short"):
            read_safetensor_metadata(stub)

    def test_raises_on_truncated_header(self, tmp_path: Path) -> None:
        """A declared header length longer than the bytes that follow raises ValueError."""
        partial = tmp_path / "truncated.safetensors"
        # Declare a 1000-byte header but supply only 10 bytes of it.
        partial.write_bytes(struct.pack("<Q", 1000) + b"x" * 10)
        with pytest.raises(ValueError, match="truncated"):
            read_safetensor_metadata(partial)

    def test_raises_on_huge_header_size(self, tmp_path: Path) -> None:
        """An implausibly large declared header size raises ValueError."""
        oversized = tmp_path / "huge.safetensors"
        oversized.write_bytes(struct.pack("<Q", 200_000_000))  # declares a ~200 MB header
        with pytest.raises(ValueError, match="Invalid header size"):
            read_safetensor_metadata(oversized)
class TestGetBaseName:
    """Check how get_base_name strips model-file extensions."""

    def test_removes_safetensors_extension(self) -> None:
        """The ``.safetensors`` suffix is dropped."""
        stem = get_base_name(Path("model.safetensors"))
        assert stem == "model"

    def test_removes_sft_extension(self) -> None:
        """The short ``.sft`` suffix is dropped."""
        stem = get_base_name(Path("model.sft"))
        assert stem == "model"

    def test_handles_uppercase_extension(self) -> None:
        """Suffix matching ignores case."""
        stem = get_base_name(Path("model.SAFETENSORS"))
        assert stem == "model"

    def test_preserves_name_without_known_extension(self) -> None:
        """Any other suffix falls back to the path stem."""
        stem = get_base_name(Path("model.bin"))
        assert stem == "model"
class TestGetDefaultOutputPath:
    """Tests for get_default_output_path against the built-in default paths."""

    @pytest.fixture(autouse=True)
    def _isolate_from_user_config(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
        """Point CONFIG_FILE at a missing file so only the built-in defaults apply.

        get_default_output_path honours ``[paths]`` overrides from config.toml,
        so without this a developer's real config could relocate the directories
        and make these assertions depend on the machine running them.
        """
        monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml")

    def test_returns_checkpoint_path(self) -> None:
        """The Checkpoint type resolves to the default checkpoints directory."""
        result = get_default_output_path("Checkpoint")
        assert result is not None
        assert "checkpoints" in str(result)

    def test_returns_lora_path(self) -> None:
        """The LORA type resolves to the default loras directory."""
        result = get_default_output_path("LORA")
        assert result is not None
        assert "loras" in str(result)

    def test_returns_none_for_unknown_type(self) -> None:
        """Unknown types and a missing type both resolve to None."""
        assert get_default_output_path("UnknownType") is None
        assert get_default_output_path(None) is None
class TestGetModelPaths:
    """Verify get_model_paths covers every model type and honours config.toml."""

    def test_returns_dict_with_all_types(self) -> None:
        """Every supported model type has an entry in the returned mapping."""
        mapping = get_model_paths()
        assert isinstance(mapping, dict)
        for model_type in ("Checkpoint", "LORA", "LoCon", "TextualInversion", "VAE", "Controlnet"):
            assert model_type in mapping

    def test_config_override(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
        """A ``[paths]`` entry in config.toml replaces the built-in default."""
        toml_path = tmp_path / "config.toml"
        toml_path.write_text('[paths]\ncheckpoints = "/custom/checkpoints"\n')
        monkeypatch.setattr(config, "CONFIG_FILE", toml_path)

        mapping = get_model_paths()
        assert mapping["Checkpoint"] == Path("/custom/checkpoints")
        # Types the config does not mention keep their default location.
        assert "loras" in str(mapping["LORA"])

    def test_get_default_output_path_uses_config(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
        """get_default_output_path follows a config.toml ``loras`` override."""
        toml_path = tmp_path / "config.toml"
        toml_path.write_text('[paths]\nloras = "/custom/loras"\n')
        monkeypatch.setattr(config, "CONFIG_FILE", toml_path)

        expected = Path("/custom/loras")
        assert get_default_output_path("LORA") == expected
        # LoCon files are stored in the LoRA directory as well.
        assert get_default_output_path("LoCon") == expected
class TestLoadApiKey:
    """Check each source load_api_key reads the CivitAI key from: env var, config.toml, legacy RC file."""

    @staticmethod
    def _use_files(monkeypatch: pytest.MonkeyPatch, config_file: Path, legacy_file: Path) -> None:
        """Clear the env var and point both on-disk key locations at the given paths."""
        monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
        monkeypatch.setattr(config, "CONFIG_FILE", config_file)
        monkeypatch.setattr(config, "LEGACY_RC_FILE", legacy_file)

    def test_returns_env_var_if_set(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """The CIVITAI_API_KEY environment variable wins."""
        monkeypatch.setenv("CIVITAI_API_KEY", "test-key-from-env")
        assert load_api_key() == "test-key-from-env"

    def test_returns_none_if_no_key(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
        """With no env var and no key files, None is returned."""
        missing = tmp_path / "nonexistent"
        self._use_files(monkeypatch, missing / "config.toml", missing)
        assert load_api_key() is None

    def test_returns_key_from_config_file(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
        """The ``[api] civitai_key`` entry of config.toml is used."""
        toml_path = tmp_path / "config.toml"
        toml_path.write_text('[api]\ncivitai_key = "key-from-config"\n')
        self._use_files(monkeypatch, toml_path, tmp_path / "nonexistent")
        assert load_api_key() == "key-from-config"

    def test_returns_key_from_legacy_file(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
        """Without a config.toml, the legacy RC file supplies the key."""
        rc_path = tmp_path / ".sftrc"
        rc_path.write_text("legacy-key")
        self._use_files(monkeypatch, tmp_path / "nonexistent" / "config.toml", rc_path)
        assert load_api_key() == "legacy-key"
class TestSaveConfig:
    """Check the TOML text that save_config writes."""

    @staticmethod
    def _redirect(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path:
        """Point CONFIG_DIR/CONFIG_FILE inside *tmp_path* and return the config file path."""
        target_dir = tmp_path / "config"
        target_file = target_dir / "config.toml"
        monkeypatch.setattr(config, "CONFIG_DIR", target_dir)
        monkeypatch.setattr(config, "CONFIG_FILE", target_file)
        return target_file

    def test_saves_simple_config(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
        """A flat string value is written as a quoted TOML assignment."""
        written = self._redirect(monkeypatch, tmp_path)
        save_config({"key": "value"})
        assert written.exists()
        assert 'key = "value"' in written.read_text()

    def test_saves_nested_config(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
        """A nested dict becomes a ``[section]`` table."""
        written = self._redirect(monkeypatch, tmp_path)
        save_config({"api": {"civitai_key": "test-key"}})
        text = written.read_text()
        assert "[api]" in text
        assert 'civitai_key = "test-key"' in text

    def test_saves_numeric_values(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
        """Integers are written bare, without quotes."""
        written = self._redirect(monkeypatch, tmp_path)
        save_config({"timeout": 30})
        assert "timeout = 30" in written.read_text()
class TestLoadConfig:
    """Check load_config when no config file is present."""

    def test_returns_empty_dict_if_no_config(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
        """A missing config file loads as an empty dict."""
        missing = tmp_path / "nonexistent.toml"
        monkeypatch.setattr(config, "CONFIG_FILE", missing)
        loaded = load_config()
        assert loaded == {}
class TestEnums:
    """Check the CivitAI API strings produced by each enum's to_api method."""

    def test_model_type_to_api(self) -> None:
        """ModelType members map to CivitAI model-type names."""
        pairs = (
            (ModelType.checkpoint, "Checkpoint"),
            (ModelType.lora, "LORA"),
            (ModelType.embedding, "TextualInversion"),
            (ModelType.vae, "VAE"),
            (ModelType.controlnet, "Controlnet"),
            (ModelType.locon, "LoCon"),
        )
        for member, api_name in pairs:
            assert member.to_api() == api_name

    def test_base_model_to_api(self) -> None:
        """BaseModel members map to CivitAI base-model labels."""
        pairs = (
            (BaseModel.sd15, "SD 1.5"),
            (BaseModel.sdxl, "SDXL 1.0"),
            (BaseModel.pony, "Pony"),
            (BaseModel.flux_dev, "Flux.1 D"),
            (BaseModel.illustrious, "Illustrious"),
        )
        for member, api_name in pairs:
            assert member.to_api() == api_name

    def test_sort_order_to_api(self) -> None:
        """SortOrder members map to CivitAI sort labels."""
        pairs = (
            (SortOrder.downloads, "Most Downloaded"),
            (SortOrder.rating, "Highest Rated"),
            (SortOrder.newest, "Newest"),
        )
        for member, api_name in pairs:
            assert member.to_api() == api_name
class TestModelFamilyDetection:
    """Cover detect_model_family and the presets from get_model_generation_defaults.

    Both helpers are reached through the already-imported ``tensors.config``
    module rather than through per-test imports.
    """

    def test_detect_pony_from_base_model(self) -> None:
        """A CivitAI base_model of "Pony" maps to ``pony`` in any letter case."""
        detect = config.detect_model_family
        assert detect("model.safetensors", "Pony") == "pony"
        assert detect("anything.safetensors", "PONY") == "pony"

    def test_detect_pony_from_filename(self) -> None:
        """Pony-named checkpoints are detected from the filename alone."""
        for filename in ("ponyDiffusionV6XL.safetensors", "autismmixPony_v10.safetensors"):
            assert config.detect_model_family(filename) == "pony"

    def test_detect_illustrious_from_base_model(self) -> None:
        """The "Illustrious" base_model maps to ``illustrious``."""
        family = config.detect_model_family("model.safetensors", "Illustrious")
        assert family == "illustrious"

    def test_detect_illustrious_from_filename(self) -> None:
        """Illustrious and NoobAI filenames both map to ``illustrious``."""
        for filename in ("illustriousXL_v10.safetensors", "noobaiXL_v10.safetensors"):
            assert config.detect_model_family(filename) == "illustrious"

    def test_detect_flux_variants(self) -> None:
        """Flux Dev and Schnell are told apart by filename and by base_model."""
        cases = (
            (("flux1-dev.safetensors",), "flux"),
            (("flux1-schnell.safetensors",), "flux_schnell"),
            (("model.safetensors", "Flux.1 D"), "flux"),
            (("model.safetensors", "Flux.1 S schnell"), "flux_schnell"),
        )
        for args, expected in cases:
            assert config.detect_model_family(*args) == expected

    def test_detect_sdxl_variants(self) -> None:
        """Plain SDXL, Lightning and Turbo are told apart by filename and by base_model."""
        cases = (
            (("juggernautXL_v9.safetensors",), "sdxl"),
            (("sdxl_lightning_4step.safetensors",), "sdxl_lightning"),
            (("sdxl_turbo.safetensors",), "sdxl_turbo"),
            (("model.safetensors", "SDXL 1.0"), "sdxl"),
            (("model.safetensors", "SDXL Lightning"), "sdxl_lightning"),
            (("model.safetensors", "SDXL Turbo"), "sdxl_turbo"),
        )
        for args, expected in cases:
            assert config.detect_model_family(*args) == expected

    def test_detect_sd15_variants(self) -> None:
        """Plain SD 1.5 and its LCM variant are told apart by filename and by base_model."""
        cases = (
            (("dreamshaper_v8.safetensors",), "sd15"),
            (("sd15_lcm.safetensors",), "sd15_lcm"),
            (("model.safetensors", "SD 1.5"), "sd15"),
            (("model.safetensors", "SD 1.5 LCM"), "sd15_lcm"),
        )
        for args, expected in cases:
            assert config.detect_model_family(*args) == expected

    def test_detect_unknown_returns_none(self) -> None:
        """Unrecognised filenames and base models yield None."""
        assert config.detect_model_family("random_model.safetensors") is None
        assert config.detect_model_family("unknown.safetensors", "Unknown") is None

    def test_get_model_generation_defaults_pony(self) -> None:
        """Pony presets: euler_ancestral / normal, 25 steps, cfg 6.5."""
        preset = config.get_model_generation_defaults("ponyDiffusionV6XL.safetensors")
        expected = {
            "family": "pony",
            "sampler": "euler_ancestral",
            "scheduler": "normal",
            "steps": 25,
            "cfg": 6.5,
        }
        for key, value in expected.items():
            assert preset[key] == value

    def test_get_model_generation_defaults_flux(self) -> None:
        """Flux Dev presets: euler / simple, KSampler cfg 1.0, FluxGuidance 3.5.

        Flux Dev is guidance-distilled, so KSampler.cfg MUST stay at 1.0; the
        prompt-adherence dial is the FluxGuidance node's ``guidance`` value.
        See https://comfyanonymous.github.io/ComfyUI_examples/flux/
        """
        preset = config.get_model_generation_defaults("flux1-dev-fp8.safetensors")
        expected = {
            "family": "flux",
            "sampler": "euler",
            "scheduler": "simple",
            "cfg": 1.0,
            "guidance": 3.5,
        }
        for key, value in expected.items():
            assert preset[key] == value

    def test_get_model_generation_defaults_flux_schnell(self) -> None:
        """Flux Schnell presets: 4 steps, cfg 1.0, guidance 3.5."""
        preset = config.get_model_generation_defaults("flux1-schnell.safetensors")
        expected = {"family": "flux_schnell", "steps": 4, "cfg": 1.0, "guidance": 3.5}
        for key, value in expected.items():
            assert preset[key] == value

    def test_detect_zimage(self) -> None:
        """ZImageTurbo is detected from filenames (any case) and from base_model."""
        cases = (
            (("zimageturbo_v1.safetensors",), "zimage"),
            (("ZIMAGE_xl.safetensors",), "zimage"),
            (("model.safetensors", "ZImageTurbo"), "zimage"),
        )
        for args, expected in cases:
            assert config.detect_model_family(*args) == expected

    def test_get_model_generation_defaults_zimage(self) -> None:
        """ZImageTurbo presets: euler / simple, 4 steps, cfg 1.0, ae.safetensors VAE."""
        preset = config.get_model_generation_defaults("zimageturbo_v1.safetensors")
        expected = {
            "family": "zimage",
            "sampler": "euler",
            "scheduler": "simple",
            "steps": 4,
            "cfg": 1.0,
            "vae": "ae.safetensors",
        }
        for key, value in expected.items():
            assert preset[key] == value

    def test_flux_uses_ae_vae(self) -> None:
        """Both Flux Dev and Flux Schnell default to the ae.safetensors VAE."""
        for filename in ("flux1-dev-fp8.safetensors", "flux1-schnell.safetensors"):
            assert config.get_model_generation_defaults(filename)["vae"] == "ae.safetensors"

    def test_get_model_generation_defaults_sdxl_lightning(self) -> None:
        """SDXL Lightning presets: euler / sgm_uniform, 8 steps, cfg 2.0."""
        preset = config.get_model_generation_defaults("sdxl_lightning_4step.safetensors")
        expected = {
            "family": "sdxl_lightning",
            "sampler": "euler",
            "scheduler": "sgm_uniform",
            "steps": 8,
            "cfg": 2.0,
        }
        for key, value in expected.items():
            assert preset[key] == value

    def test_get_model_generation_defaults_unknown_falls_back_to_sdxl(self) -> None:
        """An unrecognised model has no family and gets the dpmpp_2m / karras fallback."""
        preset = config.get_model_generation_defaults("unknown_model.safetensors")
        assert preset["family"] is None
        assert preset["sampler"] == "dpmpp_2m"
        assert preset["scheduler"] == "karras"
class TestFluxWorkflowBuilder:
    """Cover the Flux branch of tensors.comfyui._build_workflow and the non-Flux fallback."""

    @staticmethod
    def _build(model: str = "flux1-dev-fp8.safetensors", **overrides: Any) -> dict[str, Any]:
        """Build a workflow for the prompt "a cat"; *model* defaults to the Flux Dev checkpoint."""
        from tensors.comfyui import _build_workflow

        return _build_workflow(prompt="a cat", model=model, **overrides)

    def test_flux_dispatch_uses_flux_template(self) -> None:
        """A Flux checkpoint yields the Flux node graph instead of the SDXL one."""
        graph = self._build()
        # Flux template node IDs are in the 100s; the default SDXL template uses single digits.
        expected_nodes = (
            ("100", "CheckpointLoaderSimple"),
            ("120", "ModelSamplingFlux"),
            ("140", "FluxGuidance"),
            ("132", "ConditioningZeroOut"),
            ("150", "EmptySD3LatentImage"),
        )
        for node_id, class_type in expected_nodes:
            assert node_id in graph and graph[node_id]["class_type"] == class_type
        assert "3" not in graph  # the default SDXL KSampler ID must be absent

    def test_flux_ksampler_cfg_locked_to_one(self) -> None:
        """Flux Dev's KSampler cfg stays at 1.0 even when the caller passes another cfg."""
        graph = self._build(cfg=7.5)
        assert graph["160"]["inputs"]["cfg"] == 1.0
        # The caller's cfg is re-routed to the FluxGuidance node instead.
        assert graph["140"]["inputs"]["guidance"] == 7.5

    def test_flux_explicit_guidance_wins_over_cfg(self) -> None:
        """An explicit guidance value takes precedence over the re-routed cfg."""
        graph = self._build(cfg=7.5, guidance=4.0)
        assert graph["140"]["inputs"]["guidance"] == 4.0

    def test_flux_default_guidance_from_preset(self) -> None:
        """With neither cfg nor guidance supplied, the preset guidance (3.5) is used."""
        graph = self._build()
        assert graph["140"]["inputs"]["guidance"] == 3.5

    def test_flux_lora_injection(self) -> None:
        """A LoRA adds node 110 and re-wires ModelSamplingFlux and both CLIPTextEncodes to it."""
        graph = self._build(lora_name="my_style.safetensors", lora_strength=0.7)
        assert "110" in graph and graph["110"]["class_type"] == "LoraLoader"
        lora_inputs = graph["110"]["inputs"]
        assert lora_inputs["lora_name"] == "my_style.safetensors"
        assert lora_inputs["strength_model"] == 0.7
        # Downstream consumers must now read from the LoRA node's outputs.
        rewired = (
            ("120", "model", ["110", 0]),
            ("130", "clip", ["110", 1]),
            ("131", "clip", ["110", 1]),
        )
        for node_id, socket, source in rewired:
            assert graph[node_id]["inputs"][socket] == source

    def test_flux_external_vae_swaps_decoder_input(self) -> None:
        """An external VAE adds VAELoader node 171 and routes VAEDecode (170) through it."""
        graph = self._build(vae="ae.safetensors")
        assert "171" in graph and graph["171"]["class_type"] == "VAELoader"
        assert graph["171"]["inputs"]["vae_name"] == "ae.safetensors"
        assert graph["170"]["inputs"]["vae"] == ["171", 0]

    def test_flux_model_sampling_dimensions_match_latent(self) -> None:
        """ModelSamplingFlux must share the latent's width/height for the shift to be correct."""
        graph = self._build(width=1216, height=832)
        for node_id in ("120", "150"):
            assert graph[node_id]["inputs"]["width"] == 1216
            assert graph[node_id]["inputs"]["height"] == 832

    def test_non_flux_model_uses_default_template(self) -> None:
        """SDXL/Pony-style checkpoints keep using the legacy template."""
        graph = self._build(model="ponyDiffusionV6XL.safetensors")
        # The default SDXL template puts its KSampler at node "3".
        assert "3" in graph and graph["3"]["class_type"] == "KSampler"
        # No Flux-only node may appear.
        for flux_only in ("140", "120"):
            assert flux_only not in graph
class TestDisplayFormatters:
    """Check the human-readable size (input in KB) and count formatters."""

    def test_format_size_kb(self) -> None:
        """Sizes below 1024 KB are shown as whole kilobytes."""
        for size_kb, text in ((500, "500 KB"), (1023, "1023 KB")):
            assert _format_size(size_kb) == text

    def test_format_size_mb(self) -> None:
        """From 1024 KB the size is shown in megabytes with one decimal."""
        for size_kb, text in ((1024, "1.0 MB"), (2048, "2.0 MB"), (1024 * 500, "500.0 MB")):
            assert _format_size(size_kb) == text

    def test_format_size_gb(self) -> None:
        """From 1024 MB the size is shown in gigabytes with two decimals."""
        for size_kb, text in ((1024 * 1024, "1.00 GB"), (1024 * 1024 * 2.5, "2.50 GB")):
            assert _format_size(size_kb) == text

    def test_format_count_small(self) -> None:
        """Counts below 1000 are printed verbatim."""
        for count, text in ((0, "0"), (999, "999")):
            assert _format_count(count) == text

    def test_format_count_thousands(self) -> None:
        """Counts up to 999,999 get a K suffix with one decimal."""
        for count, text in ((1000, "1.0K"), (5500, "5.5K"), (999999, "1000.0K")):
            assert _format_count(count) == text

    def test_format_count_millions(self) -> None:
        """Counts of a million and above get an M suffix with one decimal."""
        for count, text in ((1_000_000, "1.0M"), (2_500_000, "2.5M")):
            assert _format_count(count) == text
class TestDisplayFunctions:
    """Smoke tests for the rich display helpers: each call must complete without raising."""

    @staticmethod
    def _console() -> Console:
        """Build the 80-column, forced-terminal console every test renders to."""
        return Console(force_terminal=True, width=80)

    @staticmethod
    def _two_key_metadata() -> dict[str, Any]:
        """Return a fresh parsed-metadata dict holding two string entries."""
        return {"metadata": {"key1": "value1", "key2": "value2"}, "tensor_count": 0, "header_size": 100}

    def test_display_file_info(self, temp_safetensor: Path) -> None:
        """display_file_info renders a real file's parsed header."""
        out = self._console()
        parsed = read_safetensor_metadata(temp_safetensor)
        display_file_info(temp_safetensor, parsed, "ABC123", out)

    def test_display_local_metadata_with_data(self) -> None:
        """display_local_metadata renders populated metadata."""
        out = self._console()
        display_local_metadata(self._two_key_metadata(), out)

    def test_display_local_metadata_empty(self) -> None:
        """display_local_metadata copes with an empty metadata mapping."""
        out = self._console()
        empty: dict[str, Any] = {"metadata": {}, "tensor_count": 0, "header_size": 100}
        display_local_metadata(empty, out)

    def test_display_local_metadata_with_filter(self) -> None:
        """display_local_metadata accepts a keys_filter restricting the shown keys."""
        out = self._console()
        display_local_metadata(self._two_key_metadata(), out, keys_filter=["key1"])

    def test_display_civitai_data_none(self) -> None:
        """display_civitai_data accepts None (no CivitAI match)."""
        display_civitai_data(None, self._console())

    def test_display_civitai_data_with_data(self) -> None:
        """display_civitai_data renders a model-version payload."""
        out = self._console()
        version = {
            "modelId": 123,
            "id": 456,
            "name": "Test Model v1",
            "baseModel": "SDXL 1.0",
            "createdAt": "2024-01-01",
            "trainedWords": ["word1", "word2"],
            "downloadUrl": "https://example.com/download",
            "files": [
                {
                    "primary": True,
                    "name": "model.safetensors",
                    "sizeKB": 5000,
                    "metadata": {"format": "SafeTensor", "fp": "fp16", "size": "full"},
                }
            ],
        }
        display_civitai_data(version, out)

    def test_display_model_info(self) -> None:
        """display_model_info renders a model payload with one version."""
        out = self._console()
        model = {
            "id": 123,
            "name": "Test Model",
            "type": "LORA",
            "nsfw": False,
            "creator": {"username": "testuser"},
            "tags": ["tag1", "tag2"],
            "stats": {"downloadCount": 1000, "thumbsUpCount": 100},
            "modelVersions": [
                {
                    "id": 456,
                    "name": "v1.0",
                    "baseModel": "SDXL 1.0",
                    "createdAt": "2024-01-01",
                    "files": [{"primary": True, "name": "model.safetensors", "sizeKB": 5000}],
                }
            ],
        }
        display_model_info(model, out)

    def test_display_search_results_empty(self) -> None:
        """display_search_results copes with an empty result list."""
        display_search_results({"items": []}, self._console())

    def test_display_search_results_with_data(self) -> None:
        """display_search_results renders a single search hit."""
        out = self._console()
        page = {
            "items": [
                {
                    "id": 123,
                    "name": "Test Model",
                    "type": "LORA",
                    "modelVersions": [{"baseModel": "SDXL 1.0", "files": [{"primary": True, "sizeKB": 5000}]}],
                    "stats": {"downloadCount": 1000, "thumbsUpCount": 100},
                }
            ],
            "metadata": {"totalItems": 1},
        }
        display_search_results(page, out)
class TestAPIFunctions:
    """Exercise the CivitAI client functions against respx-mocked HTTP endpoints."""

    @staticmethod
    def _console() -> Console:
        """Build the 80-column, forced-terminal console passed to every API call."""
        return Console(force_terminal=True, width=80)

    @respx.mock
    def test_fetch_model_version_success(self) -> None:
        """A 200 model-version response is returned as parsed JSON."""
        respx.get("https://civitai.com/api/v1/model-versions/123").mock(
            return_value=httpx.Response(200, json={"id": 123, "name": "Test"})
        )
        fetched = fetch_civitai_model_version(123, None, self._console())
        assert fetched == {"id": 123, "name": "Test"}

    @respx.mock
    def test_fetch_model_version_not_found(self) -> None:
        """A 404 model-version response yields None."""
        respx.get("https://civitai.com/api/v1/model-versions/999").mock(return_value=httpx.Response(404))
        assert fetch_civitai_model_version(999, None, self._console()) is None

    @respx.mock
    def test_fetch_model_success(self) -> None:
        """A 200 model response is returned as parsed JSON."""
        respx.get("https://civitai.com/api/v1/models/123").mock(
            return_value=httpx.Response(200, json={"id": 123, "name": "Test Model"})
        )
        fetched = fetch_civitai_model(123, None, self._console())
        assert fetched == {"id": 123, "name": "Test Model"}

    @respx.mock
    def test_fetch_model_not_found(self) -> None:
        """A 404 model response yields None."""
        respx.get("https://civitai.com/api/v1/models/999").mock(return_value=httpx.Response(404))
        assert fetch_civitai_model(999, None, self._console()) is None

    @respx.mock
    def test_fetch_by_hash_success(self) -> None:
        """A 200 by-hash response is returned as parsed JSON."""
        respx.get("https://civitai.com/api/v1/model-versions/by-hash/ABC123").mock(
            return_value=httpx.Response(200, json={"id": 456, "name": "Found"})
        )
        fetched = fetch_civitai_by_hash("ABC123", None, self._console())
        assert fetched == {"id": 456, "name": "Found"}

    @respx.mock
    def test_fetch_by_hash_not_found(self) -> None:
        """A 404 by-hash response yields None."""
        respx.get("https://civitai.com/api/v1/model-versions/by-hash/NOTFOUND").mock(return_value=httpx.Response(404))
        assert fetch_civitai_by_hash("NOTFOUND", None, self._console()) is None

    @respx.mock
    def test_search_civitai_success(self) -> None:
        """An unfiltered search returns the mocked result page."""
        respx.get("https://civitai.com/api/v1/models").mock(
            return_value=httpx.Response(200, json={"items": [{"id": 1}], "metadata": {}})
        )
        page = search_civitai("test", None, None, SortOrder.downloads, 20, None, self._console())
        assert page is not None
        assert len(page["items"]) == 1

    @respx.mock
    def test_search_civitai_with_filters(self) -> None:
        """A search filtered by model type and base model still returns a page."""
        respx.get("https://civitai.com/api/v1/models").mock(
            return_value=httpx.Response(200, json={"items": [{"id": 1, "name": "Test LORA"}], "metadata": {}})
        )
        page = search_civitai("test", ModelType.lora, BaseModel.sdxl, SortOrder.downloads, 20, None, self._console())
        assert page is not None

    @respx.mock
    def test_download_model_success(self, tmp_path: Path) -> None:
        """A 200 download writes the response body to the destination file."""
        target = tmp_path / "model.safetensors"
        respx.get("https://civitai.com/api/download/models/123").mock(
            return_value=httpx.Response(200, content=b"fake model data")
        )
        ok = download_model(123, target, None, self._console(), resume=False)
        assert ok is True
        assert target.exists()
        assert target.read_bytes() == b"fake model data"

    @respx.mock
    def test_download_model_unauthorized(self, tmp_path: Path) -> None:
        """A 401 download reports failure."""
        target = tmp_path / "model.safetensors"
        respx.get("https://civitai.com/api/download/models/123").mock(return_value=httpx.Response(401))
        ok = download_model(123, target, None, self._console(), resume=False)
        assert ok is False
class TestCLI:
    """Tests for the Typer CLI, invoked through the shared ``runner``."""

    def test_help(self) -> None:
        """``--help`` exits cleanly and mentions safetensors."""
        result = runner.invoke(app, ["--help"])
        assert result.exit_code == 0
        assert "safetensor" in result.stdout.lower()

    def test_info_file_not_found(self, tmp_path: Path) -> None:
        """``info`` on a missing path exits with code 1 and says so."""
        result = runner.invoke(app, ["info", str(tmp_path / "nonexistent.safetensors")])
        assert result.exit_code == 1
        assert "not found" in result.stdout.lower()

    def test_info_with_safetensor(self, temp_safetensor: Path) -> None:
        """``info`` on a valid file succeeds when CivitAI lookup is skipped."""
        result = runner.invoke(app, ["info", str(temp_safetensor), "--skip-civitai"])
        assert result.exit_code == 0

    def test_info_json_output(self, temp_safetensor: Path) -> None:
        """``info --json`` output includes the file's sha256."""
        result = runner.invoke(app, ["info", str(temp_safetensor), "--skip-civitai", "--json"])
        assert result.exit_code == 0
        assert "sha256" in result.stdout

    def test_info_meta_filter(self, temp_safetensor: Path) -> None:
        """``info --meta KEY`` prints the value of the requested metadata key.

        ``--skip-civitai`` keeps this test offline, like every other ``info``
        test here. Without it the command may attempt a live CivitAI lookup
        that nothing in this test mocks, which makes the result network-dependent.
        """
        result = runner.invoke(app, ["info", str(temp_safetensor), "--meta", "test_key", "--skip-civitai"])
        assert result.exit_code == 0
        assert "test_value" in result.stdout

    @respx.mock
    def test_search_command(self) -> None:
        """``search`` renders a mocked result page and exits cleanly."""
        respx.get("https://civitai.com/api/v1/models").mock(
            return_value=httpx.Response(
                200,
                json={
                    "items": [{"id": 1, "name": "Test", "type": "LORA", "modelVersions": [], "stats": {}}],
                    "metadata": {"totalItems": 1},
                },
            )
        )
        result = runner.invoke(app, ["search", "test"])
        assert result.exit_code == 0

    @respx.mock
    def test_get_command(self) -> None:
        """``get`` renders a mocked model and exits cleanly."""
        respx.get("https://civitai.com/api/v1/models/123").mock(
            return_value=httpx.Response(
                200,
                json={
                    "id": 123,
                    "name": "Test Model",
                    "type": "LORA",
                    "nsfw": False,
                    "stats": {},
                    "modelVersions": [],
                },
            )
        )
        result = runner.invoke(app, ["get", "123"])
        assert result.exit_code == 0

    @respx.mock
    def test_get_command_not_found(self) -> None:
        """``get`` for a model the API reports as 404 exits with code 1."""
        respx.get("https://civitai.com/api/v1/models/999").mock(return_value=httpx.Response(404))
        result = runner.invoke(app, ["get", "999"])
        assert result.exit_code == 1
        assert "not found" in result.stdout.lower()

    def test_config_show(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
        """``config --show`` reports the config file location."""
        monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
        monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "config.toml")
        monkeypatch.setattr(config, "LEGACY_RC_FILE", tmp_path / "nonexistent")
        result = runner.invoke(app, ["config", "--show"])
        assert result.exit_code == 0
        assert "config file" in result.stdout.lower()

    def test_download_no_args(self) -> None:
        """``dl`` without arguments exits with code 1 and asks for a target."""
        result = runner.invoke(app, ["dl"])
        assert result.exit_code == 1
        assert "must specify" in result.stdout.lower()