Files
tensors/tests/test_tensors.py
T
2026-02-03 20:55:19 +01:00

139 lines
5.4 KiB
Python

"""Tests for tensors module."""
from __future__ import annotations
import struct
from pathlib import Path
import pytest
import tensors
from tensors import (
get_base_name,
get_default_output_path,
load_api_key,
read_safetensor_metadata,
)
class TestReadSafetensorMetadata:
    """Exercise read_safetensor_metadata on valid and malformed files."""

    def test_reads_valid_safetensor(self, temp_safetensor: Path) -> None:
        """A valid file yields metadata, tensor count and header size."""
        parsed = read_safetensor_metadata(temp_safetensor)

        for expected_key in ("metadata", "tensor_count", "header_size"):
            assert expected_key in parsed
        assert parsed["metadata"]["test_key"] == "test_value"
        # The fixture file carries metadata only, so no tensors are counted.
        assert parsed["tensor_count"] == 0

    def test_raises_on_short_file(self, tmp_path: Path) -> None:
        """A file holding only five bytes is rejected as too short."""
        stub_path = tmp_path / "short.safetensors"
        stub_path.write_bytes(b"short")

        with pytest.raises(ValueError, match="too short"):
            read_safetensor_metadata(stub_path)

    def test_raises_on_truncated_header(self, tmp_path: Path) -> None:
        """A header shorter than its declared size is rejected as truncated."""
        # The length prefix claims 1000 header bytes, but only 10 follow it.
        declared_size = struct.pack("<Q", 1000)
        partial_header = b"x" * 10
        broken_path = tmp_path / "truncated.safetensors"
        broken_path.write_bytes(declared_size + partial_header)

        with pytest.raises(ValueError, match="truncated"):
            read_safetensor_metadata(broken_path)

    def test_raises_on_huge_header_size(self, tmp_path: Path) -> None:
        """A declared header size of 200MB is rejected as invalid."""
        oversized_path = tmp_path / "huge.safetensors"
        # Only the length prefix is written; it announces a 200MB header.
        oversized_path.write_bytes(struct.pack("<Q", 200_000_000))

        with pytest.raises(ValueError, match="Invalid header size"):
            read_safetensor_metadata(oversized_path)
class TestGetBaseName:
    """Exercise get_base_name with several file suffixes."""

    def test_removes_safetensors_extension(self) -> None:
        """The .safetensors suffix is stripped from the name."""
        source = Path("model.safetensors")
        assert get_base_name(source) == "model"

    def test_removes_sft_extension(self) -> None:
        """The .sft suffix is stripped from the name."""
        source = Path("model.sft")
        assert get_base_name(source) == "model"

    def test_handles_uppercase_extension(self) -> None:
        """An upper-case .SAFETENSORS suffix is stripped as well."""
        source = Path("model.SAFETENSORS")
        assert get_base_name(source) == "model"

    def test_preserves_name_without_known_extension(self) -> None:
        """A name with a .bin suffix still comes back as its stem."""
        source = Path("model.bin")
        assert get_base_name(source) == "model"
class TestGetDefaultOutputPath:
    """Exercise get_default_output_path for known and unknown model types."""

    def test_returns_checkpoint_path(self) -> None:
        """The "Checkpoint" type maps to a path containing "checkpoints"."""
        destination = get_default_output_path("Checkpoint")

        assert destination is not None
        assert "checkpoints" in str(destination)

    def test_returns_lora_path(self) -> None:
        """The "LORA" type maps to a path containing "loras"."""
        destination = get_default_output_path("LORA")

        assert destination is not None
        assert "loras" in str(destination)

    def test_returns_none_for_unknown_type(self) -> None:
        """An unrecognised type name, or None, produces None."""
        for model_type in ("UnknownType", None):
            assert get_default_output_path(model_type) is None
class TestLoadApiKey:
    """Exercise load_api_key against the environment and on-disk sources."""

    def test_returns_env_var_if_set(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """The value of CIVITAI_API_KEY is returned when it is set."""
        monkeypatch.setenv("CIVITAI_API_KEY", "test-key-from-env")

        found = load_api_key()

        assert found == "test-key-from-env"

    def test_returns_none_if_no_key(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
        """None is returned when neither the env var nor any file exists."""
        missing = tmp_path / "nonexistent"
        monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
        # Both lookup locations are redirected to paths that do not exist.
        monkeypatch.setattr(tensors, "CONFIG_FILE", missing / "config.toml")
        monkeypatch.setattr(tensors, "LEGACY_RC_FILE", missing)

        found = load_api_key()

        assert found is None

    def test_returns_key_from_config_file(
        self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
    ) -> None:
        """The key is read from the [api] table of the TOML config file."""
        monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
        toml_path = tmp_path / "config.toml"
        toml_path.write_text('[api]\ncivitai_key = "key-from-config"\n')
        monkeypatch.setattr(tensors, "CONFIG_FILE", toml_path)
        # The legacy RC location is pointed at a path that does not exist.
        monkeypatch.setattr(tensors, "LEGACY_RC_FILE", tmp_path / "nonexistent")

        found = load_api_key()

        assert found == "key-from-config"

    def test_returns_key_from_legacy_file(
        self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
    ) -> None:
        """The legacy RC file supplies the key when the config file is absent."""
        monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
        rc_path = tmp_path / ".sftrc"
        rc_path.write_text("legacy-key")
        # The TOML config location is pointed at a path that does not exist.
        monkeypatch.setattr(tensors, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml")
        monkeypatch.setattr(tensors, "LEGACY_RC_FILE", rc_path)

        found = load_api_key()

        assert found == "legacy-key"