diff --git a/tensors/cli.py b/tensors/cli.py index af0c9cf..b013366 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -29,6 +29,7 @@ from tensors.config import ( Provider, SortOrder, get_default_output_path, + get_model_paths, load_api_key, load_config, save_config, @@ -490,6 +491,7 @@ def _display_download_info( def config( show: Annotated[bool, typer.Option("--show", help="Show current config")] = False, set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None, + set_path: Annotated[str | None, typer.Option("--set-path", help="Set model path (TYPE=PATH)")] = None, ) -> None: """Manage configuration.""" if set_key: @@ -501,7 +503,29 @@ def config( console.print(f"[green]API key saved to {CONFIG_FILE}[/green]") return - if show or (not set_key): + if set_path: + # Parse TYPE=PATH format + if "=" not in set_path: + console.print("[red]Error: Use format TYPE=PATH (e.g., checkpoints=/opt/models/checkpoints)[/red]") + raise typer.Exit(1) + + path_type, path_value = set_path.split("=", 1) + path_type = path_type.lower().strip() + valid_types = ["checkpoints", "loras", "embeddings", "vae", "controlnet", "upscalers", "other"] + + if path_type not in valid_types: + console.print(f"[red]Error: Invalid type '{path_type}'. Valid: {', '.join(valid_types)}[/red]") + raise typer.Exit(1) + + cfg = load_config() + if "paths" not in cfg: + cfg["paths"] = {} + cfg["paths"][path_type] = path_value.strip() + save_config(cfg) + console.print(f"[green]Path for {path_type} set to: {path_value}[/green]") + return + + if show or (not set_key and not set_path): console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}") console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}") @@ -512,8 +536,31 @@ def config( else: console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]") + console.print() + console.print("[bold]Model paths:[/bold]") + paths = get_model_paths() + # Group by unique paths to show cleanly + shown_paths: dict[str, list[str]] = {} + for model_type, path in paths.items(): + path_str = str(path) + if path_str not in shown_paths: + shown_paths[path_str] = [] + shown_paths[path_str].append(model_type) + + cfg = load_config() + configured_paths = cfg.get("paths", {}) + + for path_str, types in sorted(shown_paths.items(), key=lambda x: x[0]): + is_custom = any( + path_str == configured_paths.get(k) + for k in ["checkpoints", "loras", "embeddings", "vae", "controlnet", "upscalers", "other"] + ) + marker = " [green](custom)[/green]" if is_custom else " [dim](default)[/dim]" + console.print(f" {', '.join(sorted(types))}: {path_str}{marker}") + console.print() console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]") + console.print("[dim]Set paths with: tsr config --set-path checkpoints=/path/to/models[/dim]") @app.command() diff --git a/tensors/config.py b/tensors/config.py index ea35c8f..8f0b8c9 100644 --- a/tensors/config.py +++ b/tensors/config.py @@ -25,11 +25,16 @@ GALLERY_DIR = DATA_DIR / "gallery" # Legacy config for migration LEGACY_RC_FILE = Path.home() / ".sftrc" -# Default download paths by model type +# Default download paths by model type (can be overridden in config.toml [paths]) DEFAULT_PATHS: dict[str, Path] = { "Checkpoint": MODELS_DIR / "checkpoints", "LORA": MODELS_DIR / "loras", "LoCon": MODELS_DIR / "loras", + "TextualInversion": MODELS_DIR / "embeddings", + "VAE": MODELS_DIR / "vae", + "Controlnet": MODELS_DIR / "controlnet", + "Upscaler": MODELS_DIR / "upscalers", + "Other": MODELS_DIR / "other", } CIVITAI_API_BASE = "https://civitai.com/api/v1" @@ -274,11 +279,59 @@ def load_api_key() -> str | None: return None +def get_model_paths() -> dict[str, Path]: + """Get model paths from config, with defaults. + + Config format in config.toml: + [paths] + checkpoints = "/opt/comfyui/models/checkpoints" + loras = "/opt/comfyui/models/loras" + embeddings = "/opt/comfyui/models/embeddings" + vae = "/opt/comfyui/models/vae" + controlnet = "/opt/comfyui/models/controlnet" + upscalers = "/opt/comfyui/models/upscale_models" + other = "/opt/comfyui/models/other" + + Returns dict mapping CivitAI model types to paths. + """ + config = load_config() + paths_config = config.get("paths", {}) + + # Map config keys to CivitAI model types + key_to_types = { + "checkpoints": ["Checkpoint"], + "loras": ["LORA", "LoCon"], + "embeddings": ["TextualInversion"], + "vae": ["VAE"], + "controlnet": ["Controlnet"], + "upscalers": ["Upscaler"], + "other": ["Other"], + } + + # Start with defaults + result = dict(DEFAULT_PATHS) + + # Override with config values + if isinstance(paths_config, dict): + for key, types in key_to_types.items(): + if key in paths_config: + path = Path(paths_config[key]) + for model_type in types: + result[model_type] = path + + return result + + def get_default_output_path(model_type: str | None) -> Path | None: - """Get default output path based on model type.""" - if model_type and model_type in DEFAULT_PATHS: - return DEFAULT_PATHS[model_type] - return None + """Get default output path based on model type. + + Checks config.toml [paths] section first, falls back to defaults. + """ + if not model_type: + return None + + paths = get_model_paths() + return paths.get(model_type) # ============================================================================ diff --git a/tensors/server/download_routes.py b/tensors/server/download_routes.py index aba46f9..bea0f3a 100644 --- a/tensors/server/download_routes.py +++ b/tensors/server/download_routes.py @@ -10,7 +10,7 @@ from fastapi import APIRouter, BackgroundTasks, HTTPException from pydantic import BaseModel as PydanticBaseModel from tensors.api import download_model_with_progress, fetch_civitai_by_hash, fetch_civitai_model, fetch_civitai_model_version -from tensors.config import MODELS_DIR, load_api_key +from tensors.config import get_default_output_path, get_model_paths, load_api_key from tensors.db import Database logger = logging.getLogger(__name__) @@ -75,18 +75,14 @@ def _get_output_dir(version_info: dict[str, Any], override: str | None) -> Path: return Path(override) model_type = version_info.get("model", {}).get("type", "Checkpoint") + path = get_default_output_path(model_type) - # Map type to directory - type_dirs = { - "Checkpoint": MODELS_DIR / "checkpoints", - "LORA": MODELS_DIR / "loras", - "LoCon": MODELS_DIR / "loras", - "TextualInversion": MODELS_DIR / "embeddings", - "VAE": MODELS_DIR / "vae", - "Controlnet": MODELS_DIR / "controlnet", - } + if path: + return path - return type_dirs.get(model_type, MODELS_DIR / "other") + # Fallback for unknown types + paths = get_model_paths() + return paths.get("Other", Path.home() / ".local" / "share" / "tensors" / "models" / "other") _KB = 1024 diff --git a/tests/test_tensors.py b/tests/test_tensors.py index 7ffd60d..621647c 100644 --- a/tests/test_tensors.py +++ b/tests/test_tensors.py @@ -26,6 +26,7 @@ from tensors.config import ( ModelType, SortOrder, get_default_output_path, + get_model_paths, load_api_key, load_config, save_config, @@ -127,6 +128,46 @@ class TestGetDefaultOutputPath: assert get_default_output_path(None) is None +class TestGetModelPaths: + """Tests for get_model_paths function.""" + + def test_returns_dict_with_all_types(self) -> None: + """Test that all model types are included.""" + paths = get_model_paths() + assert isinstance(paths, dict) + assert "Checkpoint" in paths + assert "LORA" in paths + assert "LoCon" in paths + assert "TextualInversion" in paths + assert "VAE" in paths + assert "Controlnet" in paths + + def test_config_override(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that config.toml paths override defaults.""" + # Create a config file with custom path + config_file = tmp_path / "config.toml" + config_file.write_text('[paths]\ncheckpoints = "/custom/checkpoints"\n') + monkeypatch.setattr(config, "CONFIG_FILE", config_file) + + paths = get_model_paths() + assert paths["Checkpoint"] == Path("/custom/checkpoints") + # Other types should still be defaults + assert "loras" in str(paths["LORA"]) + + def test_get_default_output_path_uses_config(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that get_default_output_path respects config overrides.""" + config_file = tmp_path / "config.toml" + config_file.write_text('[paths]\nloras = "/custom/loras"\n') + monkeypatch.setattr(config, "CONFIG_FILE", config_file) + + result = get_default_output_path("LORA") + assert result == Path("/custom/loras") + + # LoCon should also use the loras path + result = get_default_output_path("LoCon") + assert result == Path("/custom/loras") + + class TestLoadApiKey: """Tests for load_api_key function."""