Add configurable model paths
- Add [paths] section to config.toml for custom model directories - Add get_model_paths() function that merges config with defaults - Update get_default_output_path() to check config first - Add --set-path option to tsr config command - Update download_routes.py to use centralized path function - Add tests for path configuration Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
+48
-1
@@ -29,6 +29,7 @@ from tensors.config import (
|
||||
Provider,
|
||||
SortOrder,
|
||||
get_default_output_path,
|
||||
get_model_paths,
|
||||
load_api_key,
|
||||
load_config,
|
||||
save_config,
|
||||
@@ -490,6 +491,7 @@ def _display_download_info(
|
||||
def config(
|
||||
show: Annotated[bool, typer.Option("--show", help="Show current config")] = False,
|
||||
set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None,
|
||||
set_path: Annotated[str | None, typer.Option("--set-path", help="Set model path (TYPE=PATH)")] = None,
|
||||
) -> None:
|
||||
"""Manage configuration."""
|
||||
if set_key:
|
||||
@@ -501,7 +503,29 @@ def config(
|
||||
console.print(f"[green]API key saved to {CONFIG_FILE}[/green]")
|
||||
return
|
||||
|
||||
if show or (not set_key):
|
||||
if set_path:
|
||||
# Parse TYPE=PATH format
|
||||
if "=" not in set_path:
|
||||
console.print("[red]Error: Use format TYPE=PATH (e.g., checkpoints=/opt/models/checkpoints)[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
path_type, path_value = set_path.split("=", 1)
|
||||
path_type = path_type.lower().strip()
|
||||
valid_types = ["checkpoints", "loras", "embeddings", "vae", "controlnet", "upscalers", "other"]
|
||||
|
||||
if path_type not in valid_types:
|
||||
console.print(f"[red]Error: Invalid type '{path_type}'. Valid: {', '.join(valid_types)}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
cfg = load_config()
|
||||
if "paths" not in cfg:
|
||||
cfg["paths"] = {}
|
||||
cfg["paths"][path_type] = path_value.strip()
|
||||
save_config(cfg)
|
||||
console.print(f"[green]Path for {path_type} set to: {path_value}[/green]")
|
||||
return
|
||||
|
||||
if show or (not set_key and not set_path):
|
||||
console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}")
|
||||
console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}")
|
||||
|
||||
@@ -512,8 +536,31 @@ def config(
|
||||
else:
|
||||
console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]")
|
||||
|
||||
console.print()
|
||||
console.print("[bold]Model paths:[/bold]")
|
||||
paths = get_model_paths()
|
||||
# Group by unique paths to show cleanly
|
||||
shown_paths: dict[str, list[str]] = {}
|
||||
for model_type, path in paths.items():
|
||||
path_str = str(path)
|
||||
if path_str not in shown_paths:
|
||||
shown_paths[path_str] = []
|
||||
shown_paths[path_str].append(model_type)
|
||||
|
||||
cfg = load_config()
|
||||
configured_paths = cfg.get("paths", {})
|
||||
|
||||
for path_str, types in sorted(shown_paths.items(), key=lambda x: x[0]):
|
||||
is_custom = any(
|
||||
path_str == configured_paths.get(k)
|
||||
for k in ["checkpoints", "loras", "embeddings", "vae", "controlnet", "upscalers", "other"]
|
||||
)
|
||||
marker = " [green](custom)[/green]" if is_custom else " [dim](default)[/dim]"
|
||||
console.print(f" {', '.join(sorted(types))}: {path_str}{marker}")
|
||||
|
||||
console.print()
|
||||
console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]")
|
||||
console.print("[dim]Set paths with: tsr config --set-path checkpoints=/path/to/models[/dim]")
|
||||
|
||||
|
||||
@app.command()
|
||||
|
||||
+58
-5
@@ -25,11 +25,16 @@ GALLERY_DIR = DATA_DIR / "gallery"
|
||||
# Legacy config for migration
|
||||
LEGACY_RC_FILE = Path.home() / ".sftrc"
|
||||
|
||||
# Default download paths by model type
|
||||
# Default download paths by model type (can be overridden in config.toml [paths])
|
||||
DEFAULT_PATHS: dict[str, Path] = {
|
||||
"Checkpoint": MODELS_DIR / "checkpoints",
|
||||
"LORA": MODELS_DIR / "loras",
|
||||
"LoCon": MODELS_DIR / "loras",
|
||||
"TextualInversion": MODELS_DIR / "embeddings",
|
||||
"VAE": MODELS_DIR / "vae",
|
||||
"Controlnet": MODELS_DIR / "controlnet",
|
||||
"Upscaler": MODELS_DIR / "upscalers",
|
||||
"Other": MODELS_DIR / "other",
|
||||
}
|
||||
|
||||
CIVITAI_API_BASE = "https://civitai.com/api/v1"
|
||||
@@ -274,11 +279,59 @@ def load_api_key() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def get_model_paths() -> dict[str, Path]:
|
||||
"""Get model paths from config, with defaults.
|
||||
|
||||
Config format in config.toml:
|
||||
[paths]
|
||||
checkpoints = "/opt/comfyui/models/checkpoints"
|
||||
loras = "/opt/comfyui/models/loras"
|
||||
embeddings = "/opt/comfyui/models/embeddings"
|
||||
vae = "/opt/comfyui/models/vae"
|
||||
controlnet = "/opt/comfyui/models/controlnet"
|
||||
upscalers = "/opt/comfyui/models/upscale_models"
|
||||
other = "/opt/comfyui/models/other"
|
||||
|
||||
Returns dict mapping CivitAI model types to paths.
|
||||
"""
|
||||
config = load_config()
|
||||
paths_config = config.get("paths", {})
|
||||
|
||||
# Map config keys to CivitAI model types
|
||||
key_to_types = {
|
||||
"checkpoints": ["Checkpoint"],
|
||||
"loras": ["LORA", "LoCon"],
|
||||
"embeddings": ["TextualInversion"],
|
||||
"vae": ["VAE"],
|
||||
"controlnet": ["Controlnet"],
|
||||
"upscalers": ["Upscaler"],
|
||||
"other": ["Other"],
|
||||
}
|
||||
|
||||
# Start with defaults
|
||||
result = dict(DEFAULT_PATHS)
|
||||
|
||||
# Override with config values
|
||||
if isinstance(paths_config, dict):
|
||||
for key, types in key_to_types.items():
|
||||
if key in paths_config:
|
||||
path = Path(paths_config[key])
|
||||
for model_type in types:
|
||||
result[model_type] = path
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_default_output_path(model_type: str | None) -> Path | None:
|
||||
"""Get default output path based on model type."""
|
||||
if model_type and model_type in DEFAULT_PATHS:
|
||||
return DEFAULT_PATHS[model_type]
|
||||
return None
|
||||
"""Get default output path based on model type.
|
||||
|
||||
Checks config.toml [paths] section first, falls back to defaults.
|
||||
"""
|
||||
if not model_type:
|
||||
return None
|
||||
|
||||
paths = get_model_paths()
|
||||
return paths.get(model_type)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@@ -10,7 +10,7 @@ from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
|
||||
from tensors.api import download_model_with_progress, fetch_civitai_by_hash, fetch_civitai_model, fetch_civitai_model_version
|
||||
from tensors.config import MODELS_DIR, load_api_key
|
||||
from tensors.config import get_default_output_path, get_model_paths, load_api_key
|
||||
from tensors.db import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -75,18 +75,14 @@ def _get_output_dir(version_info: dict[str, Any], override: str | None) -> Path:
|
||||
return Path(override)
|
||||
|
||||
model_type = version_info.get("model", {}).get("type", "Checkpoint")
|
||||
path = get_default_output_path(model_type)
|
||||
|
||||
# Map type to directory
|
||||
type_dirs = {
|
||||
"Checkpoint": MODELS_DIR / "checkpoints",
|
||||
"LORA": MODELS_DIR / "loras",
|
||||
"LoCon": MODELS_DIR / "loras",
|
||||
"TextualInversion": MODELS_DIR / "embeddings",
|
||||
"VAE": MODELS_DIR / "vae",
|
||||
"Controlnet": MODELS_DIR / "controlnet",
|
||||
}
|
||||
if path:
|
||||
return path
|
||||
|
||||
return type_dirs.get(model_type, MODELS_DIR / "other")
|
||||
# Fallback for unknown types
|
||||
paths = get_model_paths()
|
||||
return paths.get("Other", Path.home() / ".local" / "share" / "tensors" / "models" / "other")
|
||||
|
||||
|
||||
_KB = 1024
|
||||
|
||||
@@ -26,6 +26,7 @@ from tensors.config import (
|
||||
ModelType,
|
||||
SortOrder,
|
||||
get_default_output_path,
|
||||
get_model_paths,
|
||||
load_api_key,
|
||||
load_config,
|
||||
save_config,
|
||||
@@ -127,6 +128,46 @@ class TestGetDefaultOutputPath:
|
||||
assert get_default_output_path(None) is None
|
||||
|
||||
|
||||
class TestGetModelPaths:
|
||||
"""Tests for get_model_paths function."""
|
||||
|
||||
def test_returns_dict_with_all_types(self) -> None:
|
||||
"""Test that all model types are included."""
|
||||
paths = get_model_paths()
|
||||
assert isinstance(paths, dict)
|
||||
assert "Checkpoint" in paths
|
||||
assert "LORA" in paths
|
||||
assert "LoCon" in paths
|
||||
assert "TextualInversion" in paths
|
||||
assert "VAE" in paths
|
||||
assert "Controlnet" in paths
|
||||
|
||||
def test_config_override(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test that config.toml paths override defaults."""
|
||||
# Create a config file with custom path
|
||||
config_file = tmp_path / "config.toml"
|
||||
config_file.write_text('[paths]\ncheckpoints = "/custom/checkpoints"\n')
|
||||
monkeypatch.setattr(config, "CONFIG_FILE", config_file)
|
||||
|
||||
paths = get_model_paths()
|
||||
assert paths["Checkpoint"] == Path("/custom/checkpoints")
|
||||
# Other types should still be defaults
|
||||
assert "loras" in str(paths["LORA"])
|
||||
|
||||
def test_get_default_output_path_uses_config(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test that get_default_output_path respects config overrides."""
|
||||
config_file = tmp_path / "config.toml"
|
||||
config_file.write_text('[paths]\nloras = "/custom/loras"\n')
|
||||
monkeypatch.setattr(config, "CONFIG_FILE", config_file)
|
||||
|
||||
result = get_default_output_path("LORA")
|
||||
assert result == Path("/custom/loras")
|
||||
|
||||
# LoCon should also use the loras path
|
||||
result = get_default_output_path("LoCon")
|
||||
assert result == Path("/custom/loras")
|
||||
|
||||
|
||||
class TestLoadApiKey:
|
||||
"""Tests for load_api_key function."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user