"""Tests for tensors.server package (gallery and CivitAI management).""" from __future__ import annotations import pytest from fastapi.testclient import TestClient from tensors.server import create_app @pytest.fixture() def api() -> TestClient: """Create test client.""" return TestClient(create_app()) class TestStatus: def test_status_ok(self, api: TestClient) -> None: """Test status endpoint returns ok.""" r = api.get("/status") assert r.status_code == 200 data = r.json() assert data["status"] == "ok" # ============================================================================= # Gallery Endpoint Tests # ============================================================================= @pytest.fixture def temp_gallery(tmp_path): """Create a temporary gallery for testing.""" from tensors.server.gallery import Gallery # noqa: PLC0415 gallery_dir = tmp_path / "gallery" gallery_dir.mkdir() return Gallery(gallery_dir=gallery_dir) @pytest.fixture def gallery_with_images(temp_gallery): """Gallery with some test images.""" # Create test images for i in range(3): # Create a minimal PNG (1x1 pixel) image_data = ( b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01" b"\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00" b"\x00\x00\x0cIDATx\x9cc\xf8\x0f\x00\x00\x01\x01\x00" b"\x05\x18\xd8N\x00\x00\x00\x00IEND\xaeB`\x82" ) metadata = {"prompt": f"test prompt {i}", "seed": i, "width": 512, "height": 512} temp_gallery.save_image(image_data, metadata=metadata, seed=i) return temp_gallery @pytest.fixture def gallery_api(temp_gallery) -> TestClient: """Test client for gallery API with temp gallery.""" from fastapi import FastAPI # noqa: PLC0415 # Override the gallery singleton from tensors.server import gallery_routes # noqa: PLC0415 from tensors.server.gallery_routes import create_gallery_router # noqa: PLC0415 gallery_routes._gallery = temp_gallery app = FastAPI() app.include_router(create_gallery_router()) return TestClient(app) class TestGalleryList: """Tests for gallery list endpoint.""" def test_list_images_empty(self, gallery_api: TestClient) -> None: """Test listing empty gallery.""" response = gallery_api.get("/api/images") assert response.status_code == 200 data = response.json() assert data["images"] == [] assert data["total"] == 0 def test_list_images_with_data(self, gallery_api: TestClient, gallery_with_images) -> None: """Test listing gallery with images.""" from tensors.server import gallery_routes # noqa: PLC0415 gallery_routes._gallery = gallery_with_images response = gallery_api.get("/api/images") assert response.status_code == 200 data = response.json() assert len(data["images"]) == 3 assert data["total"] == 3 def test_list_images_pagination(self, gallery_api: TestClient, gallery_with_images) -> None: """Test pagination parameters.""" from tensors.server import gallery_routes # noqa: PLC0415 gallery_routes._gallery = gallery_with_images response = gallery_api.get("/api/images?limit=2&offset=1") assert response.status_code == 200 data = response.json() assert len(data["images"]) == 2 class TestGalleryGetImage: """Tests for getting individual images.""" def test_get_image_not_found(self, gallery_api: TestClient) -> None: """Test getting non-existent image returns 404.""" response = gallery_api.get("/api/images/nonexistent") assert response.status_code == 404 def test_get_image_success(self, gallery_api: TestClient, gallery_with_images) -> None: """Test getting an image file.""" from tensors.server import gallery_routes # noqa: PLC0415 gallery_routes._gallery = gallery_with_images list_response = gallery_api.get("/api/images") images = list_response.json()["images"] image_id = images[0]["id"] response = gallery_api.get(f"/api/images/{image_id}") assert response.status_code == 200 assert response.headers["content-type"] == "image/png" class TestGalleryMetadata: """Tests for image metadata endpoints.""" def test_get_metadata_not_found(self, gallery_api: TestClient) -> None: """Test getting metadata for non-existent image.""" response = gallery_api.get("/api/images/nonexistent/meta") assert response.status_code == 404 def test_get_metadata_success(self, gallery_api: TestClient, gallery_with_images) -> None: """Test getting image metadata.""" from tensors.server import gallery_routes # noqa: PLC0415 gallery_routes._gallery = gallery_with_images list_response = gallery_api.get("/api/images") images = list_response.json()["images"] image_id = images[0]["id"] response = gallery_api.get(f"/api/images/{image_id}/meta") assert response.status_code == 200 data = response.json() assert data["id"] == image_id assert "metadata" in data def test_edit_metadata(self, gallery_api: TestClient, gallery_with_images) -> None: """Test updating image metadata.""" from tensors.server import gallery_routes # noqa: PLC0415 gallery_routes._gallery = gallery_with_images list_response = gallery_api.get("/api/images") images = list_response.json()["images"] image_id = images[0]["id"] response = gallery_api.post( f"/api/images/{image_id}/edit", json={"tags": ["test", "favorite"], "rating": 5}, ) assert response.status_code == 200 data = response.json() assert data["metadata"]["tags"] == ["test", "favorite"] assert data["metadata"]["rating"] == 5 def test_edit_metadata_not_found(self, gallery_api: TestClient) -> None: """Test editing non-existent image metadata.""" response = gallery_api.post( "/api/images/nonexistent/edit", json={"tags": ["test"]}, ) assert response.status_code == 404 class TestGalleryDelete: """Tests for deleting images.""" def test_delete_image_not_found(self, gallery_api: TestClient) -> None: """Test deleting non-existent image.""" response = gallery_api.delete("/api/images/nonexistent") assert response.status_code == 404 def test_delete_image_success(self, gallery_api: TestClient, gallery_with_images) -> None: """Test deleting an image.""" from tensors.server import gallery_routes # noqa: PLC0415 gallery_routes._gallery = gallery_with_images list_response = gallery_api.get("/api/images") initial_count = list_response.json()["total"] image_id = list_response.json()["images"][0]["id"] response = gallery_api.delete(f"/api/images/{image_id}") assert response.status_code == 200 assert response.json()["deleted"] is True list_response = gallery_api.get("/api/images") assert list_response.json()["total"] == initial_count - 1 class TestGalleryStats: """Tests for gallery statistics endpoint.""" def test_stats_empty(self, gallery_api: TestClient) -> None: """Test stats on empty gallery.""" response = gallery_api.get("/api/images/stats/summary") assert response.status_code == 200 data = response.json() assert data["total_images"] == 0 # ============================================================================= # Database Endpoint Tests # ============================================================================= @pytest.fixture def temp_db(tmp_path): """Create a temporary database for testing.""" from tensors.db import Database # noqa: PLC0415 db_path = tmp_path / "test_models.db" db = Database(db_path=db_path) db.init_schema() return db @pytest.fixture def db_api(temp_db, monkeypatch) -> TestClient: """Test client for db API with temp database.""" from fastapi import FastAPI # noqa: PLC0415 # Monkeypatch Database to use temp_db path from tensors import db as db_module # noqa: PLC0415 from tensors.server.db_routes import create_db_router # noqa: PLC0415 monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) app = FastAPI() app.include_router(create_db_router()) return TestClient(app) class TestDbEndpoints: """Tests for database API endpoints.""" def test_list_files_empty(self, db_api: TestClient) -> None: """Test listing files from empty database.""" response = db_api.get("/api/db/files") assert response.status_code == 200 assert response.json() == [] def test_search_models_empty(self, db_api: TestClient) -> None: """Test searching models in empty database.""" response = db_api.get("/api/db/models") assert response.status_code == 200 assert response.json() == [] def test_search_models_with_query(self, db_api: TestClient, temp_db, monkeypatch) -> None: """Test searching models with query parameters.""" from tensors import db as db_module # noqa: PLC0415 monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) model_data = { "id": 12345, "name": "Test Model", "type": "LORA", "tags": [], "modelVersions": [], } temp_db.cache_model(model_data) response = db_api.get("/api/db/models?query=Test") assert response.status_code == 200 results = response.json() assert len(results) >= 1 def test_get_model_not_found(self, db_api: TestClient) -> None: """Test getting non-existent model.""" response = db_api.get("/api/db/models/999999") assert response.status_code == 404 def test_get_model_success(self, db_api: TestClient, temp_db, monkeypatch) -> None: """Test getting cached model.""" from tensors import db as db_module # noqa: PLC0415 monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) model_data = { "id": 12345, "name": "Test Model", "type": "Checkpoint", "tags": [], "modelVersions": [], } temp_db.cache_model(model_data) response = db_api.get("/api/db/models/12345") assert response.status_code == 200 data = response.json() assert data["name"] == "Test Model" def test_get_stats(self, db_api: TestClient) -> None: """Test getting database stats.""" response = db_api.get("/api/db/stats") assert response.status_code == 200 data = response.json() assert "local_files" in data assert "models" in data # ============================================================================= # Gallery Class Unit Tests # ============================================================================= class TestGalleryClass: """Unit tests for the Gallery class.""" def test_save_image(self, temp_gallery) -> None: """Test saving an image to gallery.""" image_data = b"\x89PNG test data" metadata = {"prompt": "test", "seed": 42} result = temp_gallery.save_image(image_data, metadata=metadata, seed=42) assert result.id is not None assert result.path.exists() assert result.meta_path.exists() def test_list_images_empty(self, temp_gallery) -> None: """Test listing empty gallery.""" images = temp_gallery.list_images() assert images == [] def test_list_images_sorted(self, gallery_with_images) -> None: """Test images are sorted by creation time.""" images = gallery_with_images.list_images(newest_first=True) assert len(images) == 3 times = [img.created_at for img in images] assert times == sorted(times, reverse=True) def test_get_image(self, gallery_with_images) -> None: """Test getting image by ID.""" images = gallery_with_images.list_images() image_id = images[0].id result = gallery_with_images.get_image(image_id) assert result is not None assert result.id == image_id def test_get_image_not_found(self, temp_gallery) -> None: """Test getting non-existent image.""" result = temp_gallery.get_image("nonexistent") assert result is None def test_delete_image(self, gallery_with_images) -> None: """Test deleting an image.""" images = gallery_with_images.list_images() image_id = images[0].id image_path = images[0].path result = gallery_with_images.delete_image(image_id) assert result is True assert not image_path.exists() def test_delete_image_not_found(self, temp_gallery) -> None: """Test deleting non-existent image.""" result = temp_gallery.delete_image("nonexistent") assert result is False def test_count(self, gallery_with_images) -> None: """Test counting images.""" assert gallery_with_images.count() == 3 def test_update_metadata(self, gallery_with_images) -> None: """Test updating metadata.""" images = gallery_with_images.list_images() image_id = images[0].id result = gallery_with_images.update_metadata(image_id, {"custom_field": "value"}) assert result is not None assert result["custom_field"] == "value" def test_update_metadata_not_found(self, temp_gallery) -> None: """Test updating metadata for non-existent image.""" result = temp_gallery.update_metadata("nonexistent", {"field": "value"}) assert result is None # ============================================================================= # Auth Tests # ============================================================================= class TestAuth: """Tests for API key authentication.""" def test_no_auth_when_no_key_configured(self, monkeypatch) -> None: """Test auth is disabled when no API key is configured.""" from tensors.server.auth import verify_api_key monkeypatch.setattr("tensors.server.auth.get_server_api_key", lambda: None) result = verify_api_key(header_key=None, query_key=None) assert result is None def test_auth_required_when_key_configured(self, monkeypatch) -> None: """Test auth is required when API key is configured.""" from tensors.server.auth import verify_api_key monkeypatch.setattr("tensors.server.auth.get_server_api_key", lambda: "secret-key") with pytest.raises(Exception) as exc_info: verify_api_key(header_key=None, query_key=None) assert exc_info.value.status_code == 401 def test_valid_header_key(self, monkeypatch) -> None: """Test valid API key via header.""" from tensors.server.auth import verify_api_key monkeypatch.setattr("tensors.server.auth.get_server_api_key", lambda: "secret-key") result = verify_api_key(header_key="secret-key", query_key=None) assert result == "secret-key" def test_valid_query_key(self, monkeypatch) -> None: """Test valid API key via query param.""" from tensors.server.auth import verify_api_key monkeypatch.setattr("tensors.server.auth.get_server_api_key", lambda: "secret-key") result = verify_api_key(header_key=None, query_key="secret-key") assert result == "secret-key" def test_invalid_key(self, monkeypatch) -> None: """Test invalid API key returns 403.""" from tensors.server.auth import verify_api_key monkeypatch.setattr("tensors.server.auth.get_server_api_key", lambda: "secret-key") with pytest.raises(Exception) as exc_info: verify_api_key(header_key="wrong-key", query_key=None) assert exc_info.value.status_code == 403 def test_header_takes_precedence(self, monkeypatch) -> None: """Test header key takes precedence over query key.""" from tensors.server.auth import verify_api_key monkeypatch.setattr("tensors.server.auth.get_server_api_key", lambda: "secret-key") result = verify_api_key(header_key="secret-key", query_key="wrong-key") assert result == "secret-key" # ============================================================================= # CivitAI Routes Tests # ============================================================================= @pytest.fixture def civitai_api(monkeypatch) -> TestClient: """Test client for CivitAI API.""" from fastapi import FastAPI from tensors.server.civitai_routes import create_civitai_router # Disable auth for testing monkeypatch.setattr("tensors.config.get_server_api_key", lambda: None) app = FastAPI() app.include_router(create_civitai_router()) return TestClient(app) class TestCivitAISearch: """Tests for CivitAI search endpoint.""" def test_search_basic(self, civitai_api: TestClient, respx_mock) -> None: """Test basic search request.""" import respx respx_mock.get("https://civitai.com/api/v1/models").mock( return_value=respx.MockResponse( 200, json={"items": [{"id": 1, "name": "Test Model"}], "metadata": {}}, ) ) response = civitai_api.get("/api/civitai/search") assert response.status_code == 200 data = response.json() assert "items" in data def test_search_with_params(self, civitai_api: TestClient, respx_mock) -> None: """Test search with query parameters.""" import respx respx_mock.get("https://civitai.com/api/v1/models").mock( return_value=respx.MockResponse( 200, json={"items": [], "metadata": {}}, ) ) response = civitai_api.get( "/api/civitai/search", params={ "query": "anime", "types": "LORA", "baseModels": "Illustrious", "sort": "Newest", "limit": 10, "period": "Week", "tag": "character", "sfw": True, }, ) assert response.status_code == 200 def test_search_api_error(self, civitai_api: TestClient, respx_mock) -> None: """Test search handles API errors.""" import respx respx_mock.get("https://civitai.com/api/v1/models").mock( return_value=respx.MockResponse(500, json={"error": "Server error"}) ) response = civitai_api.get("/api/civitai/search") assert response.status_code == 500 def test_search_network_error(self, civitai_api: TestClient, respx_mock) -> None: """Test search handles network errors.""" import httpx respx_mock.get("https://civitai.com/api/v1/models").mock( side_effect=httpx.RequestError("Connection failed") ) response = civitai_api.get("/api/civitai/search") assert response.status_code == 500 class TestCivitAIGetModel: """Tests for CivitAI get model endpoint.""" def test_get_model_success(self, civitai_api: TestClient, respx_mock, temp_db, monkeypatch) -> None: """Test getting a model by ID.""" import respx from tensors import db as db_module monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) respx_mock.get("https://civitai.com/api/v1/models/12345").mock( return_value=respx.MockResponse( 200, json={ "id": 12345, "name": "Test Model", "type": "LORA", "tags": [], "modelVersions": [], }, ) ) response = civitai_api.get("/api/civitai/model/12345") assert response.status_code == 200 data = response.json() assert data["id"] == 12345 assert data["name"] == "Test Model" def test_get_model_not_found(self, civitai_api: TestClient, respx_mock) -> None: """Test getting non-existent model.""" import respx respx_mock.get("https://civitai.com/api/v1/models/99999").mock( return_value=respx.MockResponse(404, json={"error": "Not found"}) ) response = civitai_api.get("/api/civitai/model/99999") assert response.status_code == 404 def test_get_model_network_error(self, civitai_api: TestClient, respx_mock) -> None: """Test get model handles network errors.""" import httpx respx_mock.get("https://civitai.com/api/v1/models/12345").mock( side_effect=httpx.RequestError("Connection failed") ) response = civitai_api.get("/api/civitai/model/12345") assert response.status_code == 500 # ============================================================================= # Download Routes Tests # ============================================================================= @pytest.fixture def download_api(monkeypatch) -> TestClient: """Test client for Download API.""" from fastapi import FastAPI from tensors.server.download_routes import create_download_router # Disable auth for testing monkeypatch.setattr("tensors.config.get_server_api_key", lambda: None) app = FastAPI() app.include_router(create_download_router()) return TestClient(app) class TestDownloadRoutes: """Tests for download endpoints.""" def test_list_active_downloads_empty(self, download_api: TestClient) -> None: """Test listing active downloads when none exist.""" response = download_api.get("/api/download/active") assert response.status_code == 200 data = response.json() assert data["downloads"] == [] assert data["total"] == 0 def test_get_download_status_not_found(self, download_api: TestClient) -> None: """Test getting status of non-existent download.""" response = download_api.get("/api/download/status/nonexistent-id") assert response.status_code == 404 def test_start_download_no_identifier(self, download_api: TestClient) -> None: """Test starting download without any identifier returns 404.""" response = download_api.post("/api/download", json={}) # No version_id, model_id, or hash provided - can't find model assert response.status_code == 404 assert "not found" in response.json()["detail"].lower() # ============================================================================= # DB Routes Additional Tests # ============================================================================= class TestDbRoutesExtended: """Extended tests for database routes.""" def test_get_file_not_found(self, db_api: TestClient) -> None: """Test getting non-existent file.""" response = db_api.get("/api/db/files/99999") assert response.status_code == 404 def test_get_triggers_by_path(self, db_api: TestClient) -> None: """Test getting triggers by path (returns empty for non-existent).""" response = db_api.get("/api/db/triggers", params={"file_path": "/nonexistent/path.safetensors"}) assert response.status_code == 200 assert response.json() == [] def test_get_triggers_by_version(self, db_api: TestClient) -> None: """Test getting triggers by version (returns empty for non-existent).""" response = db_api.get("/api/db/triggers/99999") assert response.status_code == 200 assert response.json() == [] def test_cache_model(self, db_api: TestClient, temp_db, monkeypatch, respx_mock) -> None: """Test caching a model.""" import respx from tensors import db as db_module monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) respx_mock.get("https://civitai.com/api/v1/models/12345").mock( return_value=respx.MockResponse( 200, json={ "id": 12345, "name": "Cached Model", "type": "LORA", "tags": [], "modelVersions": [], }, ) ) response = db_api.post("/api/db/cache", json={"model_id": 12345}) assert response.status_code == 200 data = response.json() assert "model_id" in data def test_scan_directory_not_found(self, db_api: TestClient) -> None: """Test scanning non-existent directory returns 400.""" response = db_api.post("/api/db/scan", json={"directory": "/nonexistent/directory"}) assert response.status_code == 400 def test_link_files(self, db_api: TestClient, temp_db, monkeypatch) -> None: """Test linking files endpoint.""" from tensors import db as db_module monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) response = db_api.post("/api/db/link") assert response.status_code == 200 data = response.json() assert "linked" in data def test_search_models_with_type_filter(self, db_api: TestClient, temp_db, monkeypatch) -> None: """Test searching models with type filter.""" from tensors import db as db_module monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) # Cache some models temp_db.cache_model({"id": 1, "name": "LORA Model", "type": "LORA", "tags": [], "modelVersions": []}) temp_db.cache_model({"id": 2, "name": "Checkpoint", "type": "Checkpoint", "tags": [], "modelVersions": []}) response = db_api.get("/api/db/models?type=LORA") assert response.status_code == 200 results = response.json() assert all(r.get("type") == "LORA" for r in results if r.get("type")) def test_search_models_with_base_filter(self, db_api: TestClient, temp_db, monkeypatch) -> None: """Test searching models with base model filter.""" from tensors import db as db_module monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) response = db_api.get("/api/db/models?base=SD 1.5") assert response.status_code == 200 def test_search_models_with_limit(self, db_api: TestClient, temp_db, monkeypatch) -> None: """Test searching models with limit.""" from tensors import db as db_module monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) for i in range(10): temp_db.cache_model({"id": 100 + i, "name": f"Model {i}", "type": "LORA", "tags": [], "modelVersions": []}) response = db_api.get("/api/db/models?limit=5") assert response.status_code == 200 results = response.json() assert len(results) <= 5 def test_cache_model_not_found(self, db_api: TestClient, respx_mock) -> None: """Test caching a model that doesn't exist on CivitAI.""" import respx respx_mock.get("https://civitai.com/api/v1/models/99999").mock( return_value=respx.MockResponse(404, json={"error": "Not found"}) ) response = db_api.post("/api/db/cache", json={"model_id": 99999}) assert response.status_code == 404 def test_scan_directory_success(self, db_api: TestClient, temp_db, monkeypatch, tmp_path) -> None: """Test scanning a valid directory.""" from tensors import db as db_module monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) # Create a temporary directory (empty, no safetensors) scan_dir = tmp_path / "models" scan_dir.mkdir() response = db_api.post("/api/db/scan", json={"directory": str(scan_dir)}) assert response.status_code == 200 data = response.json() assert "scanned" in data assert data["scanned"] == 0 # Empty directory # ============================================================================= # Download Routes Helper Function Tests # ============================================================================= class TestDownloadHelpers: """Tests for download route helper functions.""" def test_format_size_bytes(self) -> None: """Test formatting bytes.""" from tensors.server.download_routes import _format_size assert _format_size(500) == "500 B" assert _format_size(0) == "0 B" def test_format_size_kb(self) -> None: """Test formatting kilobytes.""" from tensors.server.download_routes import _format_size assert _format_size(1024) == "1.0 KB" assert _format_size(2048) == "2.0 KB" assert _format_size(1536) == "1.5 KB" def test_format_size_mb(self) -> None: """Test formatting megabytes.""" from tensors.server.download_routes import _format_size assert _format_size(1024 * 1024) == "1.0 MB" assert _format_size(50 * 1024 * 1024) == "50.0 MB" def test_format_size_gb(self) -> None: """Test formatting gigabytes.""" from tensors.server.download_routes import _format_size assert _format_size(1024 * 1024 * 1024) == "1.0 GB" assert _format_size(2 * 1024 * 1024 * 1024) == "2.0 GB" def test_get_output_dir_with_override(self) -> None: """Test output dir with override.""" from tensors.server.download_routes import _get_output_dir result = _get_output_dir({}, "/custom/path") assert str(result) == "/custom/path" def test_get_output_dir_checkpoint(self) -> None: """Test output dir for checkpoint.""" from tensors.server.download_routes import _get_output_dir version_info = {"model": {"type": "Checkpoint"}} result = _get_output_dir(version_info, None) assert "checkpoints" in str(result) def test_get_output_dir_lora(self) -> None: """Test output dir for LORA.""" from tensors.server.download_routes import _get_output_dir version_info = {"model": {"type": "LORA"}} result = _get_output_dir(version_info, None) assert "loras" in str(result) def test_get_output_dir_locon(self) -> None: """Test output dir for LoCon.""" from tensors.server.download_routes import _get_output_dir version_info = {"model": {"type": "LoCon"}} result = _get_output_dir(version_info, None) assert "loras" in str(result) def test_get_output_dir_textual_inversion(self) -> None: """Test output dir for TextualInversion.""" from tensors.server.download_routes import _get_output_dir version_info = {"model": {"type": "TextualInversion"}} result = _get_output_dir(version_info, None) assert "embeddings" in str(result) def test_get_output_dir_vae(self) -> None: """Test output dir for VAE.""" from tensors.server.download_routes import _get_output_dir version_info = {"model": {"type": "VAE"}} result = _get_output_dir(version_info, None) assert "vae" in str(result) def test_get_output_dir_controlnet(self) -> None: """Test output dir for Controlnet.""" from tensors.server.download_routes import _get_output_dir version_info = {"model": {"type": "Controlnet"}} result = _get_output_dir(version_info, None) assert "controlnet" in str(result) def test_get_output_dir_unknown(self) -> None: """Test output dir for unknown type.""" from tensors.server.download_routes import _get_output_dir version_info = {"model": {"type": "UnknownType"}} result = _get_output_dir(version_info, None) assert "other" in str(result) def test_get_output_dir_no_type(self) -> None: """Test output dir with missing type (defaults to Checkpoint).""" from tensors.server.download_routes import _get_output_dir version_info = {"model": {}} result = _get_output_dir(version_info, None) assert "checkpoints" in str(result) class TestResolveVersionId: """Tests for _resolve_version_id helper.""" def test_resolve_with_version_id(self, respx_mock) -> None: """Test resolving by version ID.""" import respx from tensors.server.download_routes import _resolve_version_id respx_mock.get("https://civitai.com/api/v1/model-versions/12345").mock( return_value=respx.MockResponse(200, json={"id": 12345, "name": "v1.0"}) ) version_id, info = _resolve_version_id(12345, None, None, None) assert version_id == 12345 assert info is not None assert info["id"] == 12345 def test_resolve_with_hash(self, respx_mock) -> None: """Test resolving by hash.""" import respx from tensors.server.download_routes import _resolve_version_id respx_mock.get("https://civitai.com/api/v1/model-versions/by-hash/ABC123").mock( return_value=respx.MockResponse(200, json={"id": 555, "modelId": 100}) ) version_id, info = _resolve_version_id(None, None, "abc123", None) assert version_id == 555 assert info is not None def test_resolve_with_hash_not_found(self, respx_mock) -> None: """Test resolving by hash when not found.""" import respx from tensors.server.download_routes import _resolve_version_id respx_mock.get("https://civitai.com/api/v1/model-versions/by-hash/NOTFOUND").mock( return_value=respx.MockResponse(404, json={"error": "Not found"}) ) version_id, info = _resolve_version_id(None, None, "notfound", None) assert version_id is None assert info is None def test_resolve_with_model_id(self, respx_mock) -> None: """Test resolving by model ID (uses latest version).""" import respx from tensors.server.download_routes import _resolve_version_id respx_mock.get("https://civitai.com/api/v1/models/999").mock( return_value=respx.MockResponse( 200, json={ "id": 999, "modelVersions": [{"id": 1001, "name": "Latest"}, {"id": 1000, "name": "Old"}], }, ) ) version_id, info = _resolve_version_id(None, 999, None, None) assert version_id == 1001 assert info is not None assert info["name"] == "Latest" def test_resolve_with_model_id_no_versions(self, respx_mock) -> None: """Test resolving by model ID with no versions.""" import respx from tensors.server.download_routes import _resolve_version_id respx_mock.get("https://civitai.com/api/v1/models/888").mock( return_value=respx.MockResponse(200, json={"id": 888, "modelVersions": []}) ) version_id, info = _resolve_version_id(None, 888, None, None) assert version_id is None assert info is None def test_resolve_with_model_id_not_found(self, respx_mock) -> None: """Test resolving by model ID when model not found.""" import respx from tensors.server.download_routes import _resolve_version_id respx_mock.get("https://civitai.com/api/v1/models/777").mock( return_value=respx.MockResponse(404, json={"error": "Not found"}) ) version_id, info = _resolve_version_id(None, 777, None, None) assert version_id is None assert info is None def test_resolve_with_nothing(self) -> None: """Test resolving with no identifiers.""" from tensors.server.download_routes import _resolve_version_id version_id, info = _resolve_version_id(None, None, None, None) assert version_id is None assert info is None class TestDownloadEndpoints: """Tests for download API endpoints.""" def test_start_download_success(self, download_api: TestClient, respx_mock, tmp_path, monkeypatch) -> None: """Test starting a download successfully.""" import respx from tensors import config as config_module monkeypatch.setattr(config_module, "MODELS_DIR", tmp_path) respx_mock.get("https://civitai.com/api/v1/model-versions/12345").mock( return_value=respx.MockResponse( 200, json={ "id": 12345, "name": "v1.0", "model": {"name": "Test Model", "type": "LORA"}, "files": [{"name": "test-model.safetensors", "primary": True}], }, ) ) response = download_api.post("/api/download", json={"version_id": 12345}) assert response.status_code == 200 data = response.json() assert "download_id" in data assert data["status"] == "queued" assert data["version_id"] == 12345 def test_start_download_no_files(self, download_api: TestClient, respx_mock, tmp_path, monkeypatch) -> None: """Test starting download with no files returns 400.""" import respx from tensors import config as config_module monkeypatch.setattr(config_module, "MODELS_DIR", tmp_path) respx_mock.get("https://civitai.com/api/v1/model-versions/99999").mock( return_value=respx.MockResponse( 200, json={ "id": 99999, "name": "v1.0", "model": {"name": "Empty Model", "type": "LORA"}, "files": [], }, ) ) response = download_api.post("/api/download", json={"version_id": 99999}) assert response.status_code == 400 assert "No files found" in response.json()["detail"] def test_start_download_with_hash(self, download_api: TestClient, respx_mock, tmp_path, monkeypatch) -> None: """Test starting download using hash.""" import respx from tensors import config as config_module monkeypatch.setattr(config_module, "MODELS_DIR", tmp_path) respx_mock.get("https://civitai.com/api/v1/model-versions/by-hash/ABCD1234").mock( return_value=respx.MockResponse( 200, json={ "id": 555, "modelId": 100, "name": "v1.0", "model": {"name": "Hash Model", "type": "Checkpoint"}, "files": [{"name": "model.safetensors", "primary": True}], }, ) ) response = download_api.post("/api/download", json={"hash": "abcd1234"}) assert response.status_code == 200 data = response.json() assert data["version_id"] == 555 def test_start_download_with_model_id(self, download_api: TestClient, respx_mock, tmp_path, monkeypatch) -> None: """Test starting download using model ID (picks latest version).""" import respx from tensors import config as config_module monkeypatch.setattr(config_module, "MODELS_DIR", tmp_path) respx_mock.get("https://civitai.com/api/v1/models/200").mock( return_value=respx.MockResponse( 200, json={ "id": 200, "name": "Model With Versions", "modelVersions": [ { "id": 2001, "name": "Latest", "model": {"name": "Model", "type": "LORA"}, "files": [{"name": "latest.safetensors", "primary": True}], } ], }, ) ) response = download_api.post("/api/download", json={"model_id": 200}) assert response.status_code == 200 data = response.json() assert data["version_id"] == 2001 def test_start_download_with_output_dir(self, download_api: TestClient, respx_mock, tmp_path, monkeypatch) -> None: """Test starting download with custom output directory.""" import respx from tensors import config as config_module custom_dir = tmp_path / "custom" custom_dir.mkdir() monkeypatch.setattr(config_module, "MODELS_DIR", tmp_path) respx_mock.get("https://civitai.com/api/v1/model-versions/333").mock( return_value=respx.MockResponse( 200, json={ "id": 333, "name": "v1.0", "model": {"name": "Custom Dir Model", "type": "LORA"}, "files": [{"name": "custom.safetensors", "primary": True}], }, ) ) response = download_api.post( "/api/download", json={"version_id": 333, "output_dir": str(custom_dir)} ) assert response.status_code == 200 assert str(custom_dir) in response.json()["destination"] def test_get_download_status_success(self, download_api: TestClient, respx_mock, tmp_path, monkeypatch) -> None: """Test getting status of an existing download.""" import respx from tensors import config as config_module monkeypatch.setattr(config_module, "MODELS_DIR", tmp_path) # First create a download respx_mock.get("https://civitai.com/api/v1/model-versions/444").mock( return_value=respx.MockResponse( 200, json={ "id": 444, "name": "v1.0", "model": {"name": "Status Test", "type": "LORA"}, "files": [{"name": "status.safetensors", "primary": True}], }, ) ) create_response = download_api.post("/api/download", json={"version_id": 444}) download_id = create_response.json()["download_id"] # Now get its status status_response = download_api.get(f"/api/download/status/{download_id}") assert status_response.status_code == 200 data = status_response.json() assert data["id"] == download_id def test_list_active_downloads_with_data(self, download_api: TestClient, respx_mock, tmp_path, monkeypatch) -> None: """Test listing active downloads after creating some.""" import respx from tensors import config as config_module from tensors.server import download_routes # Clear any existing downloads download_routes._active_downloads.clear() monkeypatch.setattr(config_module, "MODELS_DIR", tmp_path) respx_mock.get("https://civitai.com/api/v1/model-versions/555").mock( return_value=respx.MockResponse( 200, json={ "id": 555, "name": "v1.0", "model": {"name": "Active Test", "type": "LORA"}, "files": [{"name": "active.safetensors", "primary": True}], }, ) ) download_api.post("/api/download", json={"version_id": 555}) response = download_api.get("/api/download/active") assert response.status_code == 200 data = response.json() assert data["total"] >= 1 assert len(data["downloads"]) >= 1 # ============================================================================= # CivitAI Routes Extended Tests # ============================================================================= class TestCivitAIRoutesExtended: """Extended tests for CivitAI routes.""" def test_search_with_nsfw_filter(self, civitai_api: TestClient, respx_mock) -> None: """Test search with NSFW filter.""" import respx respx_mock.get("https://civitai.com/api/v1/models").mock( return_value=respx.MockResponse(200, json={"items": [], "metadata": {}}) ) response = civitai_api.get("/api/civitai/search", params={"nsfw": "Soft"}) assert response.status_code == 200 def test_search_with_commercial_filter(self, civitai_api: TestClient, respx_mock) -> None: """Test search with commercial use filter.""" import respx respx_mock.get("https://civitai.com/api/v1/models").mock( return_value=respx.MockResponse(200, json={"items": [], "metadata": {}}) ) response = civitai_api.get("/api/civitai/search", params={"commercial": "Rent"}) assert response.status_code == 200 def test_search_with_username(self, civitai_api: TestClient, respx_mock) -> None: """Test search with username filter.""" import respx respx_mock.get("https://civitai.com/api/v1/models").mock( return_value=respx.MockResponse(200, json={"items": [], "metadata": {}}) ) response = civitai_api.get("/api/civitai/search", params={"username": "testuser"}) assert response.status_code == 200 def test_search_with_page(self, civitai_api: TestClient, respx_mock) -> None: """Test search with page parameter.""" import respx respx_mock.get("https://civitai.com/api/v1/models").mock( return_value=respx.MockResponse(200, json={"items": [], "metadata": {}}) ) response = civitai_api.get("/api/civitai/search", params={"page": 2}) assert response.status_code == 200 def test_get_model_caches_result(self, civitai_api: TestClient, respx_mock, temp_db, monkeypatch) -> None: """Test that getting a model caches it in the database.""" import respx from tensors import db as db_module monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) respx_mock.get("https://civitai.com/api/v1/models/77777").mock( return_value=respx.MockResponse( 200, json={ "id": 77777, "name": "Cacheable Model", "type": "Checkpoint", "tags": ["anime"], "modelVersions": [{"id": 77778, "name": "v1"}], }, ) ) response = civitai_api.get("/api/civitai/model/77777") assert response.status_code == 200 assert response.json()["name"] == "Cacheable Model" # Verify it was cached cached = temp_db.get_model(77777) assert cached is not None assert cached["name"] == "Cacheable Model" # ============================================================================= # Gallery Routes Extended Tests # ============================================================================= class TestGalleryRoutesExtended: """Extended tests for gallery routes.""" def test_list_images_oldest_first(self, gallery_api: TestClient, gallery_with_images) -> None: """Test listing images sorted oldest first.""" from tensors.server import gallery_routes gallery_routes._gallery = gallery_with_images response = gallery_api.get("/api/images?newest_first=false") assert response.status_code == 200 data = response.json() assert len(data["images"]) == 3 def test_edit_metadata_partial_update(self, gallery_api: TestClient, gallery_with_images) -> None: """Test partial metadata update (only some fields).""" from tensors.server import gallery_routes gallery_routes._gallery = gallery_with_images list_response = gallery_api.get("/api/images") image_id = list_response.json()["images"][0]["id"] # Only update notes, not tags response = gallery_api.post(f"/api/images/{image_id}/edit", json={"notes": "Test note"}) assert response.status_code == 200 data = response.json() assert data["metadata"]["notes"] == "Test note" def test_edit_metadata_favorite(self, gallery_api: TestClient, gallery_with_images) -> None: """Test setting favorite flag.""" from tensors.server import gallery_routes gallery_routes._gallery = gallery_with_images list_response = gallery_api.get("/api/images") image_id = list_response.json()["images"][0]["id"] response = gallery_api.post(f"/api/images/{image_id}/edit", json={"favorite": True}) assert response.status_code == 200 assert response.json()["metadata"]["favorite"] is True # ============================================================================= # Download Background Task Tests # ============================================================================= class TestDownloadBackgroundTasks: """Tests for download background task functions.""" def test_do_download_success(self, monkeypatch, tmp_path) -> None: """Test successful download task execution.""" from tensors.server import download_routes from tensors.server.download_routes import _do_download # Set up tracking entry download_id = "test_123" download_routes._active_downloads[download_id] = {"id": download_id, "status": "queued"} # Mock the download function def mock_download(version_id, dest_path, api_key, on_progress, resume): # Simulate progress callback on_progress(1024, 2048, 100.0) on_progress(2048, 2048, 200.0) return True monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download) monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None) dest_path = tmp_path / "model.safetensors" _do_download(12345, dest_path, None, download_id) assert download_routes._active_downloads[download_id]["status"] == "completed" assert download_routes._active_downloads[download_id]["progress"] == 100 # Cleanup del download_routes._active_downloads[download_id] def test_do_download_failure(self, monkeypatch, tmp_path) -> None: """Test failed download task execution.""" from tensors.server import download_routes from tensors.server.download_routes import _do_download download_id = "test_fail_123" download_routes._active_downloads[download_id] = {"id": download_id, "status": "queued"} # Mock download to return failure monkeypatch.setattr( download_routes, "download_model_with_progress", lambda *args, **kwargs: False ) dest_path = tmp_path / "model.safetensors" _do_download(12345, dest_path, None, download_id) assert download_routes._active_downloads[download_id]["status"] == "failed" assert "error" in download_routes._active_downloads[download_id] del download_routes._active_downloads[download_id] def test_do_download_exception(self, monkeypatch, tmp_path) -> None: """Test download task with exception.""" from tensors.server import download_routes from tensors.server.download_routes import _do_download download_id = "test_exc_123" download_routes._active_downloads[download_id] = {"id": download_id, "status": "queued"} # Mock download to raise exception def mock_download(*args, **kwargs): raise RuntimeError("Network error") monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download) dest_path = tmp_path / "model.safetensors" _do_download(12345, dest_path, None, download_id) assert download_routes._active_downloads[download_id]["status"] == "failed" assert "Network error" in download_routes._active_downloads[download_id]["error"] del download_routes._active_downloads[download_id] def test_on_progress_callback(self, monkeypatch, tmp_path) -> None: """Test progress callback updates correctly.""" from tensors.server import download_routes from tensors.server.download_routes import _do_download download_id = "test_progress_123" download_routes._active_downloads[download_id] = {"id": download_id, "status": "queued"} progress_calls = [] def mock_download(version_id, dest_path, api_key, on_progress, resume): # Test with different sizes on_progress(512, 1024, 50.0) # 512 B of 1 KB progress_calls.append(dict(download_routes._active_downloads[download_id])) on_progress(1024 * 500, 1024 * 1024, 1000.0) # 500 KB of 1 MB progress_calls.append(dict(download_routes._active_downloads[download_id])) on_progress(1024 * 1024 * 500, 1024 * 1024 * 1024, 10000.0) # 500 MB of 1 GB progress_calls.append(dict(download_routes._active_downloads[download_id])) return True monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download) monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None) dest_path = tmp_path / "model.safetensors" _do_download(12345, dest_path, None, download_id) # Check progress formatting was called assert len(progress_calls) == 3 assert progress_calls[0]["downloaded_str"] == "512 B" assert progress_calls[0]["total_str"] == "1.0 KB" assert progress_calls[1]["downloaded_str"] == "500.0 KB" assert progress_calls[2]["downloaded_str"] == "500.0 MB" assert progress_calls[2]["total_str"] == "1.0 GB" del download_routes._active_downloads[download_id] def test_on_progress_zero_total(self, monkeypatch, tmp_path) -> None: """Test progress callback with zero total (unknown size).""" from tensors.server import download_routes from tensors.server.download_routes import _do_download download_id = "test_zero_total" download_routes._active_downloads[download_id] = {"id": download_id, "status": "queued"} def mock_download(version_id, dest_path, api_key, on_progress, resume): on_progress(1024, 0, 100.0) # Unknown total return True monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download) monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None) dest_path = tmp_path / "model.safetensors" _do_download(12345, dest_path, None, download_id) assert download_routes._active_downloads[download_id]["total_str"] == "Unknown" del download_routes._active_downloads[download_id] class TestAutoLinkFile: """Tests for _auto_link_file function.""" def test_auto_link_success(self, monkeypatch, tmp_path, temp_db) -> None: """Test auto-linking a downloaded file.""" from tensors.server import download_routes from tensors.server.download_routes import _auto_link_file # Create a fake safetensor file file_path = tmp_path / "test.safetensors" file_path.write_bytes(b"fake safetensor data") from tensors import db as db_module monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) # Mock scan results scanned_files = [] def mock_scan(directory): return [{"id": 1, "file_path": str(file_path), "sha256": "abc123"}] monkeypatch.setattr(temp_db, "scan_directory", mock_scan) # Mock CivitAI lookup def mock_fetch_by_hash(sha256, api_key): return {"id": 999, "modelId": 888} monkeypatch.setattr(download_routes, "fetch_civitai_by_hash", mock_fetch_by_hash) # Mock Database context manager to use temp_db class MockDB: def __enter__(self): return temp_db def __exit__(self, *args): pass monkeypatch.setattr(download_routes, "Database", MockDB) # This should not raise _auto_link_file(file_path, None) def test_auto_link_exception_handled(self, monkeypatch, tmp_path) -> None: """Test auto-link handles exceptions gracefully.""" from tensors.server.download_routes import _auto_link_file # Mock Database to raise exception def mock_db(*args, **kwargs): raise RuntimeError("DB error") from tensors.server import download_routes monkeypatch.setattr(download_routes, "Database", mock_db) file_path = tmp_path / "test.safetensors" # Should not raise _auto_link_file(file_path, None) # ============================================================================= # DB Routes - File Lookup Tests # ============================================================================= class TestDbFileLookup: """Tests for database file lookup.""" def test_get_file_success(self, db_api: TestClient, temp_db, monkeypatch, tmp_path) -> None: """Test getting an existing file.""" from tensors import db as db_module monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) # Create and scan a file to add it to the database test_file = tmp_path / "test.safetensors" # Create minimal valid safetensor header header_data = b'{"__metadata__": {}}' header_size = len(header_data) test_file.write_bytes(header_size.to_bytes(8, "little") + header_data) # Scan to add to database results = temp_db.scan_directory(tmp_path) assert len(results) > 0 file_id = results[0]["id"] response = db_api.get(f"/api/db/files/{file_id}") assert response.status_code == 200 data = response.json() assert data["id"] == file_id def test_link_files_with_matches(self, db_api: TestClient, temp_db, monkeypatch, tmp_path, respx_mock) -> None: """Test linking files when CivitAI matches exist.""" import respx from tensors import db as db_module monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) # Create and scan a file test_file = tmp_path / "linkable.safetensors" header_data = b'{"__metadata__": {}}' header_size = len(header_data) test_file.write_bytes(header_size.to_bytes(8, "little") + header_data) temp_db.scan_directory(tmp_path) files = temp_db.list_local_files() assert len(files) > 0 sha256 = files[0]["sha256"] # Mock CivitAI hash lookup to return a match respx_mock.get(f"https://civitai.com/api/v1/model-versions/by-hash/{sha256.upper()}").mock( return_value=respx.MockResponse( 200, json={"id": 12345, "modelId": 67890, "name": "Found Model"} ) ) response = db_api.post("/api/db/link") assert response.status_code == 200 data = response.json() assert data["linked"] >= 1 assert len(data["results"]) >= 1 # ============================================================================= # Server Init Tests # ============================================================================= class TestServerInit: """Tests for server initialization.""" def test_docs_endpoint(self, api: TestClient) -> None: """Test /docs endpoint returns HTML.""" response = api.get("/docs") assert response.status_code == 200 assert "text/html" in response.headers["content-type"] def test_openapi_schema(self, api: TestClient) -> None: """Test OpenAPI schema is available.""" response = api.get("/openapi.json") assert response.status_code == 200 data = response.json() assert data["info"]["title"] == "tensors" assert "paths" in data def test_app_startup_with_auth(self, monkeypatch) -> None: """Test app startup logging with auth enabled.""" monkeypatch.setattr("tensors.config.get_server_api_key", lambda: "test-key") from tensors.server import create_app app = create_app() # App should be created successfully assert app.title == "tensors" def test_app_startup_without_auth(self, monkeypatch) -> None: """Test app startup logging without auth.""" monkeypatch.setattr("tensors.config.get_server_api_key", lambda: None) from tensors.server import create_app app = create_app() assert app.title == "tensors" class TestCivitAICacheFailure: """Test CivitAI model caching failure handling.""" def test_get_model_cache_failure_continues(self, civitai_api: TestClient, respx_mock, monkeypatch) -> None: """Test that cache failure doesn't prevent model retrieval.""" import respx respx_mock.get("https://civitai.com/api/v1/models/88888").mock( return_value=respx.MockResponse( 200, json={ "id": 88888, "name": "Cache Fail Model", "type": "LORA", "tags": [], "modelVersions": [], }, ) ) # Make Database raise an exception class FailingDB: def __enter__(self): raise RuntimeError("Database error") def __exit__(self, *args): pass from tensors.server import civitai_routes monkeypatch.setattr(civitai_routes, "Database", FailingDB) # Should still return the model even though caching failed response = civitai_api.get("/api/civitai/model/88888") assert response.status_code == 200 assert response.json()["name"] == "Cache Fail Model" # ============================================================================= # Gallery Edge Cases # ============================================================================= class TestGalleryEdgeCases: """Edge case tests for gallery functionality.""" def test_gallery_get_metadata_for_image(self, gallery_api: TestClient, gallery_with_images) -> None: """Test getting metadata returns full image info.""" from tensors.server import gallery_routes gallery_routes._gallery = gallery_with_images list_response = gallery_api.get("/api/images") image_id = list_response.json()["images"][0]["id"] response = gallery_api.get(f"/api/images/{image_id}/meta") assert response.status_code == 200 data = response.json() assert "path" in data assert "created_at" in data assert "metadata" in data def test_gallery_stats_with_images(self, gallery_api: TestClient, gallery_with_images) -> None: """Test stats with actual images.""" from tensors.server import gallery_routes gallery_routes._gallery = gallery_with_images response = gallery_api.get("/api/images/stats/summary") assert response.status_code == 200 data = response.json() assert data["total_images"] == 3 assert "gallery_dir" in data class TestGalleryClassExtended: """Extended unit tests for Gallery class.""" def test_save_image_with_seed(self, temp_gallery) -> None: """Test saving image with seed creates proper filename.""" image_data = b"\x89PNG test data" result = temp_gallery.save_image(image_data, seed=12345) assert "12345" in result.path.name assert result.path.exists() def test_save_image_without_metadata(self, temp_gallery) -> None: """Test saving image without metadata.""" image_data = b"\x89PNG test data" result = temp_gallery.save_image(image_data) assert result.path.exists() # No metadata file should exist assert not result.meta_path.exists() def test_list_images_with_offset(self, gallery_with_images) -> None: """Test list images with offset.""" images = gallery_with_images.list_images(offset=1, limit=10) assert len(images) == 2 # 3 total - 1 offset = 2 def test_get_metadata_returns_dict(self, gallery_with_images) -> None: """Test get_metadata returns metadata dict.""" images = gallery_with_images.list_images() metadata = gallery_with_images.get_metadata(images[0].id) assert isinstance(metadata, dict) assert "prompt" in metadata def test_get_metadata_nonexistent(self, temp_gallery) -> None: """Test get_metadata for non-existent image returns None.""" result = temp_gallery.get_metadata("nonexistent") assert result is None