Add configurable model paths

- Add [paths] section to config.toml for custom model directories
- Add get_model_paths() function that merges config with defaults
- Update get_default_output_path() to check config first
- Add --set-path option to tsr config command
- Update download_routes.py to use centralized path function
- Add tests for path configuration

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Adam Ladachowski
2026-02-16 13:40:45 +01:00
parent 79657a7b1f
commit 574cdb6abd
4 changed files with 154 additions and 17 deletions
+48 -1
View File
@@ -29,6 +29,7 @@ from tensors.config import (
Provider,
SortOrder,
get_default_output_path,
get_model_paths,
load_api_key,
load_config,
save_config,
@@ -490,6 +491,7 @@ def _display_download_info(
def config(
show: Annotated[bool, typer.Option("--show", help="Show current config")] = False,
set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None,
set_path: Annotated[str | None, typer.Option("--set-path", help="Set model path (TYPE=PATH)")] = None,
) -> None:
"""Manage configuration."""
if set_key:
@@ -501,7 +503,29 @@ def config(
console.print(f"[green]API key saved to {CONFIG_FILE}[/green]")
return
if show or (not set_key):
if set_path:
# Parse TYPE=PATH format
if "=" not in set_path:
console.print("[red]Error: Use format TYPE=PATH (e.g., checkpoints=/opt/models/checkpoints)[/red]")
raise typer.Exit(1)
path_type, path_value = set_path.split("=", 1)
path_type = path_type.lower().strip()
valid_types = ["checkpoints", "loras", "embeddings", "vae", "controlnet", "upscalers", "other"]
if path_type not in valid_types:
console.print(f"[red]Error: Invalid type '{path_type}'. Valid: {', '.join(valid_types)}[/red]")
raise typer.Exit(1)
cfg = load_config()
if "paths" not in cfg:
cfg["paths"] = {}
cfg["paths"][path_type] = path_value.strip()
save_config(cfg)
console.print(f"[green]Path for {path_type} set to: {path_value}[/green]")
return
if show or (not set_key and not set_path):
console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}")
console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}")
@@ -512,8 +536,31 @@ def config(
else:
console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]")
console.print()
console.print("[bold]Model paths:[/bold]")
paths = get_model_paths()
# Group by unique paths to show cleanly
shown_paths: dict[str, list[str]] = {}
for model_type, path in paths.items():
path_str = str(path)
if path_str not in shown_paths:
shown_paths[path_str] = []
shown_paths[path_str].append(model_type)
cfg = load_config()
configured_paths = cfg.get("paths", {})
for path_str, types in sorted(shown_paths.items(), key=lambda x: x[0]):
is_custom = any(
path_str == configured_paths.get(k)
for k in ["checkpoints", "loras", "embeddings", "vae", "controlnet", "upscalers", "other"]
)
marker = " [green](custom)[/green]" if is_custom else " [dim](default)[/dim]"
console.print(f" {', '.join(sorted(types))}: {path_str}{marker}")
console.print()
console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]")
console.print("[dim]Set paths with: tsr config --set-path checkpoints=/path/to/models[/dim]")
@app.command()
+58 -5
View File
@@ -25,11 +25,16 @@ GALLERY_DIR = DATA_DIR / "gallery"
# Legacy config for migration
LEGACY_RC_FILE = Path.home() / ".sftrc"
# Default download paths by model type
# Default download paths by model type (can be overridden in config.toml [paths])
DEFAULT_PATHS: dict[str, Path] = {
"Checkpoint": MODELS_DIR / "checkpoints",
"LORA": MODELS_DIR / "loras",
"LoCon": MODELS_DIR / "loras",
"TextualInversion": MODELS_DIR / "embeddings",
"VAE": MODELS_DIR / "vae",
"Controlnet": MODELS_DIR / "controlnet",
"Upscaler": MODELS_DIR / "upscalers",
"Other": MODELS_DIR / "other",
}
CIVITAI_API_BASE = "https://civitai.com/api/v1"
@@ -274,11 +279,59 @@ def load_api_key() -> str | None:
return None
def get_model_paths() -> dict[str, Path]:
"""Get model paths from config, with defaults.
Config format in config.toml:
[paths]
checkpoints = "/opt/comfyui/models/checkpoints"
loras = "/opt/comfyui/models/loras"
embeddings = "/opt/comfyui/models/embeddings"
vae = "/opt/comfyui/models/vae"
controlnet = "/opt/comfyui/models/controlnet"
upscalers = "/opt/comfyui/models/upscale_models"
other = "/opt/comfyui/models/other"
Returns dict mapping CivitAI model types to paths.
"""
config = load_config()
paths_config = config.get("paths", {})
# Map config keys to CivitAI model types
key_to_types = {
"checkpoints": ["Checkpoint"],
"loras": ["LORA", "LoCon"],
"embeddings": ["TextualInversion"],
"vae": ["VAE"],
"controlnet": ["Controlnet"],
"upscalers": ["Upscaler"],
"other": ["Other"],
}
# Start with defaults
result = dict(DEFAULT_PATHS)
# Override with config values
if isinstance(paths_config, dict):
for key, types in key_to_types.items():
if key in paths_config:
path = Path(paths_config[key])
for model_type in types:
result[model_type] = path
return result
def get_default_output_path(model_type: str | None) -> Path | None:
"""Get default output path based on model type."""
if model_type and model_type in DEFAULT_PATHS:
return DEFAULT_PATHS[model_type]
return None
"""Get default output path based on model type.
Checks config.toml [paths] section first, falls back to defaults.
"""
if not model_type:
return None
paths = get_model_paths()
return paths.get(model_type)
# ============================================================================
+7 -11
View File
@@ -10,7 +10,7 @@ from fastapi import APIRouter, BackgroundTasks, HTTPException
from pydantic import BaseModel as PydanticBaseModel
from tensors.api import download_model_with_progress, fetch_civitai_by_hash, fetch_civitai_model, fetch_civitai_model_version
from tensors.config import MODELS_DIR, load_api_key
from tensors.config import get_default_output_path, get_model_paths, load_api_key
from tensors.db import Database
logger = logging.getLogger(__name__)
@@ -75,18 +75,14 @@ def _get_output_dir(version_info: dict[str, Any], override: str | None) -> Path:
return Path(override)
model_type = version_info.get("model", {}).get("type", "Checkpoint")
path = get_default_output_path(model_type)
# Map type to directory
type_dirs = {
"Checkpoint": MODELS_DIR / "checkpoints",
"LORA": MODELS_DIR / "loras",
"LoCon": MODELS_DIR / "loras",
"TextualInversion": MODELS_DIR / "embeddings",
"VAE": MODELS_DIR / "vae",
"Controlnet": MODELS_DIR / "controlnet",
}
if path:
return path
return type_dirs.get(model_type, MODELS_DIR / "other")
# Fallback for unknown types
paths = get_model_paths()
return paths.get("Other", Path.home() / ".local" / "share" / "tensors" / "models" / "other")
_KB = 1024