💬 Commit message: Update 2026-01-28 02:46:46, 1 files, 60 lines
📁 Files changed: 1 📝 Lines changed: 60 • tensors.py
This commit is contained in:
+48
-12
@@ -32,9 +32,24 @@ console = Console()
|
|||||||
|
|
||||||
RC_FILE = Path.home() / ".sftrc"
|
RC_FILE = Path.home() / ".sftrc"
|
||||||
|
|
||||||
|
# Default download paths by model type
|
||||||
|
DEFAULT_PATHS: dict[str, Path] = {
|
||||||
|
"Checkpoint": Path.home() / ".xm" / "models" / "checkpoints",
|
||||||
|
"LORA": Path.home() / ".xm" / "models" / "loras",
|
||||||
|
"LoCon": Path.home() / ".xm" / "models" / "loras",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_api_key() -> str | None:
|
def load_api_key() -> str | None:
|
||||||
"""Load API key from ~/.sftrc if it exists."""
|
"""Load API key from ~/.sftrc or CIVITAI_API_KEY env var."""
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Check environment variable first
|
||||||
|
env_key = os.environ.get("CIVITAI_API_KEY")
|
||||||
|
if env_key:
|
||||||
|
return env_key
|
||||||
|
|
||||||
|
# Fall back to RC file
|
||||||
if RC_FILE.exists():
|
if RC_FILE.exists():
|
||||||
content = RC_FILE.read_text().strip()
|
content = RC_FILE.read_text().strip()
|
||||||
if content:
|
if content:
|
||||||
@@ -42,6 +57,13 @@ def load_api_key() -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_output_path(model_type: str | None) -> Path | None:
|
||||||
|
"""Get default output path based on model type."""
|
||||||
|
if model_type and model_type in DEFAULT_PATHS:
|
||||||
|
return DEFAULT_PATHS[model_type]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
CIVITAI_API_BASE = "https://civitai.com/api/v1"
|
CIVITAI_API_BASE = "https://civitai.com/api/v1"
|
||||||
CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
|
CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
|
||||||
|
|
||||||
@@ -611,14 +633,6 @@ def _resolve_version_id(
|
|||||||
def cmd_download(args: argparse.Namespace) -> int:
|
def cmd_download(args: argparse.Namespace) -> int:
|
||||||
"""Handle the download subcommand."""
|
"""Handle the download subcommand."""
|
||||||
api_key: str | None = args.api_key or load_api_key()
|
api_key: str | None = args.api_key or load_api_key()
|
||||||
output_dir: Path = args.output.resolve()
|
|
||||||
|
|
||||||
if not output_dir.exists():
|
|
||||||
console.print(f"[red]Error: Output directory not found: {output_dir}[/red]")
|
|
||||||
return 1
|
|
||||||
if not output_dir.is_dir():
|
|
||||||
console.print(f"[red]Error: Not a directory: {output_dir}[/red]")
|
|
||||||
return 1
|
|
||||||
|
|
||||||
# Resolve version ID from hash or model ID if needed
|
# Resolve version ID from hash or model ID if needed
|
||||||
version_id = _resolve_version_id(
|
version_id = _resolve_version_id(
|
||||||
@@ -629,7 +643,7 @@ def cmd_download(args: argparse.Namespace) -> int:
|
|||||||
console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
|
console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
# Fetch version info to get filename
|
# Fetch version info to get filename and model type
|
||||||
console.print(f"[cyan]Fetching model info for version {version_id}...[/cyan]")
|
console.print(f"[cyan]Fetching model info for version {version_id}...[/cyan]")
|
||||||
version_info = fetch_civitai_model_version(version_id, api_key)
|
version_info = fetch_civitai_model_version(version_id, api_key)
|
||||||
|
|
||||||
@@ -637,6 +651,28 @@ def cmd_download(args: argparse.Namespace) -> int:
|
|||||||
console.print("[red]Error: Could not fetch model version info.[/red]")
|
console.print("[red]Error: Could not fetch model version info.[/red]")
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
# Determine model type for default path
|
||||||
|
model_type: str | None = version_info.get("model", {}).get("type")
|
||||||
|
|
||||||
|
# Determine output directory
|
||||||
|
if args.output is None:
|
||||||
|
# Use model type-based default
|
||||||
|
output_dir = get_default_output_path(model_type)
|
||||||
|
if output_dir is None:
|
||||||
|
console.print(f"[red]Error: No default path for model type '{model_type}'. Use --output to specify.[/red]")
|
||||||
|
return 1
|
||||||
|
console.print(f"[dim]Using default path for {model_type}: {output_dir}[/dim]")
|
||||||
|
else:
|
||||||
|
output_dir = args.output.resolve()
|
||||||
|
|
||||||
|
# Create directory if it doesn't exist
|
||||||
|
if not output_dir.exists():
|
||||||
|
console.print(f"[cyan]Creating directory: {output_dir}[/cyan]")
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
elif not output_dir.is_dir():
|
||||||
|
console.print(f"[red]Error: Not a directory: {output_dir}[/red]")
|
||||||
|
return 1
|
||||||
|
|
||||||
# Find primary file or first file
|
# Find primary file or first file
|
||||||
files: list[dict[str, Any]] = version_info.get("files", [])
|
files: list[dict[str, Any]] = version_info.get("files", [])
|
||||||
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
|
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
|
||||||
@@ -762,8 +798,8 @@ def main() -> int:
|
|||||||
"--output",
|
"--output",
|
||||||
"-o",
|
"-o",
|
||||||
type=Path,
|
type=Path,
|
||||||
default=Path(),
|
default=None,
|
||||||
help="Output directory (default: current directory)",
|
help="Output directory (default: type-based, e.g. ~/.xm/models/checkpoints for Checkpoint)",
|
||||||
)
|
)
|
||||||
dl_parser.add_argument(
|
dl_parser.add_argument(
|
||||||
"--no-resume",
|
"--no-resume",
|
||||||
|
|||||||
Reference in New Issue
Block a user