e257a029da
- Add tests/test_db.py with 33 tests for Database class: - Schema initialization and migrations - Local file CRUD operations (scan, list, link) - CivitAI model caching (cache_model, tags, versions, files) - Query operations (search, get_model, get_triggers) - Statistics and context manager - Extend tests/test_server.py with 27 tests for API endpoints: - Gallery endpoints (list, get, meta, edit, delete, stats) - Database endpoints (files, models, stats) - Gallery class unit tests - Add tests/test_client.py with 33 tests for TsrClient: - Server status operations - Gallery image operations (list, get, delete, edit, download) - Model management (list, active, switch, loras) - Image generation - CivitAI download operations - Database query operations - Error handling and context manager Total: 191 tests passing with 61% coverage Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
547 lines
19 KiB
Python
547 lines
19 KiB
Python
"""Tests for tensors.server package (FastAPI sd-server proxy wrapper)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
from tensors.server import create_app
|
|
from tensors.server.models import ServerConfig
|
|
from tensors.server.process import ProcessManager
|
|
|
|
|
|
@pytest.fixture()
|
|
def pm() -> ProcessManager:
|
|
return ProcessManager()
|
|
|
|
|
|
@pytest.fixture()
|
|
def api() -> TestClient:
|
|
return TestClient(create_app())
|
|
|
|
|
|
def _get_pm(api: TestClient) -> ProcessManager:
|
|
return api.app.state.pm # type: ignore[no-any-return, attr-defined]
|
|
|
|
|
|
class TestStatus:
|
|
def test_not_running(self, api: TestClient) -> None:
|
|
r = api.get("/status")
|
|
assert r.status_code == 200
|
|
assert r.json()["running"] is False
|
|
|
|
def test_running(self, api: TestClient) -> None:
|
|
pm = _get_pm(api)
|
|
mock_proc = MagicMock()
|
|
mock_proc.poll.return_value = None
|
|
mock_proc.pid = 999
|
|
pm.proc = mock_proc
|
|
pm.config = ServerConfig(model="/m.safetensors")
|
|
r = api.get("/status")
|
|
data = r.json()
|
|
assert data["running"] is True
|
|
assert data["pid"] == 999
|
|
|
|
def test_exited(self, api: TestClient) -> None:
|
|
pm = _get_pm(api)
|
|
mock_proc = MagicMock()
|
|
mock_proc.poll.return_value = 1
|
|
pm.proc = mock_proc
|
|
r = api.get("/status")
|
|
data = r.json()
|
|
assert data["running"] is False
|
|
assert data["exit_code"] == 1
|
|
|
|
|
|
class TestReload:
|
|
@patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=True)
|
|
@patch("tensors.server.process.subprocess.Popen")
|
|
def test_reload_swaps_model(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None:
|
|
pm = _get_pm(api)
|
|
pm.config = ServerConfig(model="/old.gguf", port=5555, args=["--fa"])
|
|
mock_popen.return_value.pid = 42
|
|
mock_popen.return_value.poll.return_value = None
|
|
r = api.post("/reload", json={"model": "/new.gguf"})
|
|
assert r.status_code == 200
|
|
data = r.json()
|
|
assert data["ok"] is True
|
|
assert data["model"] == "/new.gguf"
|
|
assert data["pid"] == 42
|
|
# Verify new config preserved port and args from previous config
|
|
assert pm.config is not None
|
|
assert pm.config.port == 5555
|
|
assert pm.config.args == ["--fa"]
|
|
assert pm.config.model == "/new.gguf"
|
|
|
|
@patch.object(ProcessManager, "wait_ready", new_callable=AsyncMock, return_value=False)
|
|
@patch("tensors.server.process.subprocess.Popen")
|
|
def test_reload_fails_when_not_ready(self, mock_popen: MagicMock, mock_ready: AsyncMock, api: TestClient) -> None:
|
|
pm = _get_pm(api)
|
|
pm.config = ServerConfig(model="/old.gguf")
|
|
mock_popen.return_value.pid = 43
|
|
mock_popen.return_value.poll.return_value = None
|
|
r = api.post("/reload", json={"model": "/bad.gguf"})
|
|
assert r.status_code == 503
|
|
assert "failed" in r.json()["error"]
|
|
|
|
def test_reload_requires_model(self, api: TestClient) -> None:
|
|
r = api.post("/reload", json={})
|
|
assert r.status_code == 422
|
|
|
|
|
|
class TestProxy:
|
|
def test_proxy_503_when_not_running(self, api: TestClient) -> None:
|
|
r = api.get("/v1/models")
|
|
assert r.status_code == 503
|
|
assert "not running" in r.json()["error"]
|
|
|
|
def test_proxy_forwards_request(self, api: TestClient) -> None:
|
|
pm = _get_pm(api)
|
|
mock_proc = MagicMock()
|
|
mock_proc.poll.return_value = None
|
|
mock_proc.pid = 100
|
|
pm.proc = mock_proc
|
|
pm.config = ServerConfig(model="/m.gguf", port=1234)
|
|
|
|
upstream_response = httpx.Response(
|
|
200,
|
|
json={"data": [{"id": "model-1"}]},
|
|
headers={"content-type": "application/json"},
|
|
)
|
|
mock_client = AsyncMock()
|
|
mock_client.request.return_value = upstream_response
|
|
api.app.state.client = mock_client # type: ignore[attr-defined]
|
|
|
|
r = api.get("/v1/models")
|
|
assert r.status_code == 200
|
|
assert r.json() == {"data": [{"id": "model-1"}]}
|
|
mock_client.request.assert_called_once()
|
|
|
|
def test_proxy_forwards_post_with_body(self, api: TestClient) -> None:
|
|
pm = _get_pm(api)
|
|
mock_proc = MagicMock()
|
|
mock_proc.poll.return_value = None
|
|
mock_proc.pid = 100
|
|
pm.proc = mock_proc
|
|
pm.config = ServerConfig(model="/m.gguf", port=1234)
|
|
|
|
upstream_response = httpx.Response(200, json={"ok": True})
|
|
mock_client = AsyncMock()
|
|
mock_client.request.return_value = upstream_response
|
|
api.app.state.client = mock_client # type: ignore[attr-defined]
|
|
|
|
r = api.post("/v1/chat/completions", json={"prompt": "hello"})
|
|
assert r.status_code == 200
|
|
mock_client.request.assert_called_once()
|
|
|
|
|
|
class TestProcessManager:
|
|
def test_status_not_running(self, pm: ProcessManager) -> None:
|
|
assert pm.status() == {"running": False}
|
|
|
|
def test_build_cmd(self, pm: ProcessManager) -> None:
|
|
pm.config = ServerConfig(model="/m.gguf", port=1234, args=["--fa"])
|
|
cmd = pm.build_cmd()
|
|
assert "/m.gguf" in cmd
|
|
assert "--fa" in cmd
|
|
assert "1234" in cmd
|
|
|
|
def test_build_cmd_no_config(self, pm: ProcessManager) -> None:
|
|
with pytest.raises(RuntimeError, match="No config"):
|
|
pm.build_cmd()
|
|
|
|
@patch("tensors.server.process.subprocess.Popen")
|
|
def test_start_and_stop(self, mock_popen: MagicMock, pm: ProcessManager) -> None:
|
|
mock_popen.return_value.pid = 77
|
|
mock_popen.return_value.poll.return_value = None
|
|
mock_popen.return_value.wait.return_value = 0
|
|
pm.start(ServerConfig(model="/m.gguf"))
|
|
assert pm.proc is not None
|
|
assert pm.stop() is True
|
|
assert pm.proc is None
|
|
|
|
def test_server_config_defaults(self) -> None:
|
|
cfg = ServerConfig(model="/m.gguf")
|
|
assert cfg.port == 1234
|
|
assert cfg.args == []
|
|
|
|
|
|
# =============================================================================
|
|
# Gallery Endpoint Tests
|
|
# =============================================================================
|
|
|
|
|
|
@pytest.fixture
|
|
def temp_gallery(tmp_path):
|
|
"""Create a temporary gallery for testing."""
|
|
from tensors.server.gallery import Gallery # noqa: PLC0415
|
|
|
|
gallery_dir = tmp_path / "gallery"
|
|
gallery_dir.mkdir()
|
|
return Gallery(gallery_dir=gallery_dir)
|
|
|
|
|
|
@pytest.fixture
|
|
def gallery_with_images(temp_gallery):
|
|
"""Gallery with some test images."""
|
|
# Create test images
|
|
for i in range(3):
|
|
# Create a minimal PNG (1x1 pixel)
|
|
image_data = (
|
|
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01"
|
|
b"\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00"
|
|
b"\x00\x00\x0cIDATx\x9cc\xf8\x0f\x00\x00\x01\x01\x00"
|
|
b"\x05\x18\xd8N\x00\x00\x00\x00IEND\xaeB`\x82"
|
|
)
|
|
metadata = {"prompt": f"test prompt {i}", "seed": i, "width": 512, "height": 512}
|
|
temp_gallery.save_image(image_data, metadata=metadata, seed=i)
|
|
|
|
return temp_gallery
|
|
|
|
|
|
@pytest.fixture
|
|
def gallery_api(temp_gallery) -> TestClient:
|
|
"""Test client for gallery API with temp gallery."""
|
|
from fastapi import FastAPI # noqa: PLC0415
|
|
|
|
# Override the gallery singleton
|
|
from tensors.server import gallery_routes # noqa: PLC0415
|
|
from tensors.server.gallery_routes import create_gallery_router # noqa: PLC0415
|
|
|
|
gallery_routes._gallery = temp_gallery
|
|
|
|
app = FastAPI()
|
|
app.include_router(create_gallery_router())
|
|
return TestClient(app)
|
|
|
|
|
|
class TestGalleryList:
|
|
"""Tests for gallery list endpoint."""
|
|
|
|
def test_list_images_empty(self, gallery_api: TestClient) -> None:
|
|
"""Test listing empty gallery."""
|
|
response = gallery_api.get("/api/images")
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
assert data["images"] == []
|
|
assert data["total"] == 0
|
|
|
|
def test_list_images_with_data(self, gallery_api: TestClient, gallery_with_images) -> None:
|
|
"""Test listing gallery with images."""
|
|
from tensors.server import gallery_routes # noqa: PLC0415
|
|
|
|
gallery_routes._gallery = gallery_with_images
|
|
|
|
response = gallery_api.get("/api/images")
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
assert len(data["images"]) == 3
|
|
assert data["total"] == 3
|
|
|
|
def test_list_images_pagination(self, gallery_api: TestClient, gallery_with_images) -> None:
|
|
"""Test pagination parameters."""
|
|
from tensors.server import gallery_routes # noqa: PLC0415
|
|
|
|
gallery_routes._gallery = gallery_with_images
|
|
|
|
response = gallery_api.get("/api/images?limit=2&offset=1")
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
assert len(data["images"]) == 2
|
|
|
|
|
|
class TestGalleryGetImage:
|
|
"""Tests for getting individual images."""
|
|
|
|
def test_get_image_not_found(self, gallery_api: TestClient) -> None:
|
|
"""Test getting non-existent image returns 404."""
|
|
response = gallery_api.get("/api/images/nonexistent")
|
|
assert response.status_code == 404
|
|
|
|
def test_get_image_success(self, gallery_api: TestClient, gallery_with_images) -> None:
|
|
"""Test getting an image file."""
|
|
from tensors.server import gallery_routes # noqa: PLC0415
|
|
|
|
gallery_routes._gallery = gallery_with_images
|
|
|
|
list_response = gallery_api.get("/api/images")
|
|
images = list_response.json()["images"]
|
|
image_id = images[0]["id"]
|
|
|
|
response = gallery_api.get(f"/api/images/{image_id}")
|
|
assert response.status_code == 200
|
|
assert response.headers["content-type"] == "image/png"
|
|
|
|
|
|
class TestGalleryMetadata:
|
|
"""Tests for image metadata endpoints."""
|
|
|
|
def test_get_metadata_not_found(self, gallery_api: TestClient) -> None:
|
|
"""Test getting metadata for non-existent image."""
|
|
response = gallery_api.get("/api/images/nonexistent/meta")
|
|
assert response.status_code == 404
|
|
|
|
def test_get_metadata_success(self, gallery_api: TestClient, gallery_with_images) -> None:
|
|
"""Test getting image metadata."""
|
|
from tensors.server import gallery_routes # noqa: PLC0415
|
|
|
|
gallery_routes._gallery = gallery_with_images
|
|
|
|
list_response = gallery_api.get("/api/images")
|
|
images = list_response.json()["images"]
|
|
image_id = images[0]["id"]
|
|
|
|
response = gallery_api.get(f"/api/images/{image_id}/meta")
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
assert data["id"] == image_id
|
|
assert "metadata" in data
|
|
|
|
def test_edit_metadata(self, gallery_api: TestClient, gallery_with_images) -> None:
|
|
"""Test updating image metadata."""
|
|
from tensors.server import gallery_routes # noqa: PLC0415
|
|
|
|
gallery_routes._gallery = gallery_with_images
|
|
|
|
list_response = gallery_api.get("/api/images")
|
|
images = list_response.json()["images"]
|
|
image_id = images[0]["id"]
|
|
|
|
response = gallery_api.post(
|
|
f"/api/images/{image_id}/edit",
|
|
json={"tags": ["test", "favorite"], "rating": 5},
|
|
)
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
assert data["metadata"]["tags"] == ["test", "favorite"]
|
|
assert data["metadata"]["rating"] == 5
|
|
|
|
def test_edit_metadata_not_found(self, gallery_api: TestClient) -> None:
|
|
"""Test editing non-existent image metadata."""
|
|
response = gallery_api.post(
|
|
"/api/images/nonexistent/edit",
|
|
json={"tags": ["test"]},
|
|
)
|
|
assert response.status_code == 404
|
|
|
|
|
|
class TestGalleryDelete:
|
|
"""Tests for deleting images."""
|
|
|
|
def test_delete_image_not_found(self, gallery_api: TestClient) -> None:
|
|
"""Test deleting non-existent image."""
|
|
response = gallery_api.delete("/api/images/nonexistent")
|
|
assert response.status_code == 404
|
|
|
|
def test_delete_image_success(self, gallery_api: TestClient, gallery_with_images) -> None:
|
|
"""Test deleting an image."""
|
|
from tensors.server import gallery_routes # noqa: PLC0415
|
|
|
|
gallery_routes._gallery = gallery_with_images
|
|
|
|
list_response = gallery_api.get("/api/images")
|
|
initial_count = list_response.json()["total"]
|
|
image_id = list_response.json()["images"][0]["id"]
|
|
|
|
response = gallery_api.delete(f"/api/images/{image_id}")
|
|
assert response.status_code == 200
|
|
assert response.json()["deleted"] is True
|
|
|
|
list_response = gallery_api.get("/api/images")
|
|
assert list_response.json()["total"] == initial_count - 1
|
|
|
|
|
|
class TestGalleryStats:
|
|
"""Tests for gallery statistics endpoint."""
|
|
|
|
def test_stats_empty(self, gallery_api: TestClient) -> None:
|
|
"""Test stats on empty gallery."""
|
|
response = gallery_api.get("/api/images/stats/summary")
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
assert data["total_images"] == 0
|
|
|
|
|
|
# =============================================================================
|
|
# Database Endpoint Tests
|
|
# =============================================================================
|
|
|
|
|
|
@pytest.fixture
|
|
def temp_db(tmp_path):
|
|
"""Create a temporary database for testing."""
|
|
from tensors.db import Database # noqa: PLC0415
|
|
|
|
db_path = tmp_path / "test_models.db"
|
|
db = Database(db_path=db_path)
|
|
db.init_schema()
|
|
return db
|
|
|
|
|
|
@pytest.fixture
|
|
def db_api(temp_db, monkeypatch) -> TestClient:
|
|
"""Test client for db API with temp database."""
|
|
from fastapi import FastAPI # noqa: PLC0415
|
|
|
|
# Monkeypatch Database to use temp_db path
|
|
from tensors import db as db_module # noqa: PLC0415
|
|
from tensors.server.db_routes import create_db_router # noqa: PLC0415
|
|
|
|
monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path)
|
|
|
|
app = FastAPI()
|
|
app.include_router(create_db_router())
|
|
return TestClient(app)
|
|
|
|
|
|
class TestDbEndpoints:
|
|
"""Tests for database API endpoints."""
|
|
|
|
def test_list_files_empty(self, db_api: TestClient) -> None:
|
|
"""Test listing files from empty database."""
|
|
response = db_api.get("/api/db/files")
|
|
assert response.status_code == 200
|
|
assert response.json() == []
|
|
|
|
def test_search_models_empty(self, db_api: TestClient) -> None:
|
|
"""Test searching models in empty database."""
|
|
response = db_api.get("/api/db/models")
|
|
assert response.status_code == 200
|
|
assert response.json() == []
|
|
|
|
def test_search_models_with_query(self, db_api: TestClient, temp_db, monkeypatch) -> None:
|
|
"""Test searching models with query parameters."""
|
|
from tensors import db as db_module # noqa: PLC0415
|
|
|
|
monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path)
|
|
|
|
model_data = {
|
|
"id": 12345,
|
|
"name": "Test Model",
|
|
"type": "LORA",
|
|
"tags": [],
|
|
"modelVersions": [],
|
|
}
|
|
temp_db.cache_model(model_data)
|
|
|
|
response = db_api.get("/api/db/models?query=Test")
|
|
assert response.status_code == 200
|
|
results = response.json()
|
|
assert len(results) >= 1
|
|
|
|
def test_get_model_not_found(self, db_api: TestClient) -> None:
|
|
"""Test getting non-existent model."""
|
|
response = db_api.get("/api/db/models/999999")
|
|
assert response.status_code == 404
|
|
|
|
def test_get_model_success(self, db_api: TestClient, temp_db, monkeypatch) -> None:
|
|
"""Test getting cached model."""
|
|
from tensors import db as db_module # noqa: PLC0415
|
|
|
|
monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path)
|
|
|
|
model_data = {
|
|
"id": 12345,
|
|
"name": "Test Model",
|
|
"type": "Checkpoint",
|
|
"tags": [],
|
|
"modelVersions": [],
|
|
}
|
|
temp_db.cache_model(model_data)
|
|
|
|
response = db_api.get("/api/db/models/12345")
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
assert data["name"] == "Test Model"
|
|
|
|
def test_get_stats(self, db_api: TestClient) -> None:
|
|
"""Test getting database stats."""
|
|
response = db_api.get("/api/db/stats")
|
|
assert response.status_code == 200
|
|
data = response.json()
|
|
assert "local_files" in data
|
|
assert "models" in data
|
|
|
|
|
|
# =============================================================================
|
|
# Gallery Class Unit Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestGalleryClass:
|
|
"""Unit tests for the Gallery class."""
|
|
|
|
def test_save_image(self, temp_gallery) -> None:
|
|
"""Test saving an image to gallery."""
|
|
image_data = b"\x89PNG test data"
|
|
metadata = {"prompt": "test", "seed": 42}
|
|
|
|
result = temp_gallery.save_image(image_data, metadata=metadata, seed=42)
|
|
|
|
assert result.id is not None
|
|
assert result.path.exists()
|
|
assert result.meta_path.exists()
|
|
|
|
def test_list_images_empty(self, temp_gallery) -> None:
|
|
"""Test listing empty gallery."""
|
|
images = temp_gallery.list_images()
|
|
assert images == []
|
|
|
|
def test_list_images_sorted(self, gallery_with_images) -> None:
|
|
"""Test images are sorted by creation time."""
|
|
images = gallery_with_images.list_images(newest_first=True)
|
|
assert len(images) == 3
|
|
|
|
times = [img.created_at for img in images]
|
|
assert times == sorted(times, reverse=True)
|
|
|
|
def test_get_image(self, gallery_with_images) -> None:
|
|
"""Test getting image by ID."""
|
|
images = gallery_with_images.list_images()
|
|
image_id = images[0].id
|
|
|
|
result = gallery_with_images.get_image(image_id)
|
|
assert result is not None
|
|
assert result.id == image_id
|
|
|
|
def test_get_image_not_found(self, temp_gallery) -> None:
|
|
"""Test getting non-existent image."""
|
|
result = temp_gallery.get_image("nonexistent")
|
|
assert result is None
|
|
|
|
def test_delete_image(self, gallery_with_images) -> None:
|
|
"""Test deleting an image."""
|
|
images = gallery_with_images.list_images()
|
|
image_id = images[0].id
|
|
image_path = images[0].path
|
|
|
|
result = gallery_with_images.delete_image(image_id)
|
|
assert result is True
|
|
assert not image_path.exists()
|
|
|
|
def test_delete_image_not_found(self, temp_gallery) -> None:
|
|
"""Test deleting non-existent image."""
|
|
result = temp_gallery.delete_image("nonexistent")
|
|
assert result is False
|
|
|
|
def test_count(self, gallery_with_images) -> None:
|
|
"""Test counting images."""
|
|
assert gallery_with_images.count() == 3
|
|
|
|
def test_update_metadata(self, gallery_with_images) -> None:
|
|
"""Test updating metadata."""
|
|
images = gallery_with_images.list_images()
|
|
image_id = images[0].id
|
|
|
|
result = gallery_with_images.update_metadata(image_id, {"custom_field": "value"})
|
|
assert result is not None
|
|
assert result["custom_field"] == "value"
|
|
|
|
def test_update_metadata_not_found(self, temp_gallery) -> None:
|
|
"""Test updating metadata for non-existent image."""
|
|
result = temp_gallery.update_metadata("nonexistent", {"field": "value"})
|
|
assert result is None
|