💬 Commit message: Update 2026-02-15 18:00:36, 4 files, 228 lines

📁 Files changed: 4
📝 Lines changed: 228

  • pyproject.toml
  • api.py
  • cli.py
  • config.py
This commit is contained in:
Adam Ladachowski
2026-02-15 18:00:36 +01:00
parent 1d41f924bf
commit 5f71710298
4 changed files with 219 additions and 9 deletions
+1
View File
@@ -60,6 +60,7 @@ select = [
"RUF", # ruff-specific "RUF", # ruff-specific
] ]
ignore = [ ignore = [
"PLR0912", # Too many branches - search param building needs many conditionals
"PLR0913", # Too many arguments - CLI commands need many options "PLR0913", # Too many arguments - CLI commands need many options
] ]
+89 -6
View File
@@ -23,7 +23,16 @@ from rich.progress import (
TransferSpeedColumn, TransferSpeedColumn,
) )
from tensors.config import CIVITAI_API_BASE, CIVITAI_DOWNLOAD_BASE, BaseModel, ModelType, SortOrder from tensors.config import (
CIVITAI_API_BASE,
CIVITAI_DOWNLOAD_BASE,
BaseModel,
CommercialUse,
ModelType,
NsfwLevel,
Period,
SortOrder,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from rich.console import Console from rich.console import Console
@@ -135,15 +144,39 @@ def _build_search_params(
base_model: BaseModel | None, base_model: BaseModel | None,
sort: SortOrder, sort: SortOrder,
limit: int, limit: int,
*,
period: Period | None = None,
nsfw: NsfwLevel | bool | None = None,
tag: str | None = None,
username: str | None = None,
page: int | None = None,
commercial_use: CommercialUse | None = None,
allow_derivatives: bool | None = None,
primary_file_only: bool = False,
) -> tuple[dict[str, Any], bool]: ) -> tuple[dict[str, Any], bool]:
"""Build search parameters and return (params, has_filters).""" """Build search parameters and return (params, has_filters).
API Quirks / Workarounds:
- query + filters don't work reliably together → we fetch more and filter client-side
- nsfw=true is required to include NSFW content (default excludes it)
- baseModels is undocumented but works
"""
params: dict[str, Any] = { params: dict[str, Any] = {
"limit": min(limit, 100), "limit": min(limit, 100),
"nsfw": "true",
} }
# NSFW handling - default to including all content
if nsfw is None:
params["nsfw"] = "true" # Include NSFW by default (like website)
elif isinstance(nsfw, bool):
params["nsfw"] = str(nsfw).lower()
elif nsfw == NsfwLevel.none:
params["nsfw"] = "false" # Exclude NSFW
else:
params["nsfw"] = "true" # Include for specific levels (API filters server-side)
# API quirk: query + filters don't work reliably together # API quirk: query + filters don't work reliably together
has_filters = model_type is not None or base_model is not None has_filters = model_type is not None or base_model is not None or tag is not None
if query and not has_filters: if query and not has_filters:
params["query"] = query params["query"] = query
@@ -156,6 +189,28 @@ def _build_search_params(
params["sort"] = sort.to_api() params["sort"] = sort.to_api()
# Additional filters
if period:
params["period"] = period.to_api()
if tag:
params["tag"] = tag
if username:
params["username"] = username
if page and page > 1:
params["page"] = page
if commercial_use:
params["allowCommercialUse"] = commercial_use.to_api()
if allow_derivatives is not None:
params["allowDerivatives"] = str(allow_derivatives).lower()
if primary_file_only:
params["primaryFileOnly"] = "true"
# Request more if we need client-side filtering # Request more if we need client-side filtering
if query and has_filters: if query and has_filters:
params["limit"] = 100 params["limit"] = 100
@@ -179,9 +234,37 @@ def search_civitai(
limit: int, limit: int,
api_key: str | None, api_key: str | None,
console: Console, console: Console,
*,
period: Period | None = None,
nsfw: NsfwLevel | bool | None = None,
tag: str | None = None,
username: str | None = None,
page: int | None = None,
commercial_use: CommercialUse | None = None,
allow_derivatives: bool | None = None,
primary_file_only: bool = False,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Search CivitAI models.""" """Search CivitAI models.
params, has_filters = _build_search_params(query, model_type, base_model, sort, limit)
Implements workarounds for API limitations:
- Query + filters: fetches more results and filters client-side
- NSFW: defaults to including all content (like website behavior)
"""
params, has_filters = _build_search_params(
query,
model_type,
base_model,
sort,
limit,
period=period,
nsfw=nsfw,
tag=tag,
username=username,
page=page,
commercial_use=commercial_use,
allow_derivatives=allow_derivatives,
primary_file_only=primary_file_only,
)
url = f"{CIVITAI_API_BASE}/models" url = f"{CIVITAI_API_BASE}/models"
with Progress( with Progress(
+29 -1
View File
@@ -22,7 +22,10 @@ from tensors.api import (
from tensors.config import ( from tensors.config import (
CONFIG_FILE, CONFIG_FILE,
BaseModel, BaseModel,
CommercialUse,
ModelType, ModelType,
NsfwLevel,
Period,
SortOrder, SortOrder,
get_default_output_path, get_default_output_path,
load_api_key, load_api_key,
@@ -177,12 +180,31 @@ def search(
base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter")] = None, base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter")] = None,
sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads, sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads,
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20, limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20,
period: Annotated[Period | None, typer.Option("-p", "--period", help="Time period")] = None,
tag: Annotated[str | None, typer.Option("--tag", help="Filter by tag")] = None,
username: Annotated[str | None, typer.Option("-u", "--user", help="Filter by creator")] = None,
page: Annotated[int | None, typer.Option("--page", help="Page number")] = None,
nsfw: Annotated[NsfwLevel | None, typer.Option("--nsfw", help="NSFW filter level")] = None,
sfw: Annotated[bool, typer.Option("--sfw", help="Exclude NSFW content")] = False,
commercial: Annotated[CommercialUse | None, typer.Option("--commercial", help="Commercial use filter")] = None,
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False, json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None, api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
) -> None: ) -> None:
"""Search CivitAI models.""" """Search CivitAI models.
Examples:
tsr search "anime" # Search by name
tsr search -t lora -b pony # LoRAs for Pony
tsr search --tag anime -b illustrious # Tag + base model
tsr search -u "username" # By creator
tsr search -p week -s newest # New this week
tsr search --sfw # Exclude NSFW
"""
key = api_key or load_api_key() key = api_key or load_api_key()
# Handle SFW flag
nsfw_filter: NsfwLevel | bool | None = NsfwLevel.none if sfw else nsfw
results = search_civitai( results = search_civitai(
query=query, query=query,
model_type=model_type, model_type=model_type,
@@ -191,6 +213,12 @@ def search(
limit=limit, limit=limit,
api_key=key, api_key=key,
console=console, console=console,
period=period,
nsfw=nsfw_filter,
tag=tag,
username=username,
page=page,
commercial_use=commercial,
) )
if not results: if not results:
+100 -2
View File
@@ -50,6 +50,13 @@ class ModelType(str, Enum):
vae = "vae" vae = "vae"
controlnet = "controlnet" controlnet = "controlnet"
locon = "locon" locon = "locon"
hypernetwork = "hypernetwork"
poses = "poses"
upscaler = "upscaler"
motionmodule = "motionmodule"
wildcards = "wildcards"
workflows = "workflows"
other = "other"
def to_api(self) -> str: def to_api(self) -> str:
"""Convert to CivitAI API value.""" """Convert to CivitAI API value."""
@@ -60,6 +67,13 @@ class ModelType(str, Enum):
"vae": "VAE", "vae": "VAE",
"controlnet": "Controlnet", "controlnet": "Controlnet",
"locon": "LoCon", "locon": "LoCon",
"hypernetwork": "Hypernetwork",
"poses": "Poses",
"upscaler": "Upscaler",
"motionmodule": "MotionModule",
"wildcards": "Wildcards",
"workflows": "Workflows",
"other": "Other",
} }
return mapping[self.value] return mapping[self.value]
@@ -67,20 +81,55 @@ class ModelType(str, Enum):
class BaseModel(str, Enum): class BaseModel(str, Enum):
"""Common base models.""" """Common base models."""
# Stable Diffusion 1.x
sd14 = "sd14"
sd15 = "sd15" sd15 = "sd15"
sd15_lcm = "sd15_lcm"
sd15_hyper = "sd15_hyper"
# Stable Diffusion 2.x
sd20 = "sd20"
sd21 = "sd21"
# SDXL variants
sdxl = "sdxl" sdxl = "sdxl"
sdxl_turbo = "sdxl_turbo"
sdxl_lightning = "sdxl_lightning"
sdxl_hyper = "sdxl_hyper"
# Pony / Illustrious
pony = "pony" pony = "pony"
flux = "flux"
illustrious = "illustrious" illustrious = "illustrious"
# Flux variants
flux_dev = "flux_dev"
flux_schnell = "flux_schnell"
# SD 3.x
sd35_large = "sd35_large"
sd35_medium = "sd35_medium"
# Other
cascade = "cascade"
svd = "svd"
other = "other"
def to_api(self) -> str: def to_api(self) -> str:
"""Convert to CivitAI API value.""" """Convert to CivitAI API value."""
mapping = { mapping = {
"sd14": "SD 1.4",
"sd15": "SD 1.5", "sd15": "SD 1.5",
"sd15_lcm": "SD 1.5 LCM",
"sd15_hyper": "SD 1.5 Hyper",
"sd20": "SD 2.0",
"sd21": "SD 2.1",
"sdxl": "SDXL 1.0", "sdxl": "SDXL 1.0",
"sdxl_turbo": "SDXL Turbo",
"sdxl_lightning": "SDXL Lightning",
"sdxl_hyper": "SDXL Hyper",
"pony": "Pony", "pony": "Pony",
"flux": "Flux.1 D",
"illustrious": "Illustrious", "illustrious": "Illustrious",
"flux_dev": "Flux.1 D",
"flux_schnell": "Flux.1 S",
"sd35_large": "SD 3.5 Large",
"sd35_medium": "SD 3.5 Medium",
"cascade": "Stable Cascade",
"svd": "SVD",
"other": "Other",
} }
return mapping[self.value] return mapping[self.value]
@@ -102,6 +151,55 @@ class SortOrder(str, Enum):
return mapping[self.value] return mapping[self.value]
class Period(str, Enum):
"""Time period for sorting/filtering."""
all = "all"
year = "year"
month = "month"
week = "week"
day = "day"
def to_api(self) -> str:
"""Convert to CivitAI API value."""
mapping = {
"all": "AllTime",
"year": "Year",
"month": "Month",
"week": "Week",
"day": "Day",
}
return mapping[self.value]
class NsfwLevel(str, Enum):
"""NSFW content filter level."""
none = "none"
soft = "soft"
mature = "mature"
x = "x"
def to_api(self) -> str:
"""Convert to CivitAI API value."""
# For models endpoint, this maps to the nsfw param
# none = exclude NSFW, others = specific levels
return self.value.capitalize() if self.value != "none" else "None"
class CommercialUse(str, Enum):
"""Commercial use permissions."""
none = "none"
image = "image"
rent = "rent"
sell = "sell"
def to_api(self) -> str:
"""Convert to CivitAI API value."""
return self.value.capitalize()
# ============================================================================ # ============================================================================
# Config Functions # Config Functions
# ============================================================================ # ============================================================================