💬 Commit message: Update 2026-01-28 02:46:46, 1 files, 60 lines

📁 Files changed: 1
📝 Lines changed: 60

  • tensors.py
This commit is contained in:
Adam Ladachowski
2026-01-28 02:46:46 +01:00
parent 87547c641d
commit 2b971a54a4
+48 -12
View File
@@ -32,9 +32,24 @@ console = Console()
RC_FILE = Path.home() / ".sftrc" RC_FILE = Path.home() / ".sftrc"
# Default download paths by model type
DEFAULT_PATHS: dict[str, Path] = {
"Checkpoint": Path.home() / ".xm" / "models" / "checkpoints",
"LORA": Path.home() / ".xm" / "models" / "loras",
"LoCon": Path.home() / ".xm" / "models" / "loras",
}
def load_api_key() -> str | None: def load_api_key() -> str | None:
"""Load API key from ~/.sftrc if it exists.""" """Load API key from ~/.sftrc or CIVITAI_API_KEY env var."""
import os
# Check environment variable first
env_key = os.environ.get("CIVITAI_API_KEY")
if env_key:
return env_key
# Fall back to RC file
if RC_FILE.exists(): if RC_FILE.exists():
content = RC_FILE.read_text().strip() content = RC_FILE.read_text().strip()
if content: if content:
@@ -42,6 +57,13 @@ def load_api_key() -> str | None:
return None return None
def get_default_output_path(model_type: str | None) -> Path | None:
"""Get default output path based on model type."""
if model_type and model_type in DEFAULT_PATHS:
return DEFAULT_PATHS[model_type]
return None
CIVITAI_API_BASE = "https://civitai.com/api/v1" CIVITAI_API_BASE = "https://civitai.com/api/v1"
CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models" CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
@@ -611,14 +633,6 @@ def _resolve_version_id(
def cmd_download(args: argparse.Namespace) -> int: def cmd_download(args: argparse.Namespace) -> int:
"""Handle the download subcommand.""" """Handle the download subcommand."""
api_key: str | None = args.api_key or load_api_key() api_key: str | None = args.api_key or load_api_key()
output_dir: Path = args.output.resolve()
if not output_dir.exists():
console.print(f"[red]Error: Output directory not found: {output_dir}[/red]")
return 1
if not output_dir.is_dir():
console.print(f"[red]Error: Not a directory: {output_dir}[/red]")
return 1
# Resolve version ID from hash or model ID if needed # Resolve version ID from hash or model ID if needed
version_id = _resolve_version_id( version_id = _resolve_version_id(
@@ -629,7 +643,7 @@ def cmd_download(args: argparse.Namespace) -> int:
console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
return 1 return 1
# Fetch version info to get filename # Fetch version info to get filename and model type
console.print(f"[cyan]Fetching model info for version {version_id}...[/cyan]") console.print(f"[cyan]Fetching model info for version {version_id}...[/cyan]")
version_info = fetch_civitai_model_version(version_id, api_key) version_info = fetch_civitai_model_version(version_id, api_key)
@@ -637,6 +651,28 @@ def cmd_download(args: argparse.Namespace) -> int:
console.print("[red]Error: Could not fetch model version info.[/red]") console.print("[red]Error: Could not fetch model version info.[/red]")
return 1 return 1
# Determine model type for default path
model_type: str | None = version_info.get("model", {}).get("type")
# Determine output directory
if args.output is None:
# Use model type-based default
output_dir = get_default_output_path(model_type)
if output_dir is None:
console.print(f"[red]Error: No default path for model type '{model_type}'. Use --output to specify.[/red]")
return 1
console.print(f"[dim]Using default path for {model_type}: {output_dir}[/dim]")
else:
output_dir = args.output.resolve()
# Create directory if it doesn't exist
if not output_dir.exists():
console.print(f"[cyan]Creating directory: {output_dir}[/cyan]")
output_dir.mkdir(parents=True, exist_ok=True)
elif not output_dir.is_dir():
console.print(f"[red]Error: Not a directory: {output_dir}[/red]")
return 1
# Find primary file or first file # Find primary file or first file
files: list[dict[str, Any]] = version_info.get("files", []) files: list[dict[str, Any]] = version_info.get("files", [])
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
@@ -762,8 +798,8 @@ def main() -> int:
"--output", "--output",
"-o", "-o",
type=Path, type=Path,
default=Path(), default=None,
help="Output directory (default: current directory)", help="Output directory (default: type-based, e.g. ~/.xm/models/checkpoints for Checkpoint)",
) )
dl_parser.add_argument( dl_parser.add_argument(
"--no-resume", "--no-resume",