80faead7eb
📁 Files changed: 18 📝 Lines changed: 494 • deploy.md • TASK.md • justfile • deploy.sh • config.py • __init__.py • generate_routes.py • models_routes.py • routes.py • sd_client.py • index-CcuP2dTH.css • index-DmOZ-7Sw.js • index-J_qzb7Jl.js • index-QncGJEyk.css • index.html • client.ts • GenerateView.vue • app.ts
301 lines
8.2 KiB
Python
301 lines
8.2 KiB
Python
"""Configuration, constants, and enums for tsr CLI."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import tomllib
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
# ============================================================================
|
|
# XDG Base Directory Configuration
|
|
# ============================================================================
|
|
|
|
# Config: ~/.config/tensors/config.toml
|
|
# Data: ~/.local/share/tensors/models/, ~/.local/share/tensors/metadata/
|
|
CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config")) / "tensors"
|
|
CONFIG_FILE = CONFIG_DIR / "config.toml"
|
|
|
|
DATA_DIR = Path(os.environ.get("XDG_DATA_HOME", Path.home() / ".local" / "share")) / "tensors"
|
|
MODELS_DIR = DATA_DIR / "models"
|
|
METADATA_DIR = DATA_DIR / "metadata"
|
|
GALLERY_DIR = DATA_DIR / "gallery"
|
|
|
|
# Legacy config for migration
|
|
LEGACY_RC_FILE = Path.home() / ".sftrc"
|
|
|
|
# Default download paths by model type
|
|
DEFAULT_PATHS: dict[str, Path] = {
|
|
"Checkpoint": MODELS_DIR / "checkpoints",
|
|
"LORA": MODELS_DIR / "loras",
|
|
"LoCon": MODELS_DIR / "loras",
|
|
}
|
|
|
|
CIVITAI_API_BASE = "https://civitai.com/api/v1"
|
|
CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
|
|
|
|
|
|
# ============================================================================
|
|
# Enums for CLI
|
|
# ============================================================================
|
|
|
|
|
|
class ModelType(str, Enum):
|
|
"""CivitAI model types."""
|
|
|
|
checkpoint = "checkpoint"
|
|
lora = "lora"
|
|
embedding = "embedding"
|
|
vae = "vae"
|
|
controlnet = "controlnet"
|
|
locon = "locon"
|
|
|
|
def to_api(self) -> str:
|
|
"""Convert to CivitAI API value."""
|
|
mapping = {
|
|
"checkpoint": "Checkpoint",
|
|
"lora": "LORA",
|
|
"embedding": "TextualInversion",
|
|
"vae": "VAE",
|
|
"controlnet": "Controlnet",
|
|
"locon": "LoCon",
|
|
}
|
|
return mapping[self.value]
|
|
|
|
|
|
class BaseModel(str, Enum):
|
|
"""Common base models."""
|
|
|
|
sd15 = "sd15"
|
|
sdxl = "sdxl"
|
|
pony = "pony"
|
|
flux = "flux"
|
|
illustrious = "illustrious"
|
|
|
|
def to_api(self) -> str:
|
|
"""Convert to CivitAI API value."""
|
|
mapping = {
|
|
"sd15": "SD 1.5",
|
|
"sdxl": "SDXL 1.0",
|
|
"pony": "Pony",
|
|
"flux": "Flux.1 D",
|
|
"illustrious": "Illustrious",
|
|
}
|
|
return mapping[self.value]
|
|
|
|
|
|
class SortOrder(str, Enum):
|
|
"""Sort options for search."""
|
|
|
|
downloads = "downloads"
|
|
rating = "rating"
|
|
newest = "newest"
|
|
|
|
def to_api(self) -> str:
|
|
"""Convert to CivitAI API value."""
|
|
mapping = {
|
|
"downloads": "Most Downloaded",
|
|
"rating": "Highest Rated",
|
|
"newest": "Newest",
|
|
}
|
|
return mapping[self.value]
|
|
|
|
|
|
# ============================================================================
|
|
# Config Functions
|
|
# ============================================================================
|
|
|
|
|
|
def load_config() -> dict[str, Any]:
|
|
"""Load configuration from TOML config file."""
|
|
if CONFIG_FILE.exists():
|
|
with CONFIG_FILE.open("rb") as f:
|
|
return tomllib.load(f)
|
|
return {}
|
|
|
|
|
|
def save_config(config: dict[str, Any]) -> None:
|
|
"""Save configuration to TOML config file."""
|
|
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
lines: list[str] = []
|
|
# Write scalar values first (before any sections)
|
|
for key, value in config.items():
|
|
if not isinstance(value, dict):
|
|
if isinstance(value, str):
|
|
lines.append(f'{key} = "{value}"')
|
|
else:
|
|
lines.append(f"{key} = {value}")
|
|
|
|
if lines:
|
|
lines.append("")
|
|
|
|
# Then write sections (dicts)
|
|
for key, value in config.items():
|
|
if isinstance(value, dict):
|
|
lines.append(f"[{key}]")
|
|
for k, v in value.items():
|
|
if isinstance(v, str):
|
|
lines.append(f'{k} = "{v}"')
|
|
else:
|
|
lines.append(f"{k} = {v}")
|
|
lines.append("")
|
|
|
|
CONFIG_FILE.write_text("\n".join(lines) + "\n")
|
|
|
|
|
|
def load_api_key() -> str | None:
|
|
"""Load API key from config file or CIVITAI_API_KEY env var."""
|
|
# Check environment variable first
|
|
env_key = os.environ.get("CIVITAI_API_KEY")
|
|
if env_key:
|
|
return env_key
|
|
|
|
# Check TOML config file
|
|
config = load_config()
|
|
api_section = config.get("api", {})
|
|
if isinstance(api_section, dict):
|
|
key = api_section.get("civitai_key")
|
|
if key:
|
|
return str(key)
|
|
|
|
# Fall back to legacy RC file for migration
|
|
if LEGACY_RC_FILE.exists():
|
|
content = LEGACY_RC_FILE.read_text().strip()
|
|
if content:
|
|
return content
|
|
return None
|
|
|
|
|
|
def get_default_output_path(model_type: str | None) -> Path | None:
|
|
"""Get default output path based on model type."""
|
|
if model_type and model_type in DEFAULT_PATHS:
|
|
return DEFAULT_PATHS[model_type]
|
|
return None
|
|
|
|
|
|
# ============================================================================
|
|
# Remote Server Configuration
|
|
# ============================================================================
|
|
|
|
|
|
def get_remotes() -> dict[str, str]:
|
|
"""Get configured remote servers.
|
|
|
|
Returns a dict mapping names to URLs, e.g., {"junkpile": "http://junkpile:8080"}
|
|
"""
|
|
config = load_config()
|
|
remotes = config.get("remotes", {})
|
|
return dict(remotes) if isinstance(remotes, dict) else {}
|
|
|
|
|
|
def get_default_remote() -> str | None:
|
|
"""Get the default remote name or URL."""
|
|
config = load_config()
|
|
return config.get("default_remote")
|
|
|
|
|
|
def resolve_remote(remote: str | None) -> str | None:
|
|
"""Resolve a remote name or URL to a full URL.
|
|
|
|
Args:
|
|
remote: Remote name (from config), URL, or None
|
|
|
|
Returns:
|
|
Full URL or None if no remote specified and no default
|
|
"""
|
|
if remote is None:
|
|
# Check for default remote
|
|
default = get_default_remote()
|
|
if default:
|
|
remote = default
|
|
else:
|
|
return None
|
|
|
|
# Check if it's a URL (starts with http:// or https://)
|
|
if remote.startswith(("http://", "https://")):
|
|
return remote
|
|
|
|
# Look up in configured remotes
|
|
remotes = get_remotes()
|
|
if remote in remotes:
|
|
return remotes[remote]
|
|
|
|
# Treat as hostname with default port
|
|
return f"http://{remote}:8080"
|
|
|
|
|
|
def save_remote(name: str, url: str) -> None:
|
|
"""Save a remote server configuration."""
|
|
config = load_config()
|
|
if "remotes" not in config:
|
|
config["remotes"] = {}
|
|
config["remotes"][name] = url
|
|
save_config(config)
|
|
|
|
|
|
def set_default_remote(name: str | None) -> None:
|
|
"""Set the default remote."""
|
|
config = load_config()
|
|
if name is None:
|
|
config.pop("default_remote", None)
|
|
else:
|
|
config["default_remote"] = name
|
|
save_config(config)
|
|
|
|
|
|
# ============================================================================
|
|
# SD Server Configuration
|
|
# ============================================================================
|
|
|
|
SD_SERVER_DEFAULT_URL = "http://localhost:1234"
|
|
|
|
|
|
def get_sd_server_url() -> str:
|
|
"""Get the sd-server URL.
|
|
|
|
Resolution order:
|
|
1. SD_SERVER_URL environment variable
|
|
2. config.toml [server].sd_server_url
|
|
3. Default: http://localhost:1234
|
|
"""
|
|
# Check environment variable first
|
|
env_url = os.environ.get("SD_SERVER_URL")
|
|
if env_url:
|
|
return env_url
|
|
|
|
# Check config file
|
|
config = load_config()
|
|
server_config = config.get("server", {})
|
|
if isinstance(server_config, dict):
|
|
url = server_config.get("sd_server_url")
|
|
if url:
|
|
return str(url)
|
|
|
|
return SD_SERVER_DEFAULT_URL
|
|
|
|
|
|
def get_sd_server_api_key() -> str | None:
|
|
"""Get the sd-server API key.
|
|
|
|
Resolution order:
|
|
1. SD_SERVER_API_KEY environment variable
|
|
2. config.toml [server].sd_server_api_key
|
|
3. None (no authentication)
|
|
"""
|
|
# Check environment variable first
|
|
env_key = os.environ.get("SD_SERVER_API_KEY")
|
|
if env_key:
|
|
return env_key
|
|
|
|
# Check config file
|
|
config = load_config()
|
|
server_config = config.get("server", {})
|
|
if isinstance(server_config, dict):
|
|
key = server_config.get("sd_server_api_key")
|
|
if key:
|
|
return str(key)
|
|
|
|
return None
|