Files
tensors/tensors/safetensor.py
T
Adam Ladachowski c7c5a4a995 Phase 2.1: Add SQLite database module for models metadata
Create database infrastructure for local model file tracking and CivitAI cache:
- schema.sql: Full schema with local_files, CivitAI cache tables, and views
- db.py: Database class with CRUD operations for file scanning, CivitAI linking,
  model caching, search, and trigger word retrieval
- Update compute_sha256 to support optional console for silent batch operations

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-14 01:31:32 +01:00

105 lines
3.1 KiB
Python

"""Safetensor file reading functions."""
from __future__ import annotations
import hashlib
import json
import struct
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from pathlib import Path
from rich.progress import (
BarColumn,
DownloadColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeRemainingColumn,
TransferSpeedColumn,
)
if TYPE_CHECKING:
from rich.console import Console
# Safetensor format constants
HEADER_SIZE_BYTES = 8  # u64 little-endian
MAX_HEADER_SIZE = 100_000_000  # 100MB sanity check


def read_safetensor_metadata(file_path: Path) -> dict[str, Any]:
    """Read metadata from a safetensor file header.

    The file is expected to begin with an 8-byte little-endian unsigned
    integer giving the length of a UTF-8 encoded JSON header that follows
    immediately after it. Only that prefix and the header are read; the
    rest of the file is never touched.

    Args:
        file_path: Path of the safetensor file to inspect.

    Returns:
        A dict with three keys:
        ``"metadata"`` - the header's ``__metadata__`` object, or ``{}`` when
        the key is absent or null;
        ``"tensor_count"`` - number of header keys other than ``__metadata__``;
        ``"header_size"`` - the JSON header length in bytes.

    Raises:
        ValueError: If the file is shorter than the size prefix, the declared
            header size exceeds ``MAX_HEADER_SIZE``, the header is truncated,
            the header is not valid UTF-8 JSON (``UnicodeDecodeError`` and
            ``json.JSONDecodeError`` are both ``ValueError`` subclasses), or
            the header / ``__metadata__`` entry is not a JSON object.
        OSError: If the file cannot be opened or read.
    """
    with file_path.open("rb") as f:
        size_prefix = f.read(HEADER_SIZE_BYTES)
        if len(size_prefix) < HEADER_SIZE_BYTES:
            raise ValueError("Invalid safetensor file: too short")

        header_size = struct.unpack("<Q", size_prefix)[0]
        # Reject absurd sizes before reading so a corrupt prefix cannot
        # trigger a huge allocation.
        if header_size > MAX_HEADER_SIZE:
            raise ValueError(f"Invalid header size: {header_size}")

        header_bytes = f.read(header_size)

    if len(header_bytes) < header_size:
        raise ValueError("Invalid safetensor file: header truncated")

    header = json.loads(header_bytes.decode("utf-8"))
    # Valid JSON is not necessarily an object; without this check a list or
    # scalar header would surface as an AttributeError on ``.get`` below.
    if not isinstance(header, dict):
        raise ValueError("Invalid safetensor file: header is not a JSON object")

    # Extract __metadata__ if present; an explicit JSON null counts as absent
    # so callers always receive a dict.
    metadata = header.get("__metadata__")
    if metadata is None:
        metadata = {}
    elif not isinstance(metadata, dict):
        raise ValueError("Invalid safetensor file: __metadata__ is not a JSON object")

    # Count tensors (keys that aren't __metadata__)
    tensor_count = sum(1 for key in header if key != "__metadata__")

    return {
        "metadata": metadata,
        "tensor_count": tensor_count,
        "header_size": header_size,
    }
def compute_sha256(file_path: Path, console: Console | None = None) -> str:
    """Return the SHA256 digest of a file as an uppercase hex string.

    The file is read in fixed-size pieces so it never has to fit in memory.
    When ``console`` is given, a progress bar is rendered on it while
    hashing; when it is ``None`` the hash is computed without any output.

    Args:
        file_path: Path of the file to hash.
        console: Optional console on which to draw the progress bar.

    Returns:
        The hex digest, upper-cased.
    """
    total_bytes = file_path.stat().st_size
    digest = hashlib.sha256()
    read_size = 8 * 1024 * 1024  # 8 MiB per read

    # Silent mode - no progress display
    if console is None:
        with file_path.open("rb") as stream:
            for piece in iter(lambda: stream.read(read_size), b""):
                digest.update(piece)
        return digest.hexdigest().upper()

    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
    )
    with Progress(*columns, console=console) as bar:
        task_id = bar.add_task(f"[cyan]Hashing {file_path.name}...", total=total_bytes)
        with file_path.open("rb") as stream:
            for piece in iter(lambda: stream.read(read_size), b""):
                digest.update(piece)
                # Advance by the bytes actually read so the bar tracks file size.
                bar.update(task_id, advance=len(piece))

    return digest.hexdigest().upper()
def get_base_name(file_path: Path) -> str:
    """Return the file name with its safetensor extension removed.

    A trailing ``.safetensors`` or ``.sft`` (matched case-insensitively) is
    stripped from the file name. For any other name the path's stem is
    returned instead.
    """
    filename = file_path.name
    lowered = filename.lower()
    known_suffixes = (".safetensors", ".sft")

    # First known suffix the name ends with, or None when neither matches.
    matched = next((s for s in known_suffixes if lowered.endswith(s)), None)
    if matched is None:
        return file_path.stem
    return filename[: -len(matched)]