Add comprehensive tests, coverage 21% → 74%

This commit is contained in:
Adam Ladachowski
2026-02-03 23:10:30 +01:00
parent 438f2a93f5
commit 9f8d8d6fcd
4 changed files with 476 additions and 1 deletions
BIN
View File
Binary file not shown.
+1
View File
@@ -29,6 +29,7 @@ dev = [
"pytest>=8.0",
"pytest-cov>=4.1",
"pre-commit>=3.6",
"respx>=0.22.0",
]
[tool.ruff]
+461 -1
View File
@@ -4,13 +4,45 @@ from __future__ import annotations
import struct
from pathlib import Path
from typing import Any
import httpx
import pytest
import respx
from rich.console import Console
from typer.testing import CliRunner
from tensors import config
from tensors.config import get_default_output_path, load_api_key
from tensors.api import (
download_model,
fetch_civitai_by_hash,
fetch_civitai_model,
fetch_civitai_model_version,
search_civitai,
)
from tensors.cli import app
from tensors.config import (
BaseModel,
ModelType,
SortOrder,
get_default_output_path,
load_api_key,
load_config,
save_config,
)
from tensors.display import (
_format_count,
_format_size,
display_civitai_data,
display_file_info,
display_local_metadata,
display_model_info,
display_search_results,
)
from tensors.safetensor import get_base_name, read_safetensor_metadata
runner = CliRunner()
class TestReadSafetensorMetadata:
"""Tests for read_safetensor_metadata function."""
@@ -128,3 +160,431 @@ class TestLoadApiKey:
monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml")
monkeypatch.setattr(config, "LEGACY_RC_FILE", legacy_file)
assert load_api_key() == "legacy-key"
class TestSaveConfig:
"""Tests for save_config function."""
def test_saves_simple_config(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""Test saving a simple config."""
config_dir = tmp_path / "config"
config_file = config_dir / "config.toml"
monkeypatch.setattr(config, "CONFIG_DIR", config_dir)
monkeypatch.setattr(config, "CONFIG_FILE", config_file)
save_config({"key": "value"})
assert config_file.exists()
content = config_file.read_text()
assert 'key = "value"' in content
def test_saves_nested_config(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""Test saving a nested config with sections."""
config_dir = tmp_path / "config"
config_file = config_dir / "config.toml"
monkeypatch.setattr(config, "CONFIG_DIR", config_dir)
monkeypatch.setattr(config, "CONFIG_FILE", config_file)
save_config({"api": {"civitai_key": "test-key"}})
content = config_file.read_text()
assert "[api]" in content
assert 'civitai_key = "test-key"' in content
def test_saves_numeric_values(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""Test saving numeric values without quotes."""
config_dir = tmp_path / "config"
config_file = config_dir / "config.toml"
monkeypatch.setattr(config, "CONFIG_DIR", config_dir)
monkeypatch.setattr(config, "CONFIG_FILE", config_file)
save_config({"timeout": 30})
content = config_file.read_text()
assert "timeout = 30" in content
class TestLoadConfig:
"""Tests for load_config function."""
def test_returns_empty_dict_if_no_config(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""Test that empty dict is returned when config file doesn't exist."""
monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "nonexistent.toml")
assert load_config() == {}
class TestEnums:
"""Tests for enum to_api methods."""
def test_model_type_to_api(self) -> None:
"""Test ModelType enum to_api conversion."""
assert ModelType.checkpoint.to_api() == "Checkpoint"
assert ModelType.lora.to_api() == "LORA"
assert ModelType.embedding.to_api() == "TextualInversion"
assert ModelType.vae.to_api() == "VAE"
assert ModelType.controlnet.to_api() == "Controlnet"
assert ModelType.locon.to_api() == "LoCon"
def test_base_model_to_api(self) -> None:
"""Test BaseModel enum to_api conversion."""
assert BaseModel.sd15.to_api() == "SD 1.5"
assert BaseModel.sdxl.to_api() == "SDXL 1.0"
assert BaseModel.pony.to_api() == "Pony"
assert BaseModel.flux.to_api() == "Flux.1 D"
assert BaseModel.illustrious.to_api() == "Illustrious"
def test_sort_order_to_api(self) -> None:
"""Test SortOrder enum to_api conversion."""
assert SortOrder.downloads.to_api() == "Most Downloaded"
assert SortOrder.rating.to_api() == "Highest Rated"
assert SortOrder.newest.to_api() == "Newest"
class TestDisplayFormatters:
"""Tests for display formatting functions."""
def test_format_size_kb(self) -> None:
"""Test formatting sizes in KB."""
assert _format_size(500) == "500 KB"
assert _format_size(1023) == "1023 KB"
def test_format_size_mb(self) -> None:
"""Test formatting sizes in MB."""
assert _format_size(1024) == "1.0 MB"
assert _format_size(2048) == "2.0 MB"
assert _format_size(1024 * 500) == "500.0 MB"
def test_format_size_gb(self) -> None:
"""Test formatting sizes in GB."""
assert _format_size(1024 * 1024) == "1.00 GB"
assert _format_size(1024 * 1024 * 2.5) == "2.50 GB"
def test_format_count_small(self) -> None:
"""Test formatting small counts."""
assert _format_count(0) == "0"
assert _format_count(999) == "999"
def test_format_count_thousands(self) -> None:
"""Test formatting counts in thousands."""
assert _format_count(1000) == "1.0K"
assert _format_count(5500) == "5.5K"
assert _format_count(999999) == "1000.0K"
def test_format_count_millions(self) -> None:
"""Test formatting counts in millions."""
assert _format_count(1_000_000) == "1.0M"
assert _format_count(2_500_000) == "2.5M"
class TestDisplayFunctions:
"""Tests for display functions with console output."""
def test_display_file_info(self, temp_safetensor: Path) -> None:
"""Test display_file_info renders without error."""
console = Console(force_terminal=True, width=80)
metadata = read_safetensor_metadata(temp_safetensor)
# Should not raise
display_file_info(temp_safetensor, metadata, "ABC123", console)
def test_display_local_metadata_with_data(self) -> None:
"""Test display_local_metadata with metadata."""
console = Console(force_terminal=True, width=80)
metadata = {"metadata": {"key1": "value1", "key2": "value2"}, "tensor_count": 0, "header_size": 100}
# Should not raise
display_local_metadata(metadata, console)
def test_display_local_metadata_empty(self) -> None:
"""Test display_local_metadata with no metadata."""
console = Console(force_terminal=True, width=80)
metadata: dict[str, Any] = {"metadata": {}, "tensor_count": 0, "header_size": 100}
# Should not raise
display_local_metadata(metadata, console)
def test_display_local_metadata_with_filter(self) -> None:
"""Test display_local_metadata with key filter."""
console = Console(force_terminal=True, width=80)
metadata = {"metadata": {"key1": "value1", "key2": "value2"}, "tensor_count": 0, "header_size": 100}
# Should not raise
display_local_metadata(metadata, console, keys_filter=["key1"])
def test_display_civitai_data_none(self) -> None:
"""Test display_civitai_data with None."""
console = Console(force_terminal=True, width=80)
# Should not raise
display_civitai_data(None, console)
def test_display_civitai_data_with_data(self) -> None:
"""Test display_civitai_data with model data."""
console = Console(force_terminal=True, width=80)
data = {
"modelId": 123,
"id": 456,
"name": "Test Model v1",
"baseModel": "SDXL 1.0",
"createdAt": "2024-01-01",
"trainedWords": ["word1", "word2"],
"downloadUrl": "https://example.com/download",
"files": [
{
"primary": True,
"name": "model.safetensors",
"sizeKB": 5000,
"metadata": {"format": "SafeTensor", "fp": "fp16", "size": "full"},
}
],
}
# Should not raise
display_civitai_data(data, console)
def test_display_model_info(self) -> None:
"""Test display_model_info with model data."""
console = Console(force_terminal=True, width=80)
data = {
"id": 123,
"name": "Test Model",
"type": "LORA",
"nsfw": False,
"creator": {"username": "testuser"},
"tags": ["tag1", "tag2"],
"stats": {"downloadCount": 1000, "thumbsUpCount": 100},
"modelVersions": [
{
"id": 456,
"name": "v1.0",
"baseModel": "SDXL 1.0",
"createdAt": "2024-01-01",
"files": [{"primary": True, "name": "model.safetensors", "sizeKB": 5000}],
}
],
}
# Should not raise
display_model_info(data, console)
def test_display_search_results_empty(self) -> None:
"""Test display_search_results with no results."""
console = Console(force_terminal=True, width=80)
# Should not raise
display_search_results({"items": []}, console)
def test_display_search_results_with_data(self) -> None:
"""Test display_search_results with results."""
console = Console(force_terminal=True, width=80)
results = {
"items": [
{
"id": 123,
"name": "Test Model",
"type": "LORA",
"modelVersions": [{"baseModel": "SDXL 1.0", "files": [{"primary": True, "sizeKB": 5000}]}],
"stats": {"downloadCount": 1000, "thumbsUpCount": 100},
}
],
"metadata": {"totalItems": 1},
}
# Should not raise
display_search_results(results, console)
class TestAPIFunctions:
"""Tests for API functions with mocked HTTP."""
@respx.mock
def test_fetch_model_version_success(self) -> None:
"""Test successful model version fetch."""
console = Console(force_terminal=True, width=80)
respx.get("https://civitai.com/api/v1/model-versions/123").mock(
return_value=httpx.Response(200, json={"id": 123, "name": "Test"})
)
result = fetch_civitai_model_version(123, None, console)
assert result == {"id": 123, "name": "Test"}
@respx.mock
def test_fetch_model_version_not_found(self) -> None:
"""Test model version not found."""
console = Console(force_terminal=True, width=80)
respx.get("https://civitai.com/api/v1/model-versions/999").mock(return_value=httpx.Response(404))
result = fetch_civitai_model_version(999, None, console)
assert result is None
@respx.mock
def test_fetch_model_success(self) -> None:
"""Test successful model fetch."""
console = Console(force_terminal=True, width=80)
respx.get("https://civitai.com/api/v1/models/123").mock(
return_value=httpx.Response(200, json={"id": 123, "name": "Test Model"})
)
result = fetch_civitai_model(123, None, console)
assert result == {"id": 123, "name": "Test Model"}
@respx.mock
def test_fetch_model_not_found(self) -> None:
"""Test model not found."""
console = Console(force_terminal=True, width=80)
respx.get("https://civitai.com/api/v1/models/999").mock(return_value=httpx.Response(404))
result = fetch_civitai_model(999, None, console)
assert result is None
@respx.mock
def test_fetch_by_hash_success(self) -> None:
"""Test successful hash lookup."""
console = Console(force_terminal=True, width=80)
respx.get("https://civitai.com/api/v1/model-versions/by-hash/ABC123").mock(
return_value=httpx.Response(200, json={"id": 456, "name": "Found"})
)
result = fetch_civitai_by_hash("ABC123", None, console)
assert result == {"id": 456, "name": "Found"}
@respx.mock
def test_fetch_by_hash_not_found(self) -> None:
"""Test hash not found."""
console = Console(force_terminal=True, width=80)
respx.get("https://civitai.com/api/v1/model-versions/by-hash/NOTFOUND").mock(return_value=httpx.Response(404))
result = fetch_civitai_by_hash("NOTFOUND", None, console)
assert result is None
@respx.mock
def test_search_civitai_success(self) -> None:
"""Test successful search."""
console = Console(force_terminal=True, width=80)
respx.get("https://civitai.com/api/v1/models").mock(
return_value=httpx.Response(200, json={"items": [{"id": 1}], "metadata": {}})
)
result = search_civitai("test", None, None, SortOrder.downloads, 20, None, console)
assert result is not None
assert len(result["items"]) == 1
@respx.mock
def test_search_civitai_with_filters(self) -> None:
"""Test search with type and base model filters."""
console = Console(force_terminal=True, width=80)
respx.get("https://civitai.com/api/v1/models").mock(
return_value=httpx.Response(200, json={"items": [{"id": 1, "name": "Test LORA"}], "metadata": {}})
)
result = search_civitai("test", ModelType.lora, BaseModel.sdxl, SortOrder.downloads, 20, None, console)
assert result is not None
@respx.mock
def test_download_model_success(self, tmp_path: Path) -> None:
"""Test successful model download."""
console = Console(force_terminal=True, width=80)
dest = tmp_path / "model.safetensors"
respx.get("https://civitai.com/api/download/models/123").mock(
return_value=httpx.Response(200, content=b"fake model data")
)
result = download_model(123, dest, None, console, resume=False)
assert result is True
assert dest.exists()
assert dest.read_bytes() == b"fake model data"
@respx.mock
def test_download_model_unauthorized(self, tmp_path: Path) -> None:
"""Test download with 401 unauthorized."""
console = Console(force_terminal=True, width=80)
dest = tmp_path / "model.safetensors"
respx.get("https://civitai.com/api/download/models/123").mock(return_value=httpx.Response(401))
result = download_model(123, dest, None, console, resume=False)
assert result is False
class TestCLI:
"""Tests for CLI commands."""
def test_help(self) -> None:
"""Test --help works."""
result = runner.invoke(app, ["--help"])
assert result.exit_code == 0
assert "safetensor" in result.stdout.lower()
def test_info_file_not_found(self, tmp_path: Path) -> None:
"""Test info command with non-existent file."""
result = runner.invoke(app, ["info", str(tmp_path / "nonexistent.safetensors")])
assert result.exit_code == 1
assert "not found" in result.stdout.lower()
def test_info_with_safetensor(self, temp_safetensor: Path) -> None:
"""Test info command with valid safetensor file."""
result = runner.invoke(app, ["info", str(temp_safetensor), "--skip-civitai"])
assert result.exit_code == 0
def test_info_json_output(self, temp_safetensor: Path) -> None:
"""Test info command with JSON output."""
result = runner.invoke(app, ["info", str(temp_safetensor), "--skip-civitai", "--json"])
assert result.exit_code == 0
assert "sha256" in result.stdout
def test_info_meta_filter(self, temp_safetensor: Path) -> None:
"""Test info command with metadata filter."""
result = runner.invoke(app, ["info", str(temp_safetensor), "--meta", "test_key"])
assert result.exit_code == 0
assert "test_value" in result.stdout
@respx.mock
def test_search_command(self) -> None:
"""Test search command."""
respx.get("https://civitai.com/api/v1/models").mock(
return_value=httpx.Response(
200,
json={
"items": [{"id": 1, "name": "Test", "type": "LORA", "modelVersions": [], "stats": {}}],
"metadata": {"totalItems": 1},
},
)
)
result = runner.invoke(app, ["search", "test"])
assert result.exit_code == 0
@respx.mock
def test_get_command(self) -> None:
"""Test get command."""
respx.get("https://civitai.com/api/v1/models/123").mock(
return_value=httpx.Response(
200,
json={
"id": 123,
"name": "Test Model",
"type": "LORA",
"nsfw": False,
"stats": {},
"modelVersions": [],
},
)
)
result = runner.invoke(app, ["get", "123"])
assert result.exit_code == 0
@respx.mock
def test_get_command_not_found(self) -> None:
"""Test get command with non-existent model."""
respx.get("https://civitai.com/api/v1/models/999").mock(return_value=httpx.Response(404))
result = runner.invoke(app, ["get", "999"])
assert result.exit_code == 1
assert "not found" in result.stdout.lower()
def test_config_show(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""Test config --show command."""
monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "config.toml")
monkeypatch.setattr(config, "LEGACY_RC_FILE", tmp_path / "nonexistent")
result = runner.invoke(app, ["config", "--show"])
assert result.exit_code == 0
assert "config file" in result.stdout.lower()
def test_download_no_args(self) -> None:
"""Test dl command with no arguments."""
result = runner.invoke(app, ["dl"])
assert result.exit_code == 1
assert "must specify" in result.stdout.lower()
Generated
+14
View File
@@ -490,6 +490,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" },
]
[[package]]
name = "respx"
version = "0.22.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f4/7c/96bd0bc759cf009675ad1ee1f96535edcb11e9666b985717eb8c87192a95/respx-0.22.0.tar.gz", hash = "sha256:3c8924caa2a50bd71aefc07aa812f2466ff489f1848c96e954a5362d17095d91", size = 28439, upload-time = "2024-12-19T22:33:59.374Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8e/67/afbb0978d5399bc9ea200f1d4489a23c9a1dad4eee6376242b8182389c79/respx-0.22.0-py2.py3-none-any.whl", hash = "sha256:631128d4c9aba15e56903fb5f66fb1eff412ce28dd387ca3a81339e52dbd3ad0", size = 25127, upload-time = "2024-12-19T22:33:57.837Z" },
]
[[package]]
name = "rich"
version = "14.3.0"
@@ -578,6 +590,7 @@ dev = [
{ name = "pre-commit" },
{ name = "pytest" },
{ name = "pytest-cov" },
{ name = "respx" },
{ name = "ruff" },
]
@@ -596,6 +609,7 @@ dev = [
{ name = "pre-commit", specifier = ">=3.6" },
{ name = "pytest", specifier = ">=8.0" },
{ name = "pytest-cov", specifier = ">=4.1" },
{ name = "respx", specifier = ">=0.22.0" },
{ name = "ruff", specifier = ">=0.9.0" },
]