diff --git a/.coverage b/.coverage index 9bb1222..1d9f032 100644 Binary files a/.coverage and b/.coverage differ diff --git a/pyproject.toml b/pyproject.toml index 7801ae6..4e9aff6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [ "safetensors>=0.4.0", "httpx>=0.27.0", "rich>=13.0.0", + "typer>=0.15.0", ] [project.scripts] diff --git a/tensors.py b/tensors.py index 5644669..9c148c8 100644 --- a/tensors.py +++ b/tensors.py @@ -12,6 +12,7 @@ import os import re import struct import sys +import tomllib from pathlib import Path from typing import Any @@ -31,7 +32,12 @@ from rich.table import Table console = Console() -RC_FILE = Path.home() / ".sftrc" +# XDG Base Directory spec: ~/.config/tensors/config.toml +CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config")) / "tensors" +CONFIG_FILE = CONFIG_DIR / "config.toml" + +# Legacy config for migration +LEGACY_RC_FILE = Path.home() / ".sftrc" # Default download paths by model type DEFAULT_PATHS: dict[str, Path] = { @@ -41,16 +47,54 @@ DEFAULT_PATHS: dict[str, Path] = { } +def load_config() -> dict[str, Any]: + """Load configuration from TOML config file.""" + if CONFIG_FILE.exists(): + with CONFIG_FILE.open("rb") as f: + return tomllib.load(f) + return {} + + +def save_config(config: dict[str, Any]) -> None: + """Save configuration to TOML config file.""" + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + + lines: list[str] = [] + for key, value in config.items(): + if isinstance(value, dict): + lines.append(f"[{key}]") + for k, v in value.items(): + if isinstance(v, str): + lines.append(f'{k} = "{v}"') + else: + lines.append(f"{k} = {v}") + lines.append("") + elif isinstance(value, str): + lines.append(f'{key} = "{value}"') + else: + lines.append(f"{key} = {value}") + + CONFIG_FILE.write_text("\n".join(lines) + "\n") + + def load_api_key() -> str | None: - """Load API key from ~/.sftrc or CIVITAI_API_KEY env var.""" + """Load API key from config file or CIVITAI_API_KEY env var.""" # Check environment variable first env_key = os.environ.get("CIVITAI_API_KEY") if env_key: return env_key - # Fall back to RC file - if RC_FILE.exists(): - content = RC_FILE.read_text().strip() + # Check TOML config file + config = load_config() + api_section = config.get("api", {}) + if isinstance(api_section, dict): + key = api_section.get("civitai_key") + if key: + return str(key) + + # Fall back to legacy RC file for migration + if LEGACY_RC_FILE.exists(): + content = LEGACY_RC_FILE.read_text().strip() if content: return content return None diff --git a/tests/test_tensors.py b/tests/test_tensors.py index 2ab988d..9672d6c 100644 --- a/tests/test_tensors.py +++ b/tests/test_tensors.py @@ -110,6 +110,29 @@ class TestLoadApiKey: def test_returns_none_if_no_key(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: """Test that None is returned when no key is available.""" monkeypatch.delenv("CIVITAI_API_KEY", raising=False) - # Temporarily point RC_FILE to nonexistent file - monkeypatch.setattr(tensors, "RC_FILE", tmp_path / "nonexistent") + # Point config and legacy files to nonexistent paths + monkeypatch.setattr(tensors, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml") + monkeypatch.setattr(tensors, "LEGACY_RC_FILE", tmp_path / "nonexistent") assert load_api_key() is None + + def test_returns_key_from_config_file( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """Test that key is loaded from TOML config file.""" + monkeypatch.delenv("CIVITAI_API_KEY", raising=False) + config_file = tmp_path / "config.toml" + config_file.write_text('[api]\ncivitai_key = "key-from-config"\n') + monkeypatch.setattr(tensors, "CONFIG_FILE", config_file) + monkeypatch.setattr(tensors, "LEGACY_RC_FILE", tmp_path / "nonexistent") + assert load_api_key() == "key-from-config" + + def test_returns_key_from_legacy_file( + self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + ) -> None: + """Test that key is loaded from legacy RC file when no config exists.""" + monkeypatch.delenv("CIVITAI_API_KEY", raising=False) + legacy_file = tmp_path / ".sftrc" + legacy_file.write_text("legacy-key") + monkeypatch.setattr(tensors, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml") + monkeypatch.setattr(tensors, "LEGACY_RC_FILE", legacy_file) + assert load_api_key() == "legacy-key" diff --git a/uv.lock b/uv.lock index 05f12cd..dcfac2c 100644 --- a/uv.lock +++ b/uv.lock @@ -33,6 +33,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/db/3c/33bac158f8ab7f89b2e59426d5fe2e4f63f7ed25df84c036890172b412b5/cfgv-3.5.0-py2.py3-none-any.whl", hash = "sha256:a8dc6b26ad22ff227d2634a65cb388215ce6cc96bbcc5cfde7641ae87e8dacc0", size = 7445, upload-time = "2025-11-19T20:55:50.744Z" }, ] +[[package]] +name = "click" +version = "8.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3d/fa/656b739db8587d7b5dfa22e22ed02566950fbfbcdc20311993483657a5c0/click-8.3.1.tar.gz", hash = "sha256:12ff4785d337a1bb490bb7e9c2b1ee5da3112e94a8622f26a6c77f5d2fc6842a", size = 295065, upload-time = "2025-11-15T20:45:42.706Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl", hash = "sha256:981153a64e25f12d547d3426c367a4857371575ee7ad18df2a6183ab0545b2a6", size = 108274, upload-time = "2025-11-15T20:45:41.139Z" }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -539,6 +551,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5d/e6/ec8471c8072382cb91233ba7267fd931219753bb43814cbc71757bfd4dab/safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755", size = 341380, upload-time = "2025-11-19T15:18:44.427Z" }, ] +[[package]] +name = "shellingham" +version = "1.5.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, +] + [[package]] name = "tensors" version = "0.1.0" @@ -547,6 +568,7 @@ dependencies = [ { name = "httpx" }, { name = "rich" }, { name = "safetensors" }, + { name = "typer" }, ] [package.dev-dependencies] @@ -564,6 +586,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.27.0" }, { name = "rich", specifier = ">=13.0.0" }, { name = "safetensors", specifier = ">=0.4.0" }, + { name = "typer", specifier = ">=0.15.0" }, ] [package.metadata.requires-dev] @@ -576,6 +599,21 @@ dev = [ { name = "ruff", specifier = ">=0.9.0" }, ] +[[package]] +name = "typer" +version = "0.21.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "rich" }, + { name = "shellingham" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/36/bf/8825b5929afd84d0dabd606c67cd57b8388cb3ec385f7ef19c5cc2202069/typer-0.21.1.tar.gz", hash = "sha256:ea835607cd752343b6b2b7ce676893e5a0324082268b48f27aa058bdb7d2145d", size = 110371, upload-time = "2026-01-06T11:21:10.989Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/1d/d9257dd49ff2ca23ea5f132edf1281a0c4f9de8a762b9ae399b670a59235/typer-0.21.1-py3-none-any.whl", hash = "sha256:7985e89081c636b88d172c2ee0cfe33c253160994d47bdfdc302defd7d1f1d01", size = 47381, upload-time = "2026-01-06T11:21:09.824Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0"