diff --git a/tensors/server/civitai_routes.py b/tensors/server/civitai_routes.py index 9b90ba3..3707d07 100644 --- a/tensors/server/civitai_routes.py +++ b/tensors/server/civitai_routes.py @@ -3,11 +3,10 @@ from __future__ import annotations import logging -from enum import Enum -from typing import Annotated, Any +from typing import Any import httpx -from fastapi import APIRouter, Query, Response +from fastapi import APIRouter, Response from fastapi.responses import JSONResponse from tensors.config import CIVITAI_API_BASE, load_api_key @@ -18,42 +17,6 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/civitai", tags=["CivitAI"]) -class SortOrder(str, Enum): - """Sort order options for CivitAI search.""" - - most_downloaded = "Most Downloaded" - highest_rated = "Highest Rated" - newest = "Newest" - - -class Period(str, Enum): - """Time period filter options.""" - - all = "AllTime" - year = "Year" - month = "Month" - week = "Week" - day = "Day" - - -class NsfwLevel(str, Enum): - """NSFW content filter levels.""" - - none = "None" - soft = "Soft" - mature = "Mature" - x = "X" - - -class CommercialUse(str, Enum): - """Commercial use filter options.""" - - none = "None" - image = "Image" - rent = "Rent" - sell = "Sell" - - def _get_headers(api_key: str | None) -> dict[str, str]: """Get headers for CivitAI API requests.""" headers: dict[str, str] = {} @@ -62,87 +25,6 @@ def _get_headers(api_key: str | None) -> dict[str, str]: return headers -@router.get("/search", response_model=None) -async def search_models( - query: Annotated[str | None, Query(description="Search query")] = None, - types: Annotated[str | None, Query(description="Model type (Checkpoint, LORA, etc.)")] = None, - base_models: Annotated[str | None, Query(alias="baseModels", description="Base model")] = None, - sort: Annotated[SortOrder, Query(description="Sort order")] = SortOrder.most_downloaded, - limit: Annotated[int | None, Query(le=100, description="Max results (default: 25)", example=5)] = None, - period: Annotated[Period | None, Query(description="Time period filter")] = None, - tag: Annotated[str | None, Query(description="Filter by tag")] = None, - username: Annotated[str | None, Query(description="Filter by creator username")] = None, - page: Annotated[int | None, Query(ge=1, description="Page number")] = None, - nsfw: Annotated[NsfwLevel | None, Query(description="NSFW filter level")] = None, - sfw: Annotated[bool, Query(description="Exclude NSFW content")] = False, - commercial: Annotated[CommercialUse | None, Query(description="Commercial use filter")] = None, -) -> dict[str, Any] | Response: - """Search CivitAI models. - - Supports all CivitAI search parameters including filters for type, base model, - time period, tags, creator, NSFW level, and commercial use. - """ - api_key = load_api_key() - actual_limit = limit if limit is not None else 25 - - params: dict[str, Any] = { - "limit": min(actual_limit, 100), - "sort": sort.value, - } - - # Handle NSFW filtering - if sfw: - params["nsfw"] = "false" - elif nsfw: - params["nsfwLevel"] = nsfw.value - else: - params["nsfw"] = "true" # Default: include all - - if query: - params["query"] = query - if types: - params["types"] = types - if base_models: - params["baseModels"] = base_models - if period: - params["period"] = period.value - if tag: - params["tag"] = tag - if username: - params["username"] = username - if page: - params["page"] = page - if commercial: - params["allowCommercialUse"] = commercial.value - - url = f"{CIVITAI_API_BASE}/models" - - try: - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.get(url, params=params, headers=_get_headers(api_key)) - response.raise_for_status() - result: dict[str, Any] = response.json() - - # Cache all models from search results - items = result.get("items", []) - if items: - try: - with Database() as db: - db.init_schema() - for model_data in items: - db.cache_model(model_data) - except Exception as e: - logger.warning("Failed to cache search results: %s", e) - - return result - except httpx.HTTPStatusError as e: - logger.error("CivitAI API error: %s", e.response.status_code) - return JSONResponse({"error": f"API error: {e.response.status_code}"}, status_code=e.response.status_code) - except httpx.RequestError as e: - logger.error("CivitAI request error: %s", e) - return JSONResponse({"error": f"Request error: {e}"}, status_code=500) - - @router.get("/model/{model_id}", response_model=None) async def get_model(model_id: int) -> dict[str, Any] | Response: """Get model details from CivitAI and cache to database.""" diff --git a/tests/test_server.py b/tests/test_server.py index 0ae8f77..b3930a1 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -481,72 +481,6 @@ def civitai_api(monkeypatch) -> TestClient: return TestClient(app) -class TestCivitAISearch: - """Tests for CivitAI search endpoint.""" - - def test_search_basic(self, civitai_api: TestClient, respx_mock) -> None: - """Test basic search request.""" - import respx - - respx_mock.get("https://civitai.com/api/v1/models").mock( - return_value=respx.MockResponse( - 200, - json={"items": [{"id": 1, "name": "Test Model"}], "metadata": {}}, - ) - ) - - response = civitai_api.get("/api/civitai/search") - assert response.status_code == 200 - data = response.json() - assert "items" in data - - def test_search_with_params(self, civitai_api: TestClient, respx_mock) -> None: - """Test search with query parameters.""" - import respx - - respx_mock.get("https://civitai.com/api/v1/models").mock( - return_value=respx.MockResponse( - 200, - json={"items": [], "metadata": {}}, - ) - ) - - response = civitai_api.get( - "/api/civitai/search", - params={ - "query": "anime", - "types": "LORA", - "baseModels": "Illustrious", - "sort": "Newest", - "limit": 10, - "period": "Week", - "tag": "character", - "sfw": True, - }, - ) - assert response.status_code == 200 - - def test_search_api_error(self, civitai_api: TestClient, respx_mock) -> None: - """Test search handles API errors.""" - import respx - - respx_mock.get("https://civitai.com/api/v1/models").mock( - return_value=respx.MockResponse(500, json={"error": "Server error"}) - ) - - response = civitai_api.get("/api/civitai/search") - assert response.status_code == 500 - - def test_search_network_error(self, civitai_api: TestClient, respx_mock) -> None: - """Test search handles network errors.""" - import httpx - - respx_mock.get("https://civitai.com/api/v1/models").mock(side_effect=httpx.RequestError("Connection failed")) - - response = civitai_api.get("/api/civitai/search") - assert response.status_code == 500 - - class TestCivitAIGetModel: """Tests for CivitAI get model endpoint.""" @@ -1195,50 +1129,6 @@ class TestDownloadEndpoints: class TestCivitAIRoutesExtended: """Extended tests for CivitAI routes.""" - def test_search_with_nsfw_filter(self, civitai_api: TestClient, respx_mock) -> None: - """Test search with NSFW filter.""" - import respx - - respx_mock.get("https://civitai.com/api/v1/models").mock( - return_value=respx.MockResponse(200, json={"items": [], "metadata": {}}) - ) - - response = civitai_api.get("/api/civitai/search", params={"nsfw": "Soft"}) - assert response.status_code == 200 - - def test_search_with_commercial_filter(self, civitai_api: TestClient, respx_mock) -> None: - """Test search with commercial use filter.""" - import respx - - respx_mock.get("https://civitai.com/api/v1/models").mock( - return_value=respx.MockResponse(200, json={"items": [], "metadata": {}}) - ) - - response = civitai_api.get("/api/civitai/search", params={"commercial": "Rent"}) - assert response.status_code == 200 - - def test_search_with_username(self, civitai_api: TestClient, respx_mock) -> None: - """Test search with username filter.""" - import respx - - respx_mock.get("https://civitai.com/api/v1/models").mock( - return_value=respx.MockResponse(200, json={"items": [], "metadata": {}}) - ) - - response = civitai_api.get("/api/civitai/search", params={"username": "testuser"}) - assert response.status_code == 200 - - def test_search_with_page(self, civitai_api: TestClient, respx_mock) -> None: - """Test search with page parameter.""" - import respx - - respx_mock.get("https://civitai.com/api/v1/models").mock( - return_value=respx.MockResponse(200, json={"items": [], "metadata": {}}) - ) - - response = civitai_api.get("/api/civitai/search", params={"page": 2}) - assert response.status_code == 200 - def test_get_model_caches_result(self, civitai_api: TestClient, respx_mock, temp_db, monkeypatch) -> None: """Test that getting a model caches it in the database.""" import respx