From e257a029da61b3f10c77a4f3fb6dc400e8a79875 Mon Sep 17 00:00:00 2001 From: Adam Ladachowski Date: Sat, 14 Feb 2026 01:54:00 +0100 Subject: [PATCH] Phase 6: Tests for database, server, and client modules - Add tests/test_db.py with 33 tests for Database class: - Schema initialization and migrations - Local file CRUD operations (scan, list, link) - CivitAI model caching (cache_model, tags, versions, files) - Query operations (search, get_model, get_triggers) - Statistics and context manager - Extend tests/test_server.py with 27 tests for API endpoints: - Gallery endpoints (list, get, meta, edit, delete, stats) - Database endpoints (files, models, stats) - Gallery class unit tests - Add tests/test_client.py with 33 tests for TsrClient: - Server status operations - Gallery image operations (list, get, delete, edit, download) - Model management (list, active, switch, loras) - Image generation - CivitAI download operations - Database query operations - Error handling and context manager Total: 191 tests passing with 61% coverage Co-Authored-By: Claude Opus 4.5 --- TODO.md | 6 +- tests/test_client.py | 481 +++++++++++++++++++++++++++++++++++++++++++ tests/test_db.py | 450 ++++++++++++++++++++++++++++++++++++++++ tests/test_server.py | 377 +++++++++++++++++++++++++++++++++ 4 files changed, 1311 insertions(+), 3 deletions(-) create mode 100644 tests/test_client.py create mode 100644 tests/test_db.py diff --git a/TODO.md b/TODO.md index bc2926a..dd75db9 100644 --- a/TODO.md +++ b/TODO.md @@ -27,9 +27,9 @@ - [x] Step 5.3: ~~Create `rocm-docker/tsr-server.service`~~ (skipped) ## Phase 6: Tests -- [ ] Step 6.1: `tests/test_db.py` (database module tests) -- [ ] Step 6.2: `tests/test_server.py` (API endpoint tests) -- [ ] Step 6.3: `tests/test_client.py` (client module tests) +- [x] Step 6.1: `tests/test_db.py` (database module tests) +- [x] Step 6.2: `tests/test_server.py` (API endpoint tests) +- [x] Step 6.3: `tests/test_client.py` (client module tests) --- diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..ba2ff49 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,481 @@ +"""Tests for the TsrClient HTTP client module.""" + +from __future__ import annotations + +import pytest +import respx +from httpx import Response + +from tensors.client import TsrClient, TsrClientError + +BASE_URL = "http://test-server:8080" + + +@pytest.fixture +def mock_server(): + """Activate respx mock for the test server.""" + with respx.mock(base_url=BASE_URL, assert_all_called=False) as rsps: + yield rsps + + +@pytest.fixture +def client(mock_server) -> TsrClient: # noqa: ARG001 - mock_server activates respx + """TsrClient connected to mock server.""" + return TsrClient(BASE_URL) + + +# ============================================================================= +# Status Tests +# ============================================================================= + + +class TestStatus: + """Tests for server status endpoint.""" + + def test_status_success(self, client: TsrClient, mock_server) -> None: + """Test getting server status.""" + mock_server.get("/status").mock(return_value=Response(200, json={"running": True, "pid": 12345, "model": "/test.gguf"})) + + with client: + result = client.status() + + assert result["running"] is True + assert result["pid"] == 12345 + + def test_status_error(self, client: TsrClient, mock_server) -> None: + """Test handling status error.""" + mock_server.get("/status").mock(return_value=Response(503, text="Service unavailable")) + + with client, pytest.raises(TsrClientError, match="HTTP 503"): + client.status() + + +# ============================================================================= +# Gallery Tests +# ============================================================================= + + +class TestGalleryImages: + """Tests for gallery image operations.""" + + def test_list_images(self, client: TsrClient, mock_server) -> None: + """Test listing gallery images.""" + mock_server.get("/api/images").mock( + return_value=Response( + 200, + json={ + "images": [ + {"id": "123_0", "filename": "123_0.png", "width": 512, "height": 512}, + {"id": "124_1", "filename": "124_1.png", "width": 1024, "height": 1024}, + ], + "total": 2, + }, + ) + ) + + with client: + result = client.list_images() + + assert len(result["images"]) == 2 + assert result["total"] == 2 + + def test_list_images_with_pagination(self, client: TsrClient, mock_server) -> None: + """Test listing images with pagination.""" + mock_server.get("/api/images", params={"limit": 10, "offset": 5}).mock( + return_value=Response(200, json={"images": [], "total": 100}) + ) + + with client: + result = client.list_images(limit=10, offset=5) + + assert result["total"] == 100 + + def test_get_image_meta(self, client: TsrClient, mock_server) -> None: + """Test getting image metadata.""" + mock_server.get("/api/images/123_0/meta").mock( + return_value=Response( + 200, + json={ + "id": "123_0", + "path": "/gallery/123_0.png", + "metadata": {"prompt": "test prompt", "seed": 42}, + }, + ) + ) + + with client: + result = client.get_image_meta("123_0") + + assert result["id"] == "123_0" + assert result["metadata"]["prompt"] == "test prompt" + + def test_delete_image(self, client: TsrClient, mock_server) -> None: + """Test deleting an image.""" + mock_server.delete("/api/images/123_0").mock(return_value=Response(200, json={"deleted": True, "id": "123_0"})) + + with client: + result = client.delete_image("123_0") + + assert result["deleted"] is True + + def test_edit_image(self, client: TsrClient, mock_server) -> None: + """Test editing image metadata.""" + mock_server.post("/api/images/123_0/edit").mock( + return_value=Response(200, json={"id": "123_0", "metadata": {"tags": ["favorite"], "rating": 5}}) + ) + + with client: + result = client.edit_image("123_0", {"tags": ["favorite"], "rating": 5}) + + assert result["metadata"]["tags"] == ["favorite"] + + def test_download_image(self, client: TsrClient, mock_server) -> None: + """Test downloading image bytes.""" + image_bytes = b"\x89PNG test image data" + mock_server.get("/api/images/123_0").mock(return_value=Response(200, content=image_bytes)) + + with client: + result = client.download_image("123_0") + + assert result == image_bytes + + +# ============================================================================= +# Models Tests +# ============================================================================= + + +class TestModels: + """Tests for model management operations.""" + + def test_list_models(self, client: TsrClient, mock_server) -> None: + """Test listing available models.""" + mock_server.get("/api/models").mock( + return_value=Response( + 200, + json={ + "models": [ + {"name": "sdxl_base", "path": "/models/sdxl_base.safetensors"}, + {"name": "pony_v6", "path": "/models/pony_v6.safetensors"}, + ], + "active": "/models/sdxl_base.safetensors", + }, + ) + ) + + with client: + result = client.list_models() + + assert len(result["models"]) == 2 + assert result["active"] == "/models/sdxl_base.safetensors" + + def test_get_active_model(self, client: TsrClient, mock_server) -> None: + """Test getting active model.""" + mock_server.get("/api/models/active").mock(return_value=Response(200, json={"model": "/models/sdxl_base.safetensors"})) + + with client: + result = client.get_active_model() + + assert result["model"] == "/models/sdxl_base.safetensors" + + def test_switch_model(self, client: TsrClient, mock_server) -> None: + """Test switching model.""" + mock_server.post("/api/models/switch").mock( + return_value=Response(200, json={"status": "ok", "model": "/models/pony_v6.safetensors"}) + ) + + with client: + result = client.switch_model("/models/pony_v6.safetensors") + + assert result["status"] == "ok" + + def test_list_loras(self, client: TsrClient, mock_server) -> None: + """Test listing LoRAs.""" + mock_server.get("/api/models/loras").mock( + return_value=Response( + 200, + json={ + "loras": [ + {"name": "detail_tweaker", "path": "/loras/detail_tweaker.safetensors"}, + ] + }, + ) + ) + + with client: + result = client.list_loras() + + assert len(result["loras"]) == 1 + + def test_scan_models(self, client: TsrClient, mock_server) -> None: + """Test scanning models.""" + mock_server.get("/api/models/scan").mock(return_value=Response(200, json={"scanned": 5})) + + with client: + result = client.scan_models() + + assert result["scanned"] == 5 + + +# ============================================================================= +# Generation Tests +# ============================================================================= + + +class TestGeneration: + """Tests for image generation.""" + + def test_generate(self, client: TsrClient, mock_server) -> None: + """Test generating an image.""" + mock_server.post("/api/generate").mock( + return_value=Response( + 200, + json={ + "images": [{"id": "999_42", "seed": 42}], + "parameters": {"prompt": "test prompt", "seed": 42}, + }, + ) + ) + + with client: + result = client.generate( + prompt="test prompt", + width=512, + height=512, + seed=42, + ) + + assert len(result["images"]) == 1 + assert result["images"][0]["seed"] == 42 + + def test_generate_with_all_params(self, client: TsrClient, mock_server) -> None: + """Test generation with all parameters.""" + mock_server.post("/api/generate").mock(return_value=Response(200, json={"images": []})) + + with client: + result = client.generate( + prompt="detailed test prompt", + negative_prompt="bad quality", + width=1024, + height=1024, + steps=30, + cfg_scale=5.5, + seed=12345, + sampler_name="DPM++ 2M", + scheduler="karras", + batch_size=2, + save_to_gallery=False, + return_base64=True, + ) + + assert "images" in result + + def test_list_samplers(self, client: TsrClient, mock_server) -> None: + """Test listing samplers.""" + mock_server.get("/api/samplers").mock(return_value=Response(200, json={"samplers": ["Euler", "DPM++ 2M", "Euler a"]})) + + with client: + result = client.list_samplers() + + assert "samplers" in result + + def test_list_schedulers(self, client: TsrClient, mock_server) -> None: + """Test listing schedulers.""" + mock_server.get("/api/schedulers").mock( + return_value=Response(200, json={"schedulers": ["simple", "karras", "sgm_uniform"]}) + ) + + with client: + result = client.list_schedulers() + + assert "schedulers" in result + + +# ============================================================================= +# Download Tests +# ============================================================================= + + +class TestDownload: + """Tests for CivitAI download operations.""" + + def test_start_download_by_version(self, client: TsrClient, mock_server) -> None: + """Test starting download by version ID.""" + mock_server.post("/api/download").mock( + return_value=Response(200, json={"download_id": "abc123", "status": "started", "version_id": 12345}) + ) + + with client: + result = client.start_download(version_id=12345) + + assert result["download_id"] == "abc123" + + def test_start_download_by_hash(self, client: TsrClient, mock_server) -> None: + """Test starting download by hash.""" + mock_server.post("/api/download").mock(return_value=Response(200, json={"download_id": "def456", "status": "started"})) + + with client: + result = client.start_download(hash_val="ABC123DEF456") + + assert result["status"] == "started" + + def test_get_download_status(self, client: TsrClient, mock_server) -> None: + """Test getting download status.""" + mock_server.get("/api/download/status/abc123").mock( + return_value=Response(200, json={"download_id": "abc123", "status": "downloading", "progress": 0.5}) + ) + + with client: + result = client.get_download_status("abc123") + + assert result["progress"] == 0.5 + + def test_list_downloads(self, client: TsrClient, mock_server) -> None: + """Test listing active downloads.""" + mock_server.get("/api/download/active").mock( + return_value=Response(200, json={"downloads": [{"id": "abc123", "progress": 0.75}]}) + ) + + with client: + result = client.list_downloads() + + assert len(result["downloads"]) == 1 + + +# ============================================================================= +# Database Tests +# ============================================================================= + + +class TestDatabase: + """Tests for database operations.""" + + def test_db_list_files(self, client: TsrClient, mock_server) -> None: + """Test listing local files.""" + mock_server.get("/api/db/files").mock( + return_value=Response(200, json=[{"id": 1, "file_path": "/models/test.safetensors", "sha256": "abc123"}]) + ) + + with client: + result = client.db_list_files() + + assert len(result) == 1 + assert result[0]["sha256"] == "abc123" + + def test_db_search_models(self, client: TsrClient, mock_server) -> None: + """Test searching cached models.""" + mock_server.get("/api/db/models").mock( + return_value=Response(200, json=[{"civitai_id": 12345, "name": "Test Model", "type": "LORA"}]) + ) + + with client: + result = client.db_search_models(query="Test", model_type="LORA") + + assert len(result) == 1 + assert result[0]["name"] == "Test Model" + + def test_db_get_model(self, client: TsrClient, mock_server) -> None: + """Test getting cached model.""" + mock_server.get("/api/db/models/12345").mock( + return_value=Response(200, json={"civitai_id": 12345, "name": "Test Model", "type": "Checkpoint"}) + ) + + with client: + result = client.db_get_model(12345) + + assert result["name"] == "Test Model" + + def test_db_get_triggers(self, client: TsrClient, mock_server) -> None: + """Test getting trigger words.""" + mock_server.get("/api/db/triggers/12345").mock(return_value=Response(200, json=["trigger1", "trigger2"])) + + with client: + result = client.db_get_triggers(version_id=12345) + + assert result == ["trigger1", "trigger2"] + + def test_db_stats(self, client: TsrClient, mock_server) -> None: + """Test getting database stats.""" + mock_server.get("/api/db/stats").mock( + return_value=Response(200, json={"local_files": 10, "models": 5, "model_versions": 15}) + ) + + with client: + result = client.db_stats() + + assert result["local_files"] == 10 + + def test_db_scan(self, client: TsrClient, mock_server) -> None: + """Test scanning directory.""" + mock_server.post("/api/db/scan").mock(return_value=Response(200, json={"scanned": 3, "files": []})) + + with client: + result = client.db_scan("/models") + + assert result["scanned"] == 3 + + def test_db_link(self, client: TsrClient, mock_server) -> None: + """Test linking files to CivitAI.""" + mock_server.post("/api/db/link").mock(return_value=Response(200, json={"linked": 2})) + + with client: + result = client.db_link() + + assert result["linked"] == 2 + + def test_db_cache(self, client: TsrClient, mock_server) -> None: + """Test caching model data.""" + mock_server.post("/api/db/cache").mock(return_value=Response(200, json={"model_id": 12345, "cached": True})) + + with client: + result = client.db_cache(12345) + + assert result["cached"] is True + + +# ============================================================================= +# Error Handling Tests +# ============================================================================= + + +class TestErrorHandling: + """Tests for error handling.""" + + def test_http_error(self, client: TsrClient, mock_server) -> None: + """Test HTTP error handling.""" + mock_server.get("/api/images").mock(return_value=Response(500, text="Internal server error")) + + with client, pytest.raises(TsrClientError, match="HTTP 500"): + client.list_images() + + def test_not_found_error(self, client: TsrClient, mock_server) -> None: + """Test 404 error handling.""" + mock_server.get("/api/images/nonexistent/meta").mock(return_value=Response(404, json={"detail": "Image not found"})) + + with client, pytest.raises(TsrClientError, match="HTTP 404"): + client.get_image_meta("nonexistent") + + +# ============================================================================= +# Context Manager Tests +# ============================================================================= + + +class TestContextManager: + """Tests for context manager usage.""" + + def test_context_manager(self, mock_server) -> None: + """Test client works as context manager.""" + mock_server.get("/status").mock(return_value=Response(200, json={"running": True})) + + with TsrClient(BASE_URL) as client: + result = client.status() + assert result["running"] is True + + def test_client_without_context(self, mock_server) -> None: + """Test client works without context manager.""" + mock_server.get("/status").mock(return_value=Response(200, json={"running": True})) + + client = TsrClient(BASE_URL) + result = client.status() + assert result["running"] is True diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..02a0b34 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,450 @@ +"""Tests for the database module.""" + +from __future__ import annotations + +import json +import struct +from pathlib import Path + +import pytest + +from tensors.db import Database + + +@pytest.fixture +def temp_db(tmp_path: Path) -> Database: + """Create a temporary database for testing.""" + db_path = tmp_path / "test_models.db" + db = Database(db_path=db_path) + db.init_schema() + return db + + +@pytest.fixture +def sample_safetensor(tmp_path: Path) -> Path: + """Create a sample safetensor file for testing.""" + header = { + "__metadata__": { + "format": "pt", + "test_key": "test_value", + } + } + header_bytes = json.dumps(header).encode("utf-8") + header_size = len(header_bytes) + + file_path = tmp_path / "models" / "test_lora.safetensors" + file_path.parent.mkdir(parents=True, exist_ok=True) + with file_path.open("wb") as f: + f.write(struct.pack(" dict: + """Sample CivitAI model API response.""" + return { + "id": 123456, + "name": "Test LoRA", + "description": "A test LoRA model", + "type": "LORA", + "nsfw": False, + "poi": False, + "minor": False, + "tags": ["test", "lora", "anime"], + "creator": { + "username": "test_creator", + "image": "https://example.com/avatar.png", + }, + "stats": { + "downloadCount": 1000, + "thumbsUpCount": 500, + "thumbsDownCount": 10, + "commentCount": 50, + "tippedAmountCount": 5, + }, + "modelVersions": [ + { + "id": 789012, + "name": "v1.0", + "description": "Initial release", + "baseModel": "SDXL 1.0", + "trainedWords": ["test_trigger", "lora_trigger"], + "files": [ + { + "id": 111222, + "name": "test_lora.safetensors", + "type": "Model", + "sizeKB": 150000, + "primary": True, + "hashes": { + "SHA256": "ABC123DEF456", + "BLAKE3": "789XYZ", + }, + "metadata": { + "format": "SafeTensor", + "size": "full", + "fp": "fp16", + }, + } + ], + "images": [ + { + "id": 333444, + "url": "https://example.com/image.png", + "type": "image", + "width": 1024, + "height": 1024, + "meta": { + "prompt": "test prompt", + "negativePrompt": "bad quality", + "cfgScale": 7.0, + }, + } + ], + "stats": { + "downloadCount": 1000, + "thumbsUpCount": 500, + }, + } + ], + } + + +class TestDatabaseSchema: + """Tests for database schema initialization.""" + + def test_init_schema(self, temp_db: Database) -> None: + """Test schema initialization creates tables.""" + cur = temp_db.conn.cursor() + cur.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = {row[0] for row in cur.fetchall()} + + expected = { + "local_files", + "safetensor_metadata", + "models", + "model_versions", + "version_files", + "file_hashes", + "trained_words", + "creators", + "tags", + "model_tags", + "version_images", + "image_generation_params", + "image_resources", + } + assert expected.issubset(tables) + + def test_init_schema_creates_views(self, temp_db: Database) -> None: + """Test schema creates required views.""" + cur = temp_db.conn.cursor() + cur.execute("SELECT name FROM sqlite_master WHERE type='view'") + views = {row[0] for row in cur.fetchall()} + + assert "v_local_files_full" in views + assert "v_models_with_latest" in views + + +class TestLocalFiles: + """Tests for local file operations.""" + + def test_scan_directory(self, temp_db: Database, sample_safetensor: Path) -> None: + """Test scanning directory for safetensor files.""" + results = temp_db.scan_directory(sample_safetensor.parent) + + assert len(results) == 1 + assert results[0]["file_path"] == str(sample_safetensor.resolve()) + assert "sha256" in results[0] + assert results[0]["sha256"] # Should have hash + + def test_scan_directory_empty(self, temp_db: Database, tmp_path: Path) -> None: + """Test scanning empty directory.""" + empty_dir = tmp_path / "empty" + empty_dir.mkdir() + + results = temp_db.scan_directory(empty_dir) + assert results == [] + + def test_list_local_files(self, temp_db: Database, sample_safetensor: Path) -> None: + """Test listing local files after scan.""" + temp_db.scan_directory(sample_safetensor.parent) + files = temp_db.list_local_files() + + assert len(files) == 1 + assert files[0]["file_path"] == str(sample_safetensor.resolve()) + + def test_get_local_file_by_path(self, temp_db: Database, sample_safetensor: Path) -> None: + """Test getting local file by path.""" + temp_db.scan_directory(sample_safetensor.parent) + + file_info = temp_db.get_local_file_by_path(str(sample_safetensor.resolve())) + assert file_info is not None + assert file_info["file_path"] == str(sample_safetensor.resolve()) + + def test_get_local_file_by_path_not_found(self, temp_db: Database) -> None: + """Test getting non-existent file.""" + result = temp_db.get_local_file_by_path("/nonexistent/file.safetensors") + assert result is None + + def test_get_unlinked_files(self, temp_db: Database, sample_safetensor: Path) -> None: + """Test getting unlinked files.""" + temp_db.scan_directory(sample_safetensor.parent) + unlinked = temp_db.get_unlinked_files() + + assert len(unlinked) == 1 + assert unlinked[0].get("civitai_model_id", True) + + def test_link_file_to_civitai(self, temp_db: Database, sample_safetensor: Path) -> None: + """Test linking a file to CivitAI.""" + results = temp_db.scan_directory(sample_safetensor.parent) + file_id = results[0]["id"] + + temp_db.link_file_to_civitai(file_id, model_id=123, version_id=456) + + # Should have no unlinked files now + unlinked = temp_db.get_unlinked_files() + assert len(unlinked) == 0 + + def test_upsert_local_file_updates_existing(self, temp_db: Database, sample_safetensor: Path) -> None: + """Test that scanning same file twice updates instead of inserting.""" + temp_db.scan_directory(sample_safetensor.parent) + temp_db.scan_directory(sample_safetensor.parent) + + files = temp_db.list_local_files() + assert len(files) == 1 + + +class TestCivitAICache: + """Tests for CivitAI model caching.""" + + def test_cache_model(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test caching a full CivitAI model.""" + model_id = temp_db.cache_model(sample_civitai_model) + assert model_id > 0 + + def test_cache_model_creates_creator(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test that caching model creates creator record.""" + temp_db.cache_model(sample_civitai_model) + + cur = temp_db.conn.cursor() + cur.execute("SELECT * FROM creators WHERE username = ?", ("test_creator",)) + creator = cur.fetchone() + + assert creator is not None + assert creator["username"] == "test_creator" + + def test_cache_model_creates_tags(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test that caching model creates tags.""" + temp_db.cache_model(sample_civitai_model) + + cur = temp_db.conn.cursor() + cur.execute("SELECT COUNT(*) FROM tags") + count = cur.fetchone()[0] + + assert count == 3 # test, lora, anime + + def test_cache_model_creates_versions(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test that caching model creates versions.""" + temp_db.cache_model(sample_civitai_model) + + cur = temp_db.conn.cursor() + cur.execute("SELECT * FROM model_versions WHERE civitai_id = ?", (789012,)) + version = cur.fetchone() + + assert version is not None + assert version["name"] == "v1.0" + assert version["base_model"] == "SDXL 1.0" + + def test_cache_model_creates_trained_words(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test that caching model creates trained words.""" + temp_db.cache_model(sample_civitai_model) + + cur = temp_db.conn.cursor() + cur.execute("SELECT word FROM trained_words ORDER BY position") + words = [row[0] for row in cur.fetchall()] + + assert words == ["test_trigger", "lora_trigger"] + + def test_cache_model_creates_files_and_hashes(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test that caching model creates files and hashes.""" + temp_db.cache_model(sample_civitai_model) + + cur = temp_db.conn.cursor() + cur.execute("SELECT * FROM version_files WHERE civitai_id = ?", (111222,)) + file_record = cur.fetchone() + + assert file_record is not None + assert file_record["name"] == "test_lora.safetensors" + assert file_record["is_primary"] == 1 + + cur.execute("SELECT hash_type, hash_value FROM file_hashes WHERE file_id = ?", (file_record["id"],)) + hashes = {row[0]: row[1] for row in cur.fetchall()} + + assert hashes["SHA256"] == "ABC123DEF456" + assert hashes["BLAKE3"] == "789XYZ" + + def test_cache_model_idempotent(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test that caching same model twice is idempotent.""" + id1 = temp_db.cache_model(sample_civitai_model) + id2 = temp_db.cache_model(sample_civitai_model) + + assert id1 == id2 + + cur = temp_db.conn.cursor() + cur.execute("SELECT COUNT(*) FROM models") + assert cur.fetchone()[0] == 1 + + +class TestQueryOperations: + """Tests for search and query operations.""" + + def test_search_models_by_name(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test searching models by name.""" + temp_db.cache_model(sample_civitai_model) + results = temp_db.search_models(query="Test") + + assert len(results) == 1 + assert results[0]["name"] == "Test LoRA" + + def test_search_models_by_type(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test searching models by type.""" + temp_db.cache_model(sample_civitai_model) + results = temp_db.search_models(model_type="LORA") + + assert len(results) == 1 + + def test_search_models_by_base_model(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test searching models by base model.""" + temp_db.cache_model(sample_civitai_model) + results = temp_db.search_models(base_model="SDXL") + + assert len(results) == 1 + + def test_search_models_no_results(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test search with no matching results.""" + temp_db.cache_model(sample_civitai_model) + results = temp_db.search_models(query="nonexistent") + + assert len(results) == 0 + + def test_search_models_limit(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test search respects limit.""" + # Cache multiple models + for i in range(5): + model = sample_civitai_model.copy() + model["id"] = 100000 + i + model["name"] = f"Model {i}" + temp_db.cache_model(model) + + results = temp_db.search_models(limit=3) + assert len(results) == 3 + + def test_get_model(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test getting model by CivitAI ID.""" + temp_db.cache_model(sample_civitai_model) + model = temp_db.get_model(123456) + + assert model is not None + assert model["name"] == "Test LoRA" + assert model["type"] == "LORA" + + def test_get_model_not_found(self, temp_db: Database) -> None: + """Test getting non-existent model.""" + result = temp_db.get_model(999999) + assert result is None + + def test_get_version_by_hash(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test finding version by file hash.""" + temp_db.cache_model(sample_civitai_model) + version = temp_db.get_version_by_hash("ABC123DEF456") + + assert version is not None + assert version["model_name"] == "Test LoRA" + assert version["version_name"] == "v1.0" + + def test_get_version_by_hash_case_insensitive(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test hash lookup is case insensitive.""" + temp_db.cache_model(sample_civitai_model) + version = temp_db.get_version_by_hash("abc123def456") + + assert version is not None + + +class TestTriggerWords: + """Tests for trigger word operations.""" + + def test_get_triggers_by_version(self, temp_db: Database, sample_civitai_model: dict) -> None: + """Test getting triggers by version ID.""" + temp_db.cache_model(sample_civitai_model) + triggers = temp_db.get_triggers_by_version(789012) + + assert triggers == ["test_trigger", "lora_trigger"] + + def test_get_triggers_by_file_path(self, temp_db: Database, sample_civitai_model: dict, sample_safetensor: Path) -> None: + """Test getting triggers by linked file path.""" + temp_db.cache_model(sample_civitai_model) + results = temp_db.scan_directory(sample_safetensor.parent) + temp_db.link_file_to_civitai(results[0]["id"], model_id=123456, version_id=789012) + + triggers = temp_db.get_triggers(str(sample_safetensor.resolve())) + assert triggers == ["test_trigger", "lora_trigger"] + + def test_get_triggers_empty(self, temp_db: Database) -> None: + """Test getting triggers for unlinked file.""" + triggers = temp_db.get_triggers("/nonexistent/file.safetensors") + assert triggers == [] + + +class TestStatistics: + """Tests for database statistics.""" + + def test_get_stats_empty(self, temp_db: Database) -> None: + """Test stats on empty database.""" + stats = temp_db.get_stats() + + assert stats["local_files"] == 0 + assert stats["models"] == 0 + assert stats["model_versions"] == 0 + + def test_get_stats_with_data(self, temp_db: Database, sample_civitai_model: dict, sample_safetensor: Path) -> None: + """Test stats with data.""" + temp_db.cache_model(sample_civitai_model) + temp_db.scan_directory(sample_safetensor.parent) + + stats = temp_db.get_stats() + + assert stats["local_files"] == 1 + assert stats["models"] == 1 + assert stats["model_versions"] == 1 + assert stats["trained_words"] == 2 + assert stats["creators"] == 1 + assert stats["tags"] == 3 + + +class TestContextManager: + """Tests for database context manager.""" + + def test_context_manager(self, tmp_path: Path) -> None: + """Test database works as context manager.""" + db_path = tmp_path / "test.db" + + with Database(db_path=db_path) as db: + db.init_schema() + stats = db.get_stats() + assert stats["local_files"] == 0 + + # Connection should be closed + assert db._conn is None + + def test_connection_reuse(self, tmp_path: Path) -> None: + """Test that connection is reused within context.""" + db_path = tmp_path / "test.db" + + with Database(db_path=db_path) as db: + db.init_schema() + conn1 = db.conn + conn2 = db.conn + assert conn1 is conn2 diff --git a/tests/test_server.py b/tests/test_server.py index f9f8db4..17d98cb 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -167,3 +167,380 @@ class TestProcessManager: cfg = ServerConfig(model="/m.gguf") assert cfg.port == 1234 assert cfg.args == [] + + +# ============================================================================= +# Gallery Endpoint Tests +# ============================================================================= + + +@pytest.fixture +def temp_gallery(tmp_path): + """Create a temporary gallery for testing.""" + from tensors.server.gallery import Gallery # noqa: PLC0415 + + gallery_dir = tmp_path / "gallery" + gallery_dir.mkdir() + return Gallery(gallery_dir=gallery_dir) + + +@pytest.fixture +def gallery_with_images(temp_gallery): + """Gallery with some test images.""" + # Create test images + for i in range(3): + # Create a minimal PNG (1x1 pixel) + image_data = ( + b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01" + b"\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00" + b"\x00\x00\x0cIDATx\x9cc\xf8\x0f\x00\x00\x01\x01\x00" + b"\x05\x18\xd8N\x00\x00\x00\x00IEND\xaeB`\x82" + ) + metadata = {"prompt": f"test prompt {i}", "seed": i, "width": 512, "height": 512} + temp_gallery.save_image(image_data, metadata=metadata, seed=i) + + return temp_gallery + + +@pytest.fixture +def gallery_api(temp_gallery) -> TestClient: + """Test client for gallery API with temp gallery.""" + from fastapi import FastAPI # noqa: PLC0415 + + # Override the gallery singleton + from tensors.server import gallery_routes # noqa: PLC0415 + from tensors.server.gallery_routes import create_gallery_router # noqa: PLC0415 + + gallery_routes._gallery = temp_gallery + + app = FastAPI() + app.include_router(create_gallery_router()) + return TestClient(app) + + +class TestGalleryList: + """Tests for gallery list endpoint.""" + + def test_list_images_empty(self, gallery_api: TestClient) -> None: + """Test listing empty gallery.""" + response = gallery_api.get("/api/images") + assert response.status_code == 200 + data = response.json() + assert data["images"] == [] + assert data["total"] == 0 + + def test_list_images_with_data(self, gallery_api: TestClient, gallery_with_images) -> None: + """Test listing gallery with images.""" + from tensors.server import gallery_routes # noqa: PLC0415 + + gallery_routes._gallery = gallery_with_images + + response = gallery_api.get("/api/images") + assert response.status_code == 200 + data = response.json() + assert len(data["images"]) == 3 + assert data["total"] == 3 + + def test_list_images_pagination(self, gallery_api: TestClient, gallery_with_images) -> None: + """Test pagination parameters.""" + from tensors.server import gallery_routes # noqa: PLC0415 + + gallery_routes._gallery = gallery_with_images + + response = gallery_api.get("/api/images?limit=2&offset=1") + assert response.status_code == 200 + data = response.json() + assert len(data["images"]) == 2 + + +class TestGalleryGetImage: + """Tests for getting individual images.""" + + def test_get_image_not_found(self, gallery_api: TestClient) -> None: + """Test getting non-existent image returns 404.""" + response = gallery_api.get("/api/images/nonexistent") + assert response.status_code == 404 + + def test_get_image_success(self, gallery_api: TestClient, gallery_with_images) -> None: + """Test getting an image file.""" + from tensors.server import gallery_routes # noqa: PLC0415 + + gallery_routes._gallery = gallery_with_images + + list_response = gallery_api.get("/api/images") + images = list_response.json()["images"] + image_id = images[0]["id"] + + response = gallery_api.get(f"/api/images/{image_id}") + assert response.status_code == 200 + assert response.headers["content-type"] == "image/png" + + +class TestGalleryMetadata: + """Tests for image metadata endpoints.""" + + def test_get_metadata_not_found(self, gallery_api: TestClient) -> None: + """Test getting metadata for non-existent image.""" + response = gallery_api.get("/api/images/nonexistent/meta") + assert response.status_code == 404 + + def test_get_metadata_success(self, gallery_api: TestClient, gallery_with_images) -> None: + """Test getting image metadata.""" + from tensors.server import gallery_routes # noqa: PLC0415 + + gallery_routes._gallery = gallery_with_images + + list_response = gallery_api.get("/api/images") + images = list_response.json()["images"] + image_id = images[0]["id"] + + response = gallery_api.get(f"/api/images/{image_id}/meta") + assert response.status_code == 200 + data = response.json() + assert data["id"] == image_id + assert "metadata" in data + + def test_edit_metadata(self, gallery_api: TestClient, gallery_with_images) -> None: + """Test updating image metadata.""" + from tensors.server import gallery_routes # noqa: PLC0415 + + gallery_routes._gallery = gallery_with_images + + list_response = gallery_api.get("/api/images") + images = list_response.json()["images"] + image_id = images[0]["id"] + + response = gallery_api.post( + f"/api/images/{image_id}/edit", + json={"tags": ["test", "favorite"], "rating": 5}, + ) + assert response.status_code == 200 + data = response.json() + assert data["metadata"]["tags"] == ["test", "favorite"] + assert data["metadata"]["rating"] == 5 + + def test_edit_metadata_not_found(self, gallery_api: TestClient) -> None: + """Test editing non-existent image metadata.""" + response = gallery_api.post( + "/api/images/nonexistent/edit", + json={"tags": ["test"]}, + ) + assert response.status_code == 404 + + +class TestGalleryDelete: + """Tests for deleting images.""" + + def test_delete_image_not_found(self, gallery_api: TestClient) -> None: + """Test deleting non-existent image.""" + response = gallery_api.delete("/api/images/nonexistent") + assert response.status_code == 404 + + def test_delete_image_success(self, gallery_api: TestClient, gallery_with_images) -> None: + """Test deleting an image.""" + from tensors.server import gallery_routes # noqa: PLC0415 + + gallery_routes._gallery = gallery_with_images + + list_response = gallery_api.get("/api/images") + initial_count = list_response.json()["total"] + image_id = list_response.json()["images"][0]["id"] + + response = gallery_api.delete(f"/api/images/{image_id}") + assert response.status_code == 200 + assert response.json()["deleted"] is True + + list_response = gallery_api.get("/api/images") + assert list_response.json()["total"] == initial_count - 1 + + +class TestGalleryStats: + """Tests for gallery statistics endpoint.""" + + def test_stats_empty(self, gallery_api: TestClient) -> None: + """Test stats on empty gallery.""" + response = gallery_api.get("/api/images/stats/summary") + assert response.status_code == 200 + data = response.json() + assert data["total_images"] == 0 + + +# ============================================================================= +# Database Endpoint Tests +# ============================================================================= + + +@pytest.fixture +def temp_db(tmp_path): + """Create a temporary database for testing.""" + from tensors.db import Database # noqa: PLC0415 + + db_path = tmp_path / "test_models.db" + db = Database(db_path=db_path) + db.init_schema() + return db + + +@pytest.fixture +def db_api(temp_db, monkeypatch) -> TestClient: + """Test client for db API with temp database.""" + from fastapi import FastAPI # noqa: PLC0415 + + # Monkeypatch Database to use temp_db path + from tensors import db as db_module # noqa: PLC0415 + from tensors.server.db_routes import create_db_router # noqa: PLC0415 + + monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) + + app = FastAPI() + app.include_router(create_db_router()) + return TestClient(app) + + +class TestDbEndpoints: + """Tests for database API endpoints.""" + + def test_list_files_empty(self, db_api: TestClient) -> None: + """Test listing files from empty database.""" + response = db_api.get("/api/db/files") + assert response.status_code == 200 + assert response.json() == [] + + def test_search_models_empty(self, db_api: TestClient) -> None: + """Test searching models in empty database.""" + response = db_api.get("/api/db/models") + assert response.status_code == 200 + assert response.json() == [] + + def test_search_models_with_query(self, db_api: TestClient, temp_db, monkeypatch) -> None: + """Test searching models with query parameters.""" + from tensors import db as db_module # noqa: PLC0415 + + monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) + + model_data = { + "id": 12345, + "name": "Test Model", + "type": "LORA", + "tags": [], + "modelVersions": [], + } + temp_db.cache_model(model_data) + + response = db_api.get("/api/db/models?query=Test") + assert response.status_code == 200 + results = response.json() + assert len(results) >= 1 + + def test_get_model_not_found(self, db_api: TestClient) -> None: + """Test getting non-existent model.""" + response = db_api.get("/api/db/models/999999") + assert response.status_code == 404 + + def test_get_model_success(self, db_api: TestClient, temp_db, monkeypatch) -> None: + """Test getting cached model.""" + from tensors import db as db_module # noqa: PLC0415 + + monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) + + model_data = { + "id": 12345, + "name": "Test Model", + "type": "Checkpoint", + "tags": [], + "modelVersions": [], + } + temp_db.cache_model(model_data) + + response = db_api.get("/api/db/models/12345") + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Test Model" + + def test_get_stats(self, db_api: TestClient) -> None: + """Test getting database stats.""" + response = db_api.get("/api/db/stats") + assert response.status_code == 200 + data = response.json() + assert "local_files" in data + assert "models" in data + + +# ============================================================================= +# Gallery Class Unit Tests +# ============================================================================= + + +class TestGalleryClass: + """Unit tests for the Gallery class.""" + + def test_save_image(self, temp_gallery) -> None: + """Test saving an image to gallery.""" + image_data = b"\x89PNG test data" + metadata = {"prompt": "test", "seed": 42} + + result = temp_gallery.save_image(image_data, metadata=metadata, seed=42) + + assert result.id is not None + assert result.path.exists() + assert result.meta_path.exists() + + def test_list_images_empty(self, temp_gallery) -> None: + """Test listing empty gallery.""" + images = temp_gallery.list_images() + assert images == [] + + def test_list_images_sorted(self, gallery_with_images) -> None: + """Test images are sorted by creation time.""" + images = gallery_with_images.list_images(newest_first=True) + assert len(images) == 3 + + times = [img.created_at for img in images] + assert times == sorted(times, reverse=True) + + def test_get_image(self, gallery_with_images) -> None: + """Test getting image by ID.""" + images = gallery_with_images.list_images() + image_id = images[0].id + + result = gallery_with_images.get_image(image_id) + assert result is not None + assert result.id == image_id + + def test_get_image_not_found(self, temp_gallery) -> None: + """Test getting non-existent image.""" + result = temp_gallery.get_image("nonexistent") + assert result is None + + def test_delete_image(self, gallery_with_images) -> None: + """Test deleting an image.""" + images = gallery_with_images.list_images() + image_id = images[0].id + image_path = images[0].path + + result = gallery_with_images.delete_image(image_id) + assert result is True + assert not image_path.exists() + + def test_delete_image_not_found(self, temp_gallery) -> None: + """Test deleting non-existent image.""" + result = temp_gallery.delete_image("nonexistent") + assert result is False + + def test_count(self, gallery_with_images) -> None: + """Test counting images.""" + assert gallery_with_images.count() == 3 + + def test_update_metadata(self, gallery_with_images) -> None: + """Test updating metadata.""" + images = gallery_with_images.list_images() + image_id = images[0].id + + result = gallery_with_images.update_metadata(image_id, {"custom_field": "value"}) + assert result is not None + assert result["custom_field"] == "value" + + def test_update_metadata_not_found(self, temp_gallery) -> None: + """Test updating metadata for non-existent image.""" + result = temp_gallery.update_metadata("nonexistent", {"field": "value"}) + assert result is None