diff --git a/.coverage b/.coverage index 1d6cb60..54a3d72 100644 Binary files a/.coverage and b/.coverage differ diff --git a/pyproject.toml b/pyproject.toml index a96bf2c..e831b43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,9 +52,7 @@ select = [ "RUF", # ruff-specific ] ignore = [ - "PLR2004", # Magic values - acceptable in CLI apps "PLR0913", # Too many arguments - CLI commands need many options - "PLR0911", # Too many return statements ] [tool.ruff.lint.isort] diff --git a/tensors/api.py b/tensors/api.py index 06d72a0..3d49e56 100644 --- a/tensors/api.py +++ b/tensors/api.py @@ -3,6 +3,7 @@ from __future__ import annotations import re +from http import HTTPStatus from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -40,7 +41,7 @@ def fetch_civitai_model_version(version_id: int, api_key: str | None, console: C try: response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) - if response.status_code == 404: + if response.status_code == HTTPStatus.NOT_FOUND: return None response.raise_for_status() result: dict[str, Any] = response.json() @@ -67,7 +68,7 @@ def fetch_civitai_model(model_id: int, api_key: str | None, console: Console) -> try: response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) - if response.status_code == 404: + if response.status_code == HTTPStatus.NOT_FOUND: return None response.raise_for_status() result: dict[str, Any] = response.json() @@ -94,7 +95,7 @@ def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Consol try: response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) - if response.status_code == 404: + if response.status_code == HTTPStatus.NOT_FOUND: return None response.raise_for_status() result: dict[str, Any] = response.json() @@ -269,7 +270,7 @@ def download_model( follow_redirects=True, timeout=httpx.Timeout(30.0, read=None), ) as response: - if response.status_code == 416: + if response.status_code == HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE: console.print("[green]File already fully downloaded.[/green]") return True @@ -279,7 +280,7 @@ def download_model( except httpx.HTTPStatusError as e: console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]") - if e.response.status_code == 401: + if e.response.status_code == HTTPStatus.UNAUTHORIZED: console.print("[yellow]Hint: This model may require an API key.[/yellow]") return False except httpx.RequestError as e: diff --git a/tensors/cli.py b/tensors/cli.py index d0f58f9..4a59daa 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -38,6 +38,9 @@ from tensors.display import ( ) from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata +# Key masking threshold +MIN_KEY_LENGTH_FOR_MASKING = 8 + app = typer.Typer( name="tsr", help="Read safetensor metadata, search and download CivitAI models.", @@ -211,44 +214,50 @@ def get( display_model_info(model_data, console) +def _resolve_by_hash(hash_val: str, api_key: str | None) -> int | None: + """Resolve version ID from SHA256 hash.""" + console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]") + civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key, console) + if not civitai_data: + console.print("[red]Error: Model not found on CivitAI for this hash.[/red]") + return None + vid: int | None = civitai_data.get("id") + if vid: + console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}") + return vid + + +def _resolve_by_model_id(model_id: int, api_key: str | None) -> int | None: + """Resolve latest version ID from model ID.""" + console.print(f"[cyan]Looking up model {model_id}...[/cyan]") + model_data = fetch_civitai_model(model_id, api_key, console) + if not model_data: + console.print(f"[red]Error: Model {model_id} not found.[/red]") + return None + versions = model_data.get("modelVersions", []) + if not versions: + console.print("[red]Error: Model has no versions.[/red]") + return None + latest = versions[0] + latest_vid: int | None = latest.get("id") + if latest_vid: + console.print(f"[green]Found latest:[/green] {latest.get('name', 'N/A')} (ID: {latest_vid})") + return latest_vid + + def _resolve_version_id( version_id: int | None, hash_val: str | None, model_id: int | None, api_key: str | None, ) -> int | None: - """Resolve version ID from hash or model ID.""" + """Resolve version ID from direct ID, hash, or model ID.""" if version_id: return version_id - if hash_val: - console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]") - civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key, console) - if not civitai_data: - console.print("[red]Error: Model not found on CivitAI for this hash.[/red]") - return None - vid: int | None = civitai_data.get("id") - if vid: - console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}") - return vid - + return _resolve_by_hash(hash_val, api_key) if model_id: - console.print(f"[cyan]Looking up model {model_id}...[/cyan]") - model_data = fetch_civitai_model(model_id, api_key, console) - if not model_data: - console.print(f"[red]Error: Model {model_id} not found.[/red]") - return None - versions = model_data.get("modelVersions", []) - if not versions: - console.print("[red]Error: Model has no versions.[/red]") - return None - latest = versions[0] - latest_vid: int | None = latest.get("id") - if latest_vid: - name = latest.get("name", "N/A") - console.print(f"[green]Found latest:[/green] {name} (ID: {latest_vid})") - return latest_vid - + return _resolve_by_model_id(model_id, api_key) return None @@ -359,7 +368,7 @@ def config( key = load_api_key() if key: - masked = key[:4] + "..." + key[-4:] if len(key) > 8 else "***" + masked = key[:4] + "..." + key[-4:] if len(key) > MIN_KEY_LENGTH_FOR_MASKING else "***" console.print(f"[bold]API key:[/bold] {masked}") else: console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]") diff --git a/tensors/display.py b/tensors/display.py index 9f5d64d..0466767 100644 --- a/tensors/display.py +++ b/tensors/display.py @@ -12,23 +12,30 @@ from rich.table import Table if TYPE_CHECKING: from rich.console import Console +# Size formatting constants +KB = 1024 +MB_IN_KB = KB * KB +THOUSAND = 1000 +MILLION = 1_000_000 +MAX_TAGS_DISPLAY = 10 + def _format_size(size_kb: float) -> str: """Format size in KB to human-readable string.""" - if size_kb < 1024: + if size_kb < KB: return f"{size_kb:.0f} KB" - if size_kb < 1024 * 1024: - return f"{size_kb / 1024:.1f} MB" - return f"{size_kb / 1024 / 1024:.2f} GB" + if size_kb < MB_IN_KB: + return f"{size_kb / KB:.1f} MB" + return f"{size_kb / KB / KB:.2f} GB" def _format_count(count: int) -> str: """Format large numbers with K/M suffix.""" - if count < 1000: + if count < THOUSAND: return str(count) - if count < 1_000_000: - return f"{count / 1000:.1f}K" - return f"{count / 1_000_000:.1f}M" + if count < MILLION: + return f"{count / THOUSAND:.1f}K" + return f"{count / MILLION:.1f}M" def display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str, console: Console) -> None: @@ -174,7 +181,7 @@ def _add_model_basic_info(table: Table, model_data: dict[str, Any]) -> None: tags: list[str] = model_data.get("tags", []) if tags: - table.add_row("Tags", ", ".join(tags[:10]) + ("..." if len(tags) > 10 else "")) + table.add_row("Tags", ", ".join(tags[:MAX_TAGS_DISPLAY]) + ("..." if len(tags) > MAX_TAGS_DISPLAY else "")) stats: dict[str, Any] = model_data.get("stats", {}) if stats: diff --git a/tensors/safetensor.py b/tensors/safetensor.py index 899ca38..710d57b 100644 --- a/tensors/safetensor.py +++ b/tensors/safetensor.py @@ -24,18 +24,21 @@ from rich.progress import ( if TYPE_CHECKING: from rich.console import Console +# Safetensor format constants +HEADER_SIZE_BYTES = 8 # u64 little-endian +MAX_HEADER_SIZE = 100_000_000 # 100MB sanity check + def read_safetensor_metadata(file_path: Path) -> dict[str, Any]: """Read metadata from a safetensor file header.""" with file_path.open("rb") as f: - # First 8 bytes are the header size (little-endian u64) - header_size_bytes = f.read(8) - if len(header_size_bytes) < 8: + header_size_bytes = f.read(HEADER_SIZE_BYTES) + if len(header_size_bytes) < HEADER_SIZE_BYTES: raise ValueError("Invalid safetensor file: too short") header_size = struct.unpack(" 100_000_000: # 100MB sanity check + if header_size > MAX_HEADER_SIZE: raise ValueError(f"Invalid header size: {header_size}") header_bytes = f.read(header_size)