Fix PLR2004/PLR0911: replace magic values with constants, refactor _resolve_version_id

This commit is contained in:
Adam Ladachowski
2026-02-03 23:01:33 +01:00
parent 4bebae00f6
commit 75eccecfba
6 changed files with 67 additions and 49 deletions
BIN
View File
Binary file not shown.
-2
View File
@@ -52,9 +52,7 @@ select = [
"RUF", # ruff-specific "RUF", # ruff-specific
] ]
ignore = [ ignore = [
"PLR2004", # Magic values - acceptable in CLI apps
"PLR0913", # Too many arguments - CLI commands need many options "PLR0913", # Too many arguments - CLI commands need many options
"PLR0911", # Too many return statements
] ]
[tool.ruff.lint.isort] [tool.ruff.lint.isort]
+6 -5
View File
@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import re import re
from http import HTTPStatus
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -40,7 +41,7 @@ def fetch_civitai_model_version(version_id: int, api_key: str | None, console: C
try: try:
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
if response.status_code == 404: if response.status_code == HTTPStatus.NOT_FOUND:
return None return None
response.raise_for_status() response.raise_for_status()
result: dict[str, Any] = response.json() result: dict[str, Any] = response.json()
@@ -67,7 +68,7 @@ def fetch_civitai_model(model_id: int, api_key: str | None, console: Console) ->
try: try:
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
if response.status_code == 404: if response.status_code == HTTPStatus.NOT_FOUND:
return None return None
response.raise_for_status() response.raise_for_status()
result: dict[str, Any] = response.json() result: dict[str, Any] = response.json()
@@ -94,7 +95,7 @@ def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Consol
try: try:
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0) response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
if response.status_code == 404: if response.status_code == HTTPStatus.NOT_FOUND:
return None return None
response.raise_for_status() response.raise_for_status()
result: dict[str, Any] = response.json() result: dict[str, Any] = response.json()
@@ -269,7 +270,7 @@ def download_model(
follow_redirects=True, follow_redirects=True,
timeout=httpx.Timeout(30.0, read=None), timeout=httpx.Timeout(30.0, read=None),
) as response: ) as response:
if response.status_code == 416: if response.status_code == HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE:
console.print("[green]File already fully downloaded.[/green]") console.print("[green]File already fully downloaded.[/green]")
return True return True
@@ -279,7 +280,7 @@ def download_model(
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]") console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]")
if e.response.status_code == 401: if e.response.status_code == HTTPStatus.UNAUTHORIZED:
console.print("[yellow]Hint: This model may require an API key.[/yellow]") console.print("[yellow]Hint: This model may require an API key.[/yellow]")
return False return False
except httpx.RequestError as e: except httpx.RequestError as e:
+38 -29
View File
@@ -38,6 +38,9 @@ from tensors.display import (
) )
from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata
# Key masking threshold
MIN_KEY_LENGTH_FOR_MASKING = 8
app = typer.Typer( app = typer.Typer(
name="tsr", name="tsr",
help="Read safetensor metadata, search and download CivitAI models.", help="Read safetensor metadata, search and download CivitAI models.",
@@ -211,44 +214,50 @@ def get(
display_model_info(model_data, console) display_model_info(model_data, console)
def _resolve_by_hash(hash_val: str, api_key: str | None) -> int | None:
"""Resolve version ID from SHA256 hash."""
console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]")
civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key, console)
if not civitai_data:
console.print("[red]Error: Model not found on CivitAI for this hash.[/red]")
return None
vid: int | None = civitai_data.get("id")
if vid:
console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}")
return vid
def _resolve_by_model_id(model_id: int, api_key: str | None) -> int | None:
"""Resolve latest version ID from model ID."""
console.print(f"[cyan]Looking up model {model_id}...[/cyan]")
model_data = fetch_civitai_model(model_id, api_key, console)
if not model_data:
console.print(f"[red]Error: Model {model_id} not found.[/red]")
return None
versions = model_data.get("modelVersions", [])
if not versions:
console.print("[red]Error: Model has no versions.[/red]")
return None
latest = versions[0]
latest_vid: int | None = latest.get("id")
if latest_vid:
console.print(f"[green]Found latest:[/green] {latest.get('name', 'N/A')} (ID: {latest_vid})")
return latest_vid
def _resolve_version_id( def _resolve_version_id(
version_id: int | None, version_id: int | None,
hash_val: str | None, hash_val: str | None,
model_id: int | None, model_id: int | None,
api_key: str | None, api_key: str | None,
) -> int | None: ) -> int | None:
"""Resolve version ID from hash or model ID.""" """Resolve version ID from direct ID, hash, or model ID."""
if version_id: if version_id:
return version_id return version_id
if hash_val: if hash_val:
console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]") return _resolve_by_hash(hash_val, api_key)
civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key, console)
if not civitai_data:
console.print("[red]Error: Model not found on CivitAI for this hash.[/red]")
return None
vid: int | None = civitai_data.get("id")
if vid:
console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}")
return vid
if model_id: if model_id:
console.print(f"[cyan]Looking up model {model_id}...[/cyan]") return _resolve_by_model_id(model_id, api_key)
model_data = fetch_civitai_model(model_id, api_key, console)
if not model_data:
console.print(f"[red]Error: Model {model_id} not found.[/red]")
return None
versions = model_data.get("modelVersions", [])
if not versions:
console.print("[red]Error: Model has no versions.[/red]")
return None
latest = versions[0]
latest_vid: int | None = latest.get("id")
if latest_vid:
name = latest.get("name", "N/A")
console.print(f"[green]Found latest:[/green] {name} (ID: {latest_vid})")
return latest_vid
return None return None
@@ -359,7 +368,7 @@ def config(
key = load_api_key() key = load_api_key()
if key: if key:
masked = key[:4] + "..." + key[-4:] if len(key) > 8 else "***" masked = key[:4] + "..." + key[-4:] if len(key) > MIN_KEY_LENGTH_FOR_MASKING else "***"
console.print(f"[bold]API key:[/bold] {masked}") console.print(f"[bold]API key:[/bold] {masked}")
else: else:
console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]") console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]")
+16 -9
View File
@@ -12,23 +12,30 @@ from rich.table import Table
if TYPE_CHECKING: if TYPE_CHECKING:
from rich.console import Console from rich.console import Console
# Size formatting constants
KB = 1024
MB_IN_KB = KB * KB
THOUSAND = 1000
MILLION = 1_000_000
MAX_TAGS_DISPLAY = 10
def _format_size(size_kb: float) -> str: def _format_size(size_kb: float) -> str:
"""Format size in KB to human-readable string.""" """Format size in KB to human-readable string."""
if size_kb < 1024: if size_kb < KB:
return f"{size_kb:.0f} KB" return f"{size_kb:.0f} KB"
if size_kb < 1024 * 1024: if size_kb < MB_IN_KB:
return f"{size_kb / 1024:.1f} MB" return f"{size_kb / KB:.1f} MB"
return f"{size_kb / 1024 / 1024:.2f} GB" return f"{size_kb / KB / KB:.2f} GB"
def _format_count(count: int) -> str: def _format_count(count: int) -> str:
"""Format large numbers with K/M suffix.""" """Format large numbers with K/M suffix."""
if count < 1000: if count < THOUSAND:
return str(count) return str(count)
if count < 1_000_000: if count < MILLION:
return f"{count / 1000:.1f}K" return f"{count / THOUSAND:.1f}K"
return f"{count / 1_000_000:.1f}M" return f"{count / MILLION:.1f}M"
def display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str, console: Console) -> None: def display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str, console: Console) -> None:
@@ -174,7 +181,7 @@ def _add_model_basic_info(table: Table, model_data: dict[str, Any]) -> None:
tags: list[str] = model_data.get("tags", []) tags: list[str] = model_data.get("tags", [])
if tags: if tags:
table.add_row("Tags", ", ".join(tags[:10]) + ("..." if len(tags) > 10 else "")) table.add_row("Tags", ", ".join(tags[:MAX_TAGS_DISPLAY]) + ("..." if len(tags) > MAX_TAGS_DISPLAY else ""))
stats: dict[str, Any] = model_data.get("stats", {}) stats: dict[str, Any] = model_data.get("stats", {})
if stats: if stats:
+7 -4
View File
@@ -24,18 +24,21 @@ from rich.progress import (
if TYPE_CHECKING: if TYPE_CHECKING:
from rich.console import Console from rich.console import Console
# Safetensor format constants
HEADER_SIZE_BYTES = 8 # u64 little-endian
MAX_HEADER_SIZE = 100_000_000 # 100MB sanity check
def read_safetensor_metadata(file_path: Path) -> dict[str, Any]: def read_safetensor_metadata(file_path: Path) -> dict[str, Any]:
"""Read metadata from a safetensor file header.""" """Read metadata from a safetensor file header."""
with file_path.open("rb") as f: with file_path.open("rb") as f:
# First 8 bytes are the header size (little-endian u64) header_size_bytes = f.read(HEADER_SIZE_BYTES)
header_size_bytes = f.read(8) if len(header_size_bytes) < HEADER_SIZE_BYTES:
if len(header_size_bytes) < 8:
raise ValueError("Invalid safetensor file: too short") raise ValueError("Invalid safetensor file: too short")
header_size = struct.unpack("<Q", header_size_bytes)[0] header_size = struct.unpack("<Q", header_size_bytes)[0]
if header_size > 100_000_000: # 100MB sanity check if header_size > MAX_HEADER_SIZE:
raise ValueError(f"Invalid header size: {header_size}") raise ValueError(f"Invalid header size: {header_size}")
header_bytes = f.read(header_size) header_bytes = f.read(header_size)