💬 Commit message: Update 2026-02-15 18:00:36, 4 files, 228 lines
📁 Files changed: 4 📝 Lines changed: 228 • pyproject.toml • api.py • cli.py • config.py
This commit is contained in:
+89
-6
@@ -23,7 +23,16 @@ from rich.progress import (
|
||||
TransferSpeedColumn,
|
||||
)
|
||||
|
||||
from tensors.config import CIVITAI_API_BASE, CIVITAI_DOWNLOAD_BASE, BaseModel, ModelType, SortOrder
|
||||
from tensors.config import (
|
||||
CIVITAI_API_BASE,
|
||||
CIVITAI_DOWNLOAD_BASE,
|
||||
BaseModel,
|
||||
CommercialUse,
|
||||
ModelType,
|
||||
NsfwLevel,
|
||||
Period,
|
||||
SortOrder,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from rich.console import Console
|
||||
@@ -135,15 +144,39 @@ def _build_search_params(
|
||||
base_model: BaseModel | None,
|
||||
sort: SortOrder,
|
||||
limit: int,
|
||||
*,
|
||||
period: Period | None = None,
|
||||
nsfw: NsfwLevel | bool | None = None,
|
||||
tag: str | None = None,
|
||||
username: str | None = None,
|
||||
page: int | None = None,
|
||||
commercial_use: CommercialUse | None = None,
|
||||
allow_derivatives: bool | None = None,
|
||||
primary_file_only: bool = False,
|
||||
) -> tuple[dict[str, Any], bool]:
|
||||
"""Build search parameters and return (params, has_filters)."""
|
||||
"""Build search parameters and return (params, has_filters).
|
||||
|
||||
API Quirks / Workarounds:
|
||||
- query + filters don't work reliably together → we fetch more and filter client-side
|
||||
- nsfw=true is required to include NSFW content (default excludes it)
|
||||
- baseModels is undocumented but works
|
||||
"""
|
||||
params: dict[str, Any] = {
|
||||
"limit": min(limit, 100),
|
||||
"nsfw": "true",
|
||||
}
|
||||
|
||||
# NSFW handling - default to including all content
|
||||
if nsfw is None:
|
||||
params["nsfw"] = "true" # Include NSFW by default (like website)
|
||||
elif isinstance(nsfw, bool):
|
||||
params["nsfw"] = str(nsfw).lower()
|
||||
elif nsfw == NsfwLevel.none:
|
||||
params["nsfw"] = "false" # Exclude NSFW
|
||||
else:
|
||||
params["nsfw"] = "true" # Include for specific levels (API filters server-side)
|
||||
|
||||
# API quirk: query + filters don't work reliably together
|
||||
has_filters = model_type is not None or base_model is not None
|
||||
has_filters = model_type is not None or base_model is not None or tag is not None
|
||||
|
||||
if query and not has_filters:
|
||||
params["query"] = query
|
||||
@@ -156,6 +189,28 @@ def _build_search_params(
|
||||
|
||||
params["sort"] = sort.to_api()
|
||||
|
||||
# Additional filters
|
||||
if period:
|
||||
params["period"] = period.to_api()
|
||||
|
||||
if tag:
|
||||
params["tag"] = tag
|
||||
|
||||
if username:
|
||||
params["username"] = username
|
||||
|
||||
if page and page > 1:
|
||||
params["page"] = page
|
||||
|
||||
if commercial_use:
|
||||
params["allowCommercialUse"] = commercial_use.to_api()
|
||||
|
||||
if allow_derivatives is not None:
|
||||
params["allowDerivatives"] = str(allow_derivatives).lower()
|
||||
|
||||
if primary_file_only:
|
||||
params["primaryFileOnly"] = "true"
|
||||
|
||||
# Request more if we need client-side filtering
|
||||
if query and has_filters:
|
||||
params["limit"] = 100
|
||||
@@ -179,9 +234,37 @@ def search_civitai(
|
||||
limit: int,
|
||||
api_key: str | None,
|
||||
console: Console,
|
||||
*,
|
||||
period: Period | None = None,
|
||||
nsfw: NsfwLevel | bool | None = None,
|
||||
tag: str | None = None,
|
||||
username: str | None = None,
|
||||
page: int | None = None,
|
||||
commercial_use: CommercialUse | None = None,
|
||||
allow_derivatives: bool | None = None,
|
||||
primary_file_only: bool = False,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Search CivitAI models."""
|
||||
params, has_filters = _build_search_params(query, model_type, base_model, sort, limit)
|
||||
"""Search CivitAI models.
|
||||
|
||||
Implements workarounds for API limitations:
|
||||
- Query + filters: fetches more results and filters client-side
|
||||
- NSFW: defaults to including all content (like website behavior)
|
||||
"""
|
||||
params, has_filters = _build_search_params(
|
||||
query,
|
||||
model_type,
|
||||
base_model,
|
||||
sort,
|
||||
limit,
|
||||
period=period,
|
||||
nsfw=nsfw,
|
||||
tag=tag,
|
||||
username=username,
|
||||
page=page,
|
||||
commercial_use=commercial_use,
|
||||
allow_derivatives=allow_derivatives,
|
||||
primary_file_only=primary_file_only,
|
||||
)
|
||||
url = f"{CIVITAI_API_BASE}/models"
|
||||
|
||||
with Progress(
|
||||
|
||||
+29
-1
@@ -22,7 +22,10 @@ from tensors.api import (
|
||||
from tensors.config import (
|
||||
CONFIG_FILE,
|
||||
BaseModel,
|
||||
CommercialUse,
|
||||
ModelType,
|
||||
NsfwLevel,
|
||||
Period,
|
||||
SortOrder,
|
||||
get_default_output_path,
|
||||
load_api_key,
|
||||
@@ -177,12 +180,31 @@ def search(
|
||||
base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter")] = None,
|
||||
sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads,
|
||||
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20,
|
||||
period: Annotated[Period | None, typer.Option("-p", "--period", help="Time period")] = None,
|
||||
tag: Annotated[str | None, typer.Option("--tag", help="Filter by tag")] = None,
|
||||
username: Annotated[str | None, typer.Option("-u", "--user", help="Filter by creator")] = None,
|
||||
page: Annotated[int | None, typer.Option("--page", help="Page number")] = None,
|
||||
nsfw: Annotated[NsfwLevel | None, typer.Option("--nsfw", help="NSFW filter level")] = None,
|
||||
sfw: Annotated[bool, typer.Option("--sfw", help="Exclude NSFW content")] = False,
|
||||
commercial: Annotated[CommercialUse | None, typer.Option("--commercial", help="Commercial use filter")] = None,
|
||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||
) -> None:
|
||||
"""Search CivitAI models."""
|
||||
"""Search CivitAI models.
|
||||
|
||||
Examples:
|
||||
tsr search "anime" # Search by name
|
||||
tsr search -t lora -b pony # LoRAs for Pony
|
||||
tsr search --tag anime -b illustrious # Tag + base model
|
||||
tsr search -u "username" # By creator
|
||||
tsr search -p week -s newest # New this week
|
||||
tsr search --sfw # Exclude NSFW
|
||||
"""
|
||||
key = api_key or load_api_key()
|
||||
|
||||
# Handle SFW flag
|
||||
nsfw_filter: NsfwLevel | bool | None = NsfwLevel.none if sfw else nsfw
|
||||
|
||||
results = search_civitai(
|
||||
query=query,
|
||||
model_type=model_type,
|
||||
@@ -191,6 +213,12 @@ def search(
|
||||
limit=limit,
|
||||
api_key=key,
|
||||
console=console,
|
||||
period=period,
|
||||
nsfw=nsfw_filter,
|
||||
tag=tag,
|
||||
username=username,
|
||||
page=page,
|
||||
commercial_use=commercial,
|
||||
)
|
||||
|
||||
if not results:
|
||||
|
||||
+100
-2
@@ -50,6 +50,13 @@ class ModelType(str, Enum):
|
||||
vae = "vae"
|
||||
controlnet = "controlnet"
|
||||
locon = "locon"
|
||||
hypernetwork = "hypernetwork"
|
||||
poses = "poses"
|
||||
upscaler = "upscaler"
|
||||
motionmodule = "motionmodule"
|
||||
wildcards = "wildcards"
|
||||
workflows = "workflows"
|
||||
other = "other"
|
||||
|
||||
def to_api(self) -> str:
|
||||
"""Convert to CivitAI API value."""
|
||||
@@ -60,6 +67,13 @@ class ModelType(str, Enum):
|
||||
"vae": "VAE",
|
||||
"controlnet": "Controlnet",
|
||||
"locon": "LoCon",
|
||||
"hypernetwork": "Hypernetwork",
|
||||
"poses": "Poses",
|
||||
"upscaler": "Upscaler",
|
||||
"motionmodule": "MotionModule",
|
||||
"wildcards": "Wildcards",
|
||||
"workflows": "Workflows",
|
||||
"other": "Other",
|
||||
}
|
||||
return mapping[self.value]
|
||||
|
||||
@@ -67,20 +81,55 @@ class ModelType(str, Enum):
|
||||
class BaseModel(str, Enum):
|
||||
"""Common base models."""
|
||||
|
||||
# Stable Diffusion 1.x
|
||||
sd14 = "sd14"
|
||||
sd15 = "sd15"
|
||||
sd15_lcm = "sd15_lcm"
|
||||
sd15_hyper = "sd15_hyper"
|
||||
# Stable Diffusion 2.x
|
||||
sd20 = "sd20"
|
||||
sd21 = "sd21"
|
||||
# SDXL variants
|
||||
sdxl = "sdxl"
|
||||
sdxl_turbo = "sdxl_turbo"
|
||||
sdxl_lightning = "sdxl_lightning"
|
||||
sdxl_hyper = "sdxl_hyper"
|
||||
# Pony / Illustrious
|
||||
pony = "pony"
|
||||
flux = "flux"
|
||||
illustrious = "illustrious"
|
||||
# Flux variants
|
||||
flux_dev = "flux_dev"
|
||||
flux_schnell = "flux_schnell"
|
||||
# SD 3.x
|
||||
sd35_large = "sd35_large"
|
||||
sd35_medium = "sd35_medium"
|
||||
# Other
|
||||
cascade = "cascade"
|
||||
svd = "svd"
|
||||
other = "other"
|
||||
|
||||
def to_api(self) -> str:
|
||||
"""Convert to CivitAI API value."""
|
||||
mapping = {
|
||||
"sd14": "SD 1.4",
|
||||
"sd15": "SD 1.5",
|
||||
"sd15_lcm": "SD 1.5 LCM",
|
||||
"sd15_hyper": "SD 1.5 Hyper",
|
||||
"sd20": "SD 2.0",
|
||||
"sd21": "SD 2.1",
|
||||
"sdxl": "SDXL 1.0",
|
||||
"sdxl_turbo": "SDXL Turbo",
|
||||
"sdxl_lightning": "SDXL Lightning",
|
||||
"sdxl_hyper": "SDXL Hyper",
|
||||
"pony": "Pony",
|
||||
"flux": "Flux.1 D",
|
||||
"illustrious": "Illustrious",
|
||||
"flux_dev": "Flux.1 D",
|
||||
"flux_schnell": "Flux.1 S",
|
||||
"sd35_large": "SD 3.5 Large",
|
||||
"sd35_medium": "SD 3.5 Medium",
|
||||
"cascade": "Stable Cascade",
|
||||
"svd": "SVD",
|
||||
"other": "Other",
|
||||
}
|
||||
return mapping[self.value]
|
||||
|
||||
@@ -102,6 +151,55 @@ class SortOrder(str, Enum):
|
||||
return mapping[self.value]
|
||||
|
||||
|
||||
class Period(str, Enum):
|
||||
"""Time period for sorting/filtering."""
|
||||
|
||||
all = "all"
|
||||
year = "year"
|
||||
month = "month"
|
||||
week = "week"
|
||||
day = "day"
|
||||
|
||||
def to_api(self) -> str:
|
||||
"""Convert to CivitAI API value."""
|
||||
mapping = {
|
||||
"all": "AllTime",
|
||||
"year": "Year",
|
||||
"month": "Month",
|
||||
"week": "Week",
|
||||
"day": "Day",
|
||||
}
|
||||
return mapping[self.value]
|
||||
|
||||
|
||||
class NsfwLevel(str, Enum):
|
||||
"""NSFW content filter level."""
|
||||
|
||||
none = "none"
|
||||
soft = "soft"
|
||||
mature = "mature"
|
||||
x = "x"
|
||||
|
||||
def to_api(self) -> str:
|
||||
"""Convert to CivitAI API value."""
|
||||
# For models endpoint, this maps to the nsfw param
|
||||
# none = exclude NSFW, others = specific levels
|
||||
return self.value.capitalize() if self.value != "none" else "None"
|
||||
|
||||
|
||||
class CommercialUse(str, Enum):
|
||||
"""Commercial use permissions."""
|
||||
|
||||
none = "none"
|
||||
image = "image"
|
||||
rent = "rent"
|
||||
sell = "sell"
|
||||
|
||||
def to_api(self) -> str:
|
||||
"""Convert to CivitAI API value."""
|
||||
return self.value.capitalize()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Config Functions
|
||||
# ============================================================================
|
||||
|
||||
Reference in New Issue
Block a user