Phase 6: Tests for database, server, and client modules
- Add tests/test_db.py with 33 tests for Database class: - Schema initialization and migrations - Local file CRUD operations (scan, list, link) - CivitAI model caching (cache_model, tags, versions, files) - Query operations (search, get_model, get_triggers) - Statistics and context manager - Extend tests/test_server.py with 27 tests for API endpoints: - Gallery endpoints (list, get, meta, edit, delete, stats) - Database endpoints (files, models, stats) - Gallery class unit tests - Add tests/test_client.py with 33 tests for TsrClient: - Server status operations - Gallery image operations (list, get, delete, edit, download) - Model management (list, active, switch, loras) - Image generation - CivitAI download operations - Database query operations - Error handling and context manager Total: 191 tests passing with 61% coverage Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -27,9 +27,9 @@
|
||||
- [x] Step 5.3: ~~Create `rocm-docker/tsr-server.service`~~ (skipped)
|
||||
|
||||
## Phase 6: Tests
|
||||
- [ ] Step 6.1: `tests/test_db.py` (database module tests)
|
||||
- [ ] Step 6.2: `tests/test_server.py` (API endpoint tests)
|
||||
- [ ] Step 6.3: `tests/test_client.py` (client module tests)
|
||||
- [x] Step 6.1: `tests/test_db.py` (database module tests)
|
||||
- [x] Step 6.2: `tests/test_server.py` (API endpoint tests)
|
||||
- [x] Step 6.3: `tests/test_client.py` (client module tests)
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -0,0 +1,481 @@
|
||||
"""Tests for the TsrClient HTTP client module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
from httpx import Response
|
||||
|
||||
from tensors.client import TsrClient, TsrClientError
|
||||
|
||||
BASE_URL = "http://test-server:8080"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_server():
|
||||
"""Activate respx mock for the test server."""
|
||||
with respx.mock(base_url=BASE_URL, assert_all_called=False) as rsps:
|
||||
yield rsps
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_server) -> TsrClient: # noqa: ARG001 - mock_server activates respx
|
||||
"""TsrClient connected to mock server."""
|
||||
return TsrClient(BASE_URL)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Status Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestStatus:
|
||||
"""Tests for server status endpoint."""
|
||||
|
||||
def test_status_success(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test getting server status."""
|
||||
mock_server.get("/status").mock(return_value=Response(200, json={"running": True, "pid": 12345, "model": "/test.gguf"}))
|
||||
|
||||
with client:
|
||||
result = client.status()
|
||||
|
||||
assert result["running"] is True
|
||||
assert result["pid"] == 12345
|
||||
|
||||
def test_status_error(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test handling status error."""
|
||||
mock_server.get("/status").mock(return_value=Response(503, text="Service unavailable"))
|
||||
|
||||
with client, pytest.raises(TsrClientError, match="HTTP 503"):
|
||||
client.status()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Gallery Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGalleryImages:
|
||||
"""Tests for gallery image operations."""
|
||||
|
||||
def test_list_images(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test listing gallery images."""
|
||||
mock_server.get("/api/images").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={
|
||||
"images": [
|
||||
{"id": "123_0", "filename": "123_0.png", "width": 512, "height": 512},
|
||||
{"id": "124_1", "filename": "124_1.png", "width": 1024, "height": 1024},
|
||||
],
|
||||
"total": 2,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.list_images()
|
||||
|
||||
assert len(result["images"]) == 2
|
||||
assert result["total"] == 2
|
||||
|
||||
def test_list_images_with_pagination(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test listing images with pagination."""
|
||||
mock_server.get("/api/images", params={"limit": 10, "offset": 5}).mock(
|
||||
return_value=Response(200, json={"images": [], "total": 100})
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.list_images(limit=10, offset=5)
|
||||
|
||||
assert result["total"] == 100
|
||||
|
||||
def test_get_image_meta(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test getting image metadata."""
|
||||
mock_server.get("/api/images/123_0/meta").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={
|
||||
"id": "123_0",
|
||||
"path": "/gallery/123_0.png",
|
||||
"metadata": {"prompt": "test prompt", "seed": 42},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.get_image_meta("123_0")
|
||||
|
||||
assert result["id"] == "123_0"
|
||||
assert result["metadata"]["prompt"] == "test prompt"
|
||||
|
||||
def test_delete_image(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test deleting an image."""
|
||||
mock_server.delete("/api/images/123_0").mock(return_value=Response(200, json={"deleted": True, "id": "123_0"}))
|
||||
|
||||
with client:
|
||||
result = client.delete_image("123_0")
|
||||
|
||||
assert result["deleted"] is True
|
||||
|
||||
def test_edit_image(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test editing image metadata."""
|
||||
mock_server.post("/api/images/123_0/edit").mock(
|
||||
return_value=Response(200, json={"id": "123_0", "metadata": {"tags": ["favorite"], "rating": 5}})
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.edit_image("123_0", {"tags": ["favorite"], "rating": 5})
|
||||
|
||||
assert result["metadata"]["tags"] == ["favorite"]
|
||||
|
||||
def test_download_image(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test downloading image bytes."""
|
||||
image_bytes = b"\x89PNG test image data"
|
||||
mock_server.get("/api/images/123_0").mock(return_value=Response(200, content=image_bytes))
|
||||
|
||||
with client:
|
||||
result = client.download_image("123_0")
|
||||
|
||||
assert result == image_bytes
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Models Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestModels:
|
||||
"""Tests for model management operations."""
|
||||
|
||||
def test_list_models(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test listing available models."""
|
||||
mock_server.get("/api/models").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={
|
||||
"models": [
|
||||
{"name": "sdxl_base", "path": "/models/sdxl_base.safetensors"},
|
||||
{"name": "pony_v6", "path": "/models/pony_v6.safetensors"},
|
||||
],
|
||||
"active": "/models/sdxl_base.safetensors",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.list_models()
|
||||
|
||||
assert len(result["models"]) == 2
|
||||
assert result["active"] == "/models/sdxl_base.safetensors"
|
||||
|
||||
def test_get_active_model(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test getting active model."""
|
||||
mock_server.get("/api/models/active").mock(return_value=Response(200, json={"model": "/models/sdxl_base.safetensors"}))
|
||||
|
||||
with client:
|
||||
result = client.get_active_model()
|
||||
|
||||
assert result["model"] == "/models/sdxl_base.safetensors"
|
||||
|
||||
def test_switch_model(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test switching model."""
|
||||
mock_server.post("/api/models/switch").mock(
|
||||
return_value=Response(200, json={"status": "ok", "model": "/models/pony_v6.safetensors"})
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.switch_model("/models/pony_v6.safetensors")
|
||||
|
||||
assert result["status"] == "ok"
|
||||
|
||||
def test_list_loras(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test listing LoRAs."""
|
||||
mock_server.get("/api/models/loras").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={
|
||||
"loras": [
|
||||
{"name": "detail_tweaker", "path": "/loras/detail_tweaker.safetensors"},
|
||||
]
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.list_loras()
|
||||
|
||||
assert len(result["loras"]) == 1
|
||||
|
||||
def test_scan_models(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test scanning models."""
|
||||
mock_server.get("/api/models/scan").mock(return_value=Response(200, json={"scanned": 5}))
|
||||
|
||||
with client:
|
||||
result = client.scan_models()
|
||||
|
||||
assert result["scanned"] == 5
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Generation Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGeneration:
|
||||
"""Tests for image generation."""
|
||||
|
||||
def test_generate(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test generating an image."""
|
||||
mock_server.post("/api/generate").mock(
|
||||
return_value=Response(
|
||||
200,
|
||||
json={
|
||||
"images": [{"id": "999_42", "seed": 42}],
|
||||
"parameters": {"prompt": "test prompt", "seed": 42},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.generate(
|
||||
prompt="test prompt",
|
||||
width=512,
|
||||
height=512,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
assert len(result["images"]) == 1
|
||||
assert result["images"][0]["seed"] == 42
|
||||
|
||||
def test_generate_with_all_params(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test generation with all parameters."""
|
||||
mock_server.post("/api/generate").mock(return_value=Response(200, json={"images": []}))
|
||||
|
||||
with client:
|
||||
result = client.generate(
|
||||
prompt="detailed test prompt",
|
||||
negative_prompt="bad quality",
|
||||
width=1024,
|
||||
height=1024,
|
||||
steps=30,
|
||||
cfg_scale=5.5,
|
||||
seed=12345,
|
||||
sampler_name="DPM++ 2M",
|
||||
scheduler="karras",
|
||||
batch_size=2,
|
||||
save_to_gallery=False,
|
||||
return_base64=True,
|
||||
)
|
||||
|
||||
assert "images" in result
|
||||
|
||||
def test_list_samplers(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test listing samplers."""
|
||||
mock_server.get("/api/samplers").mock(return_value=Response(200, json={"samplers": ["Euler", "DPM++ 2M", "Euler a"]}))
|
||||
|
||||
with client:
|
||||
result = client.list_samplers()
|
||||
|
||||
assert "samplers" in result
|
||||
|
||||
def test_list_schedulers(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test listing schedulers."""
|
||||
mock_server.get("/api/schedulers").mock(
|
||||
return_value=Response(200, json={"schedulers": ["simple", "karras", "sgm_uniform"]})
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.list_schedulers()
|
||||
|
||||
assert "schedulers" in result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Download Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestDownload:
|
||||
"""Tests for CivitAI download operations."""
|
||||
|
||||
def test_start_download_by_version(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test starting download by version ID."""
|
||||
mock_server.post("/api/download").mock(
|
||||
return_value=Response(200, json={"download_id": "abc123", "status": "started", "version_id": 12345})
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.start_download(version_id=12345)
|
||||
|
||||
assert result["download_id"] == "abc123"
|
||||
|
||||
def test_start_download_by_hash(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test starting download by hash."""
|
||||
mock_server.post("/api/download").mock(return_value=Response(200, json={"download_id": "def456", "status": "started"}))
|
||||
|
||||
with client:
|
||||
result = client.start_download(hash_val="ABC123DEF456")
|
||||
|
||||
assert result["status"] == "started"
|
||||
|
||||
def test_get_download_status(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test getting download status."""
|
||||
mock_server.get("/api/download/status/abc123").mock(
|
||||
return_value=Response(200, json={"download_id": "abc123", "status": "downloading", "progress": 0.5})
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.get_download_status("abc123")
|
||||
|
||||
assert result["progress"] == 0.5
|
||||
|
||||
def test_list_downloads(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test listing active downloads."""
|
||||
mock_server.get("/api/download/active").mock(
|
||||
return_value=Response(200, json={"downloads": [{"id": "abc123", "progress": 0.75}]})
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.list_downloads()
|
||||
|
||||
assert len(result["downloads"]) == 1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Database Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestDatabase:
|
||||
"""Tests for database operations."""
|
||||
|
||||
def test_db_list_files(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test listing local files."""
|
||||
mock_server.get("/api/db/files").mock(
|
||||
return_value=Response(200, json=[{"id": 1, "file_path": "/models/test.safetensors", "sha256": "abc123"}])
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.db_list_files()
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["sha256"] == "abc123"
|
||||
|
||||
def test_db_search_models(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test searching cached models."""
|
||||
mock_server.get("/api/db/models").mock(
|
||||
return_value=Response(200, json=[{"civitai_id": 12345, "name": "Test Model", "type": "LORA"}])
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.db_search_models(query="Test", model_type="LORA")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "Test Model"
|
||||
|
||||
def test_db_get_model(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test getting cached model."""
|
||||
mock_server.get("/api/db/models/12345").mock(
|
||||
return_value=Response(200, json={"civitai_id": 12345, "name": "Test Model", "type": "Checkpoint"})
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.db_get_model(12345)
|
||||
|
||||
assert result["name"] == "Test Model"
|
||||
|
||||
def test_db_get_triggers(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test getting trigger words."""
|
||||
mock_server.get("/api/db/triggers/12345").mock(return_value=Response(200, json=["trigger1", "trigger2"]))
|
||||
|
||||
with client:
|
||||
result = client.db_get_triggers(version_id=12345)
|
||||
|
||||
assert result == ["trigger1", "trigger2"]
|
||||
|
||||
def test_db_stats(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test getting database stats."""
|
||||
mock_server.get("/api/db/stats").mock(
|
||||
return_value=Response(200, json={"local_files": 10, "models": 5, "model_versions": 15})
|
||||
)
|
||||
|
||||
with client:
|
||||
result = client.db_stats()
|
||||
|
||||
assert result["local_files"] == 10
|
||||
|
||||
def test_db_scan(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test scanning directory."""
|
||||
mock_server.post("/api/db/scan").mock(return_value=Response(200, json={"scanned": 3, "files": []}))
|
||||
|
||||
with client:
|
||||
result = client.db_scan("/models")
|
||||
|
||||
assert result["scanned"] == 3
|
||||
|
||||
def test_db_link(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test linking files to CivitAI."""
|
||||
mock_server.post("/api/db/link").mock(return_value=Response(200, json={"linked": 2}))
|
||||
|
||||
with client:
|
||||
result = client.db_link()
|
||||
|
||||
assert result["linked"] == 2
|
||||
|
||||
def test_db_cache(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test caching model data."""
|
||||
mock_server.post("/api/db/cache").mock(return_value=Response(200, json={"model_id": 12345, "cached": True}))
|
||||
|
||||
with client:
|
||||
result = client.db_cache(12345)
|
||||
|
||||
assert result["cached"] is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Error Handling Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Tests for error handling."""
|
||||
|
||||
def test_http_error(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test HTTP error handling."""
|
||||
mock_server.get("/api/images").mock(return_value=Response(500, text="Internal server error"))
|
||||
|
||||
with client, pytest.raises(TsrClientError, match="HTTP 500"):
|
||||
client.list_images()
|
||||
|
||||
def test_not_found_error(self, client: TsrClient, mock_server) -> None:
|
||||
"""Test 404 error handling."""
|
||||
mock_server.get("/api/images/nonexistent/meta").mock(return_value=Response(404, json={"detail": "Image not found"}))
|
||||
|
||||
with client, pytest.raises(TsrClientError, match="HTTP 404"):
|
||||
client.get_image_meta("nonexistent")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Context Manager Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestContextManager:
|
||||
"""Tests for context manager usage."""
|
||||
|
||||
def test_context_manager(self, mock_server) -> None:
|
||||
"""Test client works as context manager."""
|
||||
mock_server.get("/status").mock(return_value=Response(200, json={"running": True}))
|
||||
|
||||
with TsrClient(BASE_URL) as client:
|
||||
result = client.status()
|
||||
assert result["running"] is True
|
||||
|
||||
def test_client_without_context(self, mock_server) -> None:
|
||||
"""Test client works without context manager."""
|
||||
mock_server.get("/status").mock(return_value=Response(200, json={"running": True}))
|
||||
|
||||
client = TsrClient(BASE_URL)
|
||||
result = client.status()
|
||||
assert result["running"] is True
|
||||
@@ -0,0 +1,450 @@
|
||||
"""Tests for the database module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import struct
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from tensors.db import Database
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db(tmp_path: Path) -> Database:
|
||||
"""Create a temporary database for testing."""
|
||||
db_path = tmp_path / "test_models.db"
|
||||
db = Database(db_path=db_path)
|
||||
db.init_schema()
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_safetensor(tmp_path: Path) -> Path:
|
||||
"""Create a sample safetensor file for testing."""
|
||||
header = {
|
||||
"__metadata__": {
|
||||
"format": "pt",
|
||||
"test_key": "test_value",
|
||||
}
|
||||
}
|
||||
header_bytes = json.dumps(header).encode("utf-8")
|
||||
header_size = len(header_bytes)
|
||||
|
||||
file_path = tmp_path / "models" / "test_lora.safetensors"
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with file_path.open("wb") as f:
|
||||
f.write(struct.pack("<Q", header_size))
|
||||
f.write(header_bytes)
|
||||
|
||||
return file_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_civitai_model() -> dict:
|
||||
"""Sample CivitAI model API response."""
|
||||
return {
|
||||
"id": 123456,
|
||||
"name": "Test LoRA",
|
||||
"description": "A test LoRA model",
|
||||
"type": "LORA",
|
||||
"nsfw": False,
|
||||
"poi": False,
|
||||
"minor": False,
|
||||
"tags": ["test", "lora", "anime"],
|
||||
"creator": {
|
||||
"username": "test_creator",
|
||||
"image": "https://example.com/avatar.png",
|
||||
},
|
||||
"stats": {
|
||||
"downloadCount": 1000,
|
||||
"thumbsUpCount": 500,
|
||||
"thumbsDownCount": 10,
|
||||
"commentCount": 50,
|
||||
"tippedAmountCount": 5,
|
||||
},
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 789012,
|
||||
"name": "v1.0",
|
||||
"description": "Initial release",
|
||||
"baseModel": "SDXL 1.0",
|
||||
"trainedWords": ["test_trigger", "lora_trigger"],
|
||||
"files": [
|
||||
{
|
||||
"id": 111222,
|
||||
"name": "test_lora.safetensors",
|
||||
"type": "Model",
|
||||
"sizeKB": 150000,
|
||||
"primary": True,
|
||||
"hashes": {
|
||||
"SHA256": "ABC123DEF456",
|
||||
"BLAKE3": "789XYZ",
|
||||
},
|
||||
"metadata": {
|
||||
"format": "SafeTensor",
|
||||
"size": "full",
|
||||
"fp": "fp16",
|
||||
},
|
||||
}
|
||||
],
|
||||
"images": [
|
||||
{
|
||||
"id": 333444,
|
||||
"url": "https://example.com/image.png",
|
||||
"type": "image",
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"meta": {
|
||||
"prompt": "test prompt",
|
||||
"negativePrompt": "bad quality",
|
||||
"cfgScale": 7.0,
|
||||
},
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"downloadCount": 1000,
|
||||
"thumbsUpCount": 500,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class TestDatabaseSchema:
|
||||
"""Tests for database schema initialization."""
|
||||
|
||||
def test_init_schema(self, temp_db: Database) -> None:
|
||||
"""Test schema initialization creates tables."""
|
||||
cur = temp_db.conn.cursor()
|
||||
cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
tables = {row[0] for row in cur.fetchall()}
|
||||
|
||||
expected = {
|
||||
"local_files",
|
||||
"safetensor_metadata",
|
||||
"models",
|
||||
"model_versions",
|
||||
"version_files",
|
||||
"file_hashes",
|
||||
"trained_words",
|
||||
"creators",
|
||||
"tags",
|
||||
"model_tags",
|
||||
"version_images",
|
||||
"image_generation_params",
|
||||
"image_resources",
|
||||
}
|
||||
assert expected.issubset(tables)
|
||||
|
||||
def test_init_schema_creates_views(self, temp_db: Database) -> None:
|
||||
"""Test schema creates required views."""
|
||||
cur = temp_db.conn.cursor()
|
||||
cur.execute("SELECT name FROM sqlite_master WHERE type='view'")
|
||||
views = {row[0] for row in cur.fetchall()}
|
||||
|
||||
assert "v_local_files_full" in views
|
||||
assert "v_models_with_latest" in views
|
||||
|
||||
|
||||
class TestLocalFiles:
|
||||
"""Tests for local file operations."""
|
||||
|
||||
def test_scan_directory(self, temp_db: Database, sample_safetensor: Path) -> None:
|
||||
"""Test scanning directory for safetensor files."""
|
||||
results = temp_db.scan_directory(sample_safetensor.parent)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["file_path"] == str(sample_safetensor.resolve())
|
||||
assert "sha256" in results[0]
|
||||
assert results[0]["sha256"] # Should have hash
|
||||
|
||||
def test_scan_directory_empty(self, temp_db: Database, tmp_path: Path) -> None:
|
||||
"""Test scanning empty directory."""
|
||||
empty_dir = tmp_path / "empty"
|
||||
empty_dir.mkdir()
|
||||
|
||||
results = temp_db.scan_directory(empty_dir)
|
||||
assert results == []
|
||||
|
||||
def test_list_local_files(self, temp_db: Database, sample_safetensor: Path) -> None:
|
||||
"""Test listing local files after scan."""
|
||||
temp_db.scan_directory(sample_safetensor.parent)
|
||||
files = temp_db.list_local_files()
|
||||
|
||||
assert len(files) == 1
|
||||
assert files[0]["file_path"] == str(sample_safetensor.resolve())
|
||||
|
||||
def test_get_local_file_by_path(self, temp_db: Database, sample_safetensor: Path) -> None:
|
||||
"""Test getting local file by path."""
|
||||
temp_db.scan_directory(sample_safetensor.parent)
|
||||
|
||||
file_info = temp_db.get_local_file_by_path(str(sample_safetensor.resolve()))
|
||||
assert file_info is not None
|
||||
assert file_info["file_path"] == str(sample_safetensor.resolve())
|
||||
|
||||
def test_get_local_file_by_path_not_found(self, temp_db: Database) -> None:
|
||||
"""Test getting non-existent file."""
|
||||
result = temp_db.get_local_file_by_path("/nonexistent/file.safetensors")
|
||||
assert result is None
|
||||
|
||||
def test_get_unlinked_files(self, temp_db: Database, sample_safetensor: Path) -> None:
|
||||
"""Test getting unlinked files."""
|
||||
temp_db.scan_directory(sample_safetensor.parent)
|
||||
unlinked = temp_db.get_unlinked_files()
|
||||
|
||||
assert len(unlinked) == 1
|
||||
assert unlinked[0].get("civitai_model_id", True)
|
||||
|
||||
def test_link_file_to_civitai(self, temp_db: Database, sample_safetensor: Path) -> None:
|
||||
"""Test linking a file to CivitAI."""
|
||||
results = temp_db.scan_directory(sample_safetensor.parent)
|
||||
file_id = results[0]["id"]
|
||||
|
||||
temp_db.link_file_to_civitai(file_id, model_id=123, version_id=456)
|
||||
|
||||
# Should have no unlinked files now
|
||||
unlinked = temp_db.get_unlinked_files()
|
||||
assert len(unlinked) == 0
|
||||
|
||||
def test_upsert_local_file_updates_existing(self, temp_db: Database, sample_safetensor: Path) -> None:
|
||||
"""Test that scanning same file twice updates instead of inserting."""
|
||||
temp_db.scan_directory(sample_safetensor.parent)
|
||||
temp_db.scan_directory(sample_safetensor.parent)
|
||||
|
||||
files = temp_db.list_local_files()
|
||||
assert len(files) == 1
|
||||
|
||||
|
||||
class TestCivitAICache:
|
||||
"""Tests for CivitAI model caching."""
|
||||
|
||||
def test_cache_model(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test caching a full CivitAI model."""
|
||||
model_id = temp_db.cache_model(sample_civitai_model)
|
||||
assert model_id > 0
|
||||
|
||||
def test_cache_model_creates_creator(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test that caching model creates creator record."""
|
||||
temp_db.cache_model(sample_civitai_model)
|
||||
|
||||
cur = temp_db.conn.cursor()
|
||||
cur.execute("SELECT * FROM creators WHERE username = ?", ("test_creator",))
|
||||
creator = cur.fetchone()
|
||||
|
||||
assert creator is not None
|
||||
assert creator["username"] == "test_creator"
|
||||
|
||||
def test_cache_model_creates_tags(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test that caching model creates tags."""
|
||||
temp_db.cache_model(sample_civitai_model)
|
||||
|
||||
cur = temp_db.conn.cursor()
|
||||
cur.execute("SELECT COUNT(*) FROM tags")
|
||||
count = cur.fetchone()[0]
|
||||
|
||||
assert count == 3 # test, lora, anime
|
||||
|
||||
def test_cache_model_creates_versions(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test that caching model creates versions."""
|
||||
temp_db.cache_model(sample_civitai_model)
|
||||
|
||||
cur = temp_db.conn.cursor()
|
||||
cur.execute("SELECT * FROM model_versions WHERE civitai_id = ?", (789012,))
|
||||
version = cur.fetchone()
|
||||
|
||||
assert version is not None
|
||||
assert version["name"] == "v1.0"
|
||||
assert version["base_model"] == "SDXL 1.0"
|
||||
|
||||
def test_cache_model_creates_trained_words(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test that caching model creates trained words."""
|
||||
temp_db.cache_model(sample_civitai_model)
|
||||
|
||||
cur = temp_db.conn.cursor()
|
||||
cur.execute("SELECT word FROM trained_words ORDER BY position")
|
||||
words = [row[0] for row in cur.fetchall()]
|
||||
|
||||
assert words == ["test_trigger", "lora_trigger"]
|
||||
|
||||
def test_cache_model_creates_files_and_hashes(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test that caching model creates files and hashes."""
|
||||
temp_db.cache_model(sample_civitai_model)
|
||||
|
||||
cur = temp_db.conn.cursor()
|
||||
cur.execute("SELECT * FROM version_files WHERE civitai_id = ?", (111222,))
|
||||
file_record = cur.fetchone()
|
||||
|
||||
assert file_record is not None
|
||||
assert file_record["name"] == "test_lora.safetensors"
|
||||
assert file_record["is_primary"] == 1
|
||||
|
||||
cur.execute("SELECT hash_type, hash_value FROM file_hashes WHERE file_id = ?", (file_record["id"],))
|
||||
hashes = {row[0]: row[1] for row in cur.fetchall()}
|
||||
|
||||
assert hashes["SHA256"] == "ABC123DEF456"
|
||||
assert hashes["BLAKE3"] == "789XYZ"
|
||||
|
||||
def test_cache_model_idempotent(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test that caching same model twice is idempotent."""
|
||||
id1 = temp_db.cache_model(sample_civitai_model)
|
||||
id2 = temp_db.cache_model(sample_civitai_model)
|
||||
|
||||
assert id1 == id2
|
||||
|
||||
cur = temp_db.conn.cursor()
|
||||
cur.execute("SELECT COUNT(*) FROM models")
|
||||
assert cur.fetchone()[0] == 1
|
||||
|
||||
|
||||
class TestQueryOperations:
|
||||
"""Tests for search and query operations."""
|
||||
|
||||
def test_search_models_by_name(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test searching models by name."""
|
||||
temp_db.cache_model(sample_civitai_model)
|
||||
results = temp_db.search_models(query="Test")
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "Test LoRA"
|
||||
|
||||
def test_search_models_by_type(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test searching models by type."""
|
||||
temp_db.cache_model(sample_civitai_model)
|
||||
results = temp_db.search_models(model_type="LORA")
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
def test_search_models_by_base_model(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test searching models by base model."""
|
||||
temp_db.cache_model(sample_civitai_model)
|
||||
results = temp_db.search_models(base_model="SDXL")
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
def test_search_models_no_results(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test search with no matching results."""
|
||||
temp_db.cache_model(sample_civitai_model)
|
||||
results = temp_db.search_models(query="nonexistent")
|
||||
|
||||
assert len(results) == 0
|
||||
|
||||
def test_search_models_limit(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test search respects limit."""
|
||||
# Cache multiple models
|
||||
for i in range(5):
|
||||
model = sample_civitai_model.copy()
|
||||
model["id"] = 100000 + i
|
||||
model["name"] = f"Model {i}"
|
||||
temp_db.cache_model(model)
|
||||
|
||||
results = temp_db.search_models(limit=3)
|
||||
assert len(results) == 3
|
||||
|
||||
def test_get_model(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test getting model by CivitAI ID."""
|
||||
temp_db.cache_model(sample_civitai_model)
|
||||
model = temp_db.get_model(123456)
|
||||
|
||||
assert model is not None
|
||||
assert model["name"] == "Test LoRA"
|
||||
assert model["type"] == "LORA"
|
||||
|
||||
def test_get_model_not_found(self, temp_db: Database) -> None:
|
||||
"""Test getting non-existent model."""
|
||||
result = temp_db.get_model(999999)
|
||||
assert result is None
|
||||
|
||||
def test_get_version_by_hash(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test finding version by file hash."""
|
||||
temp_db.cache_model(sample_civitai_model)
|
||||
version = temp_db.get_version_by_hash("ABC123DEF456")
|
||||
|
||||
assert version is not None
|
||||
assert version["model_name"] == "Test LoRA"
|
||||
assert version["version_name"] == "v1.0"
|
||||
|
||||
def test_get_version_by_hash_case_insensitive(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test hash lookup is case insensitive."""
|
||||
temp_db.cache_model(sample_civitai_model)
|
||||
version = temp_db.get_version_by_hash("abc123def456")
|
||||
|
||||
assert version is not None
|
||||
|
||||
|
||||
class TestTriggerWords:
|
||||
"""Tests for trigger word operations."""
|
||||
|
||||
def test_get_triggers_by_version(self, temp_db: Database, sample_civitai_model: dict) -> None:
|
||||
"""Test getting triggers by version ID."""
|
||||
temp_db.cache_model(sample_civitai_model)
|
||||
triggers = temp_db.get_triggers_by_version(789012)
|
||||
|
||||
assert triggers == ["test_trigger", "lora_trigger"]
|
||||
|
||||
def test_get_triggers_by_file_path(self, temp_db: Database, sample_civitai_model: dict, sample_safetensor: Path) -> None:
|
||||
"""Test getting triggers by linked file path."""
|
||||
temp_db.cache_model(sample_civitai_model)
|
||||
results = temp_db.scan_directory(sample_safetensor.parent)
|
||||
temp_db.link_file_to_civitai(results[0]["id"], model_id=123456, version_id=789012)
|
||||
|
||||
triggers = temp_db.get_triggers(str(sample_safetensor.resolve()))
|
||||
assert triggers == ["test_trigger", "lora_trigger"]
|
||||
|
||||
def test_get_triggers_empty(self, temp_db: Database) -> None:
|
||||
"""Test getting triggers for unlinked file."""
|
||||
triggers = temp_db.get_triggers("/nonexistent/file.safetensors")
|
||||
assert triggers == []
|
||||
|
||||
|
||||
class TestStatistics:
|
||||
"""Tests for database statistics."""
|
||||
|
||||
def test_get_stats_empty(self, temp_db: Database) -> None:
|
||||
"""Test stats on empty database."""
|
||||
stats = temp_db.get_stats()
|
||||
|
||||
assert stats["local_files"] == 0
|
||||
assert stats["models"] == 0
|
||||
assert stats["model_versions"] == 0
|
||||
|
||||
def test_get_stats_with_data(self, temp_db: Database, sample_civitai_model: dict, sample_safetensor: Path) -> None:
|
||||
"""Test stats with data."""
|
||||
temp_db.cache_model(sample_civitai_model)
|
||||
temp_db.scan_directory(sample_safetensor.parent)
|
||||
|
||||
stats = temp_db.get_stats()
|
||||
|
||||
assert stats["local_files"] == 1
|
||||
assert stats["models"] == 1
|
||||
assert stats["model_versions"] == 1
|
||||
assert stats["trained_words"] == 2
|
||||
assert stats["creators"] == 1
|
||||
assert stats["tags"] == 3
|
||||
|
||||
|
||||
class TestContextManager:
|
||||
"""Tests for database context manager."""
|
||||
|
||||
def test_context_manager(self, tmp_path: Path) -> None:
|
||||
"""Test database works as context manager."""
|
||||
db_path = tmp_path / "test.db"
|
||||
|
||||
with Database(db_path=db_path) as db:
|
||||
db.init_schema()
|
||||
stats = db.get_stats()
|
||||
assert stats["local_files"] == 0
|
||||
|
||||
# Connection should be closed
|
||||
assert db._conn is None
|
||||
|
||||
def test_connection_reuse(self, tmp_path: Path) -> None:
|
||||
"""Test that connection is reused within context."""
|
||||
db_path = tmp_path / "test.db"
|
||||
|
||||
with Database(db_path=db_path) as db:
|
||||
db.init_schema()
|
||||
conn1 = db.conn
|
||||
conn2 = db.conn
|
||||
assert conn1 is conn2
|
||||
@@ -167,3 +167,380 @@ class TestProcessManager:
|
||||
cfg = ServerConfig(model="/m.gguf")
|
||||
assert cfg.port == 1234
|
||||
assert cfg.args == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Gallery Endpoint Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_gallery(tmp_path):
|
||||
"""Create a temporary gallery for testing."""
|
||||
from tensors.server.gallery import Gallery # noqa: PLC0415
|
||||
|
||||
gallery_dir = tmp_path / "gallery"
|
||||
gallery_dir.mkdir()
|
||||
return Gallery(gallery_dir=gallery_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gallery_with_images(temp_gallery):
|
||||
"""Gallery with some test images."""
|
||||
# Create test images
|
||||
for i in range(3):
|
||||
# Create a minimal PNG (1x1 pixel)
|
||||
image_data = (
|
||||
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01"
|
||||
b"\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00"
|
||||
b"\x00\x00\x0cIDATx\x9cc\xf8\x0f\x00\x00\x01\x01\x00"
|
||||
b"\x05\x18\xd8N\x00\x00\x00\x00IEND\xaeB`\x82"
|
||||
)
|
||||
metadata = {"prompt": f"test prompt {i}", "seed": i, "width": 512, "height": 512}
|
||||
temp_gallery.save_image(image_data, metadata=metadata, seed=i)
|
||||
|
||||
return temp_gallery
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gallery_api(temp_gallery) -> TestClient:
|
||||
"""Test client for gallery API with temp gallery."""
|
||||
from fastapi import FastAPI # noqa: PLC0415
|
||||
|
||||
# Override the gallery singleton
|
||||
from tensors.server import gallery_routes # noqa: PLC0415
|
||||
from tensors.server.gallery_routes import create_gallery_router # noqa: PLC0415
|
||||
|
||||
gallery_routes._gallery = temp_gallery
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(create_gallery_router())
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestGalleryList:
|
||||
"""Tests for gallery list endpoint."""
|
||||
|
||||
def test_list_images_empty(self, gallery_api: TestClient) -> None:
|
||||
"""Test listing empty gallery."""
|
||||
response = gallery_api.get("/api/images")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["images"] == []
|
||||
assert data["total"] == 0
|
||||
|
||||
def test_list_images_with_data(self, gallery_api: TestClient, gallery_with_images) -> None:
|
||||
"""Test listing gallery with images."""
|
||||
from tensors.server import gallery_routes # noqa: PLC0415
|
||||
|
||||
gallery_routes._gallery = gallery_with_images
|
||||
|
||||
response = gallery_api.get("/api/images")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["images"]) == 3
|
||||
assert data["total"] == 3
|
||||
|
||||
def test_list_images_pagination(self, gallery_api: TestClient, gallery_with_images) -> None:
|
||||
"""Test pagination parameters."""
|
||||
from tensors.server import gallery_routes # noqa: PLC0415
|
||||
|
||||
gallery_routes._gallery = gallery_with_images
|
||||
|
||||
response = gallery_api.get("/api/images?limit=2&offset=1")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["images"]) == 2
|
||||
|
||||
|
||||
class TestGalleryGetImage:
|
||||
"""Tests for getting individual images."""
|
||||
|
||||
def test_get_image_not_found(self, gallery_api: TestClient) -> None:
|
||||
"""Test getting non-existent image returns 404."""
|
||||
response = gallery_api.get("/api/images/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_image_success(self, gallery_api: TestClient, gallery_with_images) -> None:
|
||||
"""Test getting an image file."""
|
||||
from tensors.server import gallery_routes # noqa: PLC0415
|
||||
|
||||
gallery_routes._gallery = gallery_with_images
|
||||
|
||||
list_response = gallery_api.get("/api/images")
|
||||
images = list_response.json()["images"]
|
||||
image_id = images[0]["id"]
|
||||
|
||||
response = gallery_api.get(f"/api/images/{image_id}")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "image/png"
|
||||
|
||||
|
||||
class TestGalleryMetadata:
|
||||
"""Tests for image metadata endpoints."""
|
||||
|
||||
def test_get_metadata_not_found(self, gallery_api: TestClient) -> None:
|
||||
"""Test getting metadata for non-existent image."""
|
||||
response = gallery_api.get("/api/images/nonexistent/meta")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_metadata_success(self, gallery_api: TestClient, gallery_with_images) -> None:
|
||||
"""Test getting image metadata."""
|
||||
from tensors.server import gallery_routes # noqa: PLC0415
|
||||
|
||||
gallery_routes._gallery = gallery_with_images
|
||||
|
||||
list_response = gallery_api.get("/api/images")
|
||||
images = list_response.json()["images"]
|
||||
image_id = images[0]["id"]
|
||||
|
||||
response = gallery_api.get(f"/api/images/{image_id}/meta")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == image_id
|
||||
assert "metadata" in data
|
||||
|
||||
def test_edit_metadata(self, gallery_api: TestClient, gallery_with_images) -> None:
|
||||
"""Test updating image metadata."""
|
||||
from tensors.server import gallery_routes # noqa: PLC0415
|
||||
|
||||
gallery_routes._gallery = gallery_with_images
|
||||
|
||||
list_response = gallery_api.get("/api/images")
|
||||
images = list_response.json()["images"]
|
||||
image_id = images[0]["id"]
|
||||
|
||||
response = gallery_api.post(
|
||||
f"/api/images/{image_id}/edit",
|
||||
json={"tags": ["test", "favorite"], "rating": 5},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["metadata"]["tags"] == ["test", "favorite"]
|
||||
assert data["metadata"]["rating"] == 5
|
||||
|
||||
def test_edit_metadata_not_found(self, gallery_api: TestClient) -> None:
|
||||
"""Test editing non-existent image metadata."""
|
||||
response = gallery_api.post(
|
||||
"/api/images/nonexistent/edit",
|
||||
json={"tags": ["test"]},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestGalleryDelete:
|
||||
"""Tests for deleting images."""
|
||||
|
||||
def test_delete_image_not_found(self, gallery_api: TestClient) -> None:
|
||||
"""Test deleting non-existent image."""
|
||||
response = gallery_api.delete("/api/images/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_delete_image_success(self, gallery_api: TestClient, gallery_with_images) -> None:
|
||||
"""Test deleting an image."""
|
||||
from tensors.server import gallery_routes # noqa: PLC0415
|
||||
|
||||
gallery_routes._gallery = gallery_with_images
|
||||
|
||||
list_response = gallery_api.get("/api/images")
|
||||
initial_count = list_response.json()["total"]
|
||||
image_id = list_response.json()["images"][0]["id"]
|
||||
|
||||
response = gallery_api.delete(f"/api/images/{image_id}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["deleted"] is True
|
||||
|
||||
list_response = gallery_api.get("/api/images")
|
||||
assert list_response.json()["total"] == initial_count - 1
|
||||
|
||||
|
||||
class TestGalleryStats:
|
||||
"""Tests for gallery statistics endpoint."""
|
||||
|
||||
def test_stats_empty(self, gallery_api: TestClient) -> None:
|
||||
"""Test stats on empty gallery."""
|
||||
response = gallery_api.get("/api/images/stats/summary")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_images"] == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Database Endpoint Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db(tmp_path):
|
||||
"""Create a temporary database for testing."""
|
||||
from tensors.db import Database # noqa: PLC0415
|
||||
|
||||
db_path = tmp_path / "test_models.db"
|
||||
db = Database(db_path=db_path)
|
||||
db.init_schema()
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_api(temp_db, monkeypatch) -> TestClient:
|
||||
"""Test client for db API with temp database."""
|
||||
from fastapi import FastAPI # noqa: PLC0415
|
||||
|
||||
# Monkeypatch Database to use temp_db path
|
||||
from tensors import db as db_module # noqa: PLC0415
|
||||
from tensors.server.db_routes import create_db_router # noqa: PLC0415
|
||||
|
||||
monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path)
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(create_db_router())
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestDbEndpoints:
|
||||
"""Tests for database API endpoints."""
|
||||
|
||||
def test_list_files_empty(self, db_api: TestClient) -> None:
|
||||
"""Test listing files from empty database."""
|
||||
response = db_api.get("/api/db/files")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
def test_search_models_empty(self, db_api: TestClient) -> None:
|
||||
"""Test searching models in empty database."""
|
||||
response = db_api.get("/api/db/models")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
def test_search_models_with_query(self, db_api: TestClient, temp_db, monkeypatch) -> None:
|
||||
"""Test searching models with query parameters."""
|
||||
from tensors import db as db_module # noqa: PLC0415
|
||||
|
||||
monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path)
|
||||
|
||||
model_data = {
|
||||
"id": 12345,
|
||||
"name": "Test Model",
|
||||
"type": "LORA",
|
||||
"tags": [],
|
||||
"modelVersions": [],
|
||||
}
|
||||
temp_db.cache_model(model_data)
|
||||
|
||||
response = db_api.get("/api/db/models?query=Test")
|
||||
assert response.status_code == 200
|
||||
results = response.json()
|
||||
assert len(results) >= 1
|
||||
|
||||
def test_get_model_not_found(self, db_api: TestClient) -> None:
|
||||
"""Test getting non-existent model."""
|
||||
response = db_api.get("/api/db/models/999999")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_model_success(self, db_api: TestClient, temp_db, monkeypatch) -> None:
|
||||
"""Test getting cached model."""
|
||||
from tensors import db as db_module # noqa: PLC0415
|
||||
|
||||
monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path)
|
||||
|
||||
model_data = {
|
||||
"id": 12345,
|
||||
"name": "Test Model",
|
||||
"type": "Checkpoint",
|
||||
"tags": [],
|
||||
"modelVersions": [],
|
||||
}
|
||||
temp_db.cache_model(model_data)
|
||||
|
||||
response = db_api.get("/api/db/models/12345")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Model"
|
||||
|
||||
def test_get_stats(self, db_api: TestClient) -> None:
|
||||
"""Test getting database stats."""
|
||||
response = db_api.get("/api/db/stats")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "local_files" in data
|
||||
assert "models" in data
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Gallery Class Unit Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGalleryClass:
|
||||
"""Unit tests for the Gallery class."""
|
||||
|
||||
def test_save_image(self, temp_gallery) -> None:
|
||||
"""Test saving an image to gallery."""
|
||||
image_data = b"\x89PNG test data"
|
||||
metadata = {"prompt": "test", "seed": 42}
|
||||
|
||||
result = temp_gallery.save_image(image_data, metadata=metadata, seed=42)
|
||||
|
||||
assert result.id is not None
|
||||
assert result.path.exists()
|
||||
assert result.meta_path.exists()
|
||||
|
||||
def test_list_images_empty(self, temp_gallery) -> None:
|
||||
"""Test listing empty gallery."""
|
||||
images = temp_gallery.list_images()
|
||||
assert images == []
|
||||
|
||||
def test_list_images_sorted(self, gallery_with_images) -> None:
|
||||
"""Test images are sorted by creation time."""
|
||||
images = gallery_with_images.list_images(newest_first=True)
|
||||
assert len(images) == 3
|
||||
|
||||
times = [img.created_at for img in images]
|
||||
assert times == sorted(times, reverse=True)
|
||||
|
||||
def test_get_image(self, gallery_with_images) -> None:
|
||||
"""Test getting image by ID."""
|
||||
images = gallery_with_images.list_images()
|
||||
image_id = images[0].id
|
||||
|
||||
result = gallery_with_images.get_image(image_id)
|
||||
assert result is not None
|
||||
assert result.id == image_id
|
||||
|
||||
def test_get_image_not_found(self, temp_gallery) -> None:
|
||||
"""Test getting non-existent image."""
|
||||
result = temp_gallery.get_image("nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_delete_image(self, gallery_with_images) -> None:
|
||||
"""Test deleting an image."""
|
||||
images = gallery_with_images.list_images()
|
||||
image_id = images[0].id
|
||||
image_path = images[0].path
|
||||
|
||||
result = gallery_with_images.delete_image(image_id)
|
||||
assert result is True
|
||||
assert not image_path.exists()
|
||||
|
||||
def test_delete_image_not_found(self, temp_gallery) -> None:
|
||||
"""Test deleting non-existent image."""
|
||||
result = temp_gallery.delete_image("nonexistent")
|
||||
assert result is False
|
||||
|
||||
def test_count(self, gallery_with_images) -> None:
|
||||
"""Test counting images."""
|
||||
assert gallery_with_images.count() == 3
|
||||
|
||||
def test_update_metadata(self, gallery_with_images) -> None:
|
||||
"""Test updating metadata."""
|
||||
images = gallery_with_images.list_images()
|
||||
image_id = images[0].id
|
||||
|
||||
result = gallery_with_images.update_metadata(image_id, {"custom_field": "value"})
|
||||
assert result is not None
|
||||
assert result["custom_field"] == "value"
|
||||
|
||||
def test_update_metadata_not_found(self, temp_gallery) -> None:
|
||||
"""Test updating metadata for non-existent image."""
|
||||
result = temp_gallery.update_metadata("nonexistent", {"field": "value"})
|
||||
assert result is None
|
||||
|
||||
Reference in New Issue
Block a user