Add configurable model paths
- Add [paths] section to config.toml for custom model directories - Add get_model_paths() function that merges config with defaults - Update get_default_output_path() to check config first - Add --set-path option to tsr config command - Update download_routes.py to use centralized path function - Add tests for path configuration Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
+48
-1
@@ -29,6 +29,7 @@ from tensors.config import (
|
|||||||
Provider,
|
Provider,
|
||||||
SortOrder,
|
SortOrder,
|
||||||
get_default_output_path,
|
get_default_output_path,
|
||||||
|
get_model_paths,
|
||||||
load_api_key,
|
load_api_key,
|
||||||
load_config,
|
load_config,
|
||||||
save_config,
|
save_config,
|
||||||
@@ -490,6 +491,7 @@ def _display_download_info(
|
|||||||
def config(
|
def config(
|
||||||
show: Annotated[bool, typer.Option("--show", help="Show current config")] = False,
|
show: Annotated[bool, typer.Option("--show", help="Show current config")] = False,
|
||||||
set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None,
|
set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None,
|
||||||
|
set_path: Annotated[str | None, typer.Option("--set-path", help="Set model path (TYPE=PATH)")] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Manage configuration."""
|
"""Manage configuration."""
|
||||||
if set_key:
|
if set_key:
|
||||||
@@ -501,7 +503,29 @@ def config(
|
|||||||
console.print(f"[green]API key saved to {CONFIG_FILE}[/green]")
|
console.print(f"[green]API key saved to {CONFIG_FILE}[/green]")
|
||||||
return
|
return
|
||||||
|
|
||||||
if show or (not set_key):
|
if set_path:
|
||||||
|
# Parse TYPE=PATH format
|
||||||
|
if "=" not in set_path:
|
||||||
|
console.print("[red]Error: Use format TYPE=PATH (e.g., checkpoints=/opt/models/checkpoints)[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
path_type, path_value = set_path.split("=", 1)
|
||||||
|
path_type = path_type.lower().strip()
|
||||||
|
valid_types = ["checkpoints", "loras", "embeddings", "vae", "controlnet", "upscalers", "other"]
|
||||||
|
|
||||||
|
if path_type not in valid_types:
|
||||||
|
console.print(f"[red]Error: Invalid type '{path_type}'. Valid: {', '.join(valid_types)}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
cfg = load_config()
|
||||||
|
if "paths" not in cfg:
|
||||||
|
cfg["paths"] = {}
|
||||||
|
cfg["paths"][path_type] = path_value.strip()
|
||||||
|
save_config(cfg)
|
||||||
|
console.print(f"[green]Path for {path_type} set to: {path_value}[/green]")
|
||||||
|
return
|
||||||
|
|
||||||
|
if show or (not set_key and not set_path):
|
||||||
console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}")
|
console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}")
|
||||||
console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}")
|
console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}")
|
||||||
|
|
||||||
@@ -512,8 +536,31 @@ def config(
|
|||||||
else:
|
else:
|
||||||
console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]")
|
console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]")
|
||||||
|
|
||||||
|
console.print()
|
||||||
|
console.print("[bold]Model paths:[/bold]")
|
||||||
|
paths = get_model_paths()
|
||||||
|
# Group by unique paths to show cleanly
|
||||||
|
shown_paths: dict[str, list[str]] = {}
|
||||||
|
for model_type, path in paths.items():
|
||||||
|
path_str = str(path)
|
||||||
|
if path_str not in shown_paths:
|
||||||
|
shown_paths[path_str] = []
|
||||||
|
shown_paths[path_str].append(model_type)
|
||||||
|
|
||||||
|
cfg = load_config()
|
||||||
|
configured_paths = cfg.get("paths", {})
|
||||||
|
|
||||||
|
for path_str, types in sorted(shown_paths.items(), key=lambda x: x[0]):
|
||||||
|
is_custom = any(
|
||||||
|
path_str == configured_paths.get(k)
|
||||||
|
for k in ["checkpoints", "loras", "embeddings", "vae", "controlnet", "upscalers", "other"]
|
||||||
|
)
|
||||||
|
marker = " [green](custom)[/green]" if is_custom else " [dim](default)[/dim]"
|
||||||
|
console.print(f" {', '.join(sorted(types))}: {path_str}{marker}")
|
||||||
|
|
||||||
console.print()
|
console.print()
|
||||||
console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]")
|
console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]")
|
||||||
|
console.print("[dim]Set paths with: tsr config --set-path checkpoints=/path/to/models[/dim]")
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
|
|||||||
+58
-5
@@ -25,11 +25,16 @@ GALLERY_DIR = DATA_DIR / "gallery"
|
|||||||
# Legacy config for migration
|
# Legacy config for migration
|
||||||
LEGACY_RC_FILE = Path.home() / ".sftrc"
|
LEGACY_RC_FILE = Path.home() / ".sftrc"
|
||||||
|
|
||||||
# Default download paths by model type
|
# Default download paths by model type (can be overridden in config.toml [paths])
|
||||||
DEFAULT_PATHS: dict[str, Path] = {
|
DEFAULT_PATHS: dict[str, Path] = {
|
||||||
"Checkpoint": MODELS_DIR / "checkpoints",
|
"Checkpoint": MODELS_DIR / "checkpoints",
|
||||||
"LORA": MODELS_DIR / "loras",
|
"LORA": MODELS_DIR / "loras",
|
||||||
"LoCon": MODELS_DIR / "loras",
|
"LoCon": MODELS_DIR / "loras",
|
||||||
|
"TextualInversion": MODELS_DIR / "embeddings",
|
||||||
|
"VAE": MODELS_DIR / "vae",
|
||||||
|
"Controlnet": MODELS_DIR / "controlnet",
|
||||||
|
"Upscaler": MODELS_DIR / "upscalers",
|
||||||
|
"Other": MODELS_DIR / "other",
|
||||||
}
|
}
|
||||||
|
|
||||||
CIVITAI_API_BASE = "https://civitai.com/api/v1"
|
CIVITAI_API_BASE = "https://civitai.com/api/v1"
|
||||||
@@ -274,11 +279,59 @@ def load_api_key() -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_paths() -> dict[str, Path]:
|
||||||
|
"""Get model paths from config, with defaults.
|
||||||
|
|
||||||
|
Config format in config.toml:
|
||||||
|
[paths]
|
||||||
|
checkpoints = "/opt/comfyui/models/checkpoints"
|
||||||
|
loras = "/opt/comfyui/models/loras"
|
||||||
|
embeddings = "/opt/comfyui/models/embeddings"
|
||||||
|
vae = "/opt/comfyui/models/vae"
|
||||||
|
controlnet = "/opt/comfyui/models/controlnet"
|
||||||
|
upscalers = "/opt/comfyui/models/upscale_models"
|
||||||
|
other = "/opt/comfyui/models/other"
|
||||||
|
|
||||||
|
Returns dict mapping CivitAI model types to paths.
|
||||||
|
"""
|
||||||
|
config = load_config()
|
||||||
|
paths_config = config.get("paths", {})
|
||||||
|
|
||||||
|
# Map config keys to CivitAI model types
|
||||||
|
key_to_types = {
|
||||||
|
"checkpoints": ["Checkpoint"],
|
||||||
|
"loras": ["LORA", "LoCon"],
|
||||||
|
"embeddings": ["TextualInversion"],
|
||||||
|
"vae": ["VAE"],
|
||||||
|
"controlnet": ["Controlnet"],
|
||||||
|
"upscalers": ["Upscaler"],
|
||||||
|
"other": ["Other"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Start with defaults
|
||||||
|
result = dict(DEFAULT_PATHS)
|
||||||
|
|
||||||
|
# Override with config values
|
||||||
|
if isinstance(paths_config, dict):
|
||||||
|
for key, types in key_to_types.items():
|
||||||
|
if key in paths_config:
|
||||||
|
path = Path(paths_config[key])
|
||||||
|
for model_type in types:
|
||||||
|
result[model_type] = path
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def get_default_output_path(model_type: str | None) -> Path | None:
|
def get_default_output_path(model_type: str | None) -> Path | None:
|
||||||
"""Get default output path based on model type."""
|
"""Get default output path based on model type.
|
||||||
if model_type and model_type in DEFAULT_PATHS:
|
|
||||||
return DEFAULT_PATHS[model_type]
|
Checks config.toml [paths] section first, falls back to defaults.
|
||||||
return None
|
"""
|
||||||
|
if not model_type:
|
||||||
|
return None
|
||||||
|
|
||||||
|
paths = get_model_paths()
|
||||||
|
return paths.get(model_type)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from fastapi import APIRouter, BackgroundTasks, HTTPException
|
|||||||
from pydantic import BaseModel as PydanticBaseModel
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
|
|
||||||
from tensors.api import download_model_with_progress, fetch_civitai_by_hash, fetch_civitai_model, fetch_civitai_model_version
|
from tensors.api import download_model_with_progress, fetch_civitai_by_hash, fetch_civitai_model, fetch_civitai_model_version
|
||||||
from tensors.config import MODELS_DIR, load_api_key
|
from tensors.config import get_default_output_path, get_model_paths, load_api_key
|
||||||
from tensors.db import Database
|
from tensors.db import Database
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -75,18 +75,14 @@ def _get_output_dir(version_info: dict[str, Any], override: str | None) -> Path:
|
|||||||
return Path(override)
|
return Path(override)
|
||||||
|
|
||||||
model_type = version_info.get("model", {}).get("type", "Checkpoint")
|
model_type = version_info.get("model", {}).get("type", "Checkpoint")
|
||||||
|
path = get_default_output_path(model_type)
|
||||||
|
|
||||||
# Map type to directory
|
if path:
|
||||||
type_dirs = {
|
return path
|
||||||
"Checkpoint": MODELS_DIR / "checkpoints",
|
|
||||||
"LORA": MODELS_DIR / "loras",
|
|
||||||
"LoCon": MODELS_DIR / "loras",
|
|
||||||
"TextualInversion": MODELS_DIR / "embeddings",
|
|
||||||
"VAE": MODELS_DIR / "vae",
|
|
||||||
"Controlnet": MODELS_DIR / "controlnet",
|
|
||||||
}
|
|
||||||
|
|
||||||
return type_dirs.get(model_type, MODELS_DIR / "other")
|
# Fallback for unknown types
|
||||||
|
paths = get_model_paths()
|
||||||
|
return paths.get("Other", Path.home() / ".local" / "share" / "tensors" / "models" / "other")
|
||||||
|
|
||||||
|
|
||||||
_KB = 1024
|
_KB = 1024
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from tensors.config import (
|
|||||||
ModelType,
|
ModelType,
|
||||||
SortOrder,
|
SortOrder,
|
||||||
get_default_output_path,
|
get_default_output_path,
|
||||||
|
get_model_paths,
|
||||||
load_api_key,
|
load_api_key,
|
||||||
load_config,
|
load_config,
|
||||||
save_config,
|
save_config,
|
||||||
@@ -127,6 +128,46 @@ class TestGetDefaultOutputPath:
|
|||||||
assert get_default_output_path(None) is None
|
assert get_default_output_path(None) is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetModelPaths:
|
||||||
|
"""Tests for get_model_paths function."""
|
||||||
|
|
||||||
|
def test_returns_dict_with_all_types(self) -> None:
|
||||||
|
"""Test that all model types are included."""
|
||||||
|
paths = get_model_paths()
|
||||||
|
assert isinstance(paths, dict)
|
||||||
|
assert "Checkpoint" in paths
|
||||||
|
assert "LORA" in paths
|
||||||
|
assert "LoCon" in paths
|
||||||
|
assert "TextualInversion" in paths
|
||||||
|
assert "VAE" in paths
|
||||||
|
assert "Controlnet" in paths
|
||||||
|
|
||||||
|
def test_config_override(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""Test that config.toml paths override defaults."""
|
||||||
|
# Create a config file with custom path
|
||||||
|
config_file = tmp_path / "config.toml"
|
||||||
|
config_file.write_text('[paths]\ncheckpoints = "/custom/checkpoints"\n')
|
||||||
|
monkeypatch.setattr(config, "CONFIG_FILE", config_file)
|
||||||
|
|
||||||
|
paths = get_model_paths()
|
||||||
|
assert paths["Checkpoint"] == Path("/custom/checkpoints")
|
||||||
|
# Other types should still be defaults
|
||||||
|
assert "loras" in str(paths["LORA"])
|
||||||
|
|
||||||
|
def test_get_default_output_path_uses_config(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
"""Test that get_default_output_path respects config overrides."""
|
||||||
|
config_file = tmp_path / "config.toml"
|
||||||
|
config_file.write_text('[paths]\nloras = "/custom/loras"\n')
|
||||||
|
monkeypatch.setattr(config, "CONFIG_FILE", config_file)
|
||||||
|
|
||||||
|
result = get_default_output_path("LORA")
|
||||||
|
assert result == Path("/custom/loras")
|
||||||
|
|
||||||
|
# LoCon should also use the loras path
|
||||||
|
result = get_default_output_path("LoCon")
|
||||||
|
assert result == Path("/custom/loras")
|
||||||
|
|
||||||
|
|
||||||
class TestLoadApiKey:
|
class TestLoadApiKey:
|
||||||
"""Tests for load_api_key function."""
|
"""Tests for load_api_key function."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user