"""Tests for tensors.server package (gallery and CivitAI management).""" from __future__ import annotations import pytest from fastapi.testclient import TestClient from tensors.server import create_app @pytest.fixture() def api() -> TestClient: """Create test client.""" return TestClient(create_app()) class TestStatus: def test_status_ok(self, api: TestClient) -> None: """Test status endpoint returns ok.""" r = api.get("/status") assert r.status_code == 200 data = r.json() assert data["status"] == "ok" # ============================================================================= # Gallery Endpoint Tests # ============================================================================= @pytest.fixture def temp_gallery(tmp_path): """Create a temporary gallery for testing.""" from tensors.server.gallery import Gallery # noqa: PLC0415 gallery_dir = tmp_path / "gallery" gallery_dir.mkdir() return Gallery(gallery_dir=gallery_dir) @pytest.fixture def gallery_with_images(temp_gallery): """Gallery with some test images.""" # Create test images for i in range(3): # Create a minimal PNG (1x1 pixel) image_data = ( b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01" b"\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00" b"\x00\x00\x0cIDATx\x9cc\xf8\x0f\x00\x00\x01\x01\x00" b"\x05\x18\xd8N\x00\x00\x00\x00IEND\xaeB`\x82" ) metadata = {"prompt": f"test prompt {i}", "seed": i, "width": 512, "height": 512} temp_gallery.save_image(image_data, metadata=metadata, seed=i) return temp_gallery @pytest.fixture def gallery_api(temp_gallery) -> TestClient: """Test client for gallery API with temp gallery.""" from fastapi import FastAPI # noqa: PLC0415 # Override the gallery singleton from tensors.server import gallery_routes # noqa: PLC0415 from tensors.server.gallery_routes import create_gallery_router # noqa: PLC0415 gallery_routes._gallery = temp_gallery app = FastAPI() app.include_router(create_gallery_router()) return TestClient(app) class TestGalleryList: """Tests for gallery list endpoint.""" def test_list_images_empty(self, gallery_api: TestClient) -> None: """Test listing empty gallery.""" response = gallery_api.get("/api/images") assert response.status_code == 200 data = response.json() assert data["images"] == [] assert data["total"] == 0 def test_list_images_with_data(self, gallery_api: TestClient, gallery_with_images) -> None: """Test listing gallery with images.""" from tensors.server import gallery_routes # noqa: PLC0415 gallery_routes._gallery = gallery_with_images response = gallery_api.get("/api/images") assert response.status_code == 200 data = response.json() assert len(data["images"]) == 3 assert data["total"] == 3 def test_list_images_pagination(self, gallery_api: TestClient, gallery_with_images) -> None: """Test pagination parameters.""" from tensors.server import gallery_routes # noqa: PLC0415 gallery_routes._gallery = gallery_with_images response = gallery_api.get("/api/images?limit=2&offset=1") assert response.status_code == 200 data = response.json() assert len(data["images"]) == 2 class TestGalleryGetImage: """Tests for getting individual images.""" def test_get_image_not_found(self, gallery_api: TestClient) -> None: """Test getting non-existent image returns 404.""" response = gallery_api.get("/api/images/nonexistent") assert response.status_code == 404 def test_get_image_success(self, gallery_api: TestClient, gallery_with_images) -> None: """Test getting an image file.""" from tensors.server import gallery_routes # noqa: PLC0415 gallery_routes._gallery = gallery_with_images list_response = gallery_api.get("/api/images") images = list_response.json()["images"] image_id = images[0]["id"] response = gallery_api.get(f"/api/images/{image_id}") assert response.status_code == 200 assert response.headers["content-type"] == "image/png" class TestGalleryMetadata: """Tests for image metadata endpoints.""" def test_get_metadata_not_found(self, gallery_api: TestClient) -> None: """Test getting metadata for non-existent image.""" response = gallery_api.get("/api/images/nonexistent/meta") assert response.status_code == 404 def test_get_metadata_success(self, gallery_api: TestClient, gallery_with_images) -> None: """Test getting image metadata.""" from tensors.server import gallery_routes # noqa: PLC0415 gallery_routes._gallery = gallery_with_images list_response = gallery_api.get("/api/images") images = list_response.json()["images"] image_id = images[0]["id"] response = gallery_api.get(f"/api/images/{image_id}/meta") assert response.status_code == 200 data = response.json() assert data["id"] == image_id assert "metadata" in data def test_edit_metadata(self, gallery_api: TestClient, gallery_with_images) -> None: """Test updating image metadata.""" from tensors.server import gallery_routes # noqa: PLC0415 gallery_routes._gallery = gallery_with_images list_response = gallery_api.get("/api/images") images = list_response.json()["images"] image_id = images[0]["id"] response = gallery_api.post( f"/api/images/{image_id}/edit", json={"tags": ["test", "favorite"], "rating": 5}, ) assert response.status_code == 200 data = response.json() assert data["metadata"]["tags"] == ["test", "favorite"] assert data["metadata"]["rating"] == 5 def test_edit_metadata_not_found(self, gallery_api: TestClient) -> None: """Test editing non-existent image metadata.""" response = gallery_api.post( "/api/images/nonexistent/edit", json={"tags": ["test"]}, ) assert response.status_code == 404 class TestGalleryDelete: """Tests for deleting images.""" def test_delete_image_not_found(self, gallery_api: TestClient) -> None: """Test deleting non-existent image.""" response = gallery_api.delete("/api/images/nonexistent") assert response.status_code == 404 def test_delete_image_success(self, gallery_api: TestClient, gallery_with_images) -> None: """Test deleting an image.""" from tensors.server import gallery_routes # noqa: PLC0415 gallery_routes._gallery = gallery_with_images list_response = gallery_api.get("/api/images") initial_count = list_response.json()["total"] image_id = list_response.json()["images"][0]["id"] response = gallery_api.delete(f"/api/images/{image_id}") assert response.status_code == 200 assert response.json()["deleted"] is True list_response = gallery_api.get("/api/images") assert list_response.json()["total"] == initial_count - 1 class TestGalleryStats: """Tests for gallery statistics endpoint.""" def test_stats_empty(self, gallery_api: TestClient) -> None: """Test stats on empty gallery.""" response = gallery_api.get("/api/images/stats/summary") assert response.status_code == 200 data = response.json() assert data["total_images"] == 0 # ============================================================================= # Database Endpoint Tests # ============================================================================= @pytest.fixture def temp_db(tmp_path): """Create a temporary database for testing.""" from tensors.db import Database # noqa: PLC0415 db_path = tmp_path / "test_models.db" db = Database(db_path=db_path) db.init_schema() return db @pytest.fixture def db_api(temp_db, monkeypatch) -> TestClient: """Test client for db API with temp database.""" from fastapi import FastAPI # noqa: PLC0415 # Monkeypatch Database to use temp_db path from tensors import db as db_module # noqa: PLC0415 from tensors.server.db_routes import create_db_router # noqa: PLC0415 monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) app = FastAPI() app.include_router(create_db_router()) return TestClient(app) class TestDbEndpoints: """Tests for database API endpoints.""" def test_list_files_empty(self, db_api: TestClient) -> None: """Test listing files from empty database.""" response = db_api.get("/api/db/files") assert response.status_code == 200 assert response.json() == [] def test_search_models_empty(self, db_api: TestClient) -> None: """Test searching models in empty database.""" response = db_api.get("/api/db/models") assert response.status_code == 200 assert response.json() == [] def test_search_models_with_query(self, db_api: TestClient, temp_db, monkeypatch) -> None: """Test searching models with query parameters.""" from tensors import db as db_module # noqa: PLC0415 monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) model_data = { "id": 12345, "name": "Test Model", "type": "LORA", "tags": [], "modelVersions": [], } temp_db.cache_model(model_data) response = db_api.get("/api/db/models?query=Test") assert response.status_code == 200 results = response.json() assert len(results) >= 1 def test_get_model_not_found(self, db_api: TestClient) -> None: """Test getting non-existent model.""" response = db_api.get("/api/db/models/999999") assert response.status_code == 404 def test_get_model_success(self, db_api: TestClient, temp_db, monkeypatch) -> None: """Test getting cached model.""" from tensors import db as db_module # noqa: PLC0415 monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) model_data = { "id": 12345, "name": "Test Model", "type": "Checkpoint", "tags": [], "modelVersions": [], } temp_db.cache_model(model_data) response = db_api.get("/api/db/models/12345") assert response.status_code == 200 data = response.json() assert data["name"] == "Test Model" def test_get_stats(self, db_api: TestClient) -> None: """Test getting database stats.""" response = db_api.get("/api/db/stats") assert response.status_code == 200 data = response.json() assert "local_files" in data assert "models" in data # ============================================================================= # Gallery Class Unit Tests # ============================================================================= class TestGalleryClass: """Unit tests for the Gallery class.""" def test_save_image(self, temp_gallery) -> None: """Test saving an image to gallery.""" image_data = b"\x89PNG test data" metadata = {"prompt": "test", "seed": 42} result = temp_gallery.save_image(image_data, metadata=metadata, seed=42) assert result.id is not None assert result.path.exists() assert result.meta_path.exists() def test_list_images_empty(self, temp_gallery) -> None: """Test listing empty gallery.""" images = temp_gallery.list_images() assert images == [] def test_list_images_sorted(self, gallery_with_images) -> None: """Test images are sorted by creation time.""" images = gallery_with_images.list_images(newest_first=True) assert len(images) == 3 times = [img.created_at for img in images] assert times == sorted(times, reverse=True) def test_get_image(self, gallery_with_images) -> None: """Test getting image by ID.""" images = gallery_with_images.list_images() image_id = images[0].id result = gallery_with_images.get_image(image_id) assert result is not None assert result.id == image_id def test_get_image_not_found(self, temp_gallery) -> None: """Test getting non-existent image.""" result = temp_gallery.get_image("nonexistent") assert result is None def test_delete_image(self, gallery_with_images) -> None: """Test deleting an image.""" images = gallery_with_images.list_images() image_id = images[0].id image_path = images[0].path result = gallery_with_images.delete_image(image_id) assert result is True assert not image_path.exists() def test_delete_image_not_found(self, temp_gallery) -> None: """Test deleting non-existent image.""" result = temp_gallery.delete_image("nonexistent") assert result is False def test_count(self, gallery_with_images) -> None: """Test counting images.""" assert gallery_with_images.count() == 3 def test_update_metadata(self, gallery_with_images) -> None: """Test updating metadata.""" images = gallery_with_images.list_images() image_id = images[0].id result = gallery_with_images.update_metadata(image_id, {"custom_field": "value"}) assert result is not None assert result["custom_field"] == "value" def test_update_metadata_not_found(self, temp_gallery) -> None: """Test updating metadata for non-existent image.""" result = temp_gallery.update_metadata("nonexistent", {"field": "value"}) assert result is None