Fix PLR2004/PLR0911: replace magic values with constants, refactor _resolve_version_id
This commit is contained in:
@@ -52,9 +52,7 @@ select = [
|
||||
"RUF", # ruff-specific
|
||||
]
|
||||
ignore = [
|
||||
"PLR2004", # Magic values - acceptable in CLI apps
|
||||
"PLR0913", # Too many arguments - CLI commands need many options
|
||||
"PLR0911", # Too many return statements
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
|
||||
+6
-5
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -40,7 +41,7 @@ def fetch_civitai_model_version(version_id: int, api_key: str | None, console: C
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
||||
if response.status_code == 404:
|
||||
if response.status_code == HTTPStatus.NOT_FOUND:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
result: dict[str, Any] = response.json()
|
||||
@@ -67,7 +68,7 @@ def fetch_civitai_model(model_id: int, api_key: str | None, console: Console) ->
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
||||
if response.status_code == 404:
|
||||
if response.status_code == HTTPStatus.NOT_FOUND:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
result: dict[str, Any] = response.json()
|
||||
@@ -94,7 +95,7 @@ def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Consol
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
||||
if response.status_code == 404:
|
||||
if response.status_code == HTTPStatus.NOT_FOUND:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
result: dict[str, Any] = response.json()
|
||||
@@ -269,7 +270,7 @@ def download_model(
|
||||
follow_redirects=True,
|
||||
timeout=httpx.Timeout(30.0, read=None),
|
||||
) as response:
|
||||
if response.status_code == 416:
|
||||
if response.status_code == HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE:
|
||||
console.print("[green]File already fully downloaded.[/green]")
|
||||
return True
|
||||
|
||||
@@ -279,7 +280,7 @@ def download_model(
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]")
|
||||
if e.response.status_code == 401:
|
||||
if e.response.status_code == HTTPStatus.UNAUTHORIZED:
|
||||
console.print("[yellow]Hint: This model may require an API key.[/yellow]")
|
||||
return False
|
||||
except httpx.RequestError as e:
|
||||
|
||||
+24
-15
@@ -38,6 +38,9 @@ from tensors.display import (
|
||||
)
|
||||
from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata
|
||||
|
||||
# Key masking threshold
|
||||
MIN_KEY_LENGTH_FOR_MASKING = 8
|
||||
|
||||
app = typer.Typer(
|
||||
name="tsr",
|
||||
help="Read safetensor metadata, search and download CivitAI models.",
|
||||
@@ -211,17 +214,8 @@ def get(
|
||||
display_model_info(model_data, console)
|
||||
|
||||
|
||||
def _resolve_version_id(
|
||||
version_id: int | None,
|
||||
hash_val: str | None,
|
||||
model_id: int | None,
|
||||
api_key: str | None,
|
||||
) -> int | None:
|
||||
"""Resolve version ID from hash or model ID."""
|
||||
if version_id:
|
||||
return version_id
|
||||
|
||||
if hash_val:
|
||||
def _resolve_by_hash(hash_val: str, api_key: str | None) -> int | None:
|
||||
"""Resolve version ID from SHA256 hash."""
|
||||
console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]")
|
||||
civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key, console)
|
||||
if not civitai_data:
|
||||
@@ -232,7 +226,9 @@ def _resolve_version_id(
|
||||
console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}")
|
||||
return vid
|
||||
|
||||
if model_id:
|
||||
|
||||
def _resolve_by_model_id(model_id: int, api_key: str | None) -> int | None:
|
||||
"""Resolve latest version ID from model ID."""
|
||||
console.print(f"[cyan]Looking up model {model_id}...[/cyan]")
|
||||
model_data = fetch_civitai_model(model_id, api_key, console)
|
||||
if not model_data:
|
||||
@@ -245,10 +241,23 @@ def _resolve_version_id(
|
||||
latest = versions[0]
|
||||
latest_vid: int | None = latest.get("id")
|
||||
if latest_vid:
|
||||
name = latest.get("name", "N/A")
|
||||
console.print(f"[green]Found latest:[/green] {name} (ID: {latest_vid})")
|
||||
console.print(f"[green]Found latest:[/green] {latest.get('name', 'N/A')} (ID: {latest_vid})")
|
||||
return latest_vid
|
||||
|
||||
|
||||
def _resolve_version_id(
|
||||
version_id: int | None,
|
||||
hash_val: str | None,
|
||||
model_id: int | None,
|
||||
api_key: str | None,
|
||||
) -> int | None:
|
||||
"""Resolve version ID from direct ID, hash, or model ID."""
|
||||
if version_id:
|
||||
return version_id
|
||||
if hash_val:
|
||||
return _resolve_by_hash(hash_val, api_key)
|
||||
if model_id:
|
||||
return _resolve_by_model_id(model_id, api_key)
|
||||
return None
|
||||
|
||||
|
||||
@@ -359,7 +368,7 @@ def config(
|
||||
|
||||
key = load_api_key()
|
||||
if key:
|
||||
masked = key[:4] + "..." + key[-4:] if len(key) > 8 else "***"
|
||||
masked = key[:4] + "..." + key[-4:] if len(key) > MIN_KEY_LENGTH_FOR_MASKING else "***"
|
||||
console.print(f"[bold]API key:[/bold] {masked}")
|
||||
else:
|
||||
console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]")
|
||||
|
||||
+16
-9
@@ -12,23 +12,30 @@ from rich.table import Table
|
||||
if TYPE_CHECKING:
|
||||
from rich.console import Console
|
||||
|
||||
# Size formatting constants
|
||||
KB = 1024
|
||||
MB_IN_KB = KB * KB
|
||||
THOUSAND = 1000
|
||||
MILLION = 1_000_000
|
||||
MAX_TAGS_DISPLAY = 10
|
||||
|
||||
|
||||
def _format_size(size_kb: float) -> str:
|
||||
"""Format size in KB to human-readable string."""
|
||||
if size_kb < 1024:
|
||||
if size_kb < KB:
|
||||
return f"{size_kb:.0f} KB"
|
||||
if size_kb < 1024 * 1024:
|
||||
return f"{size_kb / 1024:.1f} MB"
|
||||
return f"{size_kb / 1024 / 1024:.2f} GB"
|
||||
if size_kb < MB_IN_KB:
|
||||
return f"{size_kb / KB:.1f} MB"
|
||||
return f"{size_kb / KB / KB:.2f} GB"
|
||||
|
||||
|
||||
def _format_count(count: int) -> str:
|
||||
"""Format large numbers with K/M suffix."""
|
||||
if count < 1000:
|
||||
if count < THOUSAND:
|
||||
return str(count)
|
||||
if count < 1_000_000:
|
||||
return f"{count / 1000:.1f}K"
|
||||
return f"{count / 1_000_000:.1f}M"
|
||||
if count < MILLION:
|
||||
return f"{count / THOUSAND:.1f}K"
|
||||
return f"{count / MILLION:.1f}M"
|
||||
|
||||
|
||||
def display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str, console: Console) -> None:
|
||||
@@ -174,7 +181,7 @@ def _add_model_basic_info(table: Table, model_data: dict[str, Any]) -> None:
|
||||
|
||||
tags: list[str] = model_data.get("tags", [])
|
||||
if tags:
|
||||
table.add_row("Tags", ", ".join(tags[:10]) + ("..." if len(tags) > 10 else ""))
|
||||
table.add_row("Tags", ", ".join(tags[:MAX_TAGS_DISPLAY]) + ("..." if len(tags) > MAX_TAGS_DISPLAY else ""))
|
||||
|
||||
stats: dict[str, Any] = model_data.get("stats", {})
|
||||
if stats:
|
||||
|
||||
@@ -24,18 +24,21 @@ from rich.progress import (
|
||||
if TYPE_CHECKING:
|
||||
from rich.console import Console
|
||||
|
||||
# Safetensor format constants
|
||||
HEADER_SIZE_BYTES = 8 # u64 little-endian
|
||||
MAX_HEADER_SIZE = 100_000_000 # 100MB sanity check
|
||||
|
||||
|
||||
def read_safetensor_metadata(file_path: Path) -> dict[str, Any]:
|
||||
"""Read metadata from a safetensor file header."""
|
||||
with file_path.open("rb") as f:
|
||||
# First 8 bytes are the header size (little-endian u64)
|
||||
header_size_bytes = f.read(8)
|
||||
if len(header_size_bytes) < 8:
|
||||
header_size_bytes = f.read(HEADER_SIZE_BYTES)
|
||||
if len(header_size_bytes) < HEADER_SIZE_BYTES:
|
||||
raise ValueError("Invalid safetensor file: too short")
|
||||
|
||||
header_size = struct.unpack("<Q", header_size_bytes)[0]
|
||||
|
||||
if header_size > 100_000_000: # 100MB sanity check
|
||||
if header_size > MAX_HEADER_SIZE:
|
||||
raise ValueError(f"Invalid header size: {header_size}")
|
||||
|
||||
header_bytes = f.read(header_size)
|
||||
|
||||
Reference in New Issue
Block a user