💬 Commit message: Update 2026-02-03 20:53:31, 12 files
This commit is contained in:
@@ -0,0 +1,30 @@
|
||||
"""Pytest configuration and fixtures."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import struct
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_safetensor(tmp_path):
|
||||
"""Create a minimal valid safetensor file for testing."""
|
||||
|
||||
# Create minimal safetensor with empty tensors and some metadata
|
||||
header = {
|
||||
"__metadata__": {
|
||||
"format": "pt",
|
||||
"test_key": "test_value",
|
||||
}
|
||||
}
|
||||
header_bytes = json.dumps(header).encode("utf-8")
|
||||
header_size = len(header_bytes)
|
||||
|
||||
file_path = tmp_path / "test_model.safetensors"
|
||||
with file_path.open("wb") as f:
|
||||
f.write(struct.pack("<Q", header_size))
|
||||
f.write(header_bytes)
|
||||
|
||||
return file_path
|
||||
@@ -0,0 +1,115 @@
|
||||
"""Tests for tensors module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
import tensors
|
||||
from tensors import (
|
||||
get_base_name,
|
||||
get_default_output_path,
|
||||
load_api_key,
|
||||
read_safetensor_metadata,
|
||||
)
|
||||
|
||||
|
||||
class TestReadSafetensorMetadata:
|
||||
"""Tests for read_safetensor_metadata function."""
|
||||
|
||||
def test_reads_valid_safetensor(self, temp_safetensor: Path) -> None:
|
||||
"""Test reading metadata from a valid safetensor file."""
|
||||
result = read_safetensor_metadata(temp_safetensor)
|
||||
|
||||
assert "metadata" in result
|
||||
assert "tensor_count" in result
|
||||
assert "header_size" in result
|
||||
assert result["metadata"]["test_key"] == "test_value"
|
||||
assert result["tensor_count"] == 0 # No tensors, just metadata
|
||||
|
||||
def test_raises_on_short_file(self, tmp_path: Path) -> None:
|
||||
"""Test that short files raise ValueError."""
|
||||
short_file = tmp_path / "short.safetensors"
|
||||
short_file.write_bytes(b"short")
|
||||
|
||||
with pytest.raises(ValueError, match="too short"):
|
||||
read_safetensor_metadata(short_file)
|
||||
|
||||
def test_raises_on_truncated_header(self, tmp_path: Path) -> None:
|
||||
"""Test that truncated headers raise ValueError."""
|
||||
truncated = tmp_path / "truncated.safetensors"
|
||||
# Write header size that claims 1000 bytes but only provide 10
|
||||
with truncated.open("wb") as f:
|
||||
f.write(struct.pack("<Q", 1000))
|
||||
f.write(b"x" * 10)
|
||||
|
||||
with pytest.raises(ValueError, match="truncated"):
|
||||
read_safetensor_metadata(truncated)
|
||||
|
||||
def test_raises_on_huge_header_size(self, tmp_path: Path) -> None:
|
||||
"""Test that unreasonably large header sizes raise ValueError."""
|
||||
huge = tmp_path / "huge.safetensors"
|
||||
with huge.open("wb") as f:
|
||||
f.write(struct.pack("<Q", 200_000_000)) # 200MB header
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid header size"):
|
||||
read_safetensor_metadata(huge)
|
||||
|
||||
|
||||
class TestGetBaseName:
|
||||
"""Tests for get_base_name function."""
|
||||
|
||||
def test_removes_safetensors_extension(self) -> None:
|
||||
"""Test that .safetensors extension is removed."""
|
||||
assert get_base_name(Path("model.safetensors")) == "model"
|
||||
|
||||
def test_removes_sft_extension(self) -> None:
|
||||
"""Test that .sft extension is removed."""
|
||||
assert get_base_name(Path("model.sft")) == "model"
|
||||
|
||||
def test_handles_uppercase_extension(self) -> None:
|
||||
"""Test that uppercase extensions are handled."""
|
||||
assert get_base_name(Path("model.SAFETENSORS")) == "model"
|
||||
|
||||
def test_preserves_name_without_known_extension(self) -> None:
|
||||
"""Test that unknown extensions use stem."""
|
||||
assert get_base_name(Path("model.bin")) == "model"
|
||||
|
||||
|
||||
class TestGetDefaultOutputPath:
|
||||
"""Tests for get_default_output_path function."""
|
||||
|
||||
def test_returns_checkpoint_path(self) -> None:
|
||||
"""Test that Checkpoint type returns checkpoints directory."""
|
||||
result = get_default_output_path("Checkpoint")
|
||||
assert result is not None
|
||||
assert "checkpoints" in str(result)
|
||||
|
||||
def test_returns_lora_path(self) -> None:
|
||||
"""Test that LORA type returns loras directory."""
|
||||
result = get_default_output_path("LORA")
|
||||
assert result is not None
|
||||
assert "loras" in str(result)
|
||||
|
||||
def test_returns_none_for_unknown_type(self) -> None:
|
||||
"""Test that unknown types return None."""
|
||||
assert get_default_output_path("UnknownType") is None
|
||||
assert get_default_output_path(None) is None
|
||||
|
||||
|
||||
class TestLoadApiKey:
|
||||
"""Tests for load_api_key function."""
|
||||
|
||||
def test_returns_env_var_if_set(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test that environment variable takes precedence."""
|
||||
monkeypatch.setenv("CIVITAI_API_KEY", "test-key-from-env")
|
||||
assert load_api_key() == "test-key-from-env"
|
||||
|
||||
def test_returns_none_if_no_key(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
"""Test that None is returned when no key is available."""
|
||||
monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
|
||||
# Temporarily point RC_FILE to nonexistent file
|
||||
monkeypatch.setattr(tensors, "RC_FILE", tmp_path / "nonexistent")
|
||||
assert load_api_key() is None
|
||||
Reference in New Issue
Block a user