From 9f9f137f68be9a1278607365428ede390ff3d71a Mon Sep 17 00:00:00 2001 From: Adam Ladachowski Date: Sun, 15 Feb 2026 19:35:38 +0100 Subject: [PATCH] Unify CivitAI and HuggingFace search into single endpoint - Add --provider flag to search command (civitai, hf, or all) - Default to searching both providers simultaneously - Add /api/search unified REST endpoint - Map common parameters (sort, author, tag) across providers - Remove separate `hf search` subcommand (use `search -P hf`) Co-Authored-By: Claude Opus 4.5 --- tensors/cli.py | 152 ++++++++++++++++-------------- tensors/config.py | 8 ++ tensors/server/__init__.py | 2 + tensors/server/search_routes.py | 160 ++++++++++++++++++++++++++++++++ 4 files changed, 254 insertions(+), 68 deletions(-) create mode 100644 tensors/server/search_routes.py diff --git a/tensors/cli.py b/tensors/cli.py index b099d76..af0c9cf 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -26,6 +26,7 @@ from tensors.config import ( ModelType, NsfwLevel, Period, + Provider, SortOrder, get_default_output_path, load_api_key, @@ -208,62 +209,103 @@ def _save_metadata( @app.command() def search( query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None, - model_type: Annotated[ModelType | None, typer.Option("-t", "--type", help="Model type filter")] = None, - base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter")] = None, + provider: Annotated[Provider, typer.Option("--provider", "-P", help="Search provider")] = Provider.all, + model_type: Annotated[ModelType | None, typer.Option("-t", "--type", help="Model type filter (CivitAI)")] = None, + base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter (CivitAI)")] = None, sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads, - limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20, - period: Annotated[Period | None, typer.Option("-p", "--period", help="Time period")] = None, + limit: Annotated[int, typer.Option("-n", "--limit", help="Max results per provider")] = 20, + period: Annotated[Period | None, typer.Option("-p", "--period", help="Time period (CivitAI)")] = None, tag: Annotated[str | None, typer.Option("--tag", help="Filter by tag")] = None, - username: Annotated[str | None, typer.Option("-u", "--user", help="Filter by creator")] = None, - page: Annotated[int | None, typer.Option("--page", help="Page number")] = None, - nsfw: Annotated[NsfwLevel | None, typer.Option("--nsfw", help="NSFW filter level")] = None, - sfw: Annotated[bool, typer.Option("--sfw", help="Exclude NSFW content")] = False, - commercial: Annotated[CommercialUse | None, typer.Option("--commercial", help="Commercial use filter")] = None, + username: Annotated[str | None, typer.Option("-u", "--user", "-a", "--author", help="Filter by creator/author")] = None, + page: Annotated[int | None, typer.Option("--page", help="Page number (CivitAI)")] = None, + nsfw: Annotated[NsfwLevel | None, typer.Option("--nsfw", help="NSFW filter level (CivitAI)")] = None, + sfw: Annotated[bool, typer.Option("--sfw", help="Exclude NSFW content (CivitAI)")] = False, + commercial: Annotated[CommercialUse | None, typer.Option("--commercial", help="Commercial use filter (CivitAI)")] = None, + pipeline: Annotated[str | None, typer.Option("--pipeline", help="Pipeline tag (HuggingFace)")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, ) -> None: - """Search CivitAI models. + """Search models on CivitAI and/or Hugging Face. Examples: - tsr search "anime" # Search by name - tsr search -t lora -b pony # LoRAs for Pony - tsr search --tag anime -b illustrious # Tag + base model - tsr search -u "username" # By creator - tsr search -p week -s newest # New this week - tsr search --sfw # Exclude NSFW + tsr search "flux" # Search both providers + tsr search "anime" -P civitai # CivitAI only + tsr search "llama" -P hf # Hugging Face only + tsr search -t lora -b pony # CivitAI LoRAs for Pony + tsr search -a stabilityai -P hf # HF by author + tsr search --sfw -P civitai # CivitAI SFW only """ key = api_key or load_api_key() + civitai_results: dict[str, Any] | None = None + hf_results: list[dict[str, Any]] | None = None - # Handle SFW flag - nsfw_filter: NsfwLevel | bool | None = NsfwLevel.none if sfw else nsfw + # Search CivitAI + if provider in (Provider.civitai, Provider.all): + nsfw_filter: NsfwLevel | bool | None = NsfwLevel.none if sfw else nsfw + civitai_results = search_civitai( + query=query, + model_type=model_type, + base_model=base, + sort=sort, + limit=limit, + api_key=key, + console=console if provider == Provider.civitai else None, + period=period, + nsfw=nsfw_filter, + tag=tag, + username=username, + page=page, + commercial_use=commercial, + ) + if civitai_results: + _cache_models_quietly(civitai_results.get("items", [])) - results = search_civitai( - query=query, - model_type=model_type, - base_model=base, - sort=sort, - limit=limit, - api_key=key, - console=console, - period=period, - nsfw=nsfw_filter, - tag=tag, - username=username, - page=page, - commercial_use=commercial, - ) - - if not results: - console.print("[red]Search failed.[/red]") - raise typer.Exit(1) - - # Auto-cache search results - _cache_models_quietly(results.get("items", [])) + # Search Hugging Face + if provider in (Provider.hf, Provider.all): + tags = [tag] if tag else None + hf_results = search_hf_models( + query=query, + author=username, + tags=tags, + pipeline_tag=pipeline, + sort="downloads" if sort == SortOrder.downloads else "likes" if sort == SortOrder.rating else "created_at", + limit=limit, + console=console if provider == Provider.hf else None, + ) + # Output results if json_output: - console.print_json(data=results) + output: dict[str, Any] = {} + if civitai_results: + output["civitai"] = civitai_results + if hf_results: + output["huggingface"] = hf_results + console.print_json(data=output) + return + + # Display based on provider + if provider == Provider.civitai: + if not civitai_results: + console.print("[red]CivitAI search failed.[/red]") + raise typer.Exit(1) + display_search_results(civitai_results, console) + elif provider == Provider.hf: + if hf_results is None: + console.print("[red]Hugging Face search failed.[/red]") + raise typer.Exit(1) + display_hf_search_results(hf_results, console) else: - display_search_results(results, console) + # Both providers + if civitai_results and civitai_results.get("items"): + console.print("\n[bold cyan]═══ CivitAI Results ═══[/bold cyan]") + display_search_results(civitai_results, console) + + if hf_results: + console.print("\n[bold cyan]═══ Hugging Face Results ═══[/bold cyan]") + display_hf_search_results(hf_results, console) + + if not (civitai_results and civitai_results.get("items")) and not hf_results: + console.print("[yellow]No results found on either provider.[/yellow]") @app.command() @@ -737,32 +779,6 @@ hf_app = typer.Typer(name="hf", help="Hugging Face Hub commands for safetensor f app.add_typer(hf_app) -@hf_app.command("search") -def hf_search( - query: Annotated[str | None, typer.Argument(help="Search query")] = None, - author: Annotated[str | None, typer.Option("-a", "--author", help="Filter by author/org")] = None, - pipeline: Annotated[str | None, typer.Option("-p", "--pipeline", help="Pipeline tag (text-to-image, etc.)")] = None, - sort: Annotated[str | None, typer.Option("-s", "--sort", help="Sort by (downloads, likes, created_at)")] = None, - limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 25, - json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, -) -> None: - """Search Hugging Face for models with safetensor files.""" - results = search_hf_models( - query=query, - author=author, - pipeline_tag=pipeline, - sort=sort, - limit=limit, - console=console, - ) - - if json_output: - console.print_json(data=results) - return - - display_hf_search_results(results, console) - - @hf_app.command("get") def hf_get( model_id: Annotated[str, typer.Argument(help="Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)")], diff --git a/tensors/config.py b/tensors/config.py index 9d56cd5..ea35c8f 100644 --- a/tensors/config.py +++ b/tensors/config.py @@ -41,6 +41,14 @@ CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models" # ============================================================================ +class Provider(str, Enum): + """Model search providers.""" + + civitai = "civitai" + hf = "hf" + all = "all" + + class ModelType(str, Enum): """CivitAI model types.""" diff --git a/tensors/server/__init__.py b/tensors/server/__init__.py index 1a7a802..6278852 100644 --- a/tensors/server/__init__.py +++ b/tensors/server/__init__.py @@ -14,6 +14,7 @@ from tensors.server.civitai_routes import create_civitai_router from tensors.server.db_routes import create_db_router from tensors.server.download_routes import create_download_router from tensors.server.gallery_routes import create_gallery_router +from tensors.server.search_routes import create_search_router if TYPE_CHECKING: from collections.abc import AsyncIterator @@ -61,6 +62,7 @@ def create_app() -> FastAPI: # Protected routers (auth required if configured) from tensors.server.auth import verify_api_key # noqa: PLC0415 + app.include_router(create_search_router(), dependencies=[Depends(verify_api_key)]) app.include_router(create_civitai_router(), dependencies=[Depends(verify_api_key)]) app.include_router(create_db_router(), dependencies=[Depends(verify_api_key)]) app.include_router(create_gallery_router(), dependencies=[Depends(verify_api_key)]) diff --git a/tensors/server/search_routes.py b/tensors/server/search_routes.py new file mode 100644 index 0000000..9337545 --- /dev/null +++ b/tensors/server/search_routes.py @@ -0,0 +1,160 @@ +"""FastAPI route handlers for unified model search across providers.""" + +from __future__ import annotations + +import contextlib +import logging +from enum import Enum +from typing import Annotated, Any + +from fastapi import APIRouter, Query + +from tensors.api import search_civitai +from tensors.config import ( + BaseModel as BaseModelEnum, +) +from tensors.config import ( + CommercialUse as CommercialUseEnum, +) +from tensors.config import ( + ModelType, + load_api_key, +) +from tensors.config import ( + NsfwLevel as NsfwLevelEnum, +) +from tensors.config import ( + Period as PeriodEnum, +) +from tensors.config import ( + SortOrder as SortOrderEnum, +) +from tensors.hf import search_hf_models + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/search", tags=["Search"]) + + +class Provider(str, Enum): + """Search provider options.""" + + civitai = "civitai" + hf = "hf" + all = "all" + + +class SortOrder(str, Enum): + """Sort order options.""" + + downloads = "downloads" + rating = "rating" + newest = "newest" + + +@router.get("") +async def search_models( + query: Annotated[str | None, Query(description="Search query")] = None, + provider: Annotated[Provider, Query(description="Search provider (civitai, hf, or all)")] = Provider.all, + # CivitAI-specific + types: Annotated[str | None, Query(description="Model type - CivitAI (Checkpoint, LORA, etc.)")] = None, + base_models: Annotated[str | None, Query(alias="baseModels", description="Base model - CivitAI")] = None, + period: Annotated[str | None, Query(description="Time period - CivitAI (AllTime, Year, Month, Week, Day)")] = None, + nsfw: Annotated[str | None, Query(description="NSFW level - CivitAI (None, Soft, Mature, X)")] = None, + sfw: Annotated[bool, Query(description="Exclude NSFW - CivitAI")] = False, + commercial: Annotated[str | None, Query(description="Commercial use - CivitAI (None, Image, Rent, Sell)")] = None, + page: Annotated[int | None, Query(ge=1, description="Page number - CivitAI")] = None, + # HuggingFace-specific + pipeline: Annotated[str | None, Query(description="Pipeline tag - HuggingFace (text-to-image, etc.)")] = None, + # Common + sort: Annotated[SortOrder, Query(description="Sort order")] = SortOrder.downloads, + limit: Annotated[int, Query(le=100, description="Max results per provider")] = 25, + tag: Annotated[str | None, Query(description="Filter by tag")] = None, + author: Annotated[str | None, Query(description="Filter by author/creator")] = None, +) -> dict[str, Any]: + """Search models across CivitAI and/or Hugging Face. + + Returns results from selected provider(s). When provider=all, returns + results from both CivitAI and Hugging Face in separate keys. + """ + api_key = load_api_key() + results: dict[str, Any] = {} + + # Search CivitAI + if provider in (Provider.civitai, Provider.all): + # Map sort order to CivitAI enum + civitai_sort = SortOrderEnum.downloads + if sort == SortOrder.rating: + civitai_sort = SortOrderEnum.rating + elif sort == SortOrder.newest: + civitai_sort = SortOrderEnum.newest + + # Map other enums + model_type = ModelType(types.lower()) if types else None + base_model = None + if base_models: + with contextlib.suppress(ValueError): + base_model = BaseModelEnum(base_models.lower()) + + period_enum = None + if period: + with contextlib.suppress(ValueError): + period_enum = PeriodEnum(period.lower()) + + nsfw_filter: NsfwLevelEnum | bool | None = None + if sfw: + nsfw_filter = NsfwLevelEnum.none + elif nsfw: + with contextlib.suppress(ValueError): + nsfw_filter = NsfwLevelEnum(nsfw.lower()) + + commercial_enum = None + if commercial: + with contextlib.suppress(ValueError): + commercial_enum = CommercialUseEnum(commercial.lower()) + + civitai_results = search_civitai( + query=query, + model_type=model_type, + base_model=base_model, + sort=civitai_sort, + limit=limit, + api_key=api_key, + period=period_enum, + nsfw=nsfw_filter, + tag=tag, + username=author, + page=page, + commercial_use=commercial_enum, + ) + + if civitai_results: + results["civitai"] = civitai_results + + # Search Hugging Face + if provider in (Provider.hf, Provider.all): + hf_sort = "downloads" + if sort == SortOrder.rating: + hf_sort = "likes" + elif sort == SortOrder.newest: + hf_sort = "created_at" + + tags = [tag] if tag else None + hf_results = search_hf_models( + query=query, + author=author, + tags=tags, + pipeline_tag=pipeline, + sort=hf_sort, + limit=limit, + ) + + if hf_results: + results["huggingface"] = hf_results + + return results + + +def create_search_router() -> APIRouter: + """Return the unified search API router.""" + return router