diff --git a/tensors.py b/tensors.py index e2e0268..f8de03f 100644 --- a/tensors.py +++ b/tensors.py @@ -32,9 +32,24 @@ console = Console() RC_FILE = Path.home() / ".sftrc" +# Default download paths by model type +DEFAULT_PATHS: dict[str, Path] = { + "Checkpoint": Path.home() / ".xm" / "models" / "checkpoints", + "LORA": Path.home() / ".xm" / "models" / "loras", + "LoCon": Path.home() / ".xm" / "models" / "loras", +} + def load_api_key() -> str | None: - """Load API key from ~/.sftrc if it exists.""" + """Load API key from ~/.sftrc or CIVITAI_API_KEY env var.""" + import os + + # Check environment variable first + env_key = os.environ.get("CIVITAI_API_KEY") + if env_key: + return env_key + + # Fall back to RC file if RC_FILE.exists(): content = RC_FILE.read_text().strip() if content: @@ -42,6 +57,13 @@ def load_api_key() -> str | None: return None +def get_default_output_path(model_type: str | None) -> Path | None: + """Get default output path based on model type.""" + if model_type and model_type in DEFAULT_PATHS: + return DEFAULT_PATHS[model_type] + return None + + CIVITAI_API_BASE = "https://civitai.com/api/v1" CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models" @@ -611,14 +633,6 @@ def _resolve_version_id( def cmd_download(args: argparse.Namespace) -> int: """Handle the download subcommand.""" api_key: str | None = args.api_key or load_api_key() - output_dir: Path = args.output.resolve() - - if not output_dir.exists(): - console.print(f"[red]Error: Output directory not found: {output_dir}[/red]") - return 1 - if not output_dir.is_dir(): - console.print(f"[red]Error: Not a directory: {output_dir}[/red]") - return 1 # Resolve version ID from hash or model ID if needed version_id = _resolve_version_id( @@ -629,7 +643,7 @@ def cmd_download(args: argparse.Namespace) -> int: console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") return 1 - # Fetch version info to get filename + # Fetch version info to get filename and model type console.print(f"[cyan]Fetching model info for version {version_id}...[/cyan]") version_info = fetch_civitai_model_version(version_id, api_key) @@ -637,6 +651,28 @@ def cmd_download(args: argparse.Namespace) -> int: console.print("[red]Error: Could not fetch model version info.[/red]") return 1 + # Determine model type for default path + model_type: str | None = version_info.get("model", {}).get("type") + + # Determine output directory + if args.output is None: + # Use model type-based default + output_dir = get_default_output_path(model_type) + if output_dir is None: + console.print(f"[red]Error: No default path for model type '{model_type}'. Use --output to specify.[/red]") + return 1 + console.print(f"[dim]Using default path for {model_type}: {output_dir}[/dim]") + else: + output_dir = args.output.resolve() + + # Create directory if it doesn't exist + if not output_dir.exists(): + console.print(f"[cyan]Creating directory: {output_dir}[/cyan]") + output_dir.mkdir(parents=True, exist_ok=True) + elif not output_dir.is_dir(): + console.print(f"[red]Error: Not a directory: {output_dir}[/red]") + return 1 + # Find primary file or first file files: list[dict[str, Any]] = version_info.get("files", []) primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) @@ -762,8 +798,8 @@ def main() -> int: "--output", "-o", type=Path, - default=Path(), - help="Output directory (default: current directory)", + default=None, + help="Output directory (default: type-based, e.g. ~/.xm/models/checkpoints for Checkpoint)", ) dl_parser.add_argument( "--no-resume",