Add configurable model paths

- Add [paths] section to config.toml for custom model directories
- Add get_model_paths() function that merges config with defaults
- Update get_default_output_path() to check config first
- Add --set-path option to tsr config command
- Update download_routes.py to use centralized path function
- Add tests for path configuration

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Adam Ladachowski
2026-02-16 13:40:45 +01:00
parent 79657a7b1f
commit 574cdb6abd
4 changed files with 154 additions and 17 deletions
+48 -1
View File
@@ -29,6 +29,7 @@ from tensors.config import (
Provider, Provider,
SortOrder, SortOrder,
get_default_output_path, get_default_output_path,
get_model_paths,
load_api_key, load_api_key,
load_config, load_config,
save_config, save_config,
@@ -490,6 +491,7 @@ def _display_download_info(
def config( def config(
show: Annotated[bool, typer.Option("--show", help="Show current config")] = False, show: Annotated[bool, typer.Option("--show", help="Show current config")] = False,
set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None, set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None,
set_path: Annotated[str | None, typer.Option("--set-path", help="Set model path (TYPE=PATH)")] = None,
) -> None: ) -> None:
"""Manage configuration.""" """Manage configuration."""
if set_key: if set_key:
@@ -501,7 +503,29 @@ def config(
console.print(f"[green]API key saved to {CONFIG_FILE}[/green]") console.print(f"[green]API key saved to {CONFIG_FILE}[/green]")
return return
if show or (not set_key): if set_path:
# Parse TYPE=PATH format
if "=" not in set_path:
console.print("[red]Error: Use format TYPE=PATH (e.g., checkpoints=/opt/models/checkpoints)[/red]")
raise typer.Exit(1)
path_type, path_value = set_path.split("=", 1)
path_type = path_type.lower().strip()
valid_types = ["checkpoints", "loras", "embeddings", "vae", "controlnet", "upscalers", "other"]
if path_type not in valid_types:
console.print(f"[red]Error: Invalid type '{path_type}'. Valid: {', '.join(valid_types)}[/red]")
raise typer.Exit(1)
cfg = load_config()
if "paths" not in cfg:
cfg["paths"] = {}
cfg["paths"][path_type] = path_value.strip()
save_config(cfg)
console.print(f"[green]Path for {path_type} set to: {path_value}[/green]")
return
if show or (not set_key and not set_path):
console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}") console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}")
console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}") console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}")
@@ -512,8 +536,31 @@ def config(
else: else:
console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]") console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]")
console.print()
console.print("[bold]Model paths:[/bold]")
paths = get_model_paths()
# Group by unique paths to show cleanly
shown_paths: dict[str, list[str]] = {}
for model_type, path in paths.items():
path_str = str(path)
if path_str not in shown_paths:
shown_paths[path_str] = []
shown_paths[path_str].append(model_type)
cfg = load_config()
configured_paths = cfg.get("paths", {})
for path_str, types in sorted(shown_paths.items(), key=lambda x: x[0]):
is_custom = any(
path_str == configured_paths.get(k)
for k in ["checkpoints", "loras", "embeddings", "vae", "controlnet", "upscalers", "other"]
)
marker = " [green](custom)[/green]" if is_custom else " [dim](default)[/dim]"
console.print(f" {', '.join(sorted(types))}: {path_str}{marker}")
console.print() console.print()
console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]") console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]")
console.print("[dim]Set paths with: tsr config --set-path checkpoints=/path/to/models[/dim]")
@app.command() @app.command()
+57 -4
View File
@@ -25,11 +25,16 @@ GALLERY_DIR = DATA_DIR / "gallery"
# Legacy config for migration # Legacy config for migration
LEGACY_RC_FILE = Path.home() / ".sftrc" LEGACY_RC_FILE = Path.home() / ".sftrc"
# Default download paths by model type # Default download paths by model type (can be overridden in config.toml [paths])
DEFAULT_PATHS: dict[str, Path] = { DEFAULT_PATHS: dict[str, Path] = {
"Checkpoint": MODELS_DIR / "checkpoints", "Checkpoint": MODELS_DIR / "checkpoints",
"LORA": MODELS_DIR / "loras", "LORA": MODELS_DIR / "loras",
"LoCon": MODELS_DIR / "loras", "LoCon": MODELS_DIR / "loras",
"TextualInversion": MODELS_DIR / "embeddings",
"VAE": MODELS_DIR / "vae",
"Controlnet": MODELS_DIR / "controlnet",
"Upscaler": MODELS_DIR / "upscalers",
"Other": MODELS_DIR / "other",
} }
CIVITAI_API_BASE = "https://civitai.com/api/v1" CIVITAI_API_BASE = "https://civitai.com/api/v1"
@@ -274,12 +279,60 @@ def load_api_key() -> str | None:
return None return None
def get_model_paths() -> dict[str, Path]:
"""Get model paths from config, with defaults.
Config format in config.toml:
[paths]
checkpoints = "/opt/comfyui/models/checkpoints"
loras = "/opt/comfyui/models/loras"
embeddings = "/opt/comfyui/models/embeddings"
vae = "/opt/comfyui/models/vae"
controlnet = "/opt/comfyui/models/controlnet"
upscalers = "/opt/comfyui/models/upscale_models"
other = "/opt/comfyui/models/other"
Returns dict mapping CivitAI model types to paths.
"""
config = load_config()
paths_config = config.get("paths", {})
# Map config keys to CivitAI model types
key_to_types = {
"checkpoints": ["Checkpoint"],
"loras": ["LORA", "LoCon"],
"embeddings": ["TextualInversion"],
"vae": ["VAE"],
"controlnet": ["Controlnet"],
"upscalers": ["Upscaler"],
"other": ["Other"],
}
# Start with defaults
result = dict(DEFAULT_PATHS)
# Override with config values
if isinstance(paths_config, dict):
for key, types in key_to_types.items():
if key in paths_config:
path = Path(paths_config[key])
for model_type in types:
result[model_type] = path
return result
def get_default_output_path(model_type: str | None) -> Path | None: def get_default_output_path(model_type: str | None) -> Path | None:
"""Get default output path based on model type.""" """Get default output path based on model type.
if model_type and model_type in DEFAULT_PATHS:
return DEFAULT_PATHS[model_type] Checks config.toml [paths] section first, falls back to defaults.
"""
if not model_type:
return None return None
paths = get_model_paths()
return paths.get(model_type)
# ============================================================================ # ============================================================================
# Remote Server Configuration # Remote Server Configuration
+7 -11
View File
@@ -10,7 +10,7 @@ from fastapi import APIRouter, BackgroundTasks, HTTPException
from pydantic import BaseModel as PydanticBaseModel from pydantic import BaseModel as PydanticBaseModel
from tensors.api import download_model_with_progress, fetch_civitai_by_hash, fetch_civitai_model, fetch_civitai_model_version from tensors.api import download_model_with_progress, fetch_civitai_by_hash, fetch_civitai_model, fetch_civitai_model_version
from tensors.config import MODELS_DIR, load_api_key from tensors.config import get_default_output_path, get_model_paths, load_api_key
from tensors.db import Database from tensors.db import Database
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -75,18 +75,14 @@ def _get_output_dir(version_info: dict[str, Any], override: str | None) -> Path:
return Path(override) return Path(override)
model_type = version_info.get("model", {}).get("type", "Checkpoint") model_type = version_info.get("model", {}).get("type", "Checkpoint")
path = get_default_output_path(model_type)
# Map type to directory if path:
type_dirs = { return path
"Checkpoint": MODELS_DIR / "checkpoints",
"LORA": MODELS_DIR / "loras",
"LoCon": MODELS_DIR / "loras",
"TextualInversion": MODELS_DIR / "embeddings",
"VAE": MODELS_DIR / "vae",
"Controlnet": MODELS_DIR / "controlnet",
}
return type_dirs.get(model_type, MODELS_DIR / "other") # Fallback for unknown types
paths = get_model_paths()
return paths.get("Other", Path.home() / ".local" / "share" / "tensors" / "models" / "other")
_KB = 1024 _KB = 1024
+41
View File
@@ -26,6 +26,7 @@ from tensors.config import (
ModelType, ModelType,
SortOrder, SortOrder,
get_default_output_path, get_default_output_path,
get_model_paths,
load_api_key, load_api_key,
load_config, load_config,
save_config, save_config,
@@ -127,6 +128,46 @@ class TestGetDefaultOutputPath:
assert get_default_output_path(None) is None assert get_default_output_path(None) is None
class TestGetModelPaths:
"""Tests for get_model_paths function."""
def test_returns_dict_with_all_types(self) -> None:
"""Test that all model types are included."""
paths = get_model_paths()
assert isinstance(paths, dict)
assert "Checkpoint" in paths
assert "LORA" in paths
assert "LoCon" in paths
assert "TextualInversion" in paths
assert "VAE" in paths
assert "Controlnet" in paths
def test_config_override(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that config.toml paths override defaults."""
# Create a config file with custom path
config_file = tmp_path / "config.toml"
config_file.write_text('[paths]\ncheckpoints = "/custom/checkpoints"\n')
monkeypatch.setattr(config, "CONFIG_FILE", config_file)
paths = get_model_paths()
assert paths["Checkpoint"] == Path("/custom/checkpoints")
# Other types should still be defaults
assert "loras" in str(paths["LORA"])
def test_get_default_output_path_uses_config(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that get_default_output_path respects config overrides."""
config_file = tmp_path / "config.toml"
config_file.write_text('[paths]\nloras = "/custom/loras"\n')
monkeypatch.setattr(config, "CONFIG_FILE", config_file)
result = get_default_output_path("LORA")
assert result == Path("/custom/loras")
# LoCon should also use the loras path
result = get_default_output_path("LoCon")
assert result == Path("/custom/loras")
class TestLoadApiKey: class TestLoadApiKey:
"""Tests for load_api_key function.""" """Tests for load_api_key function."""