Add comprehensive tests, coverage 21% → 74%
This commit is contained in:
@@ -29,6 +29,7 @@ dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-cov>=4.1",
|
||||
"pre-commit>=3.6",
|
||||
"respx>=0.22.0",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
+461
-1
@@ -4,13 +4,45 @@ from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
from rich.console import Console
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from tensors import config
|
||||
from tensors.config import get_default_output_path, load_api_key
|
||||
from tensors.api import (
|
||||
download_model,
|
||||
fetch_civitai_by_hash,
|
||||
fetch_civitai_model,
|
||||
fetch_civitai_model_version,
|
||||
search_civitai,
|
||||
)
|
||||
from tensors.cli import app
|
||||
from tensors.config import (
|
||||
BaseModel,
|
||||
ModelType,
|
||||
SortOrder,
|
||||
get_default_output_path,
|
||||
load_api_key,
|
||||
load_config,
|
||||
save_config,
|
||||
)
|
||||
from tensors.display import (
|
||||
_format_count,
|
||||
_format_size,
|
||||
display_civitai_data,
|
||||
display_file_info,
|
||||
display_local_metadata,
|
||||
display_model_info,
|
||||
display_search_results,
|
||||
)
|
||||
from tensors.safetensor import get_base_name, read_safetensor_metadata
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
class TestReadSafetensorMetadata:
|
||||
"""Tests for read_safetensor_metadata function."""
|
||||
@@ -128,3 +160,431 @@ class TestLoadApiKey:
|
||||
monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml")
|
||||
monkeypatch.setattr(config, "LEGACY_RC_FILE", legacy_file)
|
||||
assert load_api_key() == "legacy-key"
|
||||
|
||||
|
||||
class TestSaveConfig:
|
||||
"""Tests for save_config function."""
|
||||
|
||||
def test_saves_simple_config(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
"""Test saving a simple config."""
|
||||
config_dir = tmp_path / "config"
|
||||
config_file = config_dir / "config.toml"
|
||||
monkeypatch.setattr(config, "CONFIG_DIR", config_dir)
|
||||
monkeypatch.setattr(config, "CONFIG_FILE", config_file)
|
||||
|
||||
save_config({"key": "value"})
|
||||
|
||||
assert config_file.exists()
|
||||
content = config_file.read_text()
|
||||
assert 'key = "value"' in content
|
||||
|
||||
def test_saves_nested_config(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
"""Test saving a nested config with sections."""
|
||||
config_dir = tmp_path / "config"
|
||||
config_file = config_dir / "config.toml"
|
||||
monkeypatch.setattr(config, "CONFIG_DIR", config_dir)
|
||||
monkeypatch.setattr(config, "CONFIG_FILE", config_file)
|
||||
|
||||
save_config({"api": {"civitai_key": "test-key"}})
|
||||
|
||||
content = config_file.read_text()
|
||||
assert "[api]" in content
|
||||
assert 'civitai_key = "test-key"' in content
|
||||
|
||||
def test_saves_numeric_values(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
"""Test saving numeric values without quotes."""
|
||||
config_dir = tmp_path / "config"
|
||||
config_file = config_dir / "config.toml"
|
||||
monkeypatch.setattr(config, "CONFIG_DIR", config_dir)
|
||||
monkeypatch.setattr(config, "CONFIG_FILE", config_file)
|
||||
|
||||
save_config({"timeout": 30})
|
||||
|
||||
content = config_file.read_text()
|
||||
assert "timeout = 30" in content
|
||||
|
||||
|
||||
class TestLoadConfig:
|
||||
"""Tests for load_config function."""
|
||||
|
||||
def test_returns_empty_dict_if_no_config(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
"""Test that empty dict is returned when config file doesn't exist."""
|
||||
monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "nonexistent.toml")
|
||||
assert load_config() == {}
|
||||
|
||||
|
||||
class TestEnums:
|
||||
"""Tests for enum to_api methods."""
|
||||
|
||||
def test_model_type_to_api(self) -> None:
|
||||
"""Test ModelType enum to_api conversion."""
|
||||
assert ModelType.checkpoint.to_api() == "Checkpoint"
|
||||
assert ModelType.lora.to_api() == "LORA"
|
||||
assert ModelType.embedding.to_api() == "TextualInversion"
|
||||
assert ModelType.vae.to_api() == "VAE"
|
||||
assert ModelType.controlnet.to_api() == "Controlnet"
|
||||
assert ModelType.locon.to_api() == "LoCon"
|
||||
|
||||
def test_base_model_to_api(self) -> None:
|
||||
"""Test BaseModel enum to_api conversion."""
|
||||
assert BaseModel.sd15.to_api() == "SD 1.5"
|
||||
assert BaseModel.sdxl.to_api() == "SDXL 1.0"
|
||||
assert BaseModel.pony.to_api() == "Pony"
|
||||
assert BaseModel.flux.to_api() == "Flux.1 D"
|
||||
assert BaseModel.illustrious.to_api() == "Illustrious"
|
||||
|
||||
def test_sort_order_to_api(self) -> None:
|
||||
"""Test SortOrder enum to_api conversion."""
|
||||
assert SortOrder.downloads.to_api() == "Most Downloaded"
|
||||
assert SortOrder.rating.to_api() == "Highest Rated"
|
||||
assert SortOrder.newest.to_api() == "Newest"
|
||||
|
||||
|
||||
class TestDisplayFormatters:
|
||||
"""Tests for display formatting functions."""
|
||||
|
||||
def test_format_size_kb(self) -> None:
|
||||
"""Test formatting sizes in KB."""
|
||||
assert _format_size(500) == "500 KB"
|
||||
assert _format_size(1023) == "1023 KB"
|
||||
|
||||
def test_format_size_mb(self) -> None:
|
||||
"""Test formatting sizes in MB."""
|
||||
assert _format_size(1024) == "1.0 MB"
|
||||
assert _format_size(2048) == "2.0 MB"
|
||||
assert _format_size(1024 * 500) == "500.0 MB"
|
||||
|
||||
def test_format_size_gb(self) -> None:
|
||||
"""Test formatting sizes in GB."""
|
||||
assert _format_size(1024 * 1024) == "1.00 GB"
|
||||
assert _format_size(1024 * 1024 * 2.5) == "2.50 GB"
|
||||
|
||||
def test_format_count_small(self) -> None:
|
||||
"""Test formatting small counts."""
|
||||
assert _format_count(0) == "0"
|
||||
assert _format_count(999) == "999"
|
||||
|
||||
def test_format_count_thousands(self) -> None:
|
||||
"""Test formatting counts in thousands."""
|
||||
assert _format_count(1000) == "1.0K"
|
||||
assert _format_count(5500) == "5.5K"
|
||||
assert _format_count(999999) == "1000.0K"
|
||||
|
||||
def test_format_count_millions(self) -> None:
|
||||
"""Test formatting counts in millions."""
|
||||
assert _format_count(1_000_000) == "1.0M"
|
||||
assert _format_count(2_500_000) == "2.5M"
|
||||
|
||||
|
||||
class TestDisplayFunctions:
|
||||
"""Tests for display functions with console output."""
|
||||
|
||||
def test_display_file_info(self, temp_safetensor: Path) -> None:
|
||||
"""Test display_file_info renders without error."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
metadata = read_safetensor_metadata(temp_safetensor)
|
||||
# Should not raise
|
||||
display_file_info(temp_safetensor, metadata, "ABC123", console)
|
||||
|
||||
def test_display_local_metadata_with_data(self) -> None:
|
||||
"""Test display_local_metadata with metadata."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
metadata = {"metadata": {"key1": "value1", "key2": "value2"}, "tensor_count": 0, "header_size": 100}
|
||||
# Should not raise
|
||||
display_local_metadata(metadata, console)
|
||||
|
||||
def test_display_local_metadata_empty(self) -> None:
|
||||
"""Test display_local_metadata with no metadata."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
metadata: dict[str, Any] = {"metadata": {}, "tensor_count": 0, "header_size": 100}
|
||||
# Should not raise
|
||||
display_local_metadata(metadata, console)
|
||||
|
||||
def test_display_local_metadata_with_filter(self) -> None:
|
||||
"""Test display_local_metadata with key filter."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
metadata = {"metadata": {"key1": "value1", "key2": "value2"}, "tensor_count": 0, "header_size": 100}
|
||||
# Should not raise
|
||||
display_local_metadata(metadata, console, keys_filter=["key1"])
|
||||
|
||||
def test_display_civitai_data_none(self) -> None:
|
||||
"""Test display_civitai_data with None."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
# Should not raise
|
||||
display_civitai_data(None, console)
|
||||
|
||||
def test_display_civitai_data_with_data(self) -> None:
|
||||
"""Test display_civitai_data with model data."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
data = {
|
||||
"modelId": 123,
|
||||
"id": 456,
|
||||
"name": "Test Model v1",
|
||||
"baseModel": "SDXL 1.0",
|
||||
"createdAt": "2024-01-01",
|
||||
"trainedWords": ["word1", "word2"],
|
||||
"downloadUrl": "https://example.com/download",
|
||||
"files": [
|
||||
{
|
||||
"primary": True,
|
||||
"name": "model.safetensors",
|
||||
"sizeKB": 5000,
|
||||
"metadata": {"format": "SafeTensor", "fp": "fp16", "size": "full"},
|
||||
}
|
||||
],
|
||||
}
|
||||
# Should not raise
|
||||
display_civitai_data(data, console)
|
||||
|
||||
def test_display_model_info(self) -> None:
|
||||
"""Test display_model_info with model data."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
data = {
|
||||
"id": 123,
|
||||
"name": "Test Model",
|
||||
"type": "LORA",
|
||||
"nsfw": False,
|
||||
"creator": {"username": "testuser"},
|
||||
"tags": ["tag1", "tag2"],
|
||||
"stats": {"downloadCount": 1000, "thumbsUpCount": 100},
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 456,
|
||||
"name": "v1.0",
|
||||
"baseModel": "SDXL 1.0",
|
||||
"createdAt": "2024-01-01",
|
||||
"files": [{"primary": True, "name": "model.safetensors", "sizeKB": 5000}],
|
||||
}
|
||||
],
|
||||
}
|
||||
# Should not raise
|
||||
display_model_info(data, console)
|
||||
|
||||
def test_display_search_results_empty(self) -> None:
|
||||
"""Test display_search_results with no results."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
# Should not raise
|
||||
display_search_results({"items": []}, console)
|
||||
|
||||
def test_display_search_results_with_data(self) -> None:
|
||||
"""Test display_search_results with results."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
results = {
|
||||
"items": [
|
||||
{
|
||||
"id": 123,
|
||||
"name": "Test Model",
|
||||
"type": "LORA",
|
||||
"modelVersions": [{"baseModel": "SDXL 1.0", "files": [{"primary": True, "sizeKB": 5000}]}],
|
||||
"stats": {"downloadCount": 1000, "thumbsUpCount": 100},
|
||||
}
|
||||
],
|
||||
"metadata": {"totalItems": 1},
|
||||
}
|
||||
# Should not raise
|
||||
display_search_results(results, console)
|
||||
|
||||
|
||||
class TestAPIFunctions:
|
||||
"""Tests for API functions with mocked HTTP."""
|
||||
|
||||
@respx.mock
|
||||
def test_fetch_model_version_success(self) -> None:
|
||||
"""Test successful model version fetch."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
respx.get("https://civitai.com/api/v1/model-versions/123").mock(
|
||||
return_value=httpx.Response(200, json={"id": 123, "name": "Test"})
|
||||
)
|
||||
|
||||
result = fetch_civitai_model_version(123, None, console)
|
||||
assert result == {"id": 123, "name": "Test"}
|
||||
|
||||
@respx.mock
|
||||
def test_fetch_model_version_not_found(self) -> None:
|
||||
"""Test model version not found."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
respx.get("https://civitai.com/api/v1/model-versions/999").mock(return_value=httpx.Response(404))
|
||||
|
||||
result = fetch_civitai_model_version(999, None, console)
|
||||
assert result is None
|
||||
|
||||
@respx.mock
|
||||
def test_fetch_model_success(self) -> None:
|
||||
"""Test successful model fetch."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
respx.get("https://civitai.com/api/v1/models/123").mock(
|
||||
return_value=httpx.Response(200, json={"id": 123, "name": "Test Model"})
|
||||
)
|
||||
|
||||
result = fetch_civitai_model(123, None, console)
|
||||
assert result == {"id": 123, "name": "Test Model"}
|
||||
|
||||
@respx.mock
|
||||
def test_fetch_model_not_found(self) -> None:
|
||||
"""Test model not found."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
respx.get("https://civitai.com/api/v1/models/999").mock(return_value=httpx.Response(404))
|
||||
|
||||
result = fetch_civitai_model(999, None, console)
|
||||
assert result is None
|
||||
|
||||
@respx.mock
|
||||
def test_fetch_by_hash_success(self) -> None:
|
||||
"""Test successful hash lookup."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
respx.get("https://civitai.com/api/v1/model-versions/by-hash/ABC123").mock(
|
||||
return_value=httpx.Response(200, json={"id": 456, "name": "Found"})
|
||||
)
|
||||
|
||||
result = fetch_civitai_by_hash("ABC123", None, console)
|
||||
assert result == {"id": 456, "name": "Found"}
|
||||
|
||||
@respx.mock
|
||||
def test_fetch_by_hash_not_found(self) -> None:
|
||||
"""Test hash not found."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
respx.get("https://civitai.com/api/v1/model-versions/by-hash/NOTFOUND").mock(return_value=httpx.Response(404))
|
||||
|
||||
result = fetch_civitai_by_hash("NOTFOUND", None, console)
|
||||
assert result is None
|
||||
|
||||
@respx.mock
|
||||
def test_search_civitai_success(self) -> None:
|
||||
"""Test successful search."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
respx.get("https://civitai.com/api/v1/models").mock(
|
||||
return_value=httpx.Response(200, json={"items": [{"id": 1}], "metadata": {}})
|
||||
)
|
||||
|
||||
result = search_civitai("test", None, None, SortOrder.downloads, 20, None, console)
|
||||
assert result is not None
|
||||
assert len(result["items"]) == 1
|
||||
|
||||
@respx.mock
|
||||
def test_search_civitai_with_filters(self) -> None:
|
||||
"""Test search with type and base model filters."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
respx.get("https://civitai.com/api/v1/models").mock(
|
||||
return_value=httpx.Response(200, json={"items": [{"id": 1, "name": "Test LORA"}], "metadata": {}})
|
||||
)
|
||||
|
||||
result = search_civitai("test", ModelType.lora, BaseModel.sdxl, SortOrder.downloads, 20, None, console)
|
||||
assert result is not None
|
||||
|
||||
@respx.mock
|
||||
def test_download_model_success(self, tmp_path: Path) -> None:
|
||||
"""Test successful model download."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
dest = tmp_path / "model.safetensors"
|
||||
respx.get("https://civitai.com/api/download/models/123").mock(
|
||||
return_value=httpx.Response(200, content=b"fake model data")
|
||||
)
|
||||
|
||||
result = download_model(123, dest, None, console, resume=False)
|
||||
assert result is True
|
||||
assert dest.exists()
|
||||
assert dest.read_bytes() == b"fake model data"
|
||||
|
||||
@respx.mock
|
||||
def test_download_model_unauthorized(self, tmp_path: Path) -> None:
|
||||
"""Test download with 401 unauthorized."""
|
||||
console = Console(force_terminal=True, width=80)
|
||||
dest = tmp_path / "model.safetensors"
|
||||
respx.get("https://civitai.com/api/download/models/123").mock(return_value=httpx.Response(401))
|
||||
|
||||
result = download_model(123, dest, None, console, resume=False)
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestCLI:
|
||||
"""Tests for CLI commands."""
|
||||
|
||||
def test_help(self) -> None:
|
||||
"""Test --help works."""
|
||||
result = runner.invoke(app, ["--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "safetensor" in result.stdout.lower()
|
||||
|
||||
def test_info_file_not_found(self, tmp_path: Path) -> None:
|
||||
"""Test info command with non-existent file."""
|
||||
result = runner.invoke(app, ["info", str(tmp_path / "nonexistent.safetensors")])
|
||||
assert result.exit_code == 1
|
||||
assert "not found" in result.stdout.lower()
|
||||
|
||||
def test_info_with_safetensor(self, temp_safetensor: Path) -> None:
|
||||
"""Test info command with valid safetensor file."""
|
||||
result = runner.invoke(app, ["info", str(temp_safetensor), "--skip-civitai"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_info_json_output(self, temp_safetensor: Path) -> None:
|
||||
"""Test info command with JSON output."""
|
||||
result = runner.invoke(app, ["info", str(temp_safetensor), "--skip-civitai", "--json"])
|
||||
assert result.exit_code == 0
|
||||
assert "sha256" in result.stdout
|
||||
|
||||
def test_info_meta_filter(self, temp_safetensor: Path) -> None:
|
||||
"""Test info command with metadata filter."""
|
||||
result = runner.invoke(app, ["info", str(temp_safetensor), "--meta", "test_key"])
|
||||
assert result.exit_code == 0
|
||||
assert "test_value" in result.stdout
|
||||
|
||||
@respx.mock
|
||||
def test_search_command(self) -> None:
|
||||
"""Test search command."""
|
||||
respx.get("https://civitai.com/api/v1/models").mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"items": [{"id": 1, "name": "Test", "type": "LORA", "modelVersions": [], "stats": {}}],
|
||||
"metadata": {"totalItems": 1},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["search", "test"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
@respx.mock
|
||||
def test_get_command(self) -> None:
|
||||
"""Test get command."""
|
||||
respx.get("https://civitai.com/api/v1/models/123").mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"id": 123,
|
||||
"name": "Test Model",
|
||||
"type": "LORA",
|
||||
"nsfw": False,
|
||||
"stats": {},
|
||||
"modelVersions": [],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["get", "123"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
@respx.mock
|
||||
def test_get_command_not_found(self) -> None:
|
||||
"""Test get command with non-existent model."""
|
||||
respx.get("https://civitai.com/api/v1/models/999").mock(return_value=httpx.Response(404))
|
||||
|
||||
result = runner.invoke(app, ["get", "999"])
|
||||
assert result.exit_code == 1
|
||||
assert "not found" in result.stdout.lower()
|
||||
|
||||
def test_config_show(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
"""Test config --show command."""
|
||||
monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
|
||||
monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "config.toml")
|
||||
monkeypatch.setattr(config, "LEGACY_RC_FILE", tmp_path / "nonexistent")
|
||||
|
||||
result = runner.invoke(app, ["config", "--show"])
|
||||
assert result.exit_code == 0
|
||||
assert "config file" in result.stdout.lower()
|
||||
|
||||
def test_download_no_args(self) -> None:
|
||||
"""Test dl command with no arguments."""
|
||||
result = runner.invoke(app, ["dl"])
|
||||
assert result.exit_code == 1
|
||||
assert "must specify" in result.stdout.lower()
|
||||
|
||||
@@ -490,6 +490,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "respx"
|
||||
version = "0.22.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "httpx" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f4/7c/96bd0bc759cf009675ad1ee1f96535edcb11e9666b985717eb8c87192a95/respx-0.22.0.tar.gz", hash = "sha256:3c8924caa2a50bd71aefc07aa812f2466ff489f1848c96e954a5362d17095d91", size = 28439, upload-time = "2024-12-19T22:33:59.374Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/67/afbb0978d5399bc9ea200f1d4489a23c9a1dad4eee6376242b8182389c79/respx-0.22.0-py2.py3-none-any.whl", hash = "sha256:631128d4c9aba15e56903fb5f66fb1eff412ce28dd387ca3a81339e52dbd3ad0", size = 25127, upload-time = "2024-12-19T22:33:57.837Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rich"
|
||||
version = "14.3.0"
|
||||
@@ -578,6 +590,7 @@ dev = [
|
||||
{ name = "pre-commit" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "respx" },
|
||||
{ name = "ruff" },
|
||||
]
|
||||
|
||||
@@ -596,6 +609,7 @@ dev = [
|
||||
{ name = "pre-commit", specifier = ">=3.6" },
|
||||
{ name = "pytest", specifier = ">=8.0" },
|
||||
{ name = "pytest-cov", specifier = ">=4.1" },
|
||||
{ name = "respx", specifier = ">=0.22.0" },
|
||||
{ name = "ruff", specifier = ">=0.9.0" },
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user