[Update] [2026-02-03 22:57:52] 10 files
This commit is contained in:
+3
-7
@@ -19,7 +19,7 @@ requires = ["hatchling"]
|
|||||||
build-backend = "hatchling.build"
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["tensors.py"]
|
packages = ["tensors"]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
@@ -33,7 +33,7 @@ dev = [
|
|||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
target-version = "py312"
|
target-version = "py312"
|
||||||
line-length = 100
|
line-length = 130
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = [
|
select = [
|
||||||
@@ -51,11 +51,7 @@ select = [
|
|||||||
"PL", # pylint
|
"PL", # pylint
|
||||||
"RUF", # ruff-specific
|
"RUF", # ruff-specific
|
||||||
]
|
]
|
||||||
ignore = [
|
ignore = []
|
||||||
"PLR0911", # too many return statements
|
|
||||||
"PLR0913", # too many arguments
|
|
||||||
"PLR2004", # magic value comparison
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint.isort]
|
[tool.ruff.lint.isort]
|
||||||
known-first-party = ["tensors"]
|
known-first-party = ["tensors"]
|
||||||
|
|||||||
-1147
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,26 @@
|
|||||||
|
"""tsr: Read safetensor metadata, search and download CivitAI models."""
|
||||||
|
|
||||||
|
from tensors.cli import main
|
||||||
|
from tensors.config import (
|
||||||
|
CONFIG_DIR,
|
||||||
|
CONFIG_FILE,
|
||||||
|
LEGACY_RC_FILE,
|
||||||
|
get_default_output_path,
|
||||||
|
load_api_key,
|
||||||
|
load_config,
|
||||||
|
save_config,
|
||||||
|
)
|
||||||
|
from tensors.safetensor import get_base_name, read_safetensor_metadata
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CONFIG_DIR",
|
||||||
|
"CONFIG_FILE",
|
||||||
|
"LEGACY_RC_FILE",
|
||||||
|
"get_base_name",
|
||||||
|
"get_default_output_path",
|
||||||
|
"load_api_key",
|
||||||
|
"load_config",
|
||||||
|
"main",
|
||||||
|
"read_safetensor_metadata",
|
||||||
|
"save_config",
|
||||||
|
]
|
||||||
+287
@@ -0,0 +1,287 @@
|
|||||||
|
"""CivitAI API functions."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from rich.progress import (
|
||||||
|
BarColumn,
|
||||||
|
DownloadColumn,
|
||||||
|
Progress,
|
||||||
|
SpinnerColumn,
|
||||||
|
TaskProgressColumn,
|
||||||
|
TextColumn,
|
||||||
|
TimeRemainingColumn,
|
||||||
|
TransferSpeedColumn,
|
||||||
|
)
|
||||||
|
|
||||||
|
from tensors.config import CIVITAI_API_BASE, CIVITAI_DOWNLOAD_BASE, BaseModel, ModelType, SortOrder
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from rich.console import Console
|
||||||
|
|
||||||
|
|
||||||
|
def _get_headers(api_key: str | None) -> dict[str, str]:
|
||||||
|
"""Get headers for CivitAI API requests."""
|
||||||
|
headers: dict[str, str] = {}
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_civitai_model_version(version_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None:
|
||||||
|
"""Fetch model version information from CivitAI by version ID."""
|
||||||
|
url = f"{CIVITAI_API_BASE}/model-versions/{version_id}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
||||||
|
if response.status_code == 404:
|
||||||
|
return None
|
||||||
|
response.raise_for_status()
|
||||||
|
result: dict[str, Any] = response.json()
|
||||||
|
return result
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
console.print(f"[red]API error: {e.response.status_code}[/red]")
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
console.print(f"[red]Request error: {e}[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_civitai_model(model_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None:
|
||||||
|
"""Fetch model information from CivitAI by model ID."""
|
||||||
|
url = f"{CIVITAI_API_BASE}/models/{model_id}"
|
||||||
|
|
||||||
|
with Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
console=console,
|
||||||
|
transient=True,
|
||||||
|
) as progress:
|
||||||
|
progress.add_task("[cyan]Fetching model from CivitAI...", total=None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
||||||
|
if response.status_code == 404:
|
||||||
|
return None
|
||||||
|
response.raise_for_status()
|
||||||
|
result: dict[str, Any] = response.json()
|
||||||
|
return result
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
console.print(f"[red]API error: {e.response.status_code}[/red]")
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
console.print(f"[red]Request error: {e}[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Console) -> dict[str, Any] | None:
|
||||||
|
"""Fetch model information from CivitAI by SHA256 hash."""
|
||||||
|
url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}"
|
||||||
|
|
||||||
|
with Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
console=console,
|
||||||
|
transient=True,
|
||||||
|
) as progress:
|
||||||
|
progress.add_task("[cyan]Fetching from CivitAI...", total=None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
|
||||||
|
if response.status_code == 404:
|
||||||
|
return None
|
||||||
|
response.raise_for_status()
|
||||||
|
result: dict[str, Any] = response.json()
|
||||||
|
return result
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
console.print(f"[red]API error: {e.response.status_code}[/red]")
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
console.print(f"[red]Request error: {e}[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _build_search_params(
|
||||||
|
query: str | None,
|
||||||
|
model_type: ModelType | None,
|
||||||
|
base_model: BaseModel | None,
|
||||||
|
sort: SortOrder,
|
||||||
|
limit: int,
|
||||||
|
) -> tuple[dict[str, Any], bool]:
|
||||||
|
"""Build search parameters and return (params, has_filters)."""
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"limit": min(limit, 100),
|
||||||
|
"nsfw": "true",
|
||||||
|
}
|
||||||
|
|
||||||
|
# API quirk: query + filters don't work reliably together
|
||||||
|
has_filters = model_type is not None or base_model is not None
|
||||||
|
|
||||||
|
if query and not has_filters:
|
||||||
|
params["query"] = query
|
||||||
|
|
||||||
|
if model_type:
|
||||||
|
params["types"] = model_type.to_api()
|
||||||
|
|
||||||
|
if base_model:
|
||||||
|
params["baseModels"] = base_model.to_api()
|
||||||
|
|
||||||
|
params["sort"] = sort.to_api()
|
||||||
|
|
||||||
|
# Request more if we need client-side filtering
|
||||||
|
if query and has_filters:
|
||||||
|
params["limit"] = 100
|
||||||
|
|
||||||
|
return params, has_filters
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_results(result: dict[str, Any], query: str | None, has_filters: bool, limit: int) -> dict[str, Any]:
|
||||||
|
"""Apply client-side filtering when query + filters combined."""
|
||||||
|
if query and has_filters:
|
||||||
|
q_lower = query.lower()
|
||||||
|
result["items"] = [m for m in result.get("items", []) if q_lower in m.get("name", "").lower()][:limit]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def search_civitai(
|
||||||
|
query: str | None,
|
||||||
|
model_type: ModelType | None,
|
||||||
|
base_model: BaseModel | None,
|
||||||
|
sort: SortOrder,
|
||||||
|
limit: int,
|
||||||
|
api_key: str | None,
|
||||||
|
console: Console,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Search CivitAI models."""
|
||||||
|
params, has_filters = _build_search_params(query, model_type, base_model, sort, limit)
|
||||||
|
url = f"{CIVITAI_API_BASE}/models"
|
||||||
|
|
||||||
|
with Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
console=console,
|
||||||
|
transient=True,
|
||||||
|
) as progress:
|
||||||
|
progress.add_task("[cyan]Searching CivitAI...", total=None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = httpx.get(url, params=params, headers=_get_headers(api_key), timeout=30.0)
|
||||||
|
response.raise_for_status()
|
||||||
|
result: dict[str, Any] = response.json()
|
||||||
|
return _filter_results(result, query, has_filters, limit)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
console.print(f"[red]API error: {e.response.status_code}[/red]")
|
||||||
|
return None
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
console.print(f"[red]Request error: {e}[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_resume(dest_path: Path, resume: bool, console: Console) -> tuple[dict[str, str], str, int]:
|
||||||
|
"""Set up resume headers and mode for download."""
|
||||||
|
headers: dict[str, str] = {}
|
||||||
|
mode = "wb"
|
||||||
|
initial_size = 0
|
||||||
|
|
||||||
|
if resume and dest_path.exists():
|
||||||
|
initial_size = dest_path.stat().st_size
|
||||||
|
headers["Range"] = f"bytes={initial_size}-"
|
||||||
|
mode = "ab"
|
||||||
|
console.print(f"[cyan]Resuming download from {initial_size / (1024**2):.1f} MB[/cyan]")
|
||||||
|
|
||||||
|
return headers, mode, initial_size
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dest_from_response(response: httpx.Response, dest_path: Path) -> Path:
|
||||||
|
"""Extract destination path from response headers if dest is directory."""
|
||||||
|
content_disp = response.headers.get("content-disposition", "")
|
||||||
|
if "filename=" in content_disp:
|
||||||
|
match = re.search(r'filename="?([^";\n]+)"?', content_disp)
|
||||||
|
if match and dest_path.is_dir():
|
||||||
|
return dest_path / match.group(1)
|
||||||
|
return dest_path
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_download(
|
||||||
|
response: httpx.Response,
|
||||||
|
dest_path: Path,
|
||||||
|
mode: str,
|
||||||
|
initial_size: int,
|
||||||
|
console: Console,
|
||||||
|
) -> bool:
|
||||||
|
"""Stream download content to file with progress."""
|
||||||
|
content_length = response.headers.get("content-length")
|
||||||
|
total_size = int(content_length) + initial_size if content_length else 0
|
||||||
|
|
||||||
|
with Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TaskProgressColumn(),
|
||||||
|
DownloadColumn(),
|
||||||
|
TransferSpeedColumn(),
|
||||||
|
TimeRemainingColumn(),
|
||||||
|
console=console,
|
||||||
|
) as progress:
|
||||||
|
task = progress.add_task(
|
||||||
|
f"[cyan]Downloading {dest_path.name}...",
|
||||||
|
total=total_size if total_size > 0 else None,
|
||||||
|
completed=initial_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
with dest_path.open(mode) as f:
|
||||||
|
for chunk in response.iter_bytes(1024 * 1024):
|
||||||
|
f.write(chunk)
|
||||||
|
progress.update(task, advance=len(chunk))
|
||||||
|
|
||||||
|
console.print()
|
||||||
|
console.print(f'[magenta]Downloaded:[/magenta] [green]"{dest_path}"[/green]')
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def download_model(
|
||||||
|
version_id: int,
|
||||||
|
dest_path: Path,
|
||||||
|
api_key: str | None,
|
||||||
|
console: Console,
|
||||||
|
resume: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
"""Download a model from CivitAI by version ID with resume support."""
|
||||||
|
url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}"
|
||||||
|
params: dict[str, str] = {}
|
||||||
|
if api_key:
|
||||||
|
params["token"] = api_key
|
||||||
|
|
||||||
|
headers, mode, initial_size = _setup_resume(dest_path, resume, console)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with httpx.stream(
|
||||||
|
"GET",
|
||||||
|
url,
|
||||||
|
params=params,
|
||||||
|
headers=headers,
|
||||||
|
follow_redirects=True,
|
||||||
|
timeout=httpx.Timeout(30.0, read=None),
|
||||||
|
) as response:
|
||||||
|
if response.status_code == 416:
|
||||||
|
console.print("[green]File already fully downloaded.[/green]")
|
||||||
|
return True
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
dest_path = _get_dest_from_response(response, dest_path)
|
||||||
|
return _stream_download(response, dest_path, mode, initial_size, console)
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]")
|
||||||
|
if e.response.status_code == 401:
|
||||||
|
console.print("[yellow]Hint: This model may require an API key.[/yellow]")
|
||||||
|
return False
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
console.print(f"[red]Download error: {e}[/red]")
|
||||||
|
return False
|
||||||
+386
@@ -0,0 +1,386 @@
|
|||||||
|
"""CLI application and commands for tsr."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
from tensors.api import (
|
||||||
|
download_model,
|
||||||
|
fetch_civitai_by_hash,
|
||||||
|
fetch_civitai_model,
|
||||||
|
fetch_civitai_model_version,
|
||||||
|
search_civitai,
|
||||||
|
)
|
||||||
|
from tensors.config import (
|
||||||
|
CONFIG_FILE,
|
||||||
|
BaseModel,
|
||||||
|
ModelType,
|
||||||
|
SortOrder,
|
||||||
|
get_default_output_path,
|
||||||
|
load_api_key,
|
||||||
|
load_config,
|
||||||
|
save_config,
|
||||||
|
)
|
||||||
|
from tensors.display import (
|
||||||
|
_format_size,
|
||||||
|
display_civitai_data,
|
||||||
|
display_file_info,
|
||||||
|
display_local_metadata,
|
||||||
|
display_model_info,
|
||||||
|
display_search_results,
|
||||||
|
)
|
||||||
|
from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata
|
||||||
|
|
||||||
|
app = typer.Typer(
|
||||||
|
name="tsr",
|
||||||
|
help="Read safetensor metadata, search and download CivitAI models.",
|
||||||
|
no_args_is_help=True,
|
||||||
|
)
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def info(
|
||||||
|
file: Annotated[Path, typer.Argument(help="Path to the safetensor file")],
|
||||||
|
meta: Annotated[list[str] | None, typer.Option("--meta", "-m", help="Show specific metadata key(s) in full")] = None,
|
||||||
|
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||||
|
skip_civitai: Annotated[bool, typer.Option("--skip-civitai", help="Skip CivitAI API lookup")] = False,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
save_to: Annotated[Path | None, typer.Option("--save-to", help="Save metadata to directory")] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Read safetensor metadata and fetch CivitAI info."""
|
||||||
|
file_path = file.resolve()
|
||||||
|
|
||||||
|
if not file_path.exists():
|
||||||
|
console.print(f"[red]Error: File not found: {file_path}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
if file_path.suffix.lower() not in (".safetensors", ".sft"):
|
||||||
|
console.print("[yellow]Warning: File does not have .safetensors extension[/yellow]")
|
||||||
|
|
||||||
|
try:
|
||||||
|
local_metadata = read_safetensor_metadata(file_path)
|
||||||
|
|
||||||
|
if meta:
|
||||||
|
display_local_metadata(local_metadata, console, keys_filter=meta)
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print(f"[bold]Reading safetensor file:[/bold] {file_path.name}")
|
||||||
|
sha256_hash = compute_sha256(file_path, console)
|
||||||
|
|
||||||
|
civitai_data = None
|
||||||
|
if not skip_civitai:
|
||||||
|
key = api_key or load_api_key()
|
||||||
|
civitai_data = fetch_civitai_by_hash(sha256_hash, key, console)
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
_output_info_json(file_path, sha256_hash, local_metadata, civitai_data)
|
||||||
|
else:
|
||||||
|
display_file_info(file_path, local_metadata, sha256_hash, console)
|
||||||
|
display_local_metadata(local_metadata, console)
|
||||||
|
display_civitai_data(civitai_data, console)
|
||||||
|
|
||||||
|
if save_to:
|
||||||
|
_save_metadata(save_to, file_path, sha256_hash, local_metadata, civitai_data)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
console.print(f"[red]Error reading safetensor: {e}[/red]")
|
||||||
|
raise typer.Exit(1) from e
|
||||||
|
|
||||||
|
|
||||||
|
def _output_info_json(
|
||||||
|
file_path: Path,
|
||||||
|
sha256_hash: str,
|
||||||
|
local_metadata: dict[str, Any],
|
||||||
|
civitai_data: dict[str, Any] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Output info command result as JSON."""
|
||||||
|
output = {
|
||||||
|
"file": str(file_path),
|
||||||
|
"sha256": sha256_hash,
|
||||||
|
"header_size": local_metadata["header_size"],
|
||||||
|
"tensor_count": local_metadata["tensor_count"],
|
||||||
|
"metadata": local_metadata["metadata"],
|
||||||
|
"civitai": civitai_data,
|
||||||
|
}
|
||||||
|
console.print_json(data=output)
|
||||||
|
|
||||||
|
|
||||||
|
def _save_metadata(
|
||||||
|
save_to: Path,
|
||||||
|
file_path: Path,
|
||||||
|
sha256_hash: str,
|
||||||
|
local_metadata: dict[str, Any],
|
||||||
|
civitai_data: dict[str, Any] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Save metadata to directory."""
|
||||||
|
output_dir = save_to.resolve()
|
||||||
|
if not output_dir.exists() or not output_dir.is_dir():
|
||||||
|
console.print(f"[red]Error: Invalid directory: {output_dir}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
base_name = get_base_name(file_path)
|
||||||
|
json_path = output_dir / f"{base_name}.json"
|
||||||
|
sha_path = output_dir / f"{base_name}.sha256"
|
||||||
|
|
||||||
|
output = {
|
||||||
|
"file": str(file_path),
|
||||||
|
"sha256": sha256_hash,
|
||||||
|
"header_size": local_metadata["header_size"],
|
||||||
|
"tensor_count": local_metadata["tensor_count"],
|
||||||
|
"metadata": local_metadata["metadata"],
|
||||||
|
"civitai": civitai_data,
|
||||||
|
}
|
||||||
|
json_path.write_text(json.dumps(output, indent=2))
|
||||||
|
sha_path.write_text(f"{sha256_hash} {file_path.name}\n")
|
||||||
|
|
||||||
|
console.print()
|
||||||
|
console.print(f"[green]Saved:[/green] {json_path}")
|
||||||
|
console.print(f"[green]Saved:[/green] {sha_path}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def search(
|
||||||
|
query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None,
|
||||||
|
model_type: Annotated[ModelType | None, typer.Option("-t", "--type", help="Model type filter")] = None,
|
||||||
|
base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter")] = None,
|
||||||
|
sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads,
|
||||||
|
limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Search CivitAI models."""
|
||||||
|
key = api_key or load_api_key()
|
||||||
|
|
||||||
|
results = search_civitai(
|
||||||
|
query=query,
|
||||||
|
model_type=model_type,
|
||||||
|
base_model=base,
|
||||||
|
sort=sort,
|
||||||
|
limit=limit,
|
||||||
|
api_key=key,
|
||||||
|
console=console,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
console.print("[red]Search failed.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=results)
|
||||||
|
else:
|
||||||
|
display_search_results(results, console)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def get(
|
||||||
|
id_value: Annotated[int, typer.Argument(help="CivitAI model ID or version ID")],
|
||||||
|
version: Annotated[bool, typer.Option("-v", "--version", help="Treat ID as version ID instead of model ID")] = False,
|
||||||
|
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||||
|
json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
|
||||||
|
) -> None:
|
||||||
|
"""Fetch model information from CivitAI by model ID or version ID."""
|
||||||
|
key = api_key or load_api_key()
|
||||||
|
|
||||||
|
if version:
|
||||||
|
version_data = fetch_civitai_model_version(id_value, key, console)
|
||||||
|
if not version_data:
|
||||||
|
console.print(f"[red]Error: Version {id_value} not found on CivitAI.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=version_data)
|
||||||
|
else:
|
||||||
|
display_civitai_data(version_data, console)
|
||||||
|
else:
|
||||||
|
model_data = fetch_civitai_model(id_value, key, console)
|
||||||
|
if not model_data:
|
||||||
|
console.print(f"[red]Error: Model {id_value} not found on CivitAI.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
if json_output:
|
||||||
|
console.print_json(data=model_data)
|
||||||
|
else:
|
||||||
|
display_model_info(model_data, console)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_version_id(
|
||||||
|
version_id: int | None,
|
||||||
|
hash_val: str | None,
|
||||||
|
model_id: int | None,
|
||||||
|
api_key: str | None,
|
||||||
|
) -> int | None:
|
||||||
|
"""Resolve version ID from hash or model ID."""
|
||||||
|
if version_id:
|
||||||
|
return version_id
|
||||||
|
|
||||||
|
if hash_val:
|
||||||
|
console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]")
|
||||||
|
civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key, console)
|
||||||
|
if not civitai_data:
|
||||||
|
console.print("[red]Error: Model not found on CivitAI for this hash.[/red]")
|
||||||
|
return None
|
||||||
|
vid: int | None = civitai_data.get("id")
|
||||||
|
if vid:
|
||||||
|
console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}")
|
||||||
|
return vid
|
||||||
|
|
||||||
|
if model_id:
|
||||||
|
console.print(f"[cyan]Looking up model {model_id}...[/cyan]")
|
||||||
|
model_data = fetch_civitai_model(model_id, api_key, console)
|
||||||
|
if not model_data:
|
||||||
|
console.print(f"[red]Error: Model {model_id} not found.[/red]")
|
||||||
|
return None
|
||||||
|
versions = model_data.get("modelVersions", [])
|
||||||
|
if not versions:
|
||||||
|
console.print("[red]Error: Model has no versions.[/red]")
|
||||||
|
return None
|
||||||
|
latest = versions[0]
|
||||||
|
latest_vid: int | None = latest.get("id")
|
||||||
|
if latest_vid:
|
||||||
|
name = latest.get("name", "N/A")
|
||||||
|
console.print(f"[green]Found latest:[/green] {name} (ID: {latest_vid})")
|
||||||
|
return latest_vid
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Path | None:
|
||||||
|
"""Prepare output directory for download."""
|
||||||
|
if output is None:
|
||||||
|
output_dir = get_default_output_path(model_type_str)
|
||||||
|
if output_dir is None:
|
||||||
|
console.print(f"[red]Error: No default path for type '{model_type_str}'. Use --output to specify.[/red]")
|
||||||
|
return None
|
||||||
|
console.print(f"[dim]Using default path for {model_type_str}: {output_dir}[/dim]")
|
||||||
|
else:
|
||||||
|
output_dir = output.resolve()
|
||||||
|
|
||||||
|
if not output_dir.exists():
|
||||||
|
console.print(f"[cyan]Creating directory: {output_dir}[/cyan]")
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
elif not output_dir.is_dir():
|
||||||
|
console.print(f"[red]Error: Not a directory: {output_dir}[/red]")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return output_dir
|
||||||
|
|
||||||
|
|
||||||
|
@app.command("dl")
|
||||||
|
def download(
|
||||||
|
version_id: Annotated[int | None, typer.Option("-v", "--version-id", help="Model version ID")] = None,
|
||||||
|
model_id: Annotated[int | None, typer.Option("-m", "--model-id", help="Model ID (downloads latest)")] = None,
|
||||||
|
hash_val: Annotated[str | None, typer.Option("-H", "--hash", help="SHA256 hash to look up")] = None,
|
||||||
|
output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
|
||||||
|
no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False,
|
||||||
|
api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Download a model from CivitAI."""
|
||||||
|
key = api_key or load_api_key()
|
||||||
|
|
||||||
|
resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
|
||||||
|
if not resolved_version_id:
|
||||||
|
if not version_id and not hash_val and not model_id:
|
||||||
|
console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]")
|
||||||
|
version_info = fetch_civitai_model_version(resolved_version_id, key, console)
|
||||||
|
if not version_info:
|
||||||
|
console.print("[red]Error: Could not fetch model version info.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
model_type_str: str | None = version_info.get("model", {}).get("type")
|
||||||
|
output_dir = _prepare_download_dir(output, model_type_str)
|
||||||
|
if not output_dir:
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
files: list[dict[str, Any]] = version_info.get("files", [])
|
||||||
|
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
|
||||||
|
if not primary_file:
|
||||||
|
console.print("[red]Error: No files found for this version.[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors")
|
||||||
|
dest_path = output_dir / filename
|
||||||
|
|
||||||
|
_display_download_info(version_info, filename, primary_file, dest_path)
|
||||||
|
|
||||||
|
success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume)
|
||||||
|
if not success:
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def _display_download_info(
|
||||||
|
version_info: dict[str, Any],
|
||||||
|
filename: str,
|
||||||
|
primary_file: dict[str, Any],
|
||||||
|
dest_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Display download info table."""
|
||||||
|
table = Table(title="Model Download", show_header=True, header_style="bold magenta")
|
||||||
|
table.add_column("Property", style="cyan")
|
||||||
|
table.add_column("Value", style="green")
|
||||||
|
table.add_row("Version", version_info.get("name", "N/A"))
|
||||||
|
table.add_row("Base Model", version_info.get("baseModel", "N/A"))
|
||||||
|
table.add_row("File", filename)
|
||||||
|
table.add_row("Size", _format_size(primary_file.get("sizeKB", 0)))
|
||||||
|
table.add_row("Destination", str(dest_path))
|
||||||
|
console.print()
|
||||||
|
console.print(table)
|
||||||
|
console.print()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def config(
|
||||||
|
show: Annotated[bool, typer.Option("--show", help="Show current config")] = False,
|
||||||
|
set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Manage configuration."""
|
||||||
|
if set_key:
|
||||||
|
cfg = load_config()
|
||||||
|
if "api" not in cfg:
|
||||||
|
cfg["api"] = {}
|
||||||
|
cfg["api"]["civitai_key"] = set_key
|
||||||
|
save_config(cfg)
|
||||||
|
console.print(f"[green]API key saved to {CONFIG_FILE}[/green]")
|
||||||
|
return
|
||||||
|
|
||||||
|
if show or (not set_key):
|
||||||
|
console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}")
|
||||||
|
console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}")
|
||||||
|
|
||||||
|
key = load_api_key()
|
||||||
|
if key:
|
||||||
|
masked = key[:4] + "..." + key[-4:] if len(key) > 8 else "***"
|
||||||
|
console.print(f"[bold]API key:[/bold] {masked}")
|
||||||
|
else:
|
||||||
|
console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]")
|
||||||
|
|
||||||
|
console.print()
|
||||||
|
console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
"""Main entry point."""
|
||||||
|
# Handle legacy invocation: tsr <file.safetensors> -> tsr info <file>
|
||||||
|
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
|
||||||
|
arg = sys.argv[1]
|
||||||
|
if arg not in ("info", "search", "get", "dl", "download", "config") and (
|
||||||
|
arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists()
|
||||||
|
):
|
||||||
|
sys.argv = [sys.argv[0], "info", *sys.argv[1:]]
|
||||||
|
|
||||||
|
app()
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
@@ -0,0 +1,166 @@
|
|||||||
|
"""Configuration, constants, and enums for tsr CLI."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tomllib
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# XDG Base Directory Configuration
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Config: ~/.config/tensors/config.toml
|
||||||
|
# Data: ~/.local/share/tensors/models/, ~/.local/share/tensors/metadata/
|
||||||
|
CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config")) / "tensors"
|
||||||
|
CONFIG_FILE = CONFIG_DIR / "config.toml"
|
||||||
|
|
||||||
|
DATA_DIR = Path(os.environ.get("XDG_DATA_HOME", Path.home() / ".local" / "share")) / "tensors"
|
||||||
|
MODELS_DIR = DATA_DIR / "models"
|
||||||
|
METADATA_DIR = DATA_DIR / "metadata"
|
||||||
|
|
||||||
|
# Legacy config for migration
|
||||||
|
LEGACY_RC_FILE = Path.home() / ".sftrc"
|
||||||
|
|
||||||
|
# Default download paths by model type
|
||||||
|
DEFAULT_PATHS: dict[str, Path] = {
|
||||||
|
"Checkpoint": MODELS_DIR / "checkpoints",
|
||||||
|
"LORA": MODELS_DIR / "loras",
|
||||||
|
"LoCon": MODELS_DIR / "loras",
|
||||||
|
}
|
||||||
|
|
||||||
|
CIVITAI_API_BASE = "https://civitai.com/api/v1"
|
||||||
|
CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Enums for CLI
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ModelType(str, Enum):
|
||||||
|
"""CivitAI model types."""
|
||||||
|
|
||||||
|
checkpoint = "checkpoint"
|
||||||
|
lora = "lora"
|
||||||
|
embedding = "embedding"
|
||||||
|
vae = "vae"
|
||||||
|
controlnet = "controlnet"
|
||||||
|
locon = "locon"
|
||||||
|
|
||||||
|
def to_api(self) -> str:
|
||||||
|
"""Convert to CivitAI API value."""
|
||||||
|
mapping = {
|
||||||
|
"checkpoint": "Checkpoint",
|
||||||
|
"lora": "LORA",
|
||||||
|
"embedding": "TextualInversion",
|
||||||
|
"vae": "VAE",
|
||||||
|
"controlnet": "Controlnet",
|
||||||
|
"locon": "LoCon",
|
||||||
|
}
|
||||||
|
return mapping[self.value]
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel(str, Enum):
|
||||||
|
"""Common base models."""
|
||||||
|
|
||||||
|
sd15 = "sd15"
|
||||||
|
sdxl = "sdxl"
|
||||||
|
pony = "pony"
|
||||||
|
flux = "flux"
|
||||||
|
illustrious = "illustrious"
|
||||||
|
|
||||||
|
def to_api(self) -> str:
|
||||||
|
"""Convert to CivitAI API value."""
|
||||||
|
mapping = {
|
||||||
|
"sd15": "SD 1.5",
|
||||||
|
"sdxl": "SDXL 1.0",
|
||||||
|
"pony": "Pony",
|
||||||
|
"flux": "Flux.1 D",
|
||||||
|
"illustrious": "Illustrious",
|
||||||
|
}
|
||||||
|
return mapping[self.value]
|
||||||
|
|
||||||
|
|
||||||
|
class SortOrder(str, Enum):
|
||||||
|
"""Sort options for search."""
|
||||||
|
|
||||||
|
downloads = "downloads"
|
||||||
|
rating = "rating"
|
||||||
|
newest = "newest"
|
||||||
|
|
||||||
|
def to_api(self) -> str:
|
||||||
|
"""Convert to CivitAI API value."""
|
||||||
|
mapping = {
|
||||||
|
"downloads": "Most Downloaded",
|
||||||
|
"rating": "Highest Rated",
|
||||||
|
"newest": "Newest",
|
||||||
|
}
|
||||||
|
return mapping[self.value]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Config Functions
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def load_config() -> dict[str, Any]:
|
||||||
|
"""Load configuration from TOML config file."""
|
||||||
|
if CONFIG_FILE.exists():
|
||||||
|
with CONFIG_FILE.open("rb") as f:
|
||||||
|
return tomllib.load(f)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def save_config(config: dict[str, Any]) -> None:
|
||||||
|
"""Save configuration to TOML config file."""
|
||||||
|
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
lines: list[str] = []
|
||||||
|
for key, value in config.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
lines.append(f"[{key}]")
|
||||||
|
for k, v in value.items():
|
||||||
|
if isinstance(v, str):
|
||||||
|
lines.append(f'{k} = "{v}"')
|
||||||
|
else:
|
||||||
|
lines.append(f"{k} = {v}")
|
||||||
|
lines.append("")
|
||||||
|
elif isinstance(value, str):
|
||||||
|
lines.append(f'{key} = "{value}"')
|
||||||
|
else:
|
||||||
|
lines.append(f"{key} = {value}")
|
||||||
|
|
||||||
|
CONFIG_FILE.write_text("\n".join(lines) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def load_api_key() -> str | None:
|
||||||
|
"""Load API key from config file or CIVITAI_API_KEY env var."""
|
||||||
|
# Check environment variable first
|
||||||
|
env_key = os.environ.get("CIVITAI_API_KEY")
|
||||||
|
if env_key:
|
||||||
|
return env_key
|
||||||
|
|
||||||
|
# Check TOML config file
|
||||||
|
config = load_config()
|
||||||
|
api_section = config.get("api", {})
|
||||||
|
if isinstance(api_section, dict):
|
||||||
|
key = api_section.get("civitai_key")
|
||||||
|
if key:
|
||||||
|
return str(key)
|
||||||
|
|
||||||
|
# Fall back to legacy RC file for migration
|
||||||
|
if LEGACY_RC_FILE.exists():
|
||||||
|
content = LEGACY_RC_FILE.read_text().strip()
|
||||||
|
if content:
|
||||||
|
return content
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_output_path(model_type: str | None) -> Path | None:
|
||||||
|
"""Get default output path based on model type."""
|
||||||
|
if model_type and model_type in DEFAULT_PATHS:
|
||||||
|
return DEFAULT_PATHS[model_type]
|
||||||
|
return None
|
||||||
@@ -0,0 +1,324 @@
|
|||||||
|
"""Rich table display functions for tsr CLI."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from rich.console import Console
|
||||||
|
|
||||||
|
|
||||||
|
def _format_size(size_kb: float) -> str:
|
||||||
|
"""Format size in KB to human-readable string."""
|
||||||
|
if size_kb < 1024:
|
||||||
|
return f"{size_kb:.0f} KB"
|
||||||
|
if size_kb < 1024 * 1024:
|
||||||
|
return f"{size_kb / 1024:.1f} MB"
|
||||||
|
return f"{size_kb / 1024 / 1024:.2f} GB"
|
||||||
|
|
||||||
|
|
||||||
|
def _format_count(count: int) -> str:
|
||||||
|
"""Format large numbers with K/M suffix."""
|
||||||
|
if count < 1000:
|
||||||
|
return str(count)
|
||||||
|
if count < 1_000_000:
|
||||||
|
return f"{count / 1000:.1f}K"
|
||||||
|
return f"{count / 1_000_000:.1f}M"
|
||||||
|
|
||||||
|
|
||||||
|
def display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str, console: Console) -> None:
|
||||||
|
"""Display file information table."""
|
||||||
|
prop_width = 12
|
||||||
|
|
||||||
|
file_table = Table(title="File Information", show_header=True, header_style="bold magenta", expand=True)
|
||||||
|
file_table.add_column("Property", style="cyan", width=prop_width, no_wrap=True)
|
||||||
|
file_table.add_column("Value", style="green", no_wrap=True, overflow="ellipsis")
|
||||||
|
|
||||||
|
file_table.add_row("File", str(file_path.name))
|
||||||
|
file_table.add_row("Path", str(file_path.parent))
|
||||||
|
file_table.add_row("Size", f"{file_path.stat().st_size / (1024**3):.2f} GB")
|
||||||
|
file_table.add_row("SHA256", sha256_hash)
|
||||||
|
file_table.add_row("Header Size", f"{local_metadata['header_size']:,} bytes")
|
||||||
|
file_table.add_row("Tensor Count", str(local_metadata["tensor_count"]))
|
||||||
|
|
||||||
|
console.print()
|
||||||
|
console.print(file_table)
|
||||||
|
|
||||||
|
|
||||||
|
def display_local_metadata(local_metadata: dict[str, Any], console: Console, keys_filter: list[str] | None = None) -> None:
|
||||||
|
"""Display local safetensor metadata table."""
|
||||||
|
if not local_metadata["metadata"]:
|
||||||
|
console.print()
|
||||||
|
console.print("[yellow]No embedded metadata found in safetensor file.[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
metadata = local_metadata["metadata"]
|
||||||
|
|
||||||
|
# If specific keys requested, show them in full
|
||||||
|
if keys_filter:
|
||||||
|
for key in keys_filter:
|
||||||
|
if key in metadata:
|
||||||
|
console.print(f"[cyan]{key}[/cyan]: {metadata[key]}")
|
||||||
|
else:
|
||||||
|
console.print(f"[yellow]{key}: not found[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Find the longest key to set column width
|
||||||
|
all_keys = list(metadata.keys())
|
||||||
|
key_width = max(len(k) for k in all_keys) if all_keys else 20
|
||||||
|
|
||||||
|
# Value width: terminal minus key column and table borders (7 chars)
|
||||||
|
terminal_width = console.size.width
|
||||||
|
value_width = terminal_width - key_width - 7
|
||||||
|
|
||||||
|
meta_table = Table(
|
||||||
|
title="Safetensor Metadata",
|
||||||
|
show_header=True,
|
||||||
|
header_style="bold magenta",
|
||||||
|
)
|
||||||
|
meta_table.add_column("Key", style="cyan", width=key_width, no_wrap=True)
|
||||||
|
meta_table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis")
|
||||||
|
|
||||||
|
for key, value in sorted(metadata.items()):
|
||||||
|
meta_table.add_row(key, str(value))
|
||||||
|
|
||||||
|
console.print()
|
||||||
|
console.print(meta_table)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_civitai_table(console: Console) -> tuple[Table, int]:
|
||||||
|
"""Build CivitAI info table with proper column widths."""
|
||||||
|
prop_width = 14
|
||||||
|
terminal_width = console.size.width
|
||||||
|
overhead = 7
|
||||||
|
value_width = max(40, terminal_width - prop_width - overhead)
|
||||||
|
|
||||||
|
table = Table(title="CivitAI Model Information", show_header=True, header_style="bold magenta")
|
||||||
|
table.add_column("Property", style="cyan", width=prop_width, no_wrap=True)
|
||||||
|
table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis")
|
||||||
|
|
||||||
|
return table, value_width
|
||||||
|
|
||||||
|
|
||||||
|
def display_civitai_data(civitai_data: dict[str, Any] | None, console: Console) -> None:
|
||||||
|
"""Display CivitAI model information table."""
|
||||||
|
if not civitai_data:
|
||||||
|
console.print()
|
||||||
|
console.print("[yellow]Model not found on CivitAI.[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
civit_table, _ = _build_civitai_table(console)
|
||||||
|
|
||||||
|
civit_table.add_row("Model ID", str(civitai_data.get("modelId", "N/A")))
|
||||||
|
civit_table.add_row("Version ID", str(civitai_data.get("id", "N/A")))
|
||||||
|
civit_table.add_row("Version Name", str(civitai_data.get("name", "N/A")))
|
||||||
|
civit_table.add_row("Base Model", str(civitai_data.get("baseModel", "N/A")))
|
||||||
|
civit_table.add_row("Created At", str(civitai_data.get("createdAt", "N/A")))
|
||||||
|
|
||||||
|
trained_words: list[str] = civitai_data.get("trainedWords", [])
|
||||||
|
if trained_words:
|
||||||
|
civit_table.add_row("Trigger Words", ", ".join(trained_words))
|
||||||
|
|
||||||
|
download_url = str(civitai_data.get("downloadUrl", "N/A"))
|
||||||
|
civit_table.add_row("Download URL", download_url)
|
||||||
|
|
||||||
|
files: list[dict[str, Any]] = civitai_data.get("files", [])
|
||||||
|
for f in files:
|
||||||
|
if f.get("primary"):
|
||||||
|
civit_table.add_row("Primary File", str(f.get("name", "N/A")))
|
||||||
|
civit_table.add_row("File Size", _format_size(f.get("sizeKB", 0)))
|
||||||
|
meta: dict[str, Any] = f.get("metadata", {})
|
||||||
|
if meta:
|
||||||
|
civit_table.add_row("Format", str(meta.get("format", "N/A")))
|
||||||
|
civit_table.add_row("Precision", str(meta.get("fp", "N/A")))
|
||||||
|
civit_table.add_row("Size Type", str(meta.get("size", "N/A")))
|
||||||
|
|
||||||
|
console.print()
|
||||||
|
console.print(civit_table)
|
||||||
|
|
||||||
|
model_id = civitai_data.get("modelId")
|
||||||
|
if model_id:
|
||||||
|
console.print()
|
||||||
|
console.print(f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_model_table(console: Console) -> Table:
|
||||||
|
"""Build model info table with proper column widths."""
|
||||||
|
prop_width = 10
|
||||||
|
terminal_width = console.size.width
|
||||||
|
overhead = 7
|
||||||
|
value_width = max(40, terminal_width - prop_width - overhead)
|
||||||
|
|
||||||
|
table = Table(title="Model Information", show_header=True, header_style="bold magenta")
|
||||||
|
table.add_column("Property", style="cyan", width=prop_width, no_wrap=True)
|
||||||
|
table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis")
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def _add_model_basic_info(table: Table, model_data: dict[str, Any]) -> None:
|
||||||
|
"""Add basic model info rows to table."""
|
||||||
|
table.add_row("ID", str(model_data.get("id", "N/A")))
|
||||||
|
table.add_row("Name", str(model_data.get("name", "N/A")))
|
||||||
|
table.add_row("Type", str(model_data.get("type", "N/A")))
|
||||||
|
table.add_row("NSFW", str(model_data.get("nsfw", False)))
|
||||||
|
|
||||||
|
creator = model_data.get("creator", {})
|
||||||
|
if creator:
|
||||||
|
table.add_row("Creator", str(creator.get("username", "N/A")))
|
||||||
|
|
||||||
|
tags: list[str] = model_data.get("tags", [])
|
||||||
|
if tags:
|
||||||
|
table.add_row("Tags", ", ".join(tags[:10]) + ("..." if len(tags) > 10 else ""))
|
||||||
|
|
||||||
|
stats: dict[str, Any] = model_data.get("stats", {})
|
||||||
|
if stats:
|
||||||
|
table.add_row("Downloads", f"{stats.get('downloadCount', 0):,}")
|
||||||
|
table.add_row("Likes", f"{stats.get('thumbsUpCount', 0):,}")
|
||||||
|
|
||||||
|
mode = model_data.get("mode")
|
||||||
|
if mode:
|
||||||
|
table.add_row("Status", str(mode))
|
||||||
|
|
||||||
|
|
||||||
|
def _build_versions_table(console: Console) -> Table:
|
||||||
|
"""Build model versions table with proper column widths."""
|
||||||
|
id_width = 7
|
||||||
|
base_width = 20
|
||||||
|
created_width = 10
|
||||||
|
size_width = 8
|
||||||
|
|
||||||
|
terminal_width = console.size.width
|
||||||
|
fixed_width = id_width + base_width + created_width + size_width
|
||||||
|
overhead = 20
|
||||||
|
remaining = max(40, terminal_width - fixed_width - overhead)
|
||||||
|
name_width = remaining // 3
|
||||||
|
file_width = remaining - name_width
|
||||||
|
|
||||||
|
table = Table(title="Model Versions", show_header=True, header_style="bold magenta")
|
||||||
|
table.add_column("ID", style="cyan", width=id_width, no_wrap=True)
|
||||||
|
table.add_column("Name", style="green", width=name_width, no_wrap=True, overflow="ellipsis")
|
||||||
|
table.add_column("Base Model", style="yellow", width=base_width, no_wrap=True, overflow="ellipsis")
|
||||||
|
table.add_column("Created", style="blue", width=created_width, no_wrap=True)
|
||||||
|
table.add_column("Filename", style="white", width=file_width, no_wrap=True, overflow="ellipsis")
|
||||||
|
table.add_column("Size", justify="right", width=size_width, no_wrap=True)
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def _add_version_rows(table: Table, versions: list[dict[str, Any]]) -> None:
|
||||||
|
"""Add version rows to versions table."""
|
||||||
|
for ver in versions:
|
||||||
|
files: list[dict[str, Any]] = ver.get("files", [])
|
||||||
|
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
|
||||||
|
filename = "N/A"
|
||||||
|
size = "N/A"
|
||||||
|
if primary_file:
|
||||||
|
filename = primary_file.get("name", "N/A")
|
||||||
|
size = _format_size(primary_file.get("sizeKB", 0))
|
||||||
|
|
||||||
|
created = str(ver.get("createdAt", "N/A"))[:10]
|
||||||
|
table.add_row(
|
||||||
|
str(ver.get("id", "N/A")),
|
||||||
|
str(ver.get("name", "N/A")),
|
||||||
|
str(ver.get("baseModel", "N/A")),
|
||||||
|
created,
|
||||||
|
filename,
|
||||||
|
size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def display_model_info(model_data: dict[str, Any], console: Console) -> None:
|
||||||
|
"""Display full CivitAI model information."""
|
||||||
|
model_table = _build_model_table(console)
|
||||||
|
_add_model_basic_info(model_table, model_data)
|
||||||
|
|
||||||
|
console.print()
|
||||||
|
console.print(model_table)
|
||||||
|
|
||||||
|
versions: list[dict[str, Any]] = model_data.get("modelVersions", [])
|
||||||
|
if versions:
|
||||||
|
ver_table = _build_versions_table(console)
|
||||||
|
_add_version_rows(ver_table, versions)
|
||||||
|
console.print()
|
||||||
|
console.print(ver_table)
|
||||||
|
|
||||||
|
model_id = model_data.get("id")
|
||||||
|
if model_id:
|
||||||
|
console.print()
|
||||||
|
console.print(f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_search_table(console: Console) -> Table:
|
||||||
|
"""Build search results table with proper column widths."""
|
||||||
|
id_width = 7
|
||||||
|
type_width = 16
|
||||||
|
base_width = 20
|
||||||
|
size_width = 8
|
||||||
|
dls_width = 6
|
||||||
|
likes_width = 6
|
||||||
|
|
||||||
|
terminal_width = console.size.width
|
||||||
|
fixed_width = id_width + type_width + base_width + size_width + dls_width + likes_width
|
||||||
|
overhead = 23
|
||||||
|
name_width = max(20, terminal_width - fixed_width - overhead)
|
||||||
|
|
||||||
|
table = Table(show_header=True, header_style="bold magenta")
|
||||||
|
table.add_column("ID", style="cyan", justify="right", width=id_width, no_wrap=True)
|
||||||
|
table.add_column("Name", style="green", width=name_width, no_wrap=True, overflow="ellipsis")
|
||||||
|
table.add_column("Type", style="yellow", width=type_width, no_wrap=True)
|
||||||
|
table.add_column("Base", style="blue", width=base_width, no_wrap=True, overflow="ellipsis")
|
||||||
|
table.add_column("Size", justify="right", width=size_width, no_wrap=True)
|
||||||
|
table.add_column("DLs", justify="right", width=dls_width, no_wrap=True)
|
||||||
|
table.add_column("Likes", justify="right", width=likes_width, no_wrap=True)
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def _add_search_rows(table: Table, items: list[dict[str, Any]]) -> None:
|
||||||
|
"""Add search result rows to table."""
|
||||||
|
for model in items:
|
||||||
|
model_id = str(model.get("id", ""))
|
||||||
|
name = model.get("name", "N/A")
|
||||||
|
model_type = model.get("type", "N/A")
|
||||||
|
|
||||||
|
versions = model.get("modelVersions", [])
|
||||||
|
base_model = "N/A"
|
||||||
|
size = "N/A"
|
||||||
|
if versions:
|
||||||
|
latest = versions[0]
|
||||||
|
base_model = latest.get("baseModel", "N/A")
|
||||||
|
files = latest.get("files", [])
|
||||||
|
primary = next((f for f in files if f.get("primary")), files[0] if files else None)
|
||||||
|
if primary:
|
||||||
|
size = _format_size(primary.get("sizeKB", 0))
|
||||||
|
|
||||||
|
stats = model.get("stats", {})
|
||||||
|
downloads = _format_count(stats.get("downloadCount", 0))
|
||||||
|
likes = _format_count(stats.get("thumbsUpCount", 0))
|
||||||
|
|
||||||
|
table.add_row(model_id, name, model_type, base_model, size, downloads, likes)
|
||||||
|
|
||||||
|
|
||||||
|
def display_search_results(results: dict[str, Any], console: Console) -> None:
|
||||||
|
"""Display search results in a table."""
|
||||||
|
items = results.get("items", [])
|
||||||
|
if not items:
|
||||||
|
console.print("[yellow]No results found.[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
table = _build_search_table(console)
|
||||||
|
_add_search_rows(table, items)
|
||||||
|
|
||||||
|
console.print()
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
metadata = results.get("metadata", {})
|
||||||
|
total = metadata.get("totalItems", len(items))
|
||||||
|
console.print(f"\n[dim]Showing {len(items)} of {total:,} results[/dim]")
|
||||||
|
console.print("[dim]Use 'tsr get <id>' to view details or 'tsr dl -m <id>' to download[/dim]")
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
"""Safetensor file reading functions."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import struct
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from rich.progress import (
|
||||||
|
BarColumn,
|
||||||
|
DownloadColumn,
|
||||||
|
Progress,
|
||||||
|
SpinnerColumn,
|
||||||
|
TaskProgressColumn,
|
||||||
|
TextColumn,
|
||||||
|
TimeRemainingColumn,
|
||||||
|
TransferSpeedColumn,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from rich.console import Console
|
||||||
|
|
||||||
|
|
||||||
|
def read_safetensor_metadata(file_path: Path) -> dict[str, Any]:
|
||||||
|
"""Read metadata from a safetensor file header."""
|
||||||
|
with file_path.open("rb") as f:
|
||||||
|
# First 8 bytes are the header size (little-endian u64)
|
||||||
|
header_size_bytes = f.read(8)
|
||||||
|
if len(header_size_bytes) < 8:
|
||||||
|
raise ValueError("Invalid safetensor file: too short")
|
||||||
|
|
||||||
|
header_size = struct.unpack("<Q", header_size_bytes)[0]
|
||||||
|
|
||||||
|
if header_size > 100_000_000: # 100MB sanity check
|
||||||
|
raise ValueError(f"Invalid header size: {header_size}")
|
||||||
|
|
||||||
|
header_bytes = f.read(header_size)
|
||||||
|
if len(header_bytes) < header_size:
|
||||||
|
raise ValueError("Invalid safetensor file: header truncated")
|
||||||
|
|
||||||
|
header: dict[str, Any] = json.loads(header_bytes.decode("utf-8"))
|
||||||
|
|
||||||
|
# Extract __metadata__ if present
|
||||||
|
metadata: dict[str, Any] = header.get("__metadata__", {})
|
||||||
|
|
||||||
|
# Count tensors (keys that aren't __metadata__)
|
||||||
|
tensor_count = sum(1 for k in header if k != "__metadata__")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"metadata": metadata,
|
||||||
|
"tensor_count": tensor_count,
|
||||||
|
"header_size": header_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def compute_sha256(file_path: Path, console: Console) -> str:
|
||||||
|
"""Compute SHA256 hash of a file with progress display."""
|
||||||
|
file_size = file_path.stat().st_size
|
||||||
|
sha256 = hashlib.sha256()
|
||||||
|
chunk_size = 1024 * 1024 * 8 # 8MB chunks
|
||||||
|
|
||||||
|
with Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TaskProgressColumn(),
|
||||||
|
DownloadColumn(),
|
||||||
|
TransferSpeedColumn(),
|
||||||
|
TimeRemainingColumn(),
|
||||||
|
console=console,
|
||||||
|
) as progress:
|
||||||
|
task = progress.add_task(f"[cyan]Hashing {file_path.name}...", total=file_size)
|
||||||
|
|
||||||
|
with file_path.open("rb") as f:
|
||||||
|
while chunk := f.read(chunk_size):
|
||||||
|
sha256.update(chunk)
|
||||||
|
progress.update(task, advance=len(chunk))
|
||||||
|
|
||||||
|
return sha256.hexdigest().upper()
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_name(file_path: Path) -> str:
|
||||||
|
"""Get base filename without .safetensors extension."""
|
||||||
|
name = file_path.name
|
||||||
|
for ext in (".safetensors", ".sft"):
|
||||||
|
if name.lower().endswith(ext):
|
||||||
|
return name[: -len(ext)]
|
||||||
|
return file_path.stem
|
||||||
+11
-19
@@ -7,13 +7,9 @@ from pathlib import Path
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import tensors
|
from tensors import config
|
||||||
from tensors import (
|
from tensors.config import get_default_output_path, load_api_key
|
||||||
get_base_name,
|
from tensors.safetensor import get_base_name, read_safetensor_metadata
|
||||||
get_default_output_path,
|
|
||||||
load_api_key,
|
|
||||||
read_safetensor_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestReadSafetensorMetadata:
|
class TestReadSafetensorMetadata:
|
||||||
@@ -111,28 +107,24 @@ class TestLoadApiKey:
|
|||||||
"""Test that None is returned when no key is available."""
|
"""Test that None is returned when no key is available."""
|
||||||
monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
|
monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
|
||||||
# Point config and legacy files to nonexistent paths
|
# Point config and legacy files to nonexistent paths
|
||||||
monkeypatch.setattr(tensors, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml")
|
monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml")
|
||||||
monkeypatch.setattr(tensors, "LEGACY_RC_FILE", tmp_path / "nonexistent")
|
monkeypatch.setattr(config, "LEGACY_RC_FILE", tmp_path / "nonexistent")
|
||||||
assert load_api_key() is None
|
assert load_api_key() is None
|
||||||
|
|
||||||
def test_returns_key_from_config_file(
|
def test_returns_key_from_config_file(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||||
self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""Test that key is loaded from TOML config file."""
|
"""Test that key is loaded from TOML config file."""
|
||||||
monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
|
monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
|
||||||
config_file = tmp_path / "config.toml"
|
config_file = tmp_path / "config.toml"
|
||||||
config_file.write_text('[api]\ncivitai_key = "key-from-config"\n')
|
config_file.write_text('[api]\ncivitai_key = "key-from-config"\n')
|
||||||
monkeypatch.setattr(tensors, "CONFIG_FILE", config_file)
|
monkeypatch.setattr(config, "CONFIG_FILE", config_file)
|
||||||
monkeypatch.setattr(tensors, "LEGACY_RC_FILE", tmp_path / "nonexistent")
|
monkeypatch.setattr(config, "LEGACY_RC_FILE", tmp_path / "nonexistent")
|
||||||
assert load_api_key() == "key-from-config"
|
assert load_api_key() == "key-from-config"
|
||||||
|
|
||||||
def test_returns_key_from_legacy_file(
|
def test_returns_key_from_legacy_file(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||||
self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""Test that key is loaded from legacy RC file when no config exists."""
|
"""Test that key is loaded from legacy RC file when no config exists."""
|
||||||
monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
|
monkeypatch.delenv("CIVITAI_API_KEY", raising=False)
|
||||||
legacy_file = tmp_path / ".sftrc"
|
legacy_file = tmp_path / ".sftrc"
|
||||||
legacy_file.write_text("legacy-key")
|
legacy_file.write_text("legacy-key")
|
||||||
monkeypatch.setattr(tensors, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml")
|
monkeypatch.setattr(config, "CONFIG_FILE", tmp_path / "nonexistent" / "config.toml")
|
||||||
monkeypatch.setattr(tensors, "LEGACY_RC_FILE", legacy_file)
|
monkeypatch.setattr(config, "LEGACY_RC_FILE", legacy_file)
|
||||||
assert load_api_key() == "legacy-key"
|
assert load_api_key() == "legacy-key"
|
||||||
|
|||||||
Reference in New Issue
Block a user