Fix PLR2004/PLR0911: replace magic values with constants, refactor _resolve_version_id
This commit is contained in:
@@ -52,9 +52,7 @@ select = [
|
|||||||
"RUF", # ruff-specific
|
"RUF", # ruff-specific
|
||||||
]
|
]
|
||||||
ignore = [
|
ignore = [
|
||||||
"PLR2004", # Magic values - acceptable in CLI apps
|
|
||||||
"PLR0913", # Too many arguments - CLI commands need many options
|
"PLR0913", # Too many arguments - CLI commands need many options
|
||||||
"PLR0911", # Too many return statements
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff.lint.isort]
|
[tool.ruff.lint.isort]
|
||||||
|
|||||||
+6
-5
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -40,7 +41,7 @@ def fetch_civitai_model_version(version_id: int, api_key: str | None, console: C
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
||||||
if response.status_code == 404:
|
if response.status_code == HTTPStatus.NOT_FOUND:
|
||||||
return None
|
return None
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result: dict[str, Any] = response.json()
|
result: dict[str, Any] = response.json()
|
||||||
@@ -67,7 +68,7 @@ def fetch_civitai_model(model_id: int, api_key: str | None, console: Console) ->
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
||||||
if response.status_code == 404:
|
if response.status_code == HTTPStatus.NOT_FOUND:
|
||||||
return None
|
return None
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result: dict[str, Any] = response.json()
|
result: dict[str, Any] = response.json()
|
||||||
@@ -94,7 +95,7 @@ def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Consol
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
||||||
if response.status_code == 404:
|
if response.status_code == HTTPStatus.NOT_FOUND:
|
||||||
return None
|
return None
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result: dict[str, Any] = response.json()
|
result: dict[str, Any] = response.json()
|
||||||
@@ -269,7 +270,7 @@ def download_model(
|
|||||||
follow_redirects=True,
|
follow_redirects=True,
|
||||||
timeout=httpx.Timeout(30.0, read=None),
|
timeout=httpx.Timeout(30.0, read=None),
|
||||||
) as response:
|
) as response:
|
||||||
if response.status_code == 416:
|
if response.status_code == HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE:
|
||||||
console.print("[green]File already fully downloaded.[/green]")
|
console.print("[green]File already fully downloaded.[/green]")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -279,7 +280,7 @@ def download_model(
|
|||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]")
|
console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]")
|
||||||
if e.response.status_code == 401:
|
if e.response.status_code == HTTPStatus.UNAUTHORIZED:
|
||||||
console.print("[yellow]Hint: This model may require an API key.[/yellow]")
|
console.print("[yellow]Hint: This model may require an API key.[/yellow]")
|
||||||
return False
|
return False
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
|
|||||||
+38
-29
@@ -38,6 +38,9 @@ from tensors.display import (
|
|||||||
)
|
)
|
||||||
from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata
|
from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata
|
||||||
|
|
||||||
|
# Key masking threshold
|
||||||
|
MIN_KEY_LENGTH_FOR_MASKING = 8
|
||||||
|
|
||||||
app = typer.Typer(
|
app = typer.Typer(
|
||||||
name="tsr",
|
name="tsr",
|
||||||
help="Read safetensor metadata, search and download CivitAI models.",
|
help="Read safetensor metadata, search and download CivitAI models.",
|
||||||
@@ -211,44 +214,50 @@ def get(
|
|||||||
display_model_info(model_data, console)
|
display_model_info(model_data, console)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_by_hash(hash_val: str, api_key: str | None) -> int | None:
|
||||||
|
"""Resolve version ID from SHA256 hash."""
|
||||||
|
console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]")
|
||||||
|
civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key, console)
|
||||||
|
if not civitai_data:
|
||||||
|
console.print("[red]Error: Model not found on CivitAI for this hash.[/red]")
|
||||||
|
return None
|
||||||
|
vid: int | None = civitai_data.get("id")
|
||||||
|
if vid:
|
||||||
|
console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}")
|
||||||
|
return vid
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_by_model_id(model_id: int, api_key: str | None) -> int | None:
|
||||||
|
"""Resolve latest version ID from model ID."""
|
||||||
|
console.print(f"[cyan]Looking up model {model_id}...[/cyan]")
|
||||||
|
model_data = fetch_civitai_model(model_id, api_key, console)
|
||||||
|
if not model_data:
|
||||||
|
console.print(f"[red]Error: Model {model_id} not found.[/red]")
|
||||||
|
return None
|
||||||
|
versions = model_data.get("modelVersions", [])
|
||||||
|
if not versions:
|
||||||
|
console.print("[red]Error: Model has no versions.[/red]")
|
||||||
|
return None
|
||||||
|
latest = versions[0]
|
||||||
|
latest_vid: int | None = latest.get("id")
|
||||||
|
if latest_vid:
|
||||||
|
console.print(f"[green]Found latest:[/green] {latest.get('name', 'N/A')} (ID: {latest_vid})")
|
||||||
|
return latest_vid
|
||||||
|
|
||||||
|
|
||||||
def _resolve_version_id(
|
def _resolve_version_id(
|
||||||
version_id: int | None,
|
version_id: int | None,
|
||||||
hash_val: str | None,
|
hash_val: str | None,
|
||||||
model_id: int | None,
|
model_id: int | None,
|
||||||
api_key: str | None,
|
api_key: str | None,
|
||||||
) -> int | None:
|
) -> int | None:
|
||||||
"""Resolve version ID from hash or model ID."""
|
"""Resolve version ID from direct ID, hash, or model ID."""
|
||||||
if version_id:
|
if version_id:
|
||||||
return version_id
|
return version_id
|
||||||
|
|
||||||
if hash_val:
|
if hash_val:
|
||||||
console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]")
|
return _resolve_by_hash(hash_val, api_key)
|
||||||
civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key, console)
|
|
||||||
if not civitai_data:
|
|
||||||
console.print("[red]Error: Model not found on CivitAI for this hash.[/red]")
|
|
||||||
return None
|
|
||||||
vid: int | None = civitai_data.get("id")
|
|
||||||
if vid:
|
|
||||||
console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}")
|
|
||||||
return vid
|
|
||||||
|
|
||||||
if model_id:
|
if model_id:
|
||||||
console.print(f"[cyan]Looking up model {model_id}...[/cyan]")
|
return _resolve_by_model_id(model_id, api_key)
|
||||||
model_data = fetch_civitai_model(model_id, api_key, console)
|
|
||||||
if not model_data:
|
|
||||||
console.print(f"[red]Error: Model {model_id} not found.[/red]")
|
|
||||||
return None
|
|
||||||
versions = model_data.get("modelVersions", [])
|
|
||||||
if not versions:
|
|
||||||
console.print("[red]Error: Model has no versions.[/red]")
|
|
||||||
return None
|
|
||||||
latest = versions[0]
|
|
||||||
latest_vid: int | None = latest.get("id")
|
|
||||||
if latest_vid:
|
|
||||||
name = latest.get("name", "N/A")
|
|
||||||
console.print(f"[green]Found latest:[/green] {name} (ID: {latest_vid})")
|
|
||||||
return latest_vid
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -359,7 +368,7 @@ def config(
|
|||||||
|
|
||||||
key = load_api_key()
|
key = load_api_key()
|
||||||
if key:
|
if key:
|
||||||
masked = key[:4] + "..." + key[-4:] if len(key) > 8 else "***"
|
masked = key[:4] + "..." + key[-4:] if len(key) > MIN_KEY_LENGTH_FOR_MASKING else "***"
|
||||||
console.print(f"[bold]API key:[/bold] {masked}")
|
console.print(f"[bold]API key:[/bold] {masked}")
|
||||||
else:
|
else:
|
||||||
console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]")
|
console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]")
|
||||||
|
|||||||
+16
-9
@@ -12,23 +12,30 @@ from rich.table import Table
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
|
# Size formatting constants
|
||||||
|
KB = 1024
|
||||||
|
MB_IN_KB = KB * KB
|
||||||
|
THOUSAND = 1000
|
||||||
|
MILLION = 1_000_000
|
||||||
|
MAX_TAGS_DISPLAY = 10
|
||||||
|
|
||||||
|
|
||||||
def _format_size(size_kb: float) -> str:
|
def _format_size(size_kb: float) -> str:
|
||||||
"""Format size in KB to human-readable string."""
|
"""Format size in KB to human-readable string."""
|
||||||
if size_kb < 1024:
|
if size_kb < KB:
|
||||||
return f"{size_kb:.0f} KB"
|
return f"{size_kb:.0f} KB"
|
||||||
if size_kb < 1024 * 1024:
|
if size_kb < MB_IN_KB:
|
||||||
return f"{size_kb / 1024:.1f} MB"
|
return f"{size_kb / KB:.1f} MB"
|
||||||
return f"{size_kb / 1024 / 1024:.2f} GB"
|
return f"{size_kb / KB / KB:.2f} GB"
|
||||||
|
|
||||||
|
|
||||||
def _format_count(count: int) -> str:
|
def _format_count(count: int) -> str:
|
||||||
"""Format large numbers with K/M suffix."""
|
"""Format large numbers with K/M suffix."""
|
||||||
if count < 1000:
|
if count < THOUSAND:
|
||||||
return str(count)
|
return str(count)
|
||||||
if count < 1_000_000:
|
if count < MILLION:
|
||||||
return f"{count / 1000:.1f}K"
|
return f"{count / THOUSAND:.1f}K"
|
||||||
return f"{count / 1_000_000:.1f}M"
|
return f"{count / MILLION:.1f}M"
|
||||||
|
|
||||||
|
|
||||||
def display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str, console: Console) -> None:
|
def display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str, console: Console) -> None:
|
||||||
@@ -174,7 +181,7 @@ def _add_model_basic_info(table: Table, model_data: dict[str, Any]) -> None:
|
|||||||
|
|
||||||
tags: list[str] = model_data.get("tags", [])
|
tags: list[str] = model_data.get("tags", [])
|
||||||
if tags:
|
if tags:
|
||||||
table.add_row("Tags", ", ".join(tags[:10]) + ("..." if len(tags) > 10 else ""))
|
table.add_row("Tags", ", ".join(tags[:MAX_TAGS_DISPLAY]) + ("..." if len(tags) > MAX_TAGS_DISPLAY else ""))
|
||||||
|
|
||||||
stats: dict[str, Any] = model_data.get("stats", {})
|
stats: dict[str, Any] = model_data.get("stats", {})
|
||||||
if stats:
|
if stats:
|
||||||
|
|||||||
@@ -24,18 +24,21 @@ from rich.progress import (
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
|
# Safetensor format constants
|
||||||
|
HEADER_SIZE_BYTES = 8 # u64 little-endian
|
||||||
|
MAX_HEADER_SIZE = 100_000_000 # 100MB sanity check
|
||||||
|
|
||||||
|
|
||||||
def read_safetensor_metadata(file_path: Path) -> dict[str, Any]:
|
def read_safetensor_metadata(file_path: Path) -> dict[str, Any]:
|
||||||
"""Read metadata from a safetensor file header."""
|
"""Read metadata from a safetensor file header."""
|
||||||
with file_path.open("rb") as f:
|
with file_path.open("rb") as f:
|
||||||
# First 8 bytes are the header size (little-endian u64)
|
header_size_bytes = f.read(HEADER_SIZE_BYTES)
|
||||||
header_size_bytes = f.read(8)
|
if len(header_size_bytes) < HEADER_SIZE_BYTES:
|
||||||
if len(header_size_bytes) < 8:
|
|
||||||
raise ValueError("Invalid safetensor file: too short")
|
raise ValueError("Invalid safetensor file: too short")
|
||||||
|
|
||||||
header_size = struct.unpack("<Q", header_size_bytes)[0]
|
header_size = struct.unpack("<Q", header_size_bytes)[0]
|
||||||
|
|
||||||
if header_size > 100_000_000: # 100MB sanity check
|
if header_size > MAX_HEADER_SIZE:
|
||||||
raise ValueError(f"Invalid header size: {header_size}")
|
raise ValueError(f"Invalid header size: {header_size}")
|
||||||
|
|
||||||
header_bytes = f.read(header_size)
|
header_bytes = f.read(header_size)
|
||||||
|
|||||||
Reference in New Issue
Block a user