💬 Commit message: Update 2026-02-15 18:00:36, 4 files, 228 lines
📁 Files changed: 4 📝 Lines changed: 228 • pyproject.toml • api.py • cli.py • config.py
This commit is contained in:
@@ -60,6 +60,7 @@ select = [
|
|||||||
"RUF", # ruff-specific
|
"RUF", # ruff-specific
|
||||||
]
|
]
|
||||||
ignore = [
|
ignore = [
|
||||||
|
"PLR0912", # Too many branches - search param building needs many conditionals
|
||||||
"PLR0913", # Too many arguments - CLI commands need many options
|
"PLR0913", # Too many arguments - CLI commands need many options
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
+89
-6
@@ -23,7 +23,16 @@ from rich.progress import (
|
|||||||
TransferSpeedColumn,
|
TransferSpeedColumn,
|
||||||
)
|
)
|
||||||
|
|
||||||
from tensors.config import CIVITAI_API_BASE, CIVITAI_DOWNLOAD_BASE, BaseModel, ModelType, SortOrder
|
from tensors.config import (
|
||||||
|
CIVITAI_API_BASE,
|
||||||
|
CIVITAI_DOWNLOAD_BASE,
|
||||||
|
BaseModel,
|
||||||
|
CommercialUse,
|
||||||
|
ModelType,
|
||||||
|
NsfwLevel,
|
||||||
|
Period,
|
||||||
|
SortOrder,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
@@ -135,15 +144,39 @@ def _build_search_params(
|
|||||||
base_model: BaseModel | None,
|
base_model: BaseModel | None,
|
||||||
sort: SortOrder,
|
sort: SortOrder,
|
||||||
limit: int,
|
limit: int,
|
||||||
|
*,
|
||||||
|
period: Period | None = None,
|
||||||
|
nsfw: NsfwLevel | bool | None = None,
|
||||||
|
tag: str | None = None,
|
||||||
|
username: str | None = None,
|
||||||
|
page: int | None = None,
|
||||||
|
commercial_use: CommercialUse | None = None,
|
||||||
|
allow_derivatives: bool | None = None,
|
||||||
|
primary_file_only: bool = False,
|
||||||
) -> tuple[dict[str, Any], bool]:
|
) -> tuple[dict[str, Any], bool]:
|
||||||
"""Build search parameters and return (params, has_filters)."""
|
"""Build search parameters and return (params, has_filters).
|
||||||
|
|
||||||
|
API Quirks / Workarounds:
|
||||||
|
- query + filters don't work reliably together → we fetch more and filter client-side
|
||||||
|
- nsfw=true is required to include NSFW content (default excludes it)
|
||||||
|
- baseModels is undocumented but works
|
||||||
|
"""
|
||||||
params: dict[str, Any] = {
|
params: dict[str, Any] = {
|
||||||
"limit": min(limit, 100),
|
"limit": min(limit, 100),
|
||||||
"nsfw": "true",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# NSFW handling - default to including all content
|
||||||
|
if nsfw is None:
|
||||||
|
params["nsfw"] = "true" # Include NSFW by default (like website)
|
||||||
|
elif isinstance(nsfw, bool):
|
||||||
|
params["nsfw"] = str(nsfw).lower()
|
||||||
|
elif nsfw == NsfwLevel.none:
|
||||||
|
params["nsfw"] = "false" # Exclude NSFW
|
||||||
|
else:
|
||||||
|
params["nsfw"] = "true" # Include for specific levels (API filters server-side)
|
||||||
|
|
||||||
# API quirk: query + filters don't work reliably together
|
# API quirk: query + filters don't work reliably together
|
||||||
has_filters = model_type is not None or base_model is not None
|
has_filters = model_type is not None or base_model is not None or tag is not None
|
||||||
|
|
||||||
if query and not has_filters:
|
if query and not has_filters:
|
||||||
params["query"] = query
|
params["query"] = query
|
||||||
@@ -156,6 +189,28 @@ def _build_search_params(
|
|||||||
|
|
||||||
params["sort"] = sort.to_api()
|
params["sort"] = sort.to_api()
|
||||||
|
|
||||||
|
# Additional filters
|
||||||
|
if period:
|
||||||
|
params["period"] = period.to_api()
|
||||||
|
|
||||||
|
if tag:
|
||||||
|
params["tag"] = tag
|
||||||
|
|
||||||
|
if username:
|
||||||
|
params["username"] = username
|
||||||
|
|
||||||
|
if page and page > 1:
|
||||||
|
params["page"] = page
|
||||||
|
|
||||||
|
if commercial_use:
|
||||||
|
params["allowCommercialUse"] = commercial_use.to_api()
|
||||||
|
|
||||||
|
if allow_derivatives is not None:
|
||||||
|
params["allowDerivatives"] = str(allow_derivatives).lower()
|
||||||
|
|
||||||
|
if primary_file_only:
|
||||||
|
params["primaryFileOnly"] = "true"
|
||||||
|
|
||||||
# Request more if we need client-side filtering
|
# Request more if we need client-side filtering
|
||||||
if query and has_filters:
|
if query and has_filters:
|
||||||
params["limit"] = 100
|
params["limit"] = 100
|
||||||
@@ -179,9 +234,37 @@ def search_civitai(
|
|||||||
limit: int,
|
limit: int,
|
||||||
api_key: str | None,
|
api_key: str | None,
|
||||||
console: Console,
|
console: Console,
|
||||||
|
*,
|
||||||
|
period: Period | None = None,
|
||||||
|
nsfw: NsfwLevel | bool | None = None,
|
||||||
|
tag: str | None = None,
|
||||||
|
username: str | None = None,
|
||||||
|
page: int | None = None,
|
||||||
|
commercial_use: CommercialUse | None = None,
|
||||||
|
allow_derivatives: bool | None = None,
|
||||||
|
primary_file_only: bool = False,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Search CivitAI models."""
|
"""Search CivitAI models.
|
||||||
params, has_filters = _build_search_params(query, model_type, base_model, sort, limit)
|
|
||||||
|
Implements workarounds for API limitations:
|
||||||
|
- Query + filters: fetches more results and filters client-side
|
||||||
|
- NSFW: defaults to including all content (like website behavior)
|
||||||
|
"""
|
||||||
|
params, has_filters = _build_search_params(
|
||||||
|
query,
|
||||||
|
model_type,
|
||||||
|
base_model,
|
||||||
|
sort,
|
||||||
|
limit,
|
||||||
|
period=period,
|
||||||
|
nsfw=nsfw,
|
||||||
|
tag=tag,
|
||||||
|
username=username,
|
||||||
|
page=page,
|
||||||
|
commercial_use=commercial_use,
|
||||||
|
allow_derivatives=allow_derivatives,
|
||||||
|
primary_file_only=primary_file_only,
|
||||||
|
)
|
||||||
url = f"{CIVITAI_API_BASE}/models"
|
url = f"{CIVITAI_API_BASE}/models"
|
||||||
|
|
||||||
with Progress(
|
with Progress(
|
||||||
|
|||||||
+29
-1
@@ -22,7 +22,10 @@ from tensors.api import (
|
|||||||
from tensors.config import (
|
from tensors.config import (
|
||||||
CONFIG_FILE,
|
CONFIG_FILE,
|
||||||
BaseModel,
|
BaseModel,
|
||||||
|
CommercialUse,
|
||||||
ModelType,
|
ModelType,
|
||||||
|
NsfwLevel,
|
||||||
|
Period,
|
||||||
SortOrder,
|
SortOrder,
|
||||||
get_default_output_path,
|
get_default_output_path,
|
||||||
load_api_key,
|
load_api_key,
|
||||||
@@ -177,12 +180,31 @@ def search(
|
|||||||
base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter")] = None,
|
base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter")] = None,
|
||||||
sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads,
|
sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads,
|
||||||
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20,
|
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20,
|
||||||
|
period: Annotated[Period | None, typer.Option("-p", "--period", help="Time period")] = None,
|
||||||
|
tag: Annotated[str | None, typer.Option("--tag", help="Filter by tag")] = None,
|
||||||
|
username: Annotated[str | None, typer.Option("-u", "--user", help="Filter by creator")] = None,
|
||||||
|
page: Annotated[int | None, typer.Option("--page", help="Page number")] = None,
|
||||||
|
nsfw: Annotated[NsfwLevel | None, typer.Option("--nsfw", help="NSFW filter level")] = None,
|
||||||
|
sfw: Annotated[bool, typer.Option("--sfw", help="Exclude NSFW content")] = False,
|
||||||
|
commercial: Annotated[CommercialUse | None, typer.Option("--commercial", help="Commercial use filter")] = None,
|
||||||
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Search CivitAI models."""
|
"""Search CivitAI models.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
tsr search "anime" # Search by name
|
||||||
|
tsr search -t lora -b pony # LoRAs for Pony
|
||||||
|
tsr search --tag anime -b illustrious # Tag + base model
|
||||||
|
tsr search -u "username" # By creator
|
||||||
|
tsr search -p week -s newest # New this week
|
||||||
|
tsr search --sfw # Exclude NSFW
|
||||||
|
"""
|
||||||
key = api_key or load_api_key()
|
key = api_key or load_api_key()
|
||||||
|
|
||||||
|
# Handle SFW flag
|
||||||
|
nsfw_filter: NsfwLevel | bool | None = NsfwLevel.none if sfw else nsfw
|
||||||
|
|
||||||
results = search_civitai(
|
results = search_civitai(
|
||||||
query=query,
|
query=query,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
@@ -191,6 +213,12 @@ def search(
|
|||||||
limit=limit,
|
limit=limit,
|
||||||
api_key=key,
|
api_key=key,
|
||||||
console=console,
|
console=console,
|
||||||
|
period=period,
|
||||||
|
nsfw=nsfw_filter,
|
||||||
|
tag=tag,
|
||||||
|
username=username,
|
||||||
|
page=page,
|
||||||
|
commercial_use=commercial,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
|
|||||||
+100
-2
@@ -50,6 +50,13 @@ class ModelType(str, Enum):
|
|||||||
vae = "vae"
|
vae = "vae"
|
||||||
controlnet = "controlnet"
|
controlnet = "controlnet"
|
||||||
locon = "locon"
|
locon = "locon"
|
||||||
|
hypernetwork = "hypernetwork"
|
||||||
|
poses = "poses"
|
||||||
|
upscaler = "upscaler"
|
||||||
|
motionmodule = "motionmodule"
|
||||||
|
wildcards = "wildcards"
|
||||||
|
workflows = "workflows"
|
||||||
|
other = "other"
|
||||||
|
|
||||||
def to_api(self) -> str:
|
def to_api(self) -> str:
|
||||||
"""Convert to CivitAI API value."""
|
"""Convert to CivitAI API value."""
|
||||||
@@ -60,6 +67,13 @@ class ModelType(str, Enum):
|
|||||||
"vae": "VAE",
|
"vae": "VAE",
|
||||||
"controlnet": "Controlnet",
|
"controlnet": "Controlnet",
|
||||||
"locon": "LoCon",
|
"locon": "LoCon",
|
||||||
|
"hypernetwork": "Hypernetwork",
|
||||||
|
"poses": "Poses",
|
||||||
|
"upscaler": "Upscaler",
|
||||||
|
"motionmodule": "MotionModule",
|
||||||
|
"wildcards": "Wildcards",
|
||||||
|
"workflows": "Workflows",
|
||||||
|
"other": "Other",
|
||||||
}
|
}
|
||||||
return mapping[self.value]
|
return mapping[self.value]
|
||||||
|
|
||||||
@@ -67,20 +81,55 @@ class ModelType(str, Enum):
|
|||||||
class BaseModel(str, Enum):
|
class BaseModel(str, Enum):
|
||||||
"""Common base models."""
|
"""Common base models."""
|
||||||
|
|
||||||
|
# Stable Diffusion 1.x
|
||||||
|
sd14 = "sd14"
|
||||||
sd15 = "sd15"
|
sd15 = "sd15"
|
||||||
|
sd15_lcm = "sd15_lcm"
|
||||||
|
sd15_hyper = "sd15_hyper"
|
||||||
|
# Stable Diffusion 2.x
|
||||||
|
sd20 = "sd20"
|
||||||
|
sd21 = "sd21"
|
||||||
|
# SDXL variants
|
||||||
sdxl = "sdxl"
|
sdxl = "sdxl"
|
||||||
|
sdxl_turbo = "sdxl_turbo"
|
||||||
|
sdxl_lightning = "sdxl_lightning"
|
||||||
|
sdxl_hyper = "sdxl_hyper"
|
||||||
|
# Pony / Illustrious
|
||||||
pony = "pony"
|
pony = "pony"
|
||||||
flux = "flux"
|
|
||||||
illustrious = "illustrious"
|
illustrious = "illustrious"
|
||||||
|
# Flux variants
|
||||||
|
flux_dev = "flux_dev"
|
||||||
|
flux_schnell = "flux_schnell"
|
||||||
|
# SD 3.x
|
||||||
|
sd35_large = "sd35_large"
|
||||||
|
sd35_medium = "sd35_medium"
|
||||||
|
# Other
|
||||||
|
cascade = "cascade"
|
||||||
|
svd = "svd"
|
||||||
|
other = "other"
|
||||||
|
|
||||||
def to_api(self) -> str:
|
def to_api(self) -> str:
|
||||||
"""Convert to CivitAI API value."""
|
"""Convert to CivitAI API value."""
|
||||||
mapping = {
|
mapping = {
|
||||||
|
"sd14": "SD 1.4",
|
||||||
"sd15": "SD 1.5",
|
"sd15": "SD 1.5",
|
||||||
|
"sd15_lcm": "SD 1.5 LCM",
|
||||||
|
"sd15_hyper": "SD 1.5 Hyper",
|
||||||
|
"sd20": "SD 2.0",
|
||||||
|
"sd21": "SD 2.1",
|
||||||
"sdxl": "SDXL 1.0",
|
"sdxl": "SDXL 1.0",
|
||||||
|
"sdxl_turbo": "SDXL Turbo",
|
||||||
|
"sdxl_lightning": "SDXL Lightning",
|
||||||
|
"sdxl_hyper": "SDXL Hyper",
|
||||||
"pony": "Pony",
|
"pony": "Pony",
|
||||||
"flux": "Flux.1 D",
|
|
||||||
"illustrious": "Illustrious",
|
"illustrious": "Illustrious",
|
||||||
|
"flux_dev": "Flux.1 D",
|
||||||
|
"flux_schnell": "Flux.1 S",
|
||||||
|
"sd35_large": "SD 3.5 Large",
|
||||||
|
"sd35_medium": "SD 3.5 Medium",
|
||||||
|
"cascade": "Stable Cascade",
|
||||||
|
"svd": "SVD",
|
||||||
|
"other": "Other",
|
||||||
}
|
}
|
||||||
return mapping[self.value]
|
return mapping[self.value]
|
||||||
|
|
||||||
@@ -102,6 +151,55 @@ class SortOrder(str, Enum):
|
|||||||
return mapping[self.value]
|
return mapping[self.value]
|
||||||
|
|
||||||
|
|
||||||
|
class Period(str, Enum):
|
||||||
|
"""Time period for sorting/filtering."""
|
||||||
|
|
||||||
|
all = "all"
|
||||||
|
year = "year"
|
||||||
|
month = "month"
|
||||||
|
week = "week"
|
||||||
|
day = "day"
|
||||||
|
|
||||||
|
def to_api(self) -> str:
|
||||||
|
"""Convert to CivitAI API value."""
|
||||||
|
mapping = {
|
||||||
|
"all": "AllTime",
|
||||||
|
"year": "Year",
|
||||||
|
"month": "Month",
|
||||||
|
"week": "Week",
|
||||||
|
"day": "Day",
|
||||||
|
}
|
||||||
|
return mapping[self.value]
|
||||||
|
|
||||||
|
|
||||||
|
class NsfwLevel(str, Enum):
|
||||||
|
"""NSFW content filter level."""
|
||||||
|
|
||||||
|
none = "none"
|
||||||
|
soft = "soft"
|
||||||
|
mature = "mature"
|
||||||
|
x = "x"
|
||||||
|
|
||||||
|
def to_api(self) -> str:
|
||||||
|
"""Convert to CivitAI API value."""
|
||||||
|
# For models endpoint, this maps to the nsfw param
|
||||||
|
# none = exclude NSFW, others = specific levels
|
||||||
|
return self.value.capitalize() if self.value != "none" else "None"
|
||||||
|
|
||||||
|
|
||||||
|
class CommercialUse(str, Enum):
|
||||||
|
"""Commercial use permissions."""
|
||||||
|
|
||||||
|
none = "none"
|
||||||
|
image = "image"
|
||||||
|
rent = "rent"
|
||||||
|
sell = "sell"
|
||||||
|
|
||||||
|
def to_api(self) -> str:
|
||||||
|
"""Convert to CivitAI API value."""
|
||||||
|
return self.value.capitalize()
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Config Functions
|
# Config Functions
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
Reference in New Issue
Block a user