💬 Commit message: Update 2026-01-26 22:23:24, 1 files, 156 lines

📁 Files changed: 1
📝 Lines changed: 156

  • tensors.py
This commit is contained in:
Adam Ladachowski
2026-01-26 22:23:24 +01:00
parent 0ee55a9db5
commit e305d776a2
+156
View File
@@ -116,6 +116,36 @@ def fetch_civitai_model_version(
return None return None
def fetch_civitai_model(model_id: int, api_key: str | None = None) -> dict[str, Any] | None:
"""Fetch model information from CivitAI by model ID."""
url = f"{CIVITAI_API_BASE}/models/{model_id}"
headers: dict[str, str] = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
transient=True,
) as progress:
progress.add_task("[cyan]Fetching model from CivitAI...", total=None)
try:
response = httpx.get(url, headers=headers, timeout=30.0)
if response.status_code == 404:
return None
response.raise_for_status()
result: dict[str, Any] = response.json()
return result
except httpx.HTTPStatusError as e:
console.print(f"[red]API error: {e.response.status_code}[/red]")
return None
except httpx.RequestError as e:
console.print(f"[red]Request error: {e}[/red]")
return None
def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[str, Any] | None: def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[str, Any] | None:
"""Fetch model information from CivitAI by SHA256 hash.""" """Fetch model information from CivitAI by SHA256 hash."""
url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}" url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}"
@@ -327,6 +357,89 @@ def _display_civitai_data(civitai_data: dict[str, Any] | None) -> None:
) )
def _display_model_info(model_data: dict[str, Any]) -> None:
"""Display full CivitAI model information."""
# Main model info table
model_table = Table(title="Model Information", show_header=True, header_style="bold magenta")
model_table.add_column("Property", style="cyan")
model_table.add_column("Value", style="green", max_width=80)
model_table.add_row("ID", str(model_data.get("id", "N/A")))
model_table.add_row("Name", str(model_data.get("name", "N/A")))
model_table.add_row("Type", str(model_data.get("type", "N/A")))
model_table.add_row("NSFW", str(model_data.get("nsfw", False)))
# Creator info
creator = model_data.get("creator", {})
if creator:
model_table.add_row("Creator", str(creator.get("username", "N/A")))
# Tags
tags: list[str] = model_data.get("tags", [])
if tags:
model_table.add_row("Tags", ", ".join(tags[:10]) + ("..." if len(tags) > 10 else ""))
# Stats
stats: dict[str, Any] = model_data.get("stats", {})
if stats:
model_table.add_row("Downloads", f"{stats.get('downloadCount', 0):,}")
model_table.add_row("Favorites", f"{stats.get('favoriteCount', 0):,}")
model_table.add_row(
"Rating", f"{stats.get('rating', 0):.1f} ({stats.get('ratingCount', 0)} ratings)"
)
# Mode (archived/taken down)
mode = model_data.get("mode")
if mode:
model_table.add_row("Status", str(mode))
console.print()
console.print(model_table)
# Versions table
versions: list[dict[str, Any]] = model_data.get("modelVersions", [])
if versions:
ver_table = Table(title="Model Versions", show_header=True, header_style="bold magenta")
ver_table.add_column("ID", style="cyan")
ver_table.add_column("Name", style="green")
ver_table.add_column("Base Model", style="yellow")
ver_table.add_column("Created", style="blue")
ver_table.add_column("Primary File", style="white")
for ver in versions:
files: list[dict[str, Any]] = ver.get("files", [])
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
file_info = ""
if primary_file:
size_kb = primary_file.get("sizeKB", 0)
size_str = (
f"{size_kb / 1024:.0f} MB"
if size_kb < 1024 * 1024
else f"{size_kb / 1024 / 1024:.1f} GB"
)
file_info = f"{primary_file.get('name', 'N/A')} ({size_str})"
created = str(ver.get("createdAt", "N/A"))[:10] # Just date portion
ver_table.add_row(
str(ver.get("id", "N/A")),
str(ver.get("name", "N/A")),
str(ver.get("baseModel", "N/A")),
created,
file_info,
)
console.print()
console.print(ver_table)
# Model page link
model_id = model_data.get("id")
if model_id:
console.print()
console.print(
f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}"
)
def display_results( def display_results(
file_path: Path, file_path: Path,
local_metadata: dict[str, Any], local_metadata: dict[str, Any],
@@ -517,6 +630,25 @@ def cmd_download(args: argparse.Namespace) -> int:
return 0 if success else 1 return 0 if success else 1
def cmd_get(args: argparse.Namespace) -> int:
"""Handle the get subcommand - fetch model info by ID."""
model_id: int = args.model_id
api_key: str | None = args.api_key
model_data = fetch_civitai_model(model_id, api_key)
if not model_data:
console.print(f"[red]Error: Model {model_id} not found on CivitAI.[/red]")
return 1
if args.json_output:
console.print_json(data=model_data)
else:
_display_model_info(model_data)
return 0
def main() -> int: def main() -> int:
"""Main entry point.""" """Main entry point."""
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@@ -597,6 +729,30 @@ def main() -> int:
) )
dl_parser.set_defaults(func=cmd_download) dl_parser.set_defaults(func=cmd_download)
# Get command
get_parser = subparsers.add_parser(
"get",
help="Fetch model information from CivitAI by model ID",
)
get_parser.add_argument(
"model_id",
type=int,
help="CivitAI model ID",
)
get_parser.add_argument(
"--api-key",
type=str,
default=None,
help="CivitAI API key for authenticated requests",
)
get_parser.add_argument(
"--json",
action="store_true",
dest="json_output",
help="Output results as JSON",
)
get_parser.set_defaults(func=cmd_get)
# Parse and handle default command # Parse and handle default command
args = parser.parse_args() args = parser.parse_args()