diff --git a/.coverage b/.coverage index 78fc8c4..e98daa6 100644 Binary files a/.coverage and b/.coverage differ diff --git a/.gitignore b/.gitignore index 505a3b1..550e762 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,8 @@ wheels/ # Virtual environments .venv + +# Coverage +.coverage +htmlcov/ +coverage.xml diff --git a/tensors.py b/tensors.py index 29402a6..5fa4df8 100644 --- a/tensors.py +++ b/tensors.py @@ -46,18 +46,24 @@ console = Console() # Configuration # ============================================================================ -# XDG Base Directory spec: ~/.config/tensors/config.toml +# XDG Base Directory spec +# Config: ~/.config/tensors/config.toml +# Data: ~/.local/share/tensors/models/, ~/.local/share/tensors/metadata/ CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config")) / "tensors" CONFIG_FILE = CONFIG_DIR / "config.toml" +DATA_DIR = Path(os.environ.get("XDG_DATA_HOME", Path.home() / ".local" / "share")) / "tensors" +MODELS_DIR = DATA_DIR / "models" +METADATA_DIR = DATA_DIR / "metadata" + # Legacy config for migration LEGACY_RC_FILE = Path.home() / ".sftrc" # Default download paths by model type DEFAULT_PATHS: dict[str, Path] = { - "Checkpoint": Path.home() / ".xm" / "models" / "checkpoints", - "LORA": Path.home() / ".xm" / "models" / "loras", - "LoCon": Path.home() / ".xm" / "models" / "loras", + "Checkpoint": MODELS_DIR / "checkpoints", + "LORA": MODELS_DIR / "loras", + "LoCon": MODELS_DIR / "loras", } CIVITAI_API_BASE = "https://civitai.com/api/v1" @@ -72,30 +78,62 @@ CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models" class ModelType(str, Enum): """CivitAI model types.""" - checkpoint = "Checkpoint" - lora = "LORA" - embedding = "TextualInversion" - vae = "VAE" - controlnet = "Controlnet" - locon = "LoCon" + checkpoint = "checkpoint" + lora = "lora" + embedding = "embedding" + vae = "vae" + controlnet = "controlnet" + locon = "locon" + + def to_api(self) -> str: + """Convert to CivitAI API value.""" + mapping = { + "checkpoint": "Checkpoint", + "lora": "LORA", + "embedding": "TextualInversion", + "vae": "VAE", + "controlnet": "Controlnet", + "locon": "LoCon", + } + return mapping[self.value] class BaseModel(str, Enum): """Common base models.""" - sd15 = "SD 1.5" - sdxl = "SDXL 1.0" - pony = "Pony" - flux = "Flux.1 D" - illustrious = "Illustrious" + sd15 = "sd15" + sdxl = "sdxl" + pony = "pony" + flux = "flux" + illustrious = "illustrious" + + def to_api(self) -> str: + """Convert to CivitAI API value.""" + mapping = { + "sd15": "SD 1.5", + "sdxl": "SDXL 1.0", + "pony": "Pony", + "flux": "Flux.1 D", + "illustrious": "Illustrious", + } + return mapping[self.value] class SortOrder(str, Enum): """Sort options for search.""" - downloads = "Most Downloaded" - rating = "Highest Rated" - newest = "Newest" + downloads = "downloads" + rating = "rating" + newest = "newest" + + def to_api(self) -> str: + """Convert to CivitAI API value.""" + mapping = { + "downloads": "Most Downloaded", + "rating": "Highest Rated", + "newest": "Newest", + } + return mapping[self.value] # ============================================================================ @@ -345,12 +383,12 @@ def search_civitai( params["query"] = query if model_type: - params["types"] = model_type.value + params["types"] = model_type.to_api() if base_model: - params["baseModels"] = base_model.value + params["baseModels"] = base_model.to_api() - params["sort"] = sort.value + params["sort"] = sort.to_api() # Request more if we need client-side filtering if query and has_filters: @@ -760,8 +798,8 @@ def info( raise typer.Exit(1) base_name = get_base_name(file_path) - json_path = output_dir / f"{base_name}-xm.json" - sha_path = output_dir / f"{base_name}-xm.sha256" + json_path = output_dir / f"{base_name}.json" + sha_path = output_dir / f"{base_name}.sha256" output = { "file": str(file_path),