diff --git a/.coverage b/.coverage index 54a3d72..cea8ccf 100644 Binary files a/.coverage and b/.coverage differ diff --git a/pyproject.toml b/pyproject.toml index e831b43..531f4af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dev = [ "pytest>=8.0", "pytest-cov>=4.1", "pre-commit>=3.6", + "respx>=0.22.0", ] [tool.ruff] diff --git a/tests/test_tensors.py b/tests/test_tensors.py index 134350d..95f7892 100644 --- a/tests/test_tensors.py +++ b/tests/test_tensors.py @@ -4,13 +4,45 @@ from __future__ import annotations import struct from pathlib import Path +from typing import Any +import httpx import pytest +import respx +from rich.console import Console +from typer.testing import CliRunner from tensors import config -from tensors.config import get_default_output_path, load_api_key +from tensors.api import ( + download_model, + fetch_civitai_by_hash, + fetch_civitai_model, + fetch_civitai_model_version, + search_civitai, +) +from tensors.cli import app +from tensors.config import ( + BaseModel, + ModelType, + SortOrder, + get_default_output_path, + load_api_key, + load_config, + save_config, +) +from tensors.display import ( + _format_count, + _format_size, + display_civitai_data, + display_file_info, + display_local_metadata, + display_model_info, + display_search_results, +) from tensors.safetensor import get_base_name, read_safetensor_metadata +runner = CliRunner() + class TestReadSafetensorMetadata: """Tests for read_safetensor_metadata function.""" @@ -128,3 +160,431 @@ class TestLoadApiKey: monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml") monkeypatch.setattr(config, "LEGACY_RC_FILE", legacy_file) assert load_api_key() == "legacy-key" + + +class TestSaveConfig: + """Tests for save_config function.""" + + def test_saves_simple_config(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + """Test saving a simple config.""" + config_dir = tmp_path / "config" + config_file = config_dir / "config.toml" + monkeypatch.setattr(config, "CONFIG_DIR", config_dir) + monkeypatch.setattr(config, "CONFIG_FILE", config_file) + + save_config({"key": "value"}) + + assert config_file.exists() + content = config_file.read_text() + assert 'key = "value"' in content + + def test_saves_nested_config(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + """Test saving a nested config with sections.""" + config_dir = tmp_path / "config" + config_file = config_dir / "config.toml" + monkeypatch.setattr(config, "CONFIG_DIR", config_dir) + monkeypatch.setattr(config, "CONFIG_FILE", config_file) + + save_config({"api": {"civitai_key": "test-key"}}) + + content = config_file.read_text() + assert "[api]" in content + assert 'civitai_key = "test-key"' in content + + def test_saves_numeric_values(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + """Test saving numeric values without quotes.""" + config_dir = tmp_path / "config" + config_file = config_dir / "config.toml" + monkeypatch.setattr(config, "CONFIG_DIR", config_dir) + monkeypatch.setattr(config, "CONFIG_FILE", config_file) + + save_config({"timeout": 30}) + + content = config_file.read_text() + assert "timeout = 30" in content + + +class TestLoadConfig: + """Tests for load_config function.""" + + def test_returns_empty_dict_if_no_config(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + """Test that empty dict is returned when config file doesn't exist.""" + monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "nonexistent.toml") + assert load_config() == {} + + +class TestEnums: + """Tests for enum to_api methods.""" + + def test_model_type_to_api(self) -> None: + """Test ModelType enum to_api conversion.""" + assert ModelType.checkpoint.to_api() == "Checkpoint" + assert ModelType.lora.to_api() == "LORA" + assert ModelType.embedding.to_api() == "TextualInversion" + assert ModelType.vae.to_api() == "VAE" + assert ModelType.controlnet.to_api() == "Controlnet" + assert ModelType.locon.to_api() == "LoCon" + + def test_base_model_to_api(self) -> None: + """Test BaseModel enum to_api conversion.""" + assert BaseModel.sd15.to_api() == "SD 1.5" + assert BaseModel.sdxl.to_api() == "SDXL 1.0" + assert BaseModel.pony.to_api() == "Pony" + assert BaseModel.flux.to_api() == "Flux.1 D" + assert BaseModel.illustrious.to_api() == "Illustrious" + + def test_sort_order_to_api(self) -> None: + """Test SortOrder enum to_api conversion.""" + assert SortOrder.downloads.to_api() == "Most Downloaded" + assert SortOrder.rating.to_api() == "Highest Rated" + assert SortOrder.newest.to_api() == "Newest" + + +class TestDisplayFormatters: + """Tests for display formatting functions.""" + + def test_format_size_kb(self) -> None: + """Test formatting sizes in KB.""" + assert _format_size(500) == "500 KB" + assert _format_size(1023) == "1023 KB" + + def test_format_size_mb(self) -> None: + """Test formatting sizes in MB.""" + assert _format_size(1024) == "1.0 MB" + assert _format_size(2048) == "2.0 MB" + assert _format_size(1024 * 500) == "500.0 MB" + + def test_format_size_gb(self) -> None: + """Test formatting sizes in GB.""" + assert _format_size(1024 * 1024) == "1.00 GB" + assert _format_size(1024 * 1024 * 2.5) == "2.50 GB" + + def test_format_count_small(self) -> None: + """Test formatting small counts.""" + assert _format_count(0) == "0" + assert _format_count(999) == "999" + + def test_format_count_thousands(self) -> None: + """Test formatting counts in thousands.""" + assert _format_count(1000) == "1.0K" + assert _format_count(5500) == "5.5K" + assert _format_count(999999) == "1000.0K" + + def test_format_count_millions(self) -> None: + """Test formatting counts in millions.""" + assert _format_count(1_000_000) == "1.0M" + assert _format_count(2_500_000) == "2.5M" + + +class TestDisplayFunctions: + """Tests for display functions with console output.""" + + def test_display_file_info(self, temp_safetensor: Path) -> None: + """Test display_file_info renders without error.""" + console = Console(force_terminal=True, width=80) + metadata = read_safetensor_metadata(temp_safetensor) + # Should not raise + display_file_info(temp_safetensor, metadata, "ABC123", console) + + def test_display_local_metadata_with_data(self) -> None: + """Test display_local_metadata with metadata.""" + console = Console(force_terminal=True, width=80) + metadata = {"metadata": {"key1": "value1", "key2": "value2"}, "tensor_count": 0, "header_size": 100} + # Should not raise + display_local_metadata(metadata, console) + + def test_display_local_metadata_empty(self) -> None: + """Test display_local_metadata with no metadata.""" + console = Console(force_terminal=True, width=80) + metadata: dict[str, Any] = {"metadata": {}, "tensor_count": 0, "header_size": 100} + # Should not raise + display_local_metadata(metadata, console) + + def test_display_local_metadata_with_filter(self) -> None: + """Test display_local_metadata with key filter.""" + console = Console(force_terminal=True, width=80) + metadata = {"metadata": {"key1": "value1", "key2": "value2"}, "tensor_count": 0, "header_size": 100} + # Should not raise + display_local_metadata(metadata, console, keys_filter=["key1"]) + + def test_display_civitai_data_none(self) -> None: + """Test display_civitai_data with None.""" + console = Console(force_terminal=True, width=80) + # Should not raise + display_civitai_data(None, console) + + def test_display_civitai_data_with_data(self) -> None: + """Test display_civitai_data with model data.""" + console = Console(force_terminal=True, width=80) + data = { + "modelId": 123, + "id": 456, + "name": "Test Model v1", + "baseModel": "SDXL 1.0", + "createdAt": "2024-01-01", + "trainedWords": ["word1", "word2"], + "downloadUrl": "https://example.com/download", + "files": [ + { + "primary": True, + "name": "model.safetensors", + "sizeKB": 5000, + "metadata": {"format": "SafeTensor", "fp": "fp16", "size": "full"}, + } + ], + } + # Should not raise + display_civitai_data(data, console) + + def test_display_model_info(self) -> None: + """Test display_model_info with model data.""" + console = Console(force_terminal=True, width=80) + data = { + "id": 123, + "name": "Test Model", + "type": "LORA", + "nsfw": False, + "creator": {"username": "testuser"}, + "tags": ["tag1", "tag2"], + "stats": {"downloadCount": 1000, "thumbsUpCount": 100}, + "modelVersions": [ + { + "id": 456, + "name": "v1.0", + "baseModel": "SDXL 1.0", + "createdAt": "2024-01-01", + "files": [{"primary": True, "name": "model.safetensors", "sizeKB": 5000}], + } + ], + } + # Should not raise + display_model_info(data, console) + + def test_display_search_results_empty(self) -> None: + """Test display_search_results with no results.""" + console = Console(force_terminal=True, width=80) + # Should not raise + display_search_results({"items": []}, console) + + def test_display_search_results_with_data(self) -> None: + """Test display_search_results with results.""" + console = Console(force_terminal=True, width=80) + results = { + "items": [ + { + "id": 123, + "name": "Test Model", + "type": "LORA", + "modelVersions": [{"baseModel": "SDXL 1.0", "files": [{"primary": True, "sizeKB": 5000}]}], + "stats": {"downloadCount": 1000, "thumbsUpCount": 100}, + } + ], + "metadata": {"totalItems": 1}, + } + # Should not raise + display_search_results(results, console) + + +class TestAPIFunctions: + """Tests for API functions with mocked HTTP.""" + + @respx.mock + def test_fetch_model_version_success(self) -> None: + """Test successful model version fetch.""" + console = Console(force_terminal=True, width=80) + respx.get("https://civitai.com/api/v1/model-versions/123").mock( + return_value=httpx.Response(200, json={"id": 123, "name": "Test"}) + ) + + result = fetch_civitai_model_version(123, None, console) + assert result == {"id": 123, "name": "Test"} + + @respx.mock + def test_fetch_model_version_not_found(self) -> None: + """Test model version not found.""" + console = Console(force_terminal=True, width=80) + respx.get("https://civitai.com/api/v1/model-versions/999").mock(return_value=httpx.Response(404)) + + result = fetch_civitai_model_version(999, None, console) + assert result is None + + @respx.mock + def test_fetch_model_success(self) -> None: + """Test successful model fetch.""" + console = Console(force_terminal=True, width=80) + respx.get("https://civitai.com/api/v1/models/123").mock( + return_value=httpx.Response(200, json={"id": 123, "name": "Test Model"}) + ) + + result = fetch_civitai_model(123, None, console) + assert result == {"id": 123, "name": "Test Model"} + + @respx.mock + def test_fetch_model_not_found(self) -> None: + """Test model not found.""" + console = Console(force_terminal=True, width=80) + respx.get("https://civitai.com/api/v1/models/999").mock(return_value=httpx.Response(404)) + + result = fetch_civitai_model(999, None, console) + assert result is None + + @respx.mock + def test_fetch_by_hash_success(self) -> None: + """Test successful hash lookup.""" + console = Console(force_terminal=True, width=80) + respx.get("https://civitai.com/api/v1/model-versions/by-hash/ABC123").mock( + return_value=httpx.Response(200, json={"id": 456, "name": "Found"}) + ) + + result = fetch_civitai_by_hash("ABC123", None, console) + assert result == {"id": 456, "name": "Found"} + + @respx.mock + def test_fetch_by_hash_not_found(self) -> None: + """Test hash not found.""" + console = Console(force_terminal=True, width=80) + respx.get("https://civitai.com/api/v1/model-versions/by-hash/NOTFOUND").mock(return_value=httpx.Response(404)) + + result = fetch_civitai_by_hash("NOTFOUND", None, console) + assert result is None + + @respx.mock + def test_search_civitai_success(self) -> None: + """Test successful search.""" + console = Console(force_terminal=True, width=80) + respx.get("https://civitai.com/api/v1/models").mock( + return_value=httpx.Response(200, json={"items": [{"id": 1}], "metadata": {}}) + ) + + result = search_civitai("test", None, None, SortOrder.downloads, 20, None, console) + assert result is not None + assert len(result["items"]) == 1 + + @respx.mock + def test_search_civitai_with_filters(self) -> None: + """Test search with type and base model filters.""" + console = Console(force_terminal=True, width=80) + respx.get("https://civitai.com/api/v1/models").mock( + return_value=httpx.Response(200, json={"items": [{"id": 1, "name": "Test LORA"}], "metadata": {}}) + ) + + result = search_civitai("test", ModelType.lora, BaseModel.sdxl, SortOrder.downloads, 20, None, console) + assert result is not None + + @respx.mock + def test_download_model_success(self, tmp_path: Path) -> None: + """Test successful model download.""" + console = Console(force_terminal=True, width=80) + dest = tmp_path / "model.safetensors" + respx.get("https://civitai.com/api/download/models/123").mock( + return_value=httpx.Response(200, content=b"fake model data") + ) + + result = download_model(123, dest, None, console, resume=False) + assert result is True + assert dest.exists() + assert dest.read_bytes() == b"fake model data" + + @respx.mock + def test_download_model_unauthorized(self, tmp_path: Path) -> None: + """Test download with 401 unauthorized.""" + console = Console(force_terminal=True, width=80) + dest = tmp_path / "model.safetensors" + respx.get("https://civitai.com/api/download/models/123").mock(return_value=httpx.Response(401)) + + result = download_model(123, dest, None, console, resume=False) + assert result is False + + +class TestCLI: + """Tests for CLI commands.""" + + def test_help(self) -> None: + """Test --help works.""" + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + assert "safetensor" in result.stdout.lower() + + def test_info_file_not_found(self, tmp_path: Path) -> None: + """Test info command with non-existent file.""" + result = runner.invoke(app, ["info", str(tmp_path / "nonexistent.safetensors")]) + assert result.exit_code == 1 + assert "not found" in result.stdout.lower() + + def test_info_with_safetensor(self, temp_safetensor: Path) -> None: + """Test info command with valid safetensor file.""" + result = runner.invoke(app, ["info", str(temp_safetensor), "--skip-civitai"]) + assert result.exit_code == 0 + + def test_info_json_output(self, temp_safetensor: Path) -> None: + """Test info command with JSON output.""" + result = runner.invoke(app, ["info", str(temp_safetensor), "--skip-civitai", "--json"]) + assert result.exit_code == 0 + assert "sha256" in result.stdout + + def test_info_meta_filter(self, temp_safetensor: Path) -> None: + """Test info command with metadata filter.""" + result = runner.invoke(app, ["info", str(temp_safetensor), "--meta", "test_key"]) + assert result.exit_code == 0 + assert "test_value" in result.stdout + + @respx.mock + def test_search_command(self) -> None: + """Test search command.""" + respx.get("https://civitai.com/api/v1/models").mock( + return_value=httpx.Response( + 200, + json={ + "items": [{"id": 1, "name": "Test", "type": "LORA", "modelVersions": [], "stats": {}}], + "metadata": {"totalItems": 1}, + }, + ) + ) + + result = runner.invoke(app, ["search", "test"]) + assert result.exit_code == 0 + + @respx.mock + def test_get_command(self) -> None: + """Test get command.""" + respx.get("https://civitai.com/api/v1/models/123").mock( + return_value=httpx.Response( + 200, + json={ + "id": 123, + "name": "Test Model", + "type": "LORA", + "nsfw": False, + "stats": {}, + "modelVersions": [], + }, + ) + ) + + result = runner.invoke(app, ["get", "123"]) + assert result.exit_code == 0 + + @respx.mock + def test_get_command_not_found(self) -> None: + """Test get command with non-existent model.""" + respx.get("https://civitai.com/api/v1/models/999").mock(return_value=httpx.Response(404)) + + result = runner.invoke(app, ["get", "999"]) + assert result.exit_code == 1 + assert "not found" in result.stdout.lower() + + def test_config_show(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + """Test config --show command.""" + monkeypatch.delenv("CIVITAI_API_KEY", raising=False) + monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "config.toml") + monkeypatch.setattr(config, "LEGACY_RC_FILE", tmp_path / "nonexistent") + + result = runner.invoke(app, ["config", "--show"]) + assert result.exit_code == 0 + assert "config file" in result.stdout.lower() + + def test_download_no_args(self) -> None: + """Test dl command with no arguments.""" + result = runner.invoke(app, ["dl"]) + assert result.exit_code == 1 + assert "must specify" in result.stdout.lower() diff --git a/uv.lock b/uv.lock index b9fc952..a17a45d 100644 --- a/uv.lock +++ b/uv.lock @@ -490,6 +490,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" }, ] +[[package]] +name = "respx" +version = "0.22.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/7c/96bd0bc759cf009675ad1ee1f96535edcb11e9666b985717eb8c87192a95/respx-0.22.0.tar.gz", hash = "sha256:3c8924caa2a50bd71aefc07aa812f2466ff489f1848c96e954a5362d17095d91", size = 28439, upload-time = "2024-12-19T22:33:59.374Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/67/afbb0978d5399bc9ea200f1d4489a23c9a1dad4eee6376242b8182389c79/respx-0.22.0-py2.py3-none-any.whl", hash = "sha256:631128d4c9aba15e56903fb5f66fb1eff412ce28dd387ca3a81339e52dbd3ad0", size = 25127, upload-time = "2024-12-19T22:33:57.837Z" }, +] + [[package]] name = "rich" version = "14.3.0" @@ -578,6 +590,7 @@ dev = [ { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-cov" }, + { name = "respx" }, { name = "ruff" }, ] @@ -596,6 +609,7 @@ dev = [ { name = "pre-commit", specifier = ">=3.6" }, { name = "pytest", specifier = ">=8.0" }, { name = "pytest-cov", specifier = ">=4.1" }, + { name = "respx", specifier = ">=0.22.0" }, { name = "ruff", specifier = ">=0.9.0" }, ]