diff --git a/pyproject.toml b/pyproject.toml index 28d4684..7ac9a9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ select = [ "RUF", # ruff-specific ] ignore = [ + "PLR0911", # too many return statements "PLR0913", # too many arguments "PLR2004", # magic value comparison ] diff --git a/tensors.py b/tensors.py index d2e1105..8af4c1d 100644 --- a/tensors.py +++ b/tensors.py @@ -554,9 +554,12 @@ def cmd_info(args: argparse.Namespace) -> int: def _resolve_version_id( - version_id: int | None, sha256_hash: str | None, api_key: str | None + version_id: int | None, + sha256_hash: str | None, + model_id: int | None, + api_key: str | None, ) -> int | None: - """Resolve version ID from hash if needed.""" + """Resolve version ID from hash or model ID if needed.""" if version_id: return version_id if sha256_hash: @@ -571,6 +574,24 @@ def _resolve_version_id( else: console.print("[red]Error: Could not determine version ID from CivitAI response.[/red]") return vid + if model_id: + console.print(f"[cyan]Looking up model {model_id}...[/cyan]") + model_data = fetch_civitai_model(model_id, api_key) + if not model_data: + console.print(f"[red]Error: Model {model_id} not found on CivitAI.[/red]") + return None + versions: list[dict[str, Any]] = model_data.get("modelVersions", []) + if not versions: + console.print("[red]Error: Model has no versions.[/red]") + return None + # First version is the latest + latest = versions[0] + vid = latest.get("id") + if vid: + console.print( + f"[green]Found latest version:[/green] {latest.get('name', 'N/A')} (ID: {vid})" + ) + return vid return None @@ -586,11 +607,13 @@ def cmd_download(args: argparse.Namespace) -> int: console.print(f"[red]Error: Not a directory: {output_dir}[/red]") return 1 - # Resolve version ID from hash if needed - version_id = _resolve_version_id(args.version_id, args.hash, api_key) + # Resolve version ID from hash or model ID if needed + version_id = _resolve_version_id( + args.version_id, args.hash, getattr(args, "model_id", None), api_key + ) if not version_id: - if not args.version_id and not args.hash: - console.print("[red]Error: Must specify --version-id or --hash[/red]") + if not args.version_id and not args.hash and not getattr(args, "model_id", None): + console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]") return 1 # Fetch version info to get filename @@ -704,6 +727,12 @@ def main() -> int: type=int, help="CivitAI model version ID to download", ) + dl_parser.add_argument( + "--model-id", + "-m", + type=int, + help="CivitAI model ID (downloads latest version)", + ) dl_parser.add_argument( "--hash", "-H",