fix: populate model cache tables after download so db list resolves names

The CLI download flow only set civitai_model_id/version_id on local_files
without caching the full model payload, so 'tsr db list' joined against
empty models/versions/creators tables and showed every linked file as
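'unlinked'. The server's _auto_link_file path had additional bugs: after
the rescan it matched scan results against an unresolved str(file_path),
it repeated the CivitAI lookup by hash even though version_info had already
been fetched, and any failure was only logged while the download still
reported 'completed'.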

- New Database.register_downloaded_file() consolidates hashing, metadata
  storage, FK linking, and cache_model() into a single idempotent call
  shared by both CLI and server paths.
- Server _do_download now passes version_info straight through and
  surfaces db_file_id/db_linked/db_cached/db_error onto _active_downloads.
- Drops the broken _auto_link_file rescan helper.
This commit is contained in:
2026-05-16 00:44:12 +02:00
parent ec080803fc
commit b731a88beb
4 changed files with 151 additions and 121 deletions
+14 -41
View File
@@ -604,56 +604,29 @@ def download(
def _add_downloaded_file_to_db(dest_path: Path, version_info: dict[str, Any]) -> None:
"""Add a downloaded file to the database and link to CivitAI.
"""Add a downloaded file to the database, link to CivitAI, and cache full model data.
Args:
dest_path: Path to the downloaded file
version_info: CivitAI version info response
"""
try:
console.print("[dim]Adding to database...[/dim]")
console.print("[dim]Adding to database...[/dim]")
api_key = load_api_key()
with Database() as db:
db.init_schema()
result = db.register_downloaded_file(dest_path, version_info, api_key=api_key, console=console)
# Compute SHA256 hash
sha256 = compute_sha256(dest_path, console)
if result["error"]:
console.print(f"[yellow]Warning: Could not add to database: {result['error']}[/yellow]")
return
# Read safetensor metadata
metadata = read_safetensor_metadata(dest_path)
# Extract CivitAI IDs
console.print(f"[green]Added to database (id={result['file_id']})[/green]")
if result["linked"]:
civitai_version_id = version_info.get("id")
civitai_model_id = version_info.get("modelId") or version_info.get("model", {}).get("id")
with Database() as db:
db.init_schema()
with db.session() as session:
# Add local file record
local_file = db._upsert_local_file(
session,
file_path=str(dest_path.resolve()),
sha256=sha256,
header_size=metadata.get("header_size"),
tensor_count=metadata.get("tensor_count"),
)
# Store safetensor metadata
db._store_safetensor_metadata(session, local_file.id, metadata.get("metadata", {}))
# Link to CivitAI if we have the IDs
if civitai_model_id and civitai_version_id:
local_file.civitai_model_id = civitai_model_id
local_file.civitai_version_id = civitai_version_id
session.add(local_file)
session.commit()
file_id = local_file.id
# Report success
console.print(f"[green]Added to database (id={file_id})[/green]")
if civitai_model_id and civitai_version_id:
console.print(f"[green]Linked to CivitAI model={civitai_model_id} version={civitai_version_id}[/green]")
except Exception as e:
console.print(f"[yellow]Warning: Could not add to database: {e}[/yellow]")
console.print(f"[green]Linked to CivitAI model={civitai_model_id} version={civitai_version_id}[/green]")
if result["cached"]:
console.print("[green]Cached model metadata[/green]")
def _display_download_info(
+65
View File
@@ -257,6 +257,71 @@ class Database:
session.add(f)
session.commit()
def register_downloaded_file(
    self,
    dest_path: Path,
    version_info: dict[str, Any],
    api_key: str | None = None,
    console: Console | None = None,
) -> dict[str, Any]:
    """Register a freshly-downloaded file: hash, store metadata, link, and cache full model.

    Shared by the CLI download flow and the FastAPI background worker so both
    paths run the same registration steps. The row is written through
    ``_upsert_local_file``, so re-registering the same path is presumably an
    update rather than a duplicate insert (NOTE(review): confirm against
    ``_upsert_local_file``).

    The work runs in two independently guarded phases so that a failure while
    caching model metadata cannot hide a registration that already committed:

    1. Registration: hash the file, read its safetensor header, upsert the
       local file row, store its metadata, set the CivitAI model/version IDs
       when both are known, and commit.
    2. Caching: fetch the full model payload and pass it to ``cache_model()``.

    Args:
        dest_path: Path to the downloaded safetensor file.
        version_info: CivitAI ``model-versions/{id}`` response (already fetched).
        api_key: Optional CivitAI API key for the model fetch.
        console: Optional Rich console for hash progress output.

    Returns:
        ``{"file_id": int | None, "sha256": str | None, "linked": bool,
        "cached": bool, "error": str | None}``. ``file_id``, ``sha256`` and
        ``linked`` are only populated after the commit succeeded. When
        registration succeeds but caching fails, those fields keep their
        values and ``error`` describes the caching failure. No exception is
        propagated to the caller.
    """
    result: dict[str, Any] = {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": None}
    civitai_model_id = None

    # Phase 1: hash, store metadata, link, commit.
    try:
        sha256 = compute_sha256(dest_path, console)
        metadata = read_safetensor_metadata(dest_path)

        civitai_version_id = version_info.get("id")
        # ``model`` may be absent or explicitly null in the payload; ``or {}``
        # covers both, whereas ``.get("model", {})`` only covers a missing key.
        civitai_model_id = version_info.get("modelId") or (version_info.get("model") or {}).get("id")
        can_link = bool(civitai_model_id and civitai_version_id)

        with self.session() as session:
            local_file = self._upsert_local_file(
                session,
                file_path=str(dest_path.resolve()),
                sha256=sha256,
                header_size=metadata.get("header_size"),
                tensor_count=metadata.get("tensor_count"),
            )
            self._store_safetensor_metadata(session, local_file.id, metadata.get("metadata", {}))

            if can_link:
                local_file.civitai_model_id = civitai_model_id
                local_file.civitai_version_id = civitai_version_id
                session.add(local_file)

            session.commit()

            # Report success only after the commit went through; read the id
            # while the session is still open.
            result["file_id"] = local_file.id
            result["sha256"] = sha256
            result["linked"] = can_link
    except Exception as e:  # surface any failure to caller without crashing the download
        result["error"] = str(e)
        return result

    if not civitai_model_id:
        return result

    # Phase 2: cache full model metadata so db list can resolve names/triggers/base_model.
    # The version endpoint payload is too sparse for cache_model() (no creator, tags,
    # or full modelVersions list), so we fetch the model endpoint here.
    try:
        # Lazy import to avoid pulling httpx into modules that only need DB ops.
        # Kept inside the guard so an import failure is reported, not raised.
        from tensors.api import fetch_civitai_model  # noqa: PLC0415

        model_data = fetch_civitai_model(civitai_model_id, api_key, console)
        if model_data:
            self.cache_model(model_data)
            result["cached"] = True
    except Exception as e:  # the file row is already committed; report the cache failure only
        result["error"] = f"file registered, but caching model metadata failed: {e}"

    return result
# =========================================================================
# CivitAI Cache Operations
# =========================================================================
+14 -25
View File
@@ -106,6 +106,7 @@ def _do_download(
dest_path: Path,
api_key: str | None,
download_id: str,
version_info: dict[str, Any],
) -> None:
"""Background task to perform the download."""
try:
@@ -133,8 +134,18 @@ def _do_download(
_active_downloads[download_id]["progress"] = 100
_active_downloads[download_id]["path"] = str(dest_path)
# Auto-scan and link the downloaded file
_auto_link_file(dest_path, api_key)
# Register the file in DB: hash, link to CivitAI IDs from version_info,
# and cache full model metadata so /api/db/* endpoints return resolved data.
with Database() as db:
db.init_schema()
db_result = db.register_downloaded_file(dest_path, version_info, api_key=api_key)
_active_downloads[download_id]["db_file_id"] = db_result["file_id"]
_active_downloads[download_id]["db_linked"] = db_result["linked"]
_active_downloads[download_id]["db_cached"] = db_result["cached"]
if db_result["error"]:
_active_downloads[download_id]["db_error"] = db_result["error"]
logger.error("DB register failed for %s: %s", dest_path, db_result["error"])
else:
_active_downloads[download_id]["status"] = "failed"
_active_downloads[download_id]["error"] = "Download failed"
@@ -145,28 +156,6 @@ def _do_download(
_active_downloads[download_id]["error"] = str(e)
def _auto_link_file(file_path: Path, api_key: str | None) -> None:
"""Auto-scan and link the downloaded file to CivitAI."""
try:
with Database() as db:
db.init_schema()
# Scan the single file
results = db.scan_directory(file_path.parent)
# Find and link the new file
for result in results:
if result["file_path"] == str(file_path):
sha256 = result["sha256"]
civitai_data = fetch_civitai_by_hash(sha256, api_key)
if civitai_data:
version_id = civitai_data.get("id", 0)
model_id = civitai_data.get("modelId", 0)
if version_id and model_id:
db.link_file_to_civitai(result["id"], model_id, version_id)
except Exception:
logger.exception("Auto-link failed")
# =============================================================================
# Download Endpoints
# =============================================================================
@@ -215,7 +204,7 @@ def start_download(req: DownloadRequest, background_tasks: BackgroundTasks) -> d
}
# Start background download
background_tasks.add_task(_do_download, version_id, dest_path, api_key, download_id)
background_tasks.add_task(_do_download, version_id, dest_path, api_key, download_id, version_info)
return {
"download_id": download_id,
+58 -55
View File
@@ -1216,6 +1216,30 @@ class TestGalleryRoutesExtended:
class TestDownloadBackgroundTasks:
"""Tests for download background task functions."""
@staticmethod
def _patch_db_noop(monkeypatch, download_routes_module) -> dict:
    """Swap the routes module's ``Database`` for an in-memory stand-in.

    The stand-in supports the context-manager protocol, does nothing in
    ``init_schema()``, and answers ``register_downloaded_file()`` with a fixed
    success payload.

    Returns:
        A dict that receives the ``dest_path``, ``version_info`` and
        ``api_key`` of the latest ``register_downloaded_file()`` call.
    """
    seen: dict = {}

    class NoopDatabase:
        def __enter__(self):
            return self

        def __exit__(self, *args):
            # Falsy return: exceptions raised inside the ``with`` body propagate.
            return False

        def init_schema(self):
            pass

        def register_downloaded_file(self, dest_path, version_info, api_key=None, console=None):
            seen.update(dest_path=dest_path, version_info=version_info, api_key=api_key)
            return dict(file_id=42, sha256="deadbeef", linked=True, cached=True, error=None)

    def make_stub():
        return NoopDatabase()

    monkeypatch.setattr(download_routes_module, "Database", make_stub)
    return seen
def test_do_download_success(self, monkeypatch, tmp_path) -> None:
"""Test successful download task execution."""
from tensors.server import download_routes
@@ -1233,13 +1257,18 @@ class TestDownloadBackgroundTasks:
return True
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None)
captured = self._patch_db_noop(monkeypatch, download_routes)
dest_path = tmp_path / "model.safetensors"
_do_download(12345, dest_path, None, download_id)
version_info = {"id": 999, "modelId": 888, "name": "v1"}
_do_download(12345, dest_path, None, download_id, version_info)
assert download_routes._active_downloads[download_id]["status"] == "completed"
assert download_routes._active_downloads[download_id]["progress"] == 100
assert download_routes._active_downloads[download_id]["db_file_id"] == 42
assert download_routes._active_downloads[download_id]["db_linked"] is True
assert download_routes._active_downloads[download_id]["db_cached"] is True
assert captured["version_info"] == version_info
# Cleanup
del download_routes._active_downloads[download_id]
@@ -1256,7 +1285,7 @@ class TestDownloadBackgroundTasks:
monkeypatch.setattr(download_routes, "download_model_with_progress", lambda *args, **kwargs: False)
dest_path = tmp_path / "model.safetensors"
_do_download(12345, dest_path, None, download_id)
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
assert download_routes._active_downloads[download_id]["status"] == "failed"
assert "error" in download_routes._active_downloads[download_id]
@@ -1278,7 +1307,7 @@ class TestDownloadBackgroundTasks:
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
dest_path = tmp_path / "model.safetensors"
_do_download(12345, dest_path, None, download_id)
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
assert download_routes._active_downloads[download_id]["status"] == "failed"
assert "Network error" in download_routes._active_downloads[download_id]["error"]
@@ -1309,10 +1338,10 @@ class TestDownloadBackgroundTasks:
return True
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None)
self._patch_db_noop(monkeypatch, download_routes)
dest_path = tmp_path / "model.safetensors"
_do_download(12345, dest_path, None, download_id)
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
# Check progress formatting was called
assert len(progress_calls) == 3
@@ -1320,7 +1349,6 @@ class TestDownloadBackgroundTasks:
assert progress_calls[0]["total_str"] == "1.0 KB"
assert progress_calls[1]["downloaded_str"] == "500.0 KB"
assert progress_calls[2]["downloaded_str"] == "500.0 MB"
assert progress_calls[2]["total_str"] == "1.0 GB"
del download_routes._active_downloads[download_id]
@@ -1337,74 +1365,49 @@ class TestDownloadBackgroundTasks:
return True
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None)
self._patch_db_noop(monkeypatch, download_routes)
dest_path = tmp_path / "model.safetensors"
_do_download(12345, dest_path, None, download_id)
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
assert download_routes._active_downloads[download_id]["total_str"] == "Unknown"
del download_routes._active_downloads[download_id]
class TestAutoLinkFile:
"""Tests for _auto_link_file function."""
def test_auto_link_success(self, monkeypatch, tmp_path, temp_db) -> None:
"""Test auto-linking a downloaded file."""
def test_do_download_db_error_surfaced(self, monkeypatch, tmp_path) -> None:
"""DB errors during register must be surfaced into _active_downloads."""
from tensors.server import download_routes
from tensors.server.download_routes import _auto_link_file
from tensors.server.download_routes import _do_download
# Create a fake safetensor file
file_path = tmp_path / "test.safetensors"
file_path.write_bytes(b"fake safetensor data")
download_id = "test_db_err"
download_routes._active_downloads[download_id] = {"id": download_id, "status": "queued"}
from tensors import db as db_module
monkeypatch.setattr(download_routes, "download_model_with_progress", lambda *a, **kw: True)
monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path)
# Mock scan results
scanned_files = []
def mock_scan(directory):
return [{"id": 1, "file_path": str(file_path), "sha256": "abc123"}]
monkeypatch.setattr(temp_db, "scan_directory", mock_scan)
# Mock CivitAI lookup
def mock_fetch_by_hash(sha256, api_key):
return {"id": 999, "modelId": 888}
monkeypatch.setattr(download_routes, "fetch_civitai_by_hash", mock_fetch_by_hash)
# Mock Database context manager to use temp_db
class MockDB:
class FailingDB:
def __enter__(self):
return temp_db
return self
def __exit__(self, *args):
return False
def init_schema(self):
pass
monkeypatch.setattr(download_routes, "Database", MockDB)
def register_downloaded_file(self, *args, **kwargs):
return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"}
# This should not raise
_auto_link_file(file_path, None)
monkeypatch.setattr(download_routes, "Database", lambda: FailingDB())
def test_auto_link_exception_handled(self, monkeypatch, tmp_path) -> None:
"""Test auto-link handles exceptions gracefully."""
from tensors.server.download_routes import _auto_link_file
dest_path = tmp_path / "model.safetensors"
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
# Mock Database to raise exception
def mock_db(*args, **kwargs):
raise RuntimeError("DB error")
# Download itself succeeded; DB layer is reported separately
assert download_routes._active_downloads[download_id]["status"] == "completed"
assert download_routes._active_downloads[download_id]["db_error"] == "boom"
assert download_routes._active_downloads[download_id]["db_linked"] is False
from tensors.server import download_routes
monkeypatch.setattr(download_routes, "Database", mock_db)
file_path = tmp_path / "test.safetensors"
# Should not raise
_auto_link_file(file_path, None)
del download_routes._active_downloads[download_id]
# =============================================================================