From e305d776a2e1323a7b2e112da3ff6a7ad6acd07e Mon Sep 17 00:00:00 2001 From: Adam Ladachowski Date: Mon, 26 Jan 2026 22:23:24 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AC=20Commit=20message:=20Update=20202?= =?UTF-8?q?6-01-26=2022:23:24,=201=20files,=20156=20lines?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 📁 Files changed: 1 📝 Lines changed: 156 • tensors.py --- tensors.py | 156 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) diff --git a/tensors.py b/tensors.py index 3e233ad..d2e1105 100644 --- a/tensors.py +++ b/tensors.py @@ -116,6 +116,36 @@ def fetch_civitai_model_version( return None +def fetch_civitai_model(model_id: int, api_key: str | None = None) -> dict[str, Any] | None: + """Fetch model information from CivitAI by model ID.""" + url = f"{CIVITAI_API_BASE}/models/{model_id}" + headers: dict[str, str] = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + transient=True, + ) as progress: + progress.add_task("[cyan]Fetching model from CivitAI...", total=None) + + try: + response = httpx.get(url, headers=headers, timeout=30.0) + if response.status_code == 404: + return None + response.raise_for_status() + result: dict[str, Any] = response.json() + return result + except httpx.HTTPStatusError as e: + console.print(f"[red]API error: {e.response.status_code}[/red]") + return None + except httpx.RequestError as e: + console.print(f"[red]Request error: {e}[/red]") + return None + + def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[str, Any] | None: """Fetch model information from CivitAI by SHA256 hash.""" url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}" @@ -327,6 +357,89 @@ def _display_civitai_data(civitai_data: dict[str, Any] | None) -> None: ) +def _display_model_info(model_data: dict[str, Any]) -> None: + """Display full CivitAI model information.""" + # Main model info table + model_table = Table(title="Model Information", show_header=True, header_style="bold magenta") + model_table.add_column("Property", style="cyan") + model_table.add_column("Value", style="green", max_width=80) + + model_table.add_row("ID", str(model_data.get("id", "N/A"))) + model_table.add_row("Name", str(model_data.get("name", "N/A"))) + model_table.add_row("Type", str(model_data.get("type", "N/A"))) + model_table.add_row("NSFW", str(model_data.get("nsfw", False))) + + # Creator info + creator = model_data.get("creator", {}) + if creator: + model_table.add_row("Creator", str(creator.get("username", "N/A"))) + + # Tags + tags: list[str] = model_data.get("tags", []) + if tags: + model_table.add_row("Tags", ", ".join(tags[:10]) + ("..." if len(tags) > 10 else "")) + + # Stats + stats: dict[str, Any] = model_data.get("stats", {}) + if stats: + model_table.add_row("Downloads", f"{stats.get('downloadCount', 0):,}") + model_table.add_row("Favorites", f"{stats.get('favoriteCount', 0):,}") + model_table.add_row( + "Rating", f"{stats.get('rating', 0):.1f} ({stats.get('ratingCount', 0)} ratings)" + ) + + # Mode (archived/taken down) + mode = model_data.get("mode") + if mode: + model_table.add_row("Status", str(mode)) + + console.print() + console.print(model_table) + + # Versions table + versions: list[dict[str, Any]] = model_data.get("modelVersions", []) + if versions: + ver_table = Table(title="Model Versions", show_header=True, header_style="bold magenta") + ver_table.add_column("ID", style="cyan") + ver_table.add_column("Name", style="green") + ver_table.add_column("Base Model", style="yellow") + ver_table.add_column("Created", style="blue") + ver_table.add_column("Primary File", style="white") + + for ver in versions: + files: list[dict[str, Any]] = ver.get("files", []) + primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) + file_info = "" + if primary_file: + size_kb = primary_file.get("sizeKB", 0) + size_str = ( + f"{size_kb / 1024:.0f} MB" + if size_kb < 1024 * 1024 + else f"{size_kb / 1024 / 1024:.1f} GB" + ) + file_info = f"{primary_file.get('name', 'N/A')} ({size_str})" + + created = str(ver.get("createdAt", "N/A"))[:10] # Just date portion + ver_table.add_row( + str(ver.get("id", "N/A")), + str(ver.get("name", "N/A")), + str(ver.get("baseModel", "N/A")), + created, + file_info, + ) + + console.print() + console.print(ver_table) + + # Model page link + model_id = model_data.get("id") + if model_id: + console.print() + console.print( + f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}" + ) + + def display_results( file_path: Path, local_metadata: dict[str, Any], @@ -517,6 +630,25 @@ def cmd_download(args: argparse.Namespace) -> int: return 0 if success else 1 +def cmd_get(args: argparse.Namespace) -> int: + """Handle the get subcommand - fetch model info by ID.""" + model_id: int = args.model_id + api_key: str | None = args.api_key + + model_data = fetch_civitai_model(model_id, api_key) + + if not model_data: + console.print(f"[red]Error: Model {model_id} not found on CivitAI.[/red]") + return 1 + + if args.json_output: + console.print_json(data=model_data) + else: + _display_model_info(model_data) + + return 0 + + def main() -> int: """Main entry point.""" parser = argparse.ArgumentParser( @@ -597,6 +729,30 @@ def main() -> int: ) dl_parser.set_defaults(func=cmd_download) + # Get command + get_parser = subparsers.add_parser( + "get", + help="Fetch model information from CivitAI by model ID", + ) + get_parser.add_argument( + "model_id", + type=int, + help="CivitAI model ID", + ) + get_parser.add_argument( + "--api-key", + type=str, + default=None, + help="CivitAI API key for authenticated requests", + ) + get_parser.add_argument( + "--json", + action="store_true", + dest="json_output", + help="Output results as JSON", + ) + get_parser.set_defaults(func=cmd_get) + # Parse and handle default command args = parser.parse_args()