b0b5bca5f8
Master had been failing CI lint since before this PR. Knock out the backlog so the parallel-queue PR can ship green and future PRs don't inherit the red baseline. Changes by category: - UP042 (9): Migrate `class Foo(str, Enum)` to `class Foo(StrEnum)` in tensors/config.py (7 enums) and tensors/server/search_routes.py (2). Requires Python 3.11+, already enforced by `requires-python = ">=3.12"`. - PLR2004 (3): Extract magic comparison values in comfyui_api_routes.py to module-level constants (_DEFAULT_STEPS, _DEFAULT_CFG, _PROMPT_LOG_TRUNCATE). - PLW0108 (2): Inline `lambda: StubDB()` -> `StubDB` in test_server.py. - PLR0915 (3): Add explicit `# noqa: PLR0915` to typer command bodies that are intentionally long (template, templates_extract, _wait_for_completion_ws). - PLR1714 (1): `file_path.name == model or file_path.stem == model` -> `model in (file_path.name, file_path.stem)` in cli.py:3047. - SIM113 (1): Use `enumerate(as_completed(futures), start=1)` for the completion counter in style-sweep parallel loop. - RUF059 (1): Prefix unused tuple-unpacked vars with `_` in _run_one. - SIM105 (1): `try: ws.close() except Exception: pass` -> `contextlib.suppress(Exception)` in comfyui.py. - PLC0415 (1): Add missing `# noqa: PLC0415` to the second of two function-scoped tensors.config imports. No behavior changes. All 374 tests still pass.
1064 lines
34 KiB
Python
1064 lines
34 KiB
Python
"""Configuration, constants, and enums for tsr CLI."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import tomllib
|
|
from enum import StrEnum
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
# ============================================================================
# XDG Base Directory Configuration
# ============================================================================

# Config: ~/.config/tensors/config.toml
# Data:   ~/.local/share/tensors/models/, ~/.local/share/tensors/metadata/
#
# The XDG Base Directory spec says an *empty* XDG_* variable must be treated
# like an unset one, so `or` is used instead of an `os.environ.get` default.
# (With a plain default, XDG_CONFIG_HOME="" resolved to the relative
# directory "tensors" under the current working directory.)
CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME") or Path.home() / ".config") / "tensors"
CONFIG_FILE = CONFIG_DIR / "config.toml"

DATA_DIR = Path(os.environ.get("XDG_DATA_HOME") or Path.home() / ".local" / "share") / "tensors"
MODELS_DIR = DATA_DIR / "models"
METADATA_DIR = DATA_DIR / "metadata"
GALLERY_DIR = DATA_DIR / "gallery"

# Legacy config for migration
LEGACY_RC_FILE = Path.home() / ".sftrc"

# Default download paths by model type (can be overridden in config.toml [paths])
#
# Note: "DiffusionModel" is not an official CivitAI model_type — CivitAI lumps
# UNet-only files (e.g. Flux UNet released separately from CLIP+VAE) under
# "Checkpoint". The DiffusionModel entry here exists so users can register a
# ComfyUI `diffusion_models/` path and target it manually via `tsr dl -o`.
DEFAULT_PATHS: dict[str, Path] = {
    "Checkpoint": MODELS_DIR / "checkpoints",
    "LORA": MODELS_DIR / "loras",
    "LoCon": MODELS_DIR / "loras",  # shares the LORA directory
    "TextualInversion": MODELS_DIR / "embeddings",
    "VAE": MODELS_DIR / "vae",
    "Controlnet": MODELS_DIR / "controlnet",
    "Upscaler": MODELS_DIR / "upscalers",
    "DiffusionModel": MODELS_DIR / "diffusion_models",
    "Other": MODELS_DIR / "other",
}

# Config-file keys accepted by `tsr config --set-path KEY=PATH`. Single source
# of truth shared between the CLI validator and the display-marker logic.
VALID_PATH_TYPES: list[str] = [
    "checkpoints",
    "loras",
    "embeddings",
    "vae",
    "controlnet",
    "upscalers",
    "diffusion_models",
    "other",
]

# CivitAI REST endpoints.
CIVITAI_API_BASE = "https://civitai.com/api/v1"
CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
|
|
|
|
|
|
# ============================================================================
|
|
# Enums for CLI
|
|
# ============================================================================
|
|
|
|
|
|
class Provider(StrEnum):
|
|
"""Model search providers."""
|
|
|
|
civitai = "civitai"
|
|
hf = "hf"
|
|
all = "all"
|
|
|
|
|
|
class ModelType(StrEnum):
|
|
"""CivitAI model types."""
|
|
|
|
checkpoint = "checkpoint"
|
|
lora = "lora"
|
|
embedding = "embedding"
|
|
vae = "vae"
|
|
controlnet = "controlnet"
|
|
locon = "locon"
|
|
hypernetwork = "hypernetwork"
|
|
poses = "poses"
|
|
upscaler = "upscaler"
|
|
motionmodule = "motionmodule"
|
|
wildcards = "wildcards"
|
|
workflows = "workflows"
|
|
other = "other"
|
|
|
|
def to_api(self) -> str:
|
|
"""Convert to CivitAI API value."""
|
|
mapping = {
|
|
"checkpoint": "Checkpoint",
|
|
"lora": "LORA",
|
|
"embedding": "TextualInversion",
|
|
"vae": "VAE",
|
|
"controlnet": "Controlnet",
|
|
"locon": "LoCon",
|
|
"hypernetwork": "Hypernetwork",
|
|
"poses": "Poses",
|
|
"upscaler": "Upscaler",
|
|
"motionmodule": "MotionModule",
|
|
"wildcards": "Wildcards",
|
|
"workflows": "Workflows",
|
|
"other": "Other",
|
|
}
|
|
return mapping[self.value]
|
|
|
|
|
|
class BaseModel(StrEnum):
|
|
"""Common base models."""
|
|
|
|
# Stable Diffusion 1.x
|
|
sd14 = "sd14"
|
|
sd15 = "sd15"
|
|
sd15_lcm = "sd15_lcm"
|
|
sd15_hyper = "sd15_hyper"
|
|
# Stable Diffusion 2.x
|
|
sd20 = "sd20"
|
|
sd21 = "sd21"
|
|
# SDXL variants
|
|
sdxl = "sdxl"
|
|
sdxl_turbo = "sdxl_turbo"
|
|
sdxl_lightning = "sdxl_lightning"
|
|
sdxl_hyper = "sdxl_hyper"
|
|
# Pony / Illustrious
|
|
pony = "pony"
|
|
illustrious = "illustrious"
|
|
# Flux variants
|
|
flux_dev = "flux_dev"
|
|
flux_schnell = "flux_schnell"
|
|
# SD 3.x
|
|
sd35_large = "sd35_large"
|
|
sd35_medium = "sd35_medium"
|
|
# Other
|
|
cascade = "cascade"
|
|
svd = "svd"
|
|
other = "other"
|
|
|
|
def to_api(self) -> str:
|
|
"""Convert to CivitAI API value."""
|
|
mapping = {
|
|
"sd14": "SD 1.4",
|
|
"sd15": "SD 1.5",
|
|
"sd15_lcm": "SD 1.5 LCM",
|
|
"sd15_hyper": "SD 1.5 Hyper",
|
|
"sd20": "SD 2.0",
|
|
"sd21": "SD 2.1",
|
|
"sdxl": "SDXL 1.0",
|
|
"sdxl_turbo": "SDXL Turbo",
|
|
"sdxl_lightning": "SDXL Lightning",
|
|
"sdxl_hyper": "SDXL Hyper",
|
|
"pony": "Pony",
|
|
"illustrious": "Illustrious",
|
|
"flux_dev": "Flux.1 D",
|
|
"flux_schnell": "Flux.1 S",
|
|
"sd35_large": "SD 3.5 Large",
|
|
"sd35_medium": "SD 3.5 Medium",
|
|
"cascade": "Stable Cascade",
|
|
"svd": "SVD",
|
|
"other": "Other",
|
|
}
|
|
return mapping[self.value]
|
|
|
|
|
|
class SortOrder(StrEnum):
|
|
"""Sort options for search."""
|
|
|
|
downloads = "downloads"
|
|
rating = "rating"
|
|
newest = "newest"
|
|
|
|
def to_api(self) -> str:
|
|
"""Convert to CivitAI API value."""
|
|
mapping = {
|
|
"downloads": "Most Downloaded",
|
|
"rating": "Highest Rated",
|
|
"newest": "Newest",
|
|
}
|
|
return mapping[self.value]
|
|
|
|
|
|
class Period(StrEnum):
|
|
"""Time period for sorting/filtering."""
|
|
|
|
all = "all"
|
|
year = "year"
|
|
month = "month"
|
|
week = "week"
|
|
day = "day"
|
|
|
|
def to_api(self) -> str:
|
|
"""Convert to CivitAI API value."""
|
|
mapping = {
|
|
"all": "AllTime",
|
|
"year": "Year",
|
|
"month": "Month",
|
|
"week": "Week",
|
|
"day": "Day",
|
|
}
|
|
return mapping[self.value]
|
|
|
|
|
|
class NsfwLevel(StrEnum):
|
|
"""NSFW content filter level."""
|
|
|
|
none = "none"
|
|
soft = "soft"
|
|
mature = "mature"
|
|
x = "x"
|
|
|
|
def to_api(self) -> str:
|
|
"""Convert to CivitAI API value."""
|
|
# For models endpoint, this maps to the nsfw param
|
|
# none = exclude NSFW, others = specific levels
|
|
return self.value.capitalize() if self.value != "none" else "None"
|
|
|
|
|
|
class CommercialUse(StrEnum):
|
|
"""Commercial use permissions."""
|
|
|
|
none = "none"
|
|
image = "image"
|
|
rent = "rent"
|
|
sell = "sell"
|
|
|
|
def to_api(self) -> str:
|
|
"""Convert to CivitAI API value."""
|
|
return self.value.capitalize()
|
|
|
|
|
|
# ============================================================================
|
|
# Config Functions
|
|
# ============================================================================
|
|
|
|
|
|
def load_config() -> dict[str, Any]:
    """Read ``CONFIG_FILE`` and return its parsed TOML contents.

    Returns:
        The parsed configuration, or an empty dict when the file does not
        exist. Parse errors from ``tomllib`` propagate to the caller.
    """
    if not CONFIG_FILE.exists():
        return {}

    # tomllib requires a binary file handle.
    with CONFIG_FILE.open("rb") as toml_file:
        return tomllib.load(toml_file)
|
|
|
|
|
|
def save_config(config: dict[str, Any]) -> None:
|
|
"""Save configuration to TOML config file."""
|
|
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
lines: list[str] = []
|
|
# Write scalar values first (before any sections)
|
|
for key, value in config.items():
|
|
if not isinstance(value, dict):
|
|
if isinstance(value, str):
|
|
lines.append(f'{key} = "{value}"')
|
|
else:
|
|
lines.append(f"{key} = {value}")
|
|
|
|
if lines:
|
|
lines.append("")
|
|
|
|
# Then write sections (dicts)
|
|
for key, value in config.items():
|
|
if isinstance(value, dict):
|
|
lines.append(f"[{key}]")
|
|
for k, v in value.items():
|
|
if isinstance(v, str):
|
|
lines.append(f'{k} = "{v}"')
|
|
else:
|
|
lines.append(f"{k} = {v}")
|
|
lines.append("")
|
|
|
|
CONFIG_FILE.write_text("\n".join(lines) + "\n")
|
|
|
|
|
|
def load_api_key() -> str | None:
|
|
"""Load API key from config file or CIVITAI_API_KEY env var."""
|
|
# Check environment variable first
|
|
env_key = os.environ.get("CIVITAI_API_KEY")
|
|
if env_key:
|
|
return env_key
|
|
|
|
# Check TOML config file
|
|
config = load_config()
|
|
api_section = config.get("api", {})
|
|
if isinstance(api_section, dict):
|
|
key = api_section.get("civitai_key")
|
|
if key:
|
|
return str(key)
|
|
|
|
# Fall back to legacy RC file for migration
|
|
if LEGACY_RC_FILE.exists():
|
|
content = LEGACY_RC_FILE.read_text().strip()
|
|
if content:
|
|
return content
|
|
return None
|
|
|
|
|
|
def get_model_paths() -> dict[str, Path]:
    """Build the model-type -> directory map, applying ``[paths]`` overrides.

    Example ``config.toml``::

        [paths]
        checkpoints = "/opt/comfyui/models/checkpoints"
        loras = "/opt/comfyui/models/loras"
        embeddings = "/opt/comfyui/models/embeddings"
        vae = "/opt/comfyui/models/vae"
        controlnet = "/opt/comfyui/models/controlnet"
        upscalers = "/opt/comfyui/models/upscale_models"
        diffusion_models = "/opt/comfyui/models/diffusion_models"
        other = "/opt/comfyui/models/other"

    Returns:
        A fresh dict keyed by CivitAI model type (plus the synthetic
        "DiffusionModel" bucket). It starts as a copy of ``DEFAULT_PATHS``
        and is overridden by any keys present in the ``[paths]`` section.
    """
    # Config key -> CivitAI model types it controls. "loras" covers both LORA
    # and LoCon; "diffusion_models" feeds the synthetic "DiffusionModel" bucket.
    key_to_types = {
        "checkpoints": ["Checkpoint"],
        "loras": ["LORA", "LoCon"],
        "embeddings": ["TextualInversion"],
        "vae": ["VAE"],
        "controlnet": ["Controlnet"],
        "upscalers": ["Upscaler"],
        "diffusion_models": ["DiffusionModel"],
        "other": ["Other"],
    }

    overrides = load_config().get("paths", {})
    resolved = dict(DEFAULT_PATHS)

    # A non-table "paths" entry is ignored; defaults are returned untouched.
    if not isinstance(overrides, dict):
        return resolved

    for config_key, civitai_types in key_to_types.items():
        if config_key not in overrides:
            continue
        target_dir = Path(overrides[config_key])
        for civitai_type in civitai_types:
            resolved[civitai_type] = target_dir

    return resolved
|
|
|
|
|
|
def get_default_output_path(model_type: str | None) -> Path | None:
|
|
"""Get default output path based on model type.
|
|
|
|
Checks config.toml [paths] section first, falls back to defaults.
|
|
"""
|
|
if not model_type:
|
|
return None
|
|
|
|
paths = get_model_paths()
|
|
return paths.get(model_type)
|
|
|
|
|
|
# ============================================================================
|
|
# Remote Server Configuration
|
|
# ============================================================================
|
|
|
|
|
|
def get_remotes() -> dict[str, str]:
    """Return the remote servers declared in the config's ``remotes`` table.

    Returns:
        A new dict mapping remote names to URLs, e.g.
        ``{"junkpile": "http://junkpile:8080"}``. Empty when the entry is
        missing or is not a table.
    """
    configured = load_config().get("remotes", {})
    if not isinstance(configured, dict):
        return {}
    return dict(configured)
|
|
|
|
|
|
def get_default_remote() -> str | None:
    """Return the top-level ``default_remote`` config value (a name or URL), if set."""
    return load_config().get("default_remote")
|
|
|
|
|
|
def resolve_remote(remote: str | None) -> str | None:
|
|
"""Resolve a remote name or URL to a full URL.
|
|
|
|
Args:
|
|
remote: Remote name (from config), URL, or None
|
|
|
|
Returns:
|
|
Full URL or None if no remote specified and no default
|
|
"""
|
|
if remote is None:
|
|
# Check for default remote
|
|
default = get_default_remote()
|
|
if default:
|
|
remote = default
|
|
else:
|
|
return None
|
|
|
|
# Check if it's a URL (starts with http:// or https://)
|
|
if remote.startswith(("http://", "https://")):
|
|
return remote
|
|
|
|
# Look up in configured remotes
|
|
remotes = get_remotes()
|
|
if remote in remotes:
|
|
return remotes[remote]
|
|
|
|
# Treat as hostname with default port
|
|
return f"http://{remote}:8080"
|
|
|
|
|
|
def save_remote(name: str, url: str) -> None:
    """Add or overwrite the remote ``name`` -> ``url`` and persist the config.

    Args:
        name: Key to store under the ``remotes`` table.
        url: URL to associate with that name.
    """
    config = load_config()
    # Create the table on first use, then insert/replace the entry.
    config.setdefault("remotes", {})[name] = url
    save_config(config)
|
|
|
|
|
|
def set_default_remote(name: str | None) -> None:
    """Set or clear the ``default_remote`` config entry and persist the config.

    Args:
        name: Value to store as the default remote; ``None`` removes the
            entry (a missing entry is not an error).
    """
    config = load_config()
    if name is not None:
        config["default_remote"] = name
    else:
        config.pop("default_remote", None)
    save_config(config)
|
|
|
|
|
|
# ============================================================================
# SD Server Configuration
# ============================================================================

# Used when neither SD_SERVER_URL nor config.toml [server].sd_server_url is set.
SD_SERVER_DEFAULT_URL = "http://localhost:1234"
|
|
|
|
|
|
def get_sd_server_url() -> str:
    """Resolve the sd-server URL.

    Resolution order (first non-empty value wins):
        1. ``SD_SERVER_URL`` environment variable
        2. ``config.toml`` ``[server].sd_server_url``
        3. ``SD_SERVER_DEFAULT_URL`` (http://localhost:1234)
    """
    if env_url := os.environ.get("SD_SERVER_URL"):
        return env_url

    server_section = load_config().get("server", {})
    if isinstance(server_section, dict) and (configured_url := server_section.get("sd_server_url")):
        return str(configured_url)

    return SD_SERVER_DEFAULT_URL
|
|
|
|
|
|
def get_sd_server_api_key() -> str | None:
|
|
"""Get the sd-server API key.
|
|
|
|
Resolution order:
|
|
1. SD_SERVER_API_KEY environment variable
|
|
2. config.toml [server].sd_server_api_key
|
|
3. None (no authentication)
|
|
"""
|
|
# Check environment variable first
|
|
env_key = os.environ.get("SD_SERVER_API_KEY")
|
|
if env_key:
|
|
return env_key
|
|
|
|
# Check config file
|
|
config = load_config()
|
|
server_config = config.get("server", {})
|
|
if isinstance(server_config, dict):
|
|
key = server_config.get("sd_server_api_key")
|
|
if key:
|
|
return str(key)
|
|
|
|
return None
|
|
|
|
|
|
# ============================================================================
|
|
# Tensors Server API Key
|
|
# ============================================================================
|
|
|
|
|
|
def get_server_api_key() -> str | None:
|
|
"""Get the tensors server API key for authentication.
|
|
|
|
Resolution order:
|
|
1. TENSORS_API_KEY environment variable
|
|
2. config.toml [server].api_key
|
|
3. None (no authentication required)
|
|
"""
|
|
# Check environment variable first
|
|
env_key = os.environ.get("TENSORS_API_KEY")
|
|
if env_key:
|
|
return env_key
|
|
|
|
# Check config file
|
|
config = load_config()
|
|
server_config = config.get("server", {})
|
|
if isinstance(server_config, dict):
|
|
key = server_config.get("api_key")
|
|
if key:
|
|
return str(key)
|
|
|
|
return None
|
|
|
|
|
|
# ============================================================================
# ComfyUI Configuration
# ============================================================================

# Used when neither COMFYUI_URL nor config.toml [comfyui].url is set.
COMFYUI_DEFAULT_URL = "http://127.0.0.1:8188"

# Built-in generation parameters; config.toml [comfyui] values and
# per-family defaults take precedence where they are defined.
COMFYUI_DEFAULT_WIDTH = 1024
COMFYUI_DEFAULT_HEIGHT = 1024
COMFYUI_DEFAULT_STEPS = 20
COMFYUI_DEFAULT_CFG = 7.0
COMFYUI_DEFAULT_SAMPLER = "euler"
COMFYUI_DEFAULT_SCHEDULER = "normal"
|
|
|
|
# ============================================================================
# Model Family Defaults (Quality Tags, Negative Prompts, etc.)
# ============================================================================

# Rating tags per model family: RATING_TAGS[family][rating] is the tag to
# inject into the prompt. Families absent from this table have no rating tag
# system (prompt-driven only).
RATING_TAGS: dict[str, dict[str, str]] = {
    # Pony-style tags use an underscore separator.
    "pony": {
        "safe": "rating_safe",
        "questionable": "rating_questionable",
        "explicit": "rating_explicit",
    },
    # Illustrious-style tags use a colon separator.
    "illustrious": {
        "safe": "rating:safe",
        "questionable": "rating:questionable",
        "explicit": "rating:explicit",
    },
}
# NoobAI uses the same tags as Illustrious (this aliases the same dict object).
RATING_TAGS["noobai"] = RATING_TAGS["illustrious"]
|
|
|
|
|
|
def get_rating_tag(family: str | None, rating: str) -> str | None:
|
|
"""Get the rating tag for a model family and rating level.
|
|
|
|
Args:
|
|
family: Model family key (e.g. "pony", "illustrious") or None
|
|
rating: One of "safe", "questionable", "explicit"
|
|
|
|
Returns:
|
|
Rating tag string to inject into prompt, or None if family has no rating system
|
|
"""
|
|
if not family:
|
|
return None
|
|
tags = RATING_TAGS.get(family)
|
|
if not tags:
|
|
return None
|
|
return tags.get(rating)
|
|
|
|
|
|
# Per-family generation defaults, keyed by the family names returned by
# detect_model_family(). Common keys: quality_prefix, negative_prompt,
# width/height (square), portrait/landscape ((width, height) pairs), cfg,
# sampler, scheduler, steps, vae. Some families add extras (clip_skip,
# guidance, external_clip and the text-encoder filenames).
MODEL_FAMILY_DEFAULTS: dict[str, dict[str, Any]] = {
    "pony": {
        "quality_prefix": "score_9, score_8_up, score_7_up",
        "negative_prompt": "score_5, score_4, ugly, deformed, blurry, bad anatomy, bad hands, missing fingers",
        "width": 1024,
        "height": 1024,
        "portrait": (832, 1216),
        "landscape": (1216, 832),
        "cfg": 6.5,
        "clip_skip": 2,
        "sampler": "euler_ancestral",
        "scheduler": "normal",
        "steps": 25,
        "vae": "ponyStandardVAE_v10.safetensors",
    },
    "illustrious": {
        "quality_prefix": "masterpiece, best quality, highres",
        "negative_prompt": "worst quality, bad quality, low quality, lowres, bad anatomy, bad hands, jpeg artifacts, watermark",
        "width": 1024,
        "height": 1024,
        "portrait": (832, 1216),
        "landscape": (1216, 832),
        "cfg": 6.0,
        "sampler": "euler_ancestral",
        "scheduler": "normal",
        "steps": 25,
        "vae": "illustriousXLV20_v10.safetensors",
    },
    "sdxl": {
        "quality_prefix": "",
        "negative_prompt": "ugly, deformed, bad anatomy, bad hands, extra fingers, missing fingers, blurry, watermark",
        "width": 1024,
        "height": 1024,
        "portrait": (832, 1216),
        "landscape": (1216, 832),
        "cfg": 7.0,
        "sampler": "dpmpp_2m",
        "scheduler": "karras",
        "steps": 25,
        "vae": "sdxl_vae.safetensors",
    },
    "sdxl_lightning": {
        "quality_prefix": "",
        "negative_prompt": "ugly, deformed, bad anatomy, bad hands, extra fingers, missing fingers, blurry, watermark",
        "width": 1024,
        "height": 1024,
        "portrait": (832, 1216),
        "landscape": (1216, 832),
        "cfg": 2.0,
        "sampler": "euler",
        "scheduler": "sgm_uniform",
        "steps": 8,
        "vae": "sdxl_vae.safetensors",
    },
    "sdxl_turbo": {
        "quality_prefix": "",
        "negative_prompt": "",
        "width": 1024,
        "height": 1024,
        "portrait": (832, 1216),
        "landscape": (1216, 832),
        "cfg": 1.0,
        "sampler": "euler_ancestral",
        "scheduler": "normal",
        "steps": 4,
        "vae": "sdxl_vae.safetensors",
    },
    "sd15": {
        "quality_prefix": "masterpiece, best quality",
        "negative_prompt": (
            "(worst quality:2), (low quality:2), bad anatomy, bad hands, extra fingers, "
            "missing fingers, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, watermark"
        ),
        "width": 512,
        "height": 512,
        "portrait": (512, 768),
        "landscape": (768, 512),
        "cfg": 7.0,
        "sampler": "euler_ancestral",
        "scheduler": "normal",
        "steps": 25,
        "vae": "vae-ft-mse-840000-ema-pruned.safetensors",
    },
    "sd15_lcm": {
        "quality_prefix": "masterpiece, best quality",
        "negative_prompt": "",
        "width": 512,
        "height": 512,
        "portrait": (512, 768),
        "landscape": (768, 512),
        "cfg": 1.5,
        "sampler": "lcm",
        "scheduler": "normal",
        "steps": 6,
        "vae": "vae-ft-mse-840000-ema-pruned.safetensors",
    },
    "flux": {
        "quality_prefix": "",
        "negative_prompt": "",
        "width": 1024,
        "height": 1024,
        "portrait": (832, 1216),
        "landscape": (1216, 832),
        # Flux Dev is guidance-distilled: KSampler.cfg MUST be 1.0.
        # Real prompt-adherence dial lives on the FluxGuidance node (see "guidance" below).
        # Source: https://comfyanonymous.github.io/ComfyUI_examples/flux/
        "cfg": 1.0,
        "guidance": 3.5,
        "sampler": "euler",
        "scheduler": "simple",
        "steps": 20,
        "vae": "ae.safetensors",
    },
    "flux_schnell": {
        "quality_prefix": "",
        "negative_prompt": "",
        "width": 1024,
        "height": 1024,
        "portrait": (832, 1216),
        "landscape": (1216, 832),
        # Schnell is also distilled; FluxGuidance is typically left at 3.5 but
        # has minimal effect since the model is trained for 4 steps regardless.
        "cfg": 1.0,
        "guidance": 3.5,
        "sampler": "euler",
        "scheduler": "simple",
        "steps": 4,
        "vae": "ae.safetensors",
    },
    # UNet-only Flux checkpoints — same architecture as "flux" but the file
    # ships without CLIP/T5/VAE baked in. Workflow must load them externally
    # via UNETLoader + DualCLIPLoader + VAELoader instead of CheckpointLoaderSimple.
    # external_clip=True signals this to the workflow builder.
    "flux_unet": {
        "quality_prefix": "",
        "negative_prompt": "",
        "width": 1024,
        "height": 1024,
        "portrait": (832, 1216),
        "landscape": (1216, 832),
        "cfg": 1.0,
        "guidance": 3.5,
        "sampler": "euler",
        "scheduler": "simple",
        "steps": 20,
        "vae": "ae.safetensors",
        "external_clip": True,
        "clip_l": "clip_l.safetensors",
        "clip_t5": "t5xxl_fp16.safetensors",
    },
    # Flux.2 Klein 9B — newer Black Forest Labs release. Different architecture
    # from Flux.1: single Qwen3-8B text encoder (12288-dim conditioning, 3 stacked
    # hidden layers), Flux2 latent format, custom Flux2Scheduler, dedicated VAE
    # (flux2-vae.safetensors). Workflow uses CLIPLoader (type=flux2) instead of
    # DualCLIPLoader, and the custom-sampling pipeline (SamplerCustomAdvanced +
    # BasicGuider + Flux2Scheduler) instead of plain KSampler.
    "flux2_klein": {
        "quality_prefix": "",
        "negative_prompt": "",
        "width": 1024,
        "height": 1024,
        "portrait": (832, 1216),
        "landscape": (1216, 832),
        "cfg": 1.0,
        "guidance": 3.5,
        "sampler": "euler",
        "scheduler": "simple",  # unused — Flux2Scheduler provides sigmas
        "steps": 20,
        "vae": "flux2-vae.safetensors",
        "external_clip": True,
        "clip_encoder": "qwen_3_8b_fp8mixed.safetensors",
        "clip_type": "flux2",
    },
    "zimage": {
        "quality_prefix": "",
        "negative_prompt": "",
        "width": 1024,
        "height": 1024,
        "portrait": (832, 1216),
        "landscape": (1216, 832),
        "cfg": 1.0,
        "sampler": "euler",
        "scheduler": "simple",
        "steps": 4,
        "vae": "ae.safetensors",
    },
}
|
|
|
|
|
|
# Filename substrings identifying UNet-only Flux checkpoints. These files ship
# without baked-in CLIP/T5/VAE, so they need UNETLoader + DualCLIPLoader +
# VAELoader instead of CheckpointLoaderSimple.
#
# Patterns are matched by plain substring containment against the
# *lowercased* filename, so every entry must itself be lowercase. Any single
# match is enough; the order of entries is irrelevant. Add new patterns here
# as they are encountered.
FLUX_UNET_ONLY_PATTERNS: tuple[str, ...] = (
    # Keep the underscore: bare "lust" would falsely match "illustrious".
    "lust_",  # lust_v10.safetensors (Flux.2 Klein 9B-base)
    "cyberrealisticflux",  # cyberrealisticFlux_v25.safetensors
    "getphatflux",  # getphatFLUXReality_v11Softcore.safetensors
    "moodydesire",  # moodyDesireMix_v20PRO.safetensors
    "fcfluxpony",  # fcFluxPonyPerfectBase_fcFluxPerfectBase.safetensors
    "prototype_",  # prototype_v10.safetensors (Flux unet-only, no "flux" in name)
)
|
|
|
|
|
|
def _is_flux_unet_only(name_lower: str) -> bool:
    """Return True if the filename contains any ``FLUX_UNET_ONLY_PATTERNS`` entry.

    Args:
        name_lower: Model filename, already lowercased by the caller.
    """
    for pattern in FLUX_UNET_ONLY_PATTERNS:
        if pattern in name_lower:
            return True
    return False
|
|
|
|
|
|
# Filename substrings identifying all-in-one Flux checkpoints (matched against
# the lowercased filename). These files bundle UNet + CLIP-L + T5 + VAE and
# load via CheckpointLoaderSimple, but their filename does NOT contain "flux",
# so the generic substring check misses them and they would fall through to
# SDXL defaults (which fails when the SDXL VAE isn't installed on the target
# backend).
#
# How to recognise a candidate before adding it: inspect the checkpoint
# header. Keys like `model.diffusion_model.double_blocks.*` together with
# bundled `text_encoders.*` and `vae.*` prefixes indicate FLUX all-in-one.
# Add new patterns here as they are encountered.
FLUX_ALL_IN_ONE_PATTERNS: tuple[str, ...] = (
    "ultrasense",  # ultrasenseInfinity_v10.safetensors
    "bodyslider",  # bodySliderFitness_v10.safetensors
)
|
|
|
|
|
|
def _is_flux_all_in_one(name_lower: str) -> bool:
    """Return True if the filename contains any ``FLUX_ALL_IN_ONE_PATTERNS`` entry.

    Args:
        name_lower: Model filename, already lowercased by the caller.
    """
    for pattern in FLUX_ALL_IN_ONE_PATTERNS:
        if pattern in name_lower:
            return True
    return False
|
|
|
|
|
|
# Filename substrings identifying Flux.2 Klein 9B checkpoints (matched against
# the lowercased filename). These checkpoints are UNet-only AND require the
# Flux.2 architecture (Qwen3-8B encoder, Flux2 scheduler). The base_model
# field ("Flux.2 Klein") is the primary signal; these patterns are a fallback
# for checkpoints whose DB metadata is missing or wrong, so a filename match
# classifies the model as Klein regardless of what base_model says.
FLUX2_KLEIN_PATTERNS: tuple[str, ...] = (
    "lust_",  # lust_v10.safetensors
    "moodydesire",  # moodyDesireMix_v20PRO.safetensors
)
|
|
|
|
|
|
def _is_flux2_klein(name_lower: str, base_lower: str) -> bool:
    """Return True if the model is a Flux.2 Klein 9B checkpoint.

    Either signal is sufficient: the base_model field mentioning
    "flux.2 klein" / "flux2 klein", or the filename containing one of
    ``FLUX2_KLEIN_PATTERNS``.

    Args:
        name_lower: Model filename, already lowercased by the caller.
        base_lower: base_model field, already lowercased ("" when unknown).
    """
    tagged_as_klein = any(marker in base_lower for marker in ("flux.2 klein", "flux2 klein"))
    return tagged_as_klein or any(pattern in name_lower for pattern in FLUX2_KLEIN_PATTERNS)
|
|
|
|
|
|
def _family_from_base_model(base_lower: str) -> str | None:  # noqa: PLR0911
    """Map a lowercased CivitAI base_model string to a family key, or None."""
    if "pony" in base_lower:
        return "pony"
    if "illustrious" in base_lower:
        return "illustrious"
    # Specific variants are tested before the generic family in each group.
    if "flux" in base_lower:
        return "flux_schnell" if "schnell" in base_lower else "flux"
    # ZImageTurbo
    if "zimage" in base_lower:
        return "zimage"
    if "sd 1.5" in base_lower or "sd 1.4" in base_lower:
        return "sd15_lcm" if "lcm" in base_lower else "sd15"
    if "sdxl" in base_lower:
        if "lightning" in base_lower:
            return "sdxl_lightning"
        if "turbo" in base_lower:
            return "sdxl_turbo"
        return "sdxl"
    return None


def _family_from_filename(name_lower: str) -> str | None:  # noqa: PLR0911
    """Guess a family key from a lowercased filename that does not contain "flux"."""
    if "pony" in name_lower:
        return "pony"
    if "illustrious" in name_lower or "noob" in name_lower:
        return "illustrious"
    # ZImageTurbo
    if "zimage" in name_lower:
        return "zimage"
    # SDXL speed variants. "xl" also covers "sdxl", so one test is enough.
    if "lightning" in name_lower and "xl" in name_lower:
        return "sdxl_lightning"
    if "turbo" in name_lower and "xl" in name_lower:
        return "sdxl_turbo"
    # SD 1.5 variants
    sd15_markers = ("sd15", "sd1.5", "sd_1.5")
    if "lcm" in name_lower and any(marker in name_lower for marker in sd15_markers):
        return "sd15_lcm"
    if any(marker in name_lower for marker in (*sd15_markers, "dreamshaper", "realistic", "deliberate", "anything")):
        return "sd15"
    if any(marker in name_lower for marker in ("sdxl", "xl_")):
        return "sdxl"
    return None


def detect_model_family(model_name: str, base_model: str | None = None) -> str | None:
    """Detect model family from filename or CivitAI base_model field.

    Matching is case-insensitive and proceeds in this order; the first stage
    that yields a family wins:

    1. Flux.2 Klein (base_model field or known filename pattern)
    2. Known UNet-only Flux filenames
    3. Known all-in-one Flux filenames that omit "flux"
    4. Any filename containing "flux" (architecture beats base_model)
    5. The base_model field
    6. Remaining filename heuristics

    Args:
        model_name: Filename of the model (e.g., "ponyDiffusionV6XL.safetensors")
        base_model: Optional CivitAI base_model field (e.g., "Pony", "SDXL 1.0")

    Returns:
        Model family key (pony, illustrious, sdxl, sdxl_lightning, sdxl_turbo,
        sd15, sd15_lcm, flux, flux_schnell, flux_unet, flux2_klein, zimage)
        or None if unknown
    """
    name_lower = model_name.lower()
    base_lower = (base_model or "").lower()

    # Flux.2 Klein 9B override: must run BEFORE flux_unet (Klein patterns like
    # "lust_" and "moodydesire" also appear in FLUX_UNET_ONLY_PATTERNS) AND
    # before the generic flux check.
    if _is_flux2_klein(name_lower, base_lower):
        return "flux2_klein"

    # UNet-only Flux override: must run BEFORE the generic flux check below,
    # since some patterns ("cyberrealisticflux", "getphatflux", "fcfluxpony")
    # also contain the substring "flux". Filename wins over base_model
    # field — these checkpoints are often mis-tagged on CivitAI.
    if _is_flux_unet_only(name_lower):
        return "flux_unet"

    # All-in-one Flux override for checkpoints whose filename omits "flux"
    # (e.g. "ultrasenseInfinity_v10.safetensors"). Without this the caller's
    # SDXL fallback applies and the generated workflow asks ComfyUI for
    # sdxl_vae.safetensors — which fails on Flux-only backends.
    if _is_flux_all_in_one(name_lower):
        return "flux"

    # Architecture override: filename containing "flux" wins over any base_model
    # field (handles hybrid models like "FluxPony" that CivitAI tags as "Pony"
    # but are architecturally Flux and need the Flux workflow). Because every
    # "flux" filename returns here, the later filename heuristics never need
    # to test for it again.
    if "flux" in name_lower:
        return "flux_schnell" if "schnell" in name_lower else "flux"

    # base_model (most reliable remaining signal) first, then the filename.
    return _family_from_base_model(base_lower) or _family_from_filename(name_lower)
|
|
|
|
|
|
def get_model_generation_defaults(model_name: str, base_model: str | None = None) -> dict[str, Any]:
    """Return generation defaults for a model, chosen by its detected family.

    Args:
        model_name: Filename of the model.
        base_model: Optional CivitAI base_model field.

    Returns:
        A new dict copied from the family's ``MODEL_FAMILY_DEFAULTS`` entry
        (the "sdxl" entry when the family is unknown or has no entry). The
        keys sampler, scheduler, steps, cfg, width, height, quality_prefix
        and negative_prompt are always present, filled from the global
        ComfyUI defaults if the entry lacks them. The extra "family" key
        holds the detected family, or None.
    """
    family = detect_model_family(model_name, base_model)

    # Unknown models borrow the SDXL settings.
    lookup_key = family if family in MODEL_FAMILY_DEFAULTS else "sdxl"
    defaults = dict(MODEL_FAMILY_DEFAULTS.get(lookup_key, {}))

    # Guarantee the keys callers rely on, without overriding family values.
    required_fallbacks = (
        ("sampler", COMFYUI_DEFAULT_SAMPLER),
        ("scheduler", COMFYUI_DEFAULT_SCHEDULER),
        ("steps", COMFYUI_DEFAULT_STEPS),
        ("cfg", COMFYUI_DEFAULT_CFG),
        ("width", COMFYUI_DEFAULT_WIDTH),
        ("height", COMFYUI_DEFAULT_HEIGHT),
        ("quality_prefix", ""),
        ("negative_prompt", ""),
    )
    for key, fallback in required_fallbacks:
        defaults.setdefault(key, fallback)

    # Record what was detected (None when detection failed), for reference.
    defaults["family"] = family
    return defaults
|
|
|
|
|
|
def resolve_orientation(family: str | None, orientation: str = "square") -> tuple[int, int]:
    """Return the (width, height) to use for a model family and orientation.

    Args:
        family: Model family key (e.g. "pony", "sd15", "sdxl"); None, empty,
            or an unknown key falls back to the "sdxl" entry.
        orientation: One of "square", "portrait", "landscape". Any other
            value is treated like "square".

    Returns:
        The family's "portrait"/"landscape" pair when requested and defined;
        otherwise the family's square (width, height).
    """
    family_defaults = MODEL_FAMILY_DEFAULTS.get(family or "sdxl", MODEL_FAMILY_DEFAULTS["sdxl"])
    square: tuple[int, int] = (family_defaults["width"], family_defaults["height"])

    if orientation in ("portrait", "landscape"):
        # The dict key is the orientation name itself.
        size: tuple[int, int] = family_defaults.get(orientation, square)
        return size
    return square
|
|
|
|
|
|
def get_comfyui_url() -> str:
    """Resolve the ComfyUI server URL.

    Resolution order (first non-empty value wins):
        1. ``COMFYUI_URL`` environment variable
        2. ``config.toml`` ``[comfyui].url``
        3. ``COMFYUI_DEFAULT_URL`` (http://127.0.0.1:8188)

    Config example:
        [comfyui]
        url = "http://192.168.1.100:8188"
    """
    if env_url := os.environ.get("COMFYUI_URL"):
        return env_url

    comfyui_section = load_config().get("comfyui", {})
    if isinstance(comfyui_section, dict) and (configured_url := comfyui_section.get("url")):
        return str(configured_url)

    return COMFYUI_DEFAULT_URL
|
|
|
|
|
|
def get_comfyui_defaults() -> dict[str, Any]:
    """Return the default ComfyUI generation parameters.

    Each parameter comes from the ``[comfyui]`` section of config.toml when
    present there, otherwise from the built-in ``COMFYUI_DEFAULT_*`` value.

    Config example:
        [comfyui]
        url = "http://127.0.0.1:8188"
        default_model = "flux1-dev-fp8.safetensors"
        width = 1024
        height = 1024
        steps = 20
        cfg = 7.0
        sampler = "euler"
        scheduler = "normal"

    Returns:
        Dict with keys: model, width, height, steps, cfg, sampler, scheduler.
        "model" is None unless ``default_model`` is configured. Configured
        values are coerced (int for width/height/steps, float for cfg, str
        for the rest), so a value that cannot be coerced raises.
    """
    defaults: dict[str, Any] = {
        "model": None,
        "width": COMFYUI_DEFAULT_WIDTH,
        "height": COMFYUI_DEFAULT_HEIGHT,
        "steps": COMFYUI_DEFAULT_STEPS,
        "cfg": COMFYUI_DEFAULT_CFG,
        "sampler": COMFYUI_DEFAULT_SAMPLER,
        "scheduler": COMFYUI_DEFAULT_SCHEDULER,
    }

    comfyui_section = load_config().get("comfyui", {})
    if not isinstance(comfyui_section, dict):
        return defaults

    # (config key, result key, coercion). Only "default_model" is renamed.
    overrides = (
        ("default_model", "model", str),
        ("width", "width", int),
        ("height", "height", int),
        ("steps", "steps", int),
        ("cfg", "cfg", float),
        ("sampler", "sampler", str),
        ("scheduler", "scheduler", str),
    )
    for config_key, result_key, coerce in overrides:
        if config_key in comfyui_section:
            defaults[result_key] = coerce(comfyui_section[config_key])

    return defaults
|