Unify CivitAI and HuggingFace search into single endpoint
- Add --provider flag to search command (civitai, hf, or all) - Default to searching both providers simultaneously - Add /api/search unified REST endpoint - Map common parameters (sort, author, tag) across providers - Remove separate `hf search` subcommand (use `search -P hf`) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
+70
-54
@@ -26,6 +26,7 @@ from tensors.config import (
|
||||
ModelType,
|
||||
NsfwLevel,
|
||||
Period,
|
||||
Provider,
|
||||
SortOrder,
|
||||
get_default_output_path,
|
||||
load_api_key,
|
||||
@@ -208,43 +209,47 @@ def _save_metadata(
|
||||
@app.command()
|
||||
def search(
|
||||
query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None,
|
||||
model_type: Annotated[ModelType | None, typer.Option("-t", "--type", help="Model type filter")] = None,
|
||||
base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter")] = None,
|
||||
provider: Annotated[Provider, typer.Option("--provider", "-P", help="Search provider")] = Provider.all,
|
||||
model_type: Annotated[ModelType | None, typer.Option("-t", "--type", help="Model type filter (CivitAI)")] = None,
|
||||
base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter (CivitAI)")] = None,
|
||||
sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads,
|
||||
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20,
|
||||
period: Annotated[Period | None, typer.Option("-p", "--period", help="Time period")] = None,
|
||||
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results per provider")] = 20,
|
||||
period: Annotated[Period | None, typer.Option("-p", "--period", help="Time period (CivitAI)")] = None,
|
||||
tag: Annotated[str | None, typer.Option("--tag", help="Filter by tag")] = None,
|
||||
username: Annotated[str | None, typer.Option("-u", "--user", help="Filter by creator")] = None,
|
||||
page: Annotated[int | None, typer.Option("--page", help="Page number")] = None,
|
||||
nsfw: Annotated[NsfwLevel | None, typer.Option("--nsfw", help="NSFW filter level")] = None,
|
||||
sfw: Annotated[bool, typer.Option("--sfw", help="Exclude NSFW content")] = False,
|
||||
commercial: Annotated[CommercialUse | None, typer.Option("--commercial", help="Commercial use filter")] = None,
|
||||
username: Annotated[str | None, typer.Option("-u", "--user", "-a", "--author", help="Filter by creator/author")] = None,
|
||||
page: Annotated[int | None, typer.Option("--page", help="Page number (CivitAI)")] = None,
|
||||
nsfw: Annotated[NsfwLevel | None, typer.Option("--nsfw", help="NSFW filter level (CivitAI)")] = None,
|
||||
sfw: Annotated[bool, typer.Option("--sfw", help="Exclude NSFW content (CivitAI)")] = False,
|
||||
commercial: Annotated[CommercialUse | None, typer.Option("--commercial", help="Commercial use filter (CivitAI)")] = None,
|
||||
pipeline: Annotated[str | None, typer.Option("--pipeline", help="Pipeline tag (HuggingFace)")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||
) -> None:
|
||||
"""Search CivitAI models.
|
||||
"""Search models on CivitAI and/or Hugging Face.
|
||||
|
||||
Examples:
|
||||
tsr search "anime" # Search by name
|
||||
tsr search -t lora -b pony # LoRAs for Pony
|
||||
tsr search --tag anime -b illustrious # Tag + base model
|
||||
tsr search -u "username" # By creator
|
||||
tsr search -p week -s newest # New this week
|
||||
tsr search --sfw # Exclude NSFW
|
||||
tsr search "flux" # Search both providers
|
||||
tsr search "anime" -P civitai # CivitAI only
|
||||
tsr search "llama" -P hf # Hugging Face only
|
||||
tsr search -t lora -b pony # CivitAI LoRAs for Pony
|
||||
tsr search -a stabilityai -P hf # HF by author
|
||||
tsr search --sfw -P civitai # CivitAI SFW only
|
||||
"""
|
||||
key = api_key or load_api_key()
|
||||
civitai_results: dict[str, Any] | None = None
|
||||
hf_results: list[dict[str, Any]] | None = None
|
||||
|
||||
# Handle SFW flag
|
||||
# Search CivitAI
|
||||
if provider in (Provider.civitai, Provider.all):
|
||||
nsfw_filter: NsfwLevel | bool | None = NsfwLevel.none if sfw else nsfw
|
||||
|
||||
results = search_civitai(
|
||||
civitai_results = search_civitai(
|
||||
query=query,
|
||||
model_type=model_type,
|
||||
base_model=base,
|
||||
sort=sort,
|
||||
limit=limit,
|
||||
api_key=key,
|
||||
console=console,
|
||||
console=console if provider == Provider.civitai else None,
|
||||
period=period,
|
||||
nsfw=nsfw_filter,
|
||||
tag=tag,
|
||||
@@ -252,18 +257,55 @@ def search(
|
||||
page=page,
|
||||
commercial_use=commercial,
|
||||
)
|
||||
if civitai_results:
|
||||
_cache_models_quietly(civitai_results.get("items", []))
|
||||
|
||||
if not results:
|
||||
console.print("[red]Search failed.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Auto-cache search results
|
||||
_cache_models_quietly(results.get("items", []))
|
||||
# Search Hugging Face
|
||||
if provider in (Provider.hf, Provider.all):
|
||||
tags = [tag] if tag else None
|
||||
hf_results = search_hf_models(
|
||||
query=query,
|
||||
author=username,
|
||||
tags=tags,
|
||||
pipeline_tag=pipeline,
|
||||
sort="downloads" if sort == SortOrder.downloads else "likes" if sort == SortOrder.rating else "created_at",
|
||||
limit=limit,
|
||||
console=console if provider == Provider.hf else None,
|
||||
)
|
||||
|
||||
# Output results
|
||||
if json_output:
|
||||
console.print_json(data=results)
|
||||
output: dict[str, Any] = {}
|
||||
if civitai_results:
|
||||
output["civitai"] = civitai_results
|
||||
if hf_results:
|
||||
output["huggingface"] = hf_results
|
||||
console.print_json(data=output)
|
||||
return
|
||||
|
||||
# Display based on provider
|
||||
if provider == Provider.civitai:
|
||||
if not civitai_results:
|
||||
console.print("[red]CivitAI search failed.[/red]")
|
||||
raise typer.Exit(1)
|
||||
display_search_results(civitai_results, console)
|
||||
elif provider == Provider.hf:
|
||||
if hf_results is None:
|
||||
console.print("[red]Hugging Face search failed.[/red]")
|
||||
raise typer.Exit(1)
|
||||
display_hf_search_results(hf_results, console)
|
||||
else:
|
||||
display_search_results(results, console)
|
||||
# Both providers
|
||||
if civitai_results and civitai_results.get("items"):
|
||||
console.print("\n[bold cyan]═══ CivitAI Results ═══[/bold cyan]")
|
||||
display_search_results(civitai_results, console)
|
||||
|
||||
if hf_results:
|
||||
console.print("\n[bold cyan]═══ Hugging Face Results ═══[/bold cyan]")
|
||||
display_hf_search_results(hf_results, console)
|
||||
|
||||
if not (civitai_results and civitai_results.get("items")) and not hf_results:
|
||||
console.print("[yellow]No results found on either provider.[/yellow]")
|
||||
|
||||
|
||||
@app.command()
|
||||
@@ -737,32 +779,6 @@ hf_app = typer.Typer(name="hf", help="Hugging Face Hub commands for safetensor f
|
||||
app.add_typer(hf_app)
|
||||
|
||||
|
||||
@hf_app.command("search")
|
||||
def hf_search(
|
||||
query: Annotated[str | None, typer.Argument(help="Search query")] = None,
|
||||
author: Annotated[str | None, typer.Option("-a", "--author", help="Filter by author/org")] = None,
|
||||
pipeline: Annotated[str | None, typer.Option("-p", "--pipeline", help="Pipeline tag (text-to-image, etc.)")] = None,
|
||||
sort: Annotated[str | None, typer.Option("-s", "--sort", help="Sort by (downloads, likes, created_at)")] = None,
|
||||
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 25,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
) -> None:
|
||||
"""Search Hugging Face for models with safetensor files."""
|
||||
results = search_hf_models(
|
||||
query=query,
|
||||
author=author,
|
||||
pipeline_tag=pipeline,
|
||||
sort=sort,
|
||||
limit=limit,
|
||||
console=console,
|
||||
)
|
||||
|
||||
if json_output:
|
||||
console.print_json(data=results)
|
||||
return
|
||||
|
||||
display_hf_search_results(results, console)
|
||||
|
||||
|
||||
@hf_app.command("get")
|
||||
def hf_get(
|
||||
model_id: Annotated[str, typer.Argument(help="Model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)")],
|
||||
|
||||
@@ -41,6 +41,14 @@ CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class Provider(str, Enum):
|
||||
"""Model search providers."""
|
||||
|
||||
civitai = "civitai"
|
||||
hf = "hf"
|
||||
all = "all"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
"""CivitAI model types."""
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from tensors.server.civitai_routes import create_civitai_router
|
||||
from tensors.server.db_routes import create_db_router
|
||||
from tensors.server.download_routes import create_download_router
|
||||
from tensors.server.gallery_routes import create_gallery_router
|
||||
from tensors.server.search_routes import create_search_router
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
@@ -61,6 +62,7 @@ def create_app() -> FastAPI:
|
||||
# Protected routers (auth required if configured)
|
||||
from tensors.server.auth import verify_api_key # noqa: PLC0415
|
||||
|
||||
app.include_router(create_search_router(), dependencies=[Depends(verify_api_key)])
|
||||
app.include_router(create_civitai_router(), dependencies=[Depends(verify_api_key)])
|
||||
app.include_router(create_db_router(), dependencies=[Depends(verify_api_key)])
|
||||
app.include_router(create_gallery_router(), dependencies=[Depends(verify_api_key)])
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
"""FastAPI route handlers for unified model search across providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Query
|
||||
|
||||
from tensors.api import search_civitai
|
||||
from tensors.config import (
|
||||
BaseModel as BaseModelEnum,
|
||||
)
|
||||
from tensors.config import (
|
||||
CommercialUse as CommercialUseEnum,
|
||||
)
|
||||
from tensors.config import (
|
||||
ModelType,
|
||||
load_api_key,
|
||||
)
|
||||
from tensors.config import (
|
||||
NsfwLevel as NsfwLevelEnum,
|
||||
)
|
||||
from tensors.config import (
|
||||
Period as PeriodEnum,
|
||||
)
|
||||
from tensors.config import (
|
||||
SortOrder as SortOrderEnum,
|
||||
)
|
||||
from tensors.hf import search_hf_models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/search", tags=["Search"])
|
||||
|
||||
|
||||
class Provider(str, Enum):
|
||||
"""Search provider options."""
|
||||
|
||||
civitai = "civitai"
|
||||
hf = "hf"
|
||||
all = "all"
|
||||
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
"""Sort order options."""
|
||||
|
||||
downloads = "downloads"
|
||||
rating = "rating"
|
||||
newest = "newest"
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def search_models(
|
||||
query: Annotated[str | None, Query(description="Search query")] = None,
|
||||
provider: Annotated[Provider, Query(description="Search provider (civitai, hf, or all)")] = Provider.all,
|
||||
# CivitAI-specific
|
||||
types: Annotated[str | None, Query(description="Model type - CivitAI (Checkpoint, LORA, etc.)")] = None,
|
||||
base_models: Annotated[str | None, Query(alias="baseModels", description="Base model - CivitAI")] = None,
|
||||
period: Annotated[str | None, Query(description="Time period - CivitAI (AllTime, Year, Month, Week, Day)")] = None,
|
||||
nsfw: Annotated[str | None, Query(description="NSFW level - CivitAI (None, Soft, Mature, X)")] = None,
|
||||
sfw: Annotated[bool, Query(description="Exclude NSFW - CivitAI")] = False,
|
||||
commercial: Annotated[str | None, Query(description="Commercial use - CivitAI (None, Image, Rent, Sell)")] = None,
|
||||
page: Annotated[int | None, Query(ge=1, description="Page number - CivitAI")] = None,
|
||||
# HuggingFace-specific
|
||||
pipeline: Annotated[str | None, Query(description="Pipeline tag - HuggingFace (text-to-image, etc.)")] = None,
|
||||
# Common
|
||||
sort: Annotated[SortOrder, Query(description="Sort order")] = SortOrder.downloads,
|
||||
limit: Annotated[int, Query(le=100, description="Max results per provider")] = 25,
|
||||
tag: Annotated[str | None, Query(description="Filter by tag")] = None,
|
||||
author: Annotated[str | None, Query(description="Filter by author/creator")] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Search models across CivitAI and/or Hugging Face.
|
||||
|
||||
Returns results from selected provider(s). When provider=all, returns
|
||||
results from both CivitAI and Hugging Face in separate keys.
|
||||
"""
|
||||
api_key = load_api_key()
|
||||
results: dict[str, Any] = {}
|
||||
|
||||
# Search CivitAI
|
||||
if provider in (Provider.civitai, Provider.all):
|
||||
# Map sort order to CivitAI enum
|
||||
civitai_sort = SortOrderEnum.downloads
|
||||
if sort == SortOrder.rating:
|
||||
civitai_sort = SortOrderEnum.rating
|
||||
elif sort == SortOrder.newest:
|
||||
civitai_sort = SortOrderEnum.newest
|
||||
|
||||
# Map other enums
|
||||
model_type = ModelType(types.lower()) if types else None
|
||||
base_model = None
|
||||
if base_models:
|
||||
with contextlib.suppress(ValueError):
|
||||
base_model = BaseModelEnum(base_models.lower())
|
||||
|
||||
period_enum = None
|
||||
if period:
|
||||
with contextlib.suppress(ValueError):
|
||||
period_enum = PeriodEnum(period.lower())
|
||||
|
||||
nsfw_filter: NsfwLevelEnum | bool | None = None
|
||||
if sfw:
|
||||
nsfw_filter = NsfwLevelEnum.none
|
||||
elif nsfw:
|
||||
with contextlib.suppress(ValueError):
|
||||
nsfw_filter = NsfwLevelEnum(nsfw.lower())
|
||||
|
||||
commercial_enum = None
|
||||
if commercial:
|
||||
with contextlib.suppress(ValueError):
|
||||
commercial_enum = CommercialUseEnum(commercial.lower())
|
||||
|
||||
civitai_results = search_civitai(
|
||||
query=query,
|
||||
model_type=model_type,
|
||||
base_model=base_model,
|
||||
sort=civitai_sort,
|
||||
limit=limit,
|
||||
api_key=api_key,
|
||||
period=period_enum,
|
||||
nsfw=nsfw_filter,
|
||||
tag=tag,
|
||||
username=author,
|
||||
page=page,
|
||||
commercial_use=commercial_enum,
|
||||
)
|
||||
|
||||
if civitai_results:
|
||||
results["civitai"] = civitai_results
|
||||
|
||||
# Search Hugging Face
|
||||
if provider in (Provider.hf, Provider.all):
|
||||
hf_sort = "downloads"
|
||||
if sort == SortOrder.rating:
|
||||
hf_sort = "likes"
|
||||
elif sort == SortOrder.newest:
|
||||
hf_sort = "created_at"
|
||||
|
||||
tags = [tag] if tag else None
|
||||
hf_results = search_hf_models(
|
||||
query=query,
|
||||
author=author,
|
||||
tags=tags,
|
||||
pipeline_tag=pipeline,
|
||||
sort=hf_sort,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if hf_results:
|
||||
results["huggingface"] = hf_results
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def create_search_router() -> APIRouter:
|
||||
"""Return the unified search API router."""
|
||||
return router
|
||||
Reference in New Issue
Block a user