131 lines
5.4 KiB
Python
131 lines
5.4 KiB
Python
"""Tests for tensors module."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import struct
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from tensors import config
|
|
from tensors.config import get_default_output_path, load_api_key
|
|
from tensors.safetensor import get_base_name, read_safetensor_metadata
|
|
|
|
|
|
class TestReadSafetensorMetadata:
|
|
"""Tests for read_safetensor_metadata function."""
|
|
|
|
def test_reads_valid_safetensor(self, temp_safetensor: Path) -> None:
|
|
"""Test reading metadata from a valid safetensor file."""
|
|
result = read_safetensor_metadata(temp_safetensor)
|
|
|
|
assert "metadata" in result
|
|
assert "tensor_count" in result
|
|
assert "header_size" in result
|
|
assert result["metadata"]["test_key"] == "test_value"
|
|
assert result["tensor_count"] == 0 # No tensors, just metadata
|
|
|
|
def test_raises_on_short_file(self, tmp_path: Path) -> None:
|
|
"""Test that short files raise ValueError."""
|
|
short_file = tmp_path / "short.safetensors"
|
|
short_file.write_bytes(b"short")
|
|
|
|
with pytest.raises(ValueError, match="too short"):
|
|
read_safetensor_metadata(short_file)
|
|
|
|
def test_raises_on_truncated_header(self, tmp_path: Path) -> None:
|
|
"""Test that truncated headers raise ValueError."""
|
|
truncated = tmp_path / "truncated.safetensors"
|
|
# Write header size that claims 1000 bytes but only provide 10
|
|
with truncated.open("wb") as f:
|
|
f.write(struct.pack("<Q", 1000))
|
|
f.write(b"x" * 10)
|
|
|
|
with pytest.raises(ValueError, match="truncated"):
|
|
read_safetensor_metadata(truncated)
|
|
|
|
def test_raises_on_huge_header_size(self, tmp_path: Path) -> None:
|
|
"""Test that unreasonably large header sizes raise ValueError."""
|
|
huge = tmp_path / "huge.safetensors"
|
|
with huge.open("wb") as f:
|
|
f.write(struct.pack("<Q", 200_000_000)) # 200MB header
|
|
|
|
with pytest.raises(ValueError, match="Invalid header size"):
|
|
read_safetensor_metadata(huge)
|
|
|
|
|
|
class TestGetBaseName:
|
|
"""Tests for get_base_name function."""
|
|
|
|
def test_removes_safetensors_extension(self) -> None:
|
|
"""Test that .safetensors extension is removed."""
|
|
assert get_base_name(Path("model.safetensors")) == "model"
|
|
|
|
def test_removes_sft_extension(self) -> None:
|
|
"""Test that .sft extension is removed."""
|
|
assert get_base_name(Path("model.sft")) == "model"
|
|
|
|
def test_handles_uppercase_extension(self) -> None:
|
|
"""Test that uppercase extensions are handled."""
|
|
assert get_base_name(Path("model.SAFETENSORS")) == "model"
|
|
|
|
def test_preserves_name_without_known_extension(self) -> None:
|
|
"""Test that unknown extensions use stem."""
|
|
assert get_base_name(Path("model.bin")) == "model"
|
|
|
|
|
|
class TestGetDefaultOutputPath:
|
|
"""Tests for get_default_output_path function."""
|
|
|
|
def test_returns_checkpoint_path(self) -> None:
|
|
"""Test that Checkpoint type returns checkpoints directory."""
|
|
result = get_default_output_path("Checkpoint")
|
|
assert result is not None
|
|
assert "checkpoints" in str(result)
|
|
|
|
def test_returns_lora_path(self) -> None:
|
|
"""Test that LORA type returns loras directory."""
|
|
result = get_default_output_path("LORA")
|
|
assert result is not None
|
|
assert "loras" in str(result)
|
|
|
|
def test_returns_none_for_unknown_type(self) -> None:
|
|
"""Test that unknown types return None."""
|
|
assert get_default_output_path("UnknownType") is None
|
|
assert get_default_output_path(None) is None
|
|
|
|
|
|
class TestLoadApiKey:
|
|
"""Tests for load_api_key function."""
|
|
|
|
def test_returns_env_var_if_set(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
|
"""Test that environment variable takes precedence."""
|
|
monkeypatch.setenv("CIVITAI_API_KEY", "test-key-from-env")
|
|
assert load_api_key() == "test-key-from-env"
|
|
|
|
def test_returns_none_if_no_key(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
|
"""Test that None is returned when no key is available."""
|
|
monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
|
|
# Point config and legacy files to nonexistent paths
|
|
monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml")
|
|
monkeypatch.setattr(config, "LEGACY_RC_FILE", tmp_path / "nonexistent")
|
|
assert load_api_key() is None
|
|
|
|
def test_returns_key_from_config_file(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
|
"""Test that key is loaded from TOML config file."""
|
|
monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
|
|
config_file = tmp_path / "config.toml"
|
|
config_file.write_text('[api]\ncivitai_key = "key-from-config"\n')
|
|
monkeypatch.setattr(config, "CONFIG_FILE", config_file)
|
|
monkeypatch.setattr(config, "LEGACY_RC_FILE", tmp_path / "nonexistent")
|
|
assert load_api_key() == "key-from-config"
|
|
|
|
def test_returns_key_from_legacy_file(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
|
"""Test that key is loaded from legacy RC file when no config exists."""
|
|
monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
|
|
legacy_file = tmp_path / ".sftrc"
|
|
legacy_file.write_text("legacy-key")
|
|
monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml")
|
|
monkeypatch.setattr(config, "LEGACY_RC_FILE", legacy_file)
|
|
assert load_api_key() == "legacy-key"
|