Unify CivitAI and HuggingFace search into single endpoint

- Add --provider flag to search command (civitai, hf, or all)
- Default to searching both providers simultaneously
- Add /api/search unified REST endpoint
- Map common parameters (sort, author, tag) across providers
- Remove separate `hf search` subcommand (use `search -P hf`)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Adam Ladachowski
2026-02-15 19:35:38 +01:00
parent eb151dac8d
commit 9f9f137f68
4 changed files with 254 additions and 68 deletions
+84 -68
View File
@@ -26,6 +26,7 @@ from tensors.config import (
ModelType, ModelType,
NsfwLevel, NsfwLevel,
Period, Period,
Provider,
SortOrder, SortOrder,
get_default_output_path, get_default_output_path,
load_api_key, load_api_key,
@@ -208,62 +209,103 @@ def _save_metadata(
@app.command() @app.command()
def search( def search(
query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None, query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None,
model_type: Annotated[ModelType | None, typer.Option("-t", "--type", help="Model type filter")] = None, provider: Annotated[Provider, typer.Option("--provider", "-P", help="Search provider")] = Provider.all,
base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter")] = None, model_type: Annotated[ModelType | None, typer.Option("-t", "--type", help="Model type filter (CivitAI)")] = None,
base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter (CivitAI)")] = None,
sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads, sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads,
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20, limit: Annotated[int, typer.Option("-n", "--limit", help="Max results per provider")] = 20,
period: Annotated[Period | None, typer.Option("-p", "--period", help="Time period")] = None, period: Annotated[Period | None, typer.Option("-p", "--period", help="Time period (CivitAI)")] = None,
tag: Annotated[str | None, typer.Option("--tag", help="Filter by tag")] = None, tag: Annotated[str | None, typer.Option("--tag", help="Filter by tag")] = None,
username: Annotated[str | None, typer.Option("-u", "--user", help="Filter by creator")] = None, username: Annotated[str | None, typer.Option("-u", "--user", "-a", "--author", help="Filter by creator/author")] = None,
page: Annotated[int | None, typer.Option("--page", help="Page number")] = None, page: Annotated[int | None, typer.Option("--page", help="Page number (CivitAI)")] = None,
nsfw: Annotated[NsfwLevel | None, typer.Option("--nsfw", help="NSFW filter level")] = None, nsfw: Annotated[NsfwLevel | None, typer.Option("--nsfw", help="NSFW filter level (CivitAI)")] = None,
sfw: Annotated[bool, typer.Option("--sfw", help="Exclude NSFW content")] = False, sfw: Annotated[bool, typer.Option("--sfw", help="Exclude NSFW content (CivitAI)")] = False,
commercial: Annotated[CommercialUse | None, typer.Option("--commercial", help="Commercial use filter")] = None, commercial: Annotated[CommercialUse | None, typer.Option("--commercial", help="Commercial use filter (CivitAI)")] = None,
pipeline: Annotated[str | None, typer.Option("--pipeline", help="Pipeline tag (HuggingFace)")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
) -> None: ) -> None:
"""Search CivitAI models. """Search models on CivitAI and/or Hugging Face.
Examples: Examples:
tsr search "anime" # Search by name tsr search "flux" # Search both providers
tsr search -t lora -b pony # LoRAs for Pony tsr search "anime" -P civitai # CivitAI only
tsr search --tag anime -b illustrious # Tag + base model tsr search "llama" -P hf # Hugging Face only
tsr search -u "username" # By creator tsr search -t lora -b pony # CivitAI LoRAs for Pony
tsr search -p week -s newest # New this week tsr search -a stabilityai -P hf # HF by author
tsr search --sfw # Exclude NSFW tsr search --sfw -P civitai # CivitAI SFW only
""" """
key = api_key or load_api_key() key = api_key or load_api_key()
civitai_results: dict[str, Any] | None = None
hf_results: list[dict[str, Any]] | None = None
# Handle SFW flag # Search CivitAI
nsfw_filter: NsfwLevel | bool | None = NsfwLevel.none if sfw else nsfw if provider in (Provider.civitai, Provider.all):
nsfw_filter: NsfwLevel | bool | None = NsfwLevel.none if sfw else nsfw
civitai_results = search_civitai(
query=query,
model_type=model_type,
base_model=base,
sort=sort,
limit=limit,
api_key=key,
console=console if provider == Provider.civitai else None,
period=period,
nsfw=nsfw_filter,
tag=tag,
username=username,
page=page,
commercial_use=commercial,
)
if civitai_results:
_cache_models_quietly(civitai_results.get("items", []))
results = search_civitai( # Search Hugging Face
query=query, if provider in (Provider.hf, Provider.all):
model_type=model_type, tags = [tag] if tag else None
base_model=base, hf_results = search_hf_models(
sort=sort, query=query,
limit=limit, author=username,
api_key=key, tags=tags,
console=console, pipeline_tag=pipeline,
period=period, sort="downloads" if sort == SortOrder.downloads else "likes" if sort == SortOrder.rating else "created_at",
nsfw=nsfw_filter, limit=limit,
tag=tag, console=console if provider == Provider.hf else None,
username=username, )
page=page,
commercial_use=commercial,
)
if not results:
console.print("[red]Search failed.[/red]")
raise typer.Exit(1)
# Auto-cache search results
_cache_models_quietly(results.get("items", []))
# Output results
if json_output: if json_output:
console.print_json(data=results) output: dict[str, Any] = {}
if civitai_results:
output["civitai"] = civitai_results
if hf_results:
output["huggingface"] = hf_results
console.print_json(data=output)
return
# Display based on provider
if provider == Provider.civitai:
if not civitai_results:
console.print("[red]CivitAI search failed.[/red]")
raise typer.Exit(1)
display_search_results(civitai_results, console)
elif provider == Provider.hf:
if hf_results is None:
console.print("[red]Hugging Face search failed.[/red]")
raise typer.Exit(1)
display_hf_search_results(hf_results, console)
else: else:
display_search_results(results, console) # Both providers
if civitai_results and civitai_results.get("items"):
console.print("\n[bold cyan]═══ CivitAI Results ═══[/bold cyan]")
display_search_results(civitai_results, console)
if hf_results:
console.print("\n[bold cyan]═══ Hugging Face Results ═══[/bold cyan]")
display_hf_search_results(hf_results, console)
if not (civitai_results and civitai_results.get("items")) and not hf_results:
console.print("[yellow]No results found on either provider.[/yellow]")
@app.command() @app.command()
@@ -737,32 +779,6 @@ hf_app = typer.Typer(name="hf", help="Hugging Face Hub commands for safetensor f
app.add_typer(hf_app) app.add_typer(hf_app)
@hf_app.command("search")
def hf_search(
query: Annotated[str | None, typer.Argument(help="Search query")] = None,
author: Annotated[str | None, typer.Option("-a", "--author", help="Filter by author/org")] = None,
pipeline: Annotated[str | None, typer.Option("-p", "--pipeline", help="Pipeline tag (text-to-image, etc.)")] = None,
sort: Annotated[str | None, typer.Option("-s", "--sort", help="Sort by (downloads, likes, created_at)")] = None,
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 25,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
"""Search Hugging Face for models with safetensor files."""
results = search_hf_models(
query=query,
author=author,
pipeline_tag=pipeline,
sort=sort,
limit=limit,
console=console,
)
if json_output:
console.print_json(data=results)
return
display_hf_search_results(results, console)
@hf_app.command("get") @hf_app.command("get")
def hf_get( def hf_get(
model_id: Annotated[str, typer.Argument(help="Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)")], model_id: Annotated[str, typer.Argument(help="Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)")],
+8
View File
@@ -41,6 +41,14 @@ CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
# ============================================================================ # ============================================================================
class Provider(str, Enum):
"""Model search providers."""
civitai = "civitai"
hf = "hf"
all = "all"
class ModelType(str, Enum): class ModelType(str, Enum):
"""CivitAI model types.""" """CivitAI model types."""
+2
View File
@@ -14,6 +14,7 @@ from tensors.server.civitai_routes import create_civitai_router
from tensors.server.db_routes import create_db_router from tensors.server.db_routes import create_db_router
from tensors.server.download_routes import create_download_router from tensors.server.download_routes import create_download_router
from tensors.server.gallery_routes import create_gallery_router from tensors.server.gallery_routes import create_gallery_router
from tensors.server.search_routes import create_search_router
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
@@ -61,6 +62,7 @@ def create_app() -> FastAPI:
# Protected routers (auth required if configured) # Protected routers (auth required if configured)
from tensors.server.auth import verify_api_key # noqa: PLC0415 from tensors.server.auth import verify_api_key # noqa: PLC0415
app.include_router(create_search_router(), dependencies=[Depends(verify_api_key)])
app.include_router(create_civitai_router(), dependencies=[Depends(verify_api_key)]) app.include_router(create_civitai_router(), dependencies=[Depends(verify_api_key)])
app.include_router(create_db_router(), dependencies=[Depends(verify_api_key)]) app.include_router(create_db_router(), dependencies=[Depends(verify_api_key)])
app.include_router(create_gallery_router(), dependencies=[Depends(verify_api_key)]) app.include_router(create_gallery_router(), dependencies=[Depends(verify_api_key)])
+160
View File
@@ -0,0 +1,160 @@
"""FastAPI route handlers for unified model search across providers."""
from __future__ import annotations
import contextlib
import logging
from enum import Enum
from typing import Annotated, Any
from fastapi import APIRouter, Query
from tensors.api import search_civitai
from tensors.config import (
BaseModel as BaseModelEnum,
)
from tensors.config import (
CommercialUse as CommercialUseEnum,
)
from tensors.config import (
ModelType,
load_api_key,
)
from tensors.config import (
NsfwLevel as NsfwLevelEnum,
)
from tensors.config import (
Period as PeriodEnum,
)
from tensors.config import (
SortOrder as SortOrderEnum,
)
from tensors.hf import search_hf_models
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/search", tags=["Search"])
class Provider(str, Enum):
"""Search provider options."""
civitai = "civitai"
hf = "hf"
all = "all"
class SortOrder(str, Enum):
"""Sort order options."""
downloads = "downloads"
rating = "rating"
newest = "newest"
@router.get("")
async def search_models(
query: Annotated[str | None, Query(description="Search query")] = None,
provider: Annotated[Provider, Query(description="Search provider (civitai, hf, or all)")] = Provider.all,
# CivitAI-specific
types: Annotated[str | None, Query(description="Model type - CivitAI (Checkpoint, LORA, etc.)")] = None,
base_models: Annotated[str | None, Query(alias="baseModels", description="Base model - CivitAI")] = None,
period: Annotated[str | None, Query(description="Time period - CivitAI (AllTime, Year, Month, Week, Day)")] = None,
nsfw: Annotated[str | None, Query(description="NSFW level - CivitAI (None, Soft, Mature, X)")] = None,
sfw: Annotated[bool, Query(description="Exclude NSFW - CivitAI")] = False,
commercial: Annotated[str | None, Query(description="Commercial use - CivitAI (None, Image, Rent, Sell)")] = None,
page: Annotated[int | None, Query(ge=1, description="Page number - CivitAI")] = None,
# HuggingFace-specific
pipeline: Annotated[str | None, Query(description="Pipeline tag - HuggingFace (text-to-image, etc.)")] = None,
# Common
sort: Annotated[SortOrder, Query(description="Sort order")] = SortOrder.downloads,
limit: Annotated[int, Query(le=100, description="Max results per provider")] = 25,
tag: Annotated[str | None, Query(description="Filter by tag")] = None,
author: Annotated[str | None, Query(description="Filter by author/creator")] = None,
) -> dict[str, Any]:
"""Search models across CivitAI and/or Hugging Face.
Returns results from selected provider(s). When provider=all, returns
results from both CivitAI and Hugging Face in separate keys.
"""
api_key = load_api_key()
results: dict[str, Any] = {}
# Search CivitAI
if provider in (Provider.civitai, Provider.all):
# Map sort order to CivitAI enum
civitai_sort = SortOrderEnum.downloads
if sort == SortOrder.rating:
civitai_sort = SortOrderEnum.rating
elif sort == SortOrder.newest:
civitai_sort = SortOrderEnum.newest
# Map other enums
model_type = ModelType(types.lower()) if types else None
base_model = None
if base_models:
with contextlib.suppress(ValueError):
base_model = BaseModelEnum(base_models.lower())
period_enum = None
if period:
with contextlib.suppress(ValueError):
period_enum = PeriodEnum(period.lower())
nsfw_filter: NsfwLevelEnum | bool | None = None
if sfw:
nsfw_filter = NsfwLevelEnum.none
elif nsfw:
with contextlib.suppress(ValueError):
nsfw_filter = NsfwLevelEnum(nsfw.lower())
commercial_enum = None
if commercial:
with contextlib.suppress(ValueError):
commercial_enum = CommercialUseEnum(commercial.lower())
civitai_results = search_civitai(
query=query,
model_type=model_type,
base_model=base_model,
sort=civitai_sort,
limit=limit,
api_key=api_key,
period=period_enum,
nsfw=nsfw_filter,
tag=tag,
username=author,
page=page,
commercial_use=commercial_enum,
)
if civitai_results:
results["civitai"] = civitai_results
# Search Hugging Face
if provider in (Provider.hf, Provider.all):
hf_sort = "downloads"
if sort == SortOrder.rating:
hf_sort = "likes"
elif sort == SortOrder.newest:
hf_sort = "created_at"
tags = [tag] if tag else None
hf_results = search_hf_models(
query=query,
author=author,
tags=tags,
pipeline_tag=pipeline,
sort=hf_sort,
limit=limit,
)
if hf_results:
results["huggingface"] = hf_results
return results
def create_search_router() -> APIRouter:
"""Return the unified search API router."""
return router