diff --git a/pyproject.toml b/pyproject.toml index 9b3d8d2..2cd2cee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ select = [ "RUF", # ruff-specific ] ignore = [ + "PLR0912", # Too many branches - search param building needs many conditionals "PLR0913", # Too many arguments - CLI commands need many options ] diff --git a/tensors/api.py b/tensors/api.py index d3ad199..4016f2e 100644 --- a/tensors/api.py +++ b/tensors/api.py @@ -23,7 +23,16 @@ from rich.progress import ( TransferSpeedColumn, ) -from tensors.config import CIVITAI_API_BASE, CIVITAI_DOWNLOAD_BASE, BaseModel, ModelType, SortOrder +from tensors.config import ( + CIVITAI_API_BASE, + CIVITAI_DOWNLOAD_BASE, + BaseModel, + CommercialUse, + ModelType, + NsfwLevel, + Period, + SortOrder, +) if TYPE_CHECKING: from rich.console import Console @@ -135,15 +144,39 @@ def _build_search_params( base_model: BaseModel | None, sort: SortOrder, limit: int, + *, + period: Period | None = None, + nsfw: NsfwLevel | bool | None = None, + tag: str | None = None, + username: str | None = None, + page: int | None = None, + commercial_use: CommercialUse | None = None, + allow_derivatives: bool | None = None, + primary_file_only: bool = False, ) -> tuple[dict[str, Any], bool]: - """Build search parameters and return (params, has_filters).""" + """Build search parameters and return (params, has_filters). + + API Quirks / Workarounds: + - query + filters don't work reliably together → we fetch more and filter client-side + - nsfw=true is required to include NSFW content (default excludes it) + - baseModels is undocumented but works + """ params: dict[str, Any] = { "limit": min(limit, 100), - "nsfw": "true", } + # NSFW handling - default to including all content + if nsfw is None: + params["nsfw"] = "true" # Include NSFW by default (like website) + elif isinstance(nsfw, bool): + params["nsfw"] = str(nsfw).lower() + elif nsfw == NsfwLevel.none: + params["nsfw"] = "false" # Exclude NSFW + else: + params["nsfw"] = "true" # Include for specific levels (API filters server-side) + # API quirk: query + filters don't work reliably together - has_filters = model_type is not None or base_model is not None + has_filters = model_type is not None or base_model is not None or tag is not None if query and not has_filters: params["query"] = query @@ -156,6 +189,28 @@ def _build_search_params( params["sort"] = sort.to_api() + # Additional filters + if period: + params["period"] = period.to_api() + + if tag: + params["tag"] = tag + + if username: + params["username"] = username + + if page and page > 1: + params["page"] = page + + if commercial_use: + params["allowCommercialUse"] = commercial_use.to_api() + + if allow_derivatives is not None: + params["allowDerivatives"] = str(allow_derivatives).lower() + + if primary_file_only: + params["primaryFileOnly"] = "true" + # Request more if we need client-side filtering if query and has_filters: params["limit"] = 100 @@ -179,9 +234,37 @@ def search_civitai( limit: int, api_key: str | None, console: Console, + *, + period: Period | None = None, + nsfw: NsfwLevel | bool | None = None, + tag: str | None = None, + username: str | None = None, + page: int | None = None, + commercial_use: CommercialUse | None = None, + allow_derivatives: bool | None = None, + primary_file_only: bool = False, ) -> dict[str, Any] | None: - """Search CivitAI models.""" - params, has_filters = _build_search_params(query, model_type, base_model, sort, limit) + """Search CivitAI models. + + Implements workarounds for API limitations: + - Query + filters: fetches more results and filters client-side + - NSFW: defaults to including all content (like website behavior) + """ + params, has_filters = _build_search_params( + query, + model_type, + base_model, + sort, + limit, + period=period, + nsfw=nsfw, + tag=tag, + username=username, + page=page, + commercial_use=commercial_use, + allow_derivatives=allow_derivatives, + primary_file_only=primary_file_only, + ) url = f"{CIVITAI_API_BASE}/models" with Progress( diff --git a/tensors/cli.py b/tensors/cli.py index 5f52f6d..9ce9155 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -22,7 +22,10 @@ from tensors.api import ( from tensors.config import ( CONFIG_FILE, BaseModel, + CommercialUse, ModelType, + NsfwLevel, + Period, SortOrder, get_default_output_path, load_api_key, @@ -177,12 +180,31 @@ def search( base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter")] = None, sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads, limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20, + period: Annotated[Period | None, typer.Option("-p", "--period", help="Time period")] = None, + tag: Annotated[str | None, typer.Option("--tag", help="Filter by tag")] = None, + username: Annotated[str | None, typer.Option("-u", "--user", help="Filter by creator")] = None, + page: Annotated[int | None, typer.Option("--page", help="Page number")] = None, + nsfw: Annotated[NsfwLevel | None, typer.Option("--nsfw", help="NSFW filter level")] = None, + sfw: Annotated[bool, typer.Option("--sfw", help="Exclude NSFW content")] = False, + commercial: Annotated[CommercialUse | None, typer.Option("--commercial", help="Commercial use filter")] = None, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, ) -> None: - """Search CivitAI models.""" + """Search CivitAI models. + + Examples: + tsr search "anime" # Search by name + tsr search -t lora -b pony # LoRAs for Pony + tsr search --tag anime -b illustrious # Tag + base model + tsr search -u "username" # By creator + tsr search -p week -s newest # New this week + tsr search --sfw # Exclude NSFW + """ key = api_key or load_api_key() + # Handle SFW flag + nsfw_filter: NsfwLevel | bool | None = NsfwLevel.none if sfw else nsfw + results = search_civitai( query=query, model_type=model_type, @@ -191,6 +213,12 @@ def search( limit=limit, api_key=key, console=console, + period=period, + nsfw=nsfw_filter, + tag=tag, + username=username, + page=page, + commercial_use=commercial, ) if not results: diff --git a/tensors/config.py b/tensors/config.py index 430f84e..9321f4a 100644 --- a/tensors/config.py +++ b/tensors/config.py @@ -50,6 +50,13 @@ class ModelType(str, Enum): vae = "vae" controlnet = "controlnet" locon = "locon" + hypernetwork = "hypernetwork" + poses = "poses" + upscaler = "upscaler" + motionmodule = "motionmodule" + wildcards = "wildcards" + workflows = "workflows" + other = "other" def to_api(self) -> str: """Convert to CivitAI API value.""" @@ -60,6 +67,13 @@ class ModelType(str, Enum): "vae": "VAE", "controlnet": "Controlnet", "locon": "LoCon", + "hypernetwork": "Hypernetwork", + "poses": "Poses", + "upscaler": "Upscaler", + "motionmodule": "MotionModule", + "wildcards": "Wildcards", + "workflows": "Workflows", + "other": "Other", } return mapping[self.value] @@ -67,20 +81,55 @@ class ModelType(str, Enum): class BaseModel(str, Enum): """Common base models.""" + # Stable Diffusion 1.x + sd14 = "sd14" sd15 = "sd15" + sd15_lcm = "sd15_lcm" + sd15_hyper = "sd15_hyper" + # Stable Diffusion 2.x + sd20 = "sd20" + sd21 = "sd21" + # SDXL variants sdxl = "sdxl" + sdxl_turbo = "sdxl_turbo" + sdxl_lightning = "sdxl_lightning" + sdxl_hyper = "sdxl_hyper" + # Pony / Illustrious pony = "pony" - flux = "flux" illustrious = "illustrious" + # Flux variants + flux_dev = "flux_dev" + flux_schnell = "flux_schnell" + # SD 3.x + sd35_large = "sd35_large" + sd35_medium = "sd35_medium" + # Other + cascade = "cascade" + svd = "svd" + other = "other" def to_api(self) -> str: """Convert to CivitAI API value.""" mapping = { + "sd14": "SD 1.4", "sd15": "SD 1.5", + "sd15_lcm": "SD 1.5 LCM", + "sd15_hyper": "SD 1.5 Hyper", + "sd20": "SD 2.0", + "sd21": "SD 2.1", "sdxl": "SDXL 1.0", + "sdxl_turbo": "SDXL Turbo", + "sdxl_lightning": "SDXL Lightning", + "sdxl_hyper": "SDXL Hyper", "pony": "Pony", - "flux": "Flux.1 D", "illustrious": "Illustrious", + "flux_dev": "Flux.1 D", + "flux_schnell": "Flux.1 S", + "sd35_large": "SD 3.5 Large", + "sd35_medium": "SD 3.5 Medium", + "cascade": "Stable Cascade", + "svd": "SVD", + "other": "Other", } return mapping[self.value] @@ -102,6 +151,55 @@ class SortOrder(str, Enum): return mapping[self.value] +class Period(str, Enum): + """Time period for sorting/filtering.""" + + all = "all" + year = "year" + month = "month" + week = "week" + day = "day" + + def to_api(self) -> str: + """Convert to CivitAI API value.""" + mapping = { + "all": "AllTime", + "year": "Year", + "month": "Month", + "week": "Week", + "day": "Day", + } + return mapping[self.value] + + +class NsfwLevel(str, Enum): + """NSFW content filter level.""" + + none = "none" + soft = "soft" + mature = "mature" + x = "x" + + def to_api(self) -> str: + """Convert to CivitAI API value.""" + # For models endpoint, this maps to the nsfw param + # none = exclude NSFW, others = specific levels + return self.value.capitalize() if self.value != "none" else "None" + + +class CommercialUse(str, Enum): + """Commercial use permissions.""" + + none = "none" + image = "image" + rent = "rent" + sell = "sell" + + def to_api(self) -> str: + """Convert to CivitAI API value.""" + return self.value.capitalize() + + # ============================================================================ # Config Functions # ============================================================================