"""Tests for tensors module.""" from __future__ import annotations import struct from pathlib import Path import pytest import tensors from tensors import ( get_base_name, get_default_output_path, load_api_key, read_safetensor_metadata, ) class TestReadSafetensorMetadata: """Tests for read_safetensor_metadata function.""" def test_reads_valid_safetensor(self, temp_safetensor: Path) -> None: """Test reading metadata from a valid safetensor file.""" result = read_safetensor_metadata(temp_safetensor) assert "metadata" in result assert "tensor_count" in result assert "header_size" in result assert result["metadata"]["test_key"] == "test_value" assert result["tensor_count"] == 0 # No tensors, just metadata def test_raises_on_short_file(self, tmp_path: Path) -> None: """Test that short files raise ValueError.""" short_file = tmp_path / "short.safetensors" short_file.write_bytes(b"short") with pytest.raises(ValueError, match="too short"): read_safetensor_metadata(short_file) def test_raises_on_truncated_header(self, tmp_path: Path) -> None: """Test that truncated headers raise ValueError.""" truncated = tmp_path / "truncated.safetensors" # Write header size that claims 1000 bytes but only provide 10 with truncated.open("wb") as f: f.write(struct.pack(" None: """Test that unreasonably large header sizes raise ValueError.""" huge = tmp_path / "huge.safetensors" with huge.open("wb") as f: f.write(struct.pack(" None: """Test that .safetensors extension is removed.""" assert get_base_name(Path("model.safetensors")) == "model" def test_removes_sft_extension(self) -> None: """Test that .sft extension is removed.""" assert get_base_name(Path("model.sft")) == "model" def test_handles_uppercase_extension(self) -> None: """Test that uppercase extensions are handled.""" assert get_base_name(Path("model.SAFETENSORS")) == "model" def test_preserves_name_without_known_extension(self) -> None: """Test that unknown extensions use stem.""" assert get_base_name(Path("model.bin")) == "model" 
class TestGetDefaultOutputPath:
    """Exercise get_default_output_path for known and unknown model types."""

    def test_returns_checkpoint_path(self) -> None:
        """The Checkpoint type yields a path whose text contains 'checkpoints'."""
        checkpoint_dir = get_default_output_path("Checkpoint")
        assert checkpoint_dir is not None
        assert "checkpoints" in str(checkpoint_dir)

    def test_returns_lora_path(self) -> None:
        """The LORA type yields a path whose text contains 'loras'."""
        lora_dir = get_default_output_path("LORA")
        assert lora_dir is not None
        assert "loras" in str(lora_dir)

    def test_returns_none_for_unknown_type(self) -> None:
        """An unrecognised type name and None both produce None."""
        for model_type in ("UnknownType", None):
            assert get_default_output_path(model_type) is None


class TestLoadApiKey:
    """Exercise load_api_key against the environment, TOML config and legacy file."""

    def test_returns_env_var_if_set(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """A CIVITAI_API_KEY value present in the environment is returned."""
        env_key = "test-key-from-env"
        monkeypatch.setenv("CIVITAI_API_KEY", env_key)
        assert load_api_key() == env_key

    def test_returns_none_if_no_key(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
        """With no env var and no key files on disk, the result is None."""
        monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
        # Neither patched location is ever created under tmp_path.
        missing_dir = tmp_path / "nonexistent"
        monkeypatch.setattr(tensors, "CONFIG_FILE", missing_dir / "config.toml")
        monkeypatch.setattr(tensors, "LEGACY_RC_FILE", missing_dir)
        assert load_api_key() is None

    def test_returns_key_from_config_file(
        self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
    ) -> None:
        """The civitai_key entry of the [api] table in the TOML config is returned."""
        monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
        toml_path = tmp_path / "config.toml"
        toml_path.write_text('[api]\ncivitai_key = "key-from-config"\n')
        monkeypatch.setattr(tensors, "CONFIG_FILE", toml_path)
        # The legacy file is pointed at a path that does not exist.
        monkeypatch.setattr(tensors, "LEGACY_RC_FILE", tmp_path / "nonexistent")
        assert load_api_key() == "key-from-config"

    def test_returns_key_from_legacy_file(
        self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
    ) -> None:
        """The legacy RC file's content is returned when the config file is absent."""
        monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
        rc_path = tmp_path / ".sftrc"
        rc_path.write_text("legacy-key")
        # The TOML config is pointed at a path that does not exist.
        monkeypatch.setattr(tensors, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml")
        monkeypatch.setattr(tensors, "LEGACY_RC_FILE", rc_path)
        assert load_api_key() == "legacy-key"