fix: populate model cache tables after download so db list resolves names
The CLI download flow only set civitai_model_id/version_id on local_files without caching the full model payload, so 'tsr db list' joined against empty models/versions/creators tables and showed every linked file as 'unlinked'. The server's _auto_link_file path had additional bugs: resolved-vs-unresolved path comparison after rescan, redundant CivitAI hash lookup, and silent failure swallowed by 'completed' status. - New Database.register_downloaded_file() consolidates hashing, metadata storage, FK linking, and cache_model() into a single idempotent call shared by both CLI and server paths. - Server _do_download now passes version_info straight through and surfaces db_file_id/db_linked/db_cached/db_error onto _active_downloads. - Drops the broken _auto_link_file rescan helper.
This commit is contained in:
+12
-39
@@ -604,56 +604,29 @@ def download(
|
||||
|
||||
|
||||
def _add_downloaded_file_to_db(dest_path: Path, version_info: dict[str, Any]) -> None:
|
||||
"""Add a downloaded file to the database and link to CivitAI.
|
||||
"""Add a downloaded file to the database, link to CivitAI, and cache full model data.
|
||||
|
||||
Args:
|
||||
dest_path: Path to the downloaded file
|
||||
version_info: CivitAI version info response
|
||||
"""
|
||||
try:
|
||||
console.print("[dim]Adding to database...[/dim]")
|
||||
|
||||
# Compute SHA256 hash
|
||||
sha256 = compute_sha256(dest_path, console)
|
||||
|
||||
# Read safetensor metadata
|
||||
metadata = read_safetensor_metadata(dest_path)
|
||||
|
||||
# Extract CivitAI IDs
|
||||
civitai_version_id = version_info.get("id")
|
||||
civitai_model_id = version_info.get("modelId") or version_info.get("model", {}).get("id")
|
||||
|
||||
api_key = load_api_key()
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
with db.session() as session:
|
||||
# Add local file record
|
||||
local_file = db._upsert_local_file(
|
||||
session,
|
||||
file_path=str(dest_path.resolve()),
|
||||
sha256=sha256,
|
||||
header_size=metadata.get("header_size"),
|
||||
tensor_count=metadata.get("tensor_count"),
|
||||
)
|
||||
result = db.register_downloaded_file(dest_path, version_info, api_key=api_key, console=console)
|
||||
|
||||
# Store safetensor metadata
|
||||
db._store_safetensor_metadata(session, local_file.id, metadata.get("metadata", {}))
|
||||
if result["error"]:
|
||||
console.print(f"[yellow]Warning: Could not add to database: {result['error']}[/yellow]")
|
||||
return
|
||||
|
||||
# Link to CivitAI if we have the IDs
|
||||
if civitai_model_id and civitai_version_id:
|
||||
local_file.civitai_model_id = civitai_model_id
|
||||
local_file.civitai_version_id = civitai_version_id
|
||||
session.add(local_file)
|
||||
|
||||
session.commit()
|
||||
file_id = local_file.id
|
||||
|
||||
# Report success
|
||||
console.print(f"[green]Added to database (id={file_id})[/green]")
|
||||
if civitai_model_id and civitai_version_id:
|
||||
console.print(f"[green]Added to database (id={result['file_id']})[/green]")
|
||||
if result["linked"]:
|
||||
civitai_version_id = version_info.get("id")
|
||||
civitai_model_id = version_info.get("modelId") or version_info.get("model", {}).get("id")
|
||||
console.print(f"[green]Linked to CivitAI model={civitai_model_id} version={civitai_version_id}[/green]")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[yellow]Warning: Could not add to database: {e}[/yellow]")
|
||||
if result["cached"]:
|
||||
console.print("[green]Cached model metadata[/green]")
|
||||
|
||||
|
||||
def _display_download_info(
|
||||
|
||||
@@ -257,6 +257,71 @@ class Database:
|
||||
session.add(f)
|
||||
session.commit()
|
||||
|
||||
def register_downloaded_file(
|
||||
self,
|
||||
dest_path: Path,
|
||||
version_info: dict[str, Any],
|
||||
api_key: str | None = None,
|
||||
console: Console | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Register a freshly-downloaded file: hash, store metadata, link, and cache full model.
|
||||
|
||||
Idempotent and shared by the CLI download flow and the FastAPI background worker so
|
||||
both paths produce identical DB state (local_file row + cached models/versions/tags
|
||||
so ``db list`` can resolve names, triggers, and base_model).
|
||||
|
||||
Args:
|
||||
dest_path: Path to the downloaded safetensor file.
|
||||
version_info: CivitAI ``model-versions/{id}`` response (already fetched).
|
||||
api_key: Optional CivitAI API key for the model fetch.
|
||||
console: Optional Rich console for hash progress output.
|
||||
|
||||
Returns:
|
||||
``{"file_id": int, "sha256": str, "linked": bool, "cached": bool, "error": str | None}``
|
||||
"""
|
||||
# Lazy import to avoid pulling httpx into modules that only need DB ops
|
||||
from tensors.api import fetch_civitai_model # noqa: PLC0415
|
||||
|
||||
result: dict[str, Any] = {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": None}
|
||||
try:
|
||||
sha256 = compute_sha256(dest_path, console)
|
||||
metadata = read_safetensor_metadata(dest_path)
|
||||
|
||||
civitai_version_id = version_info.get("id")
|
||||
civitai_model_id = version_info.get("modelId") or version_info.get("model", {}).get("id")
|
||||
|
||||
with self.session() as session:
|
||||
local_file = self._upsert_local_file(
|
||||
session,
|
||||
file_path=str(dest_path.resolve()),
|
||||
sha256=sha256,
|
||||
header_size=metadata.get("header_size"),
|
||||
tensor_count=metadata.get("tensor_count"),
|
||||
)
|
||||
self._store_safetensor_metadata(session, local_file.id, metadata.get("metadata", {}))
|
||||
|
||||
if civitai_model_id and civitai_version_id:
|
||||
local_file.civitai_model_id = civitai_model_id
|
||||
local_file.civitai_version_id = civitai_version_id
|
||||
session.add(local_file)
|
||||
result["linked"] = True
|
||||
|
||||
session.commit()
|
||||
result["file_id"] = local_file.id
|
||||
result["sha256"] = sha256
|
||||
|
||||
# Cache full model metadata so db list can resolve names/triggers/base_model.
|
||||
# The version endpoint payload is too sparse for cache_model() (no creator, tags,
|
||||
# or full modelVersions list), so we fetch the model endpoint here.
|
||||
if civitai_model_id:
|
||||
model_data = fetch_civitai_model(civitai_model_id, api_key, console)
|
||||
if model_data:
|
||||
self.cache_model(model_data)
|
||||
result["cached"] = True
|
||||
except Exception as e: # surface any failure to caller without crashing the download
|
||||
result["error"] = str(e)
|
||||
return result
|
||||
|
||||
# =========================================================================
|
||||
# CivitAI Cache Operations
|
||||
# =========================================================================
|
||||
|
||||
@@ -106,6 +106,7 @@ def _do_download(
|
||||
dest_path: Path,
|
||||
api_key: str | None,
|
||||
download_id: str,
|
||||
version_info: dict[str, Any],
|
||||
) -> None:
|
||||
"""Background task to perform the download."""
|
||||
try:
|
||||
@@ -133,8 +134,18 @@ def _do_download(
|
||||
_active_downloads[download_id]["progress"] = 100
|
||||
_active_downloads[download_id]["path"] = str(dest_path)
|
||||
|
||||
# Auto-scan and link the downloaded file
|
||||
_auto_link_file(dest_path, api_key)
|
||||
# Register the file in DB: hash, link to CivitAI IDs from version_info,
|
||||
# and cache full model metadata so /api/db/* endpoints return resolved data.
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
db_result = db.register_downloaded_file(dest_path, version_info, api_key=api_key)
|
||||
|
||||
_active_downloads[download_id]["db_file_id"] = db_result["file_id"]
|
||||
_active_downloads[download_id]["db_linked"] = db_result["linked"]
|
||||
_active_downloads[download_id]["db_cached"] = db_result["cached"]
|
||||
if db_result["error"]:
|
||||
_active_downloads[download_id]["db_error"] = db_result["error"]
|
||||
logger.error("DB register failed for %s: %s", dest_path, db_result["error"])
|
||||
else:
|
||||
_active_downloads[download_id]["status"] = "failed"
|
||||
_active_downloads[download_id]["error"] = "Download failed"
|
||||
@@ -145,28 +156,6 @@ def _do_download(
|
||||
_active_downloads[download_id]["error"] = str(e)
|
||||
|
||||
|
||||
def _auto_link_file(file_path: Path, api_key: str | None) -> None:
|
||||
"""Auto-scan and link the downloaded file to CivitAI."""
|
||||
try:
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
# Scan the single file
|
||||
results = db.scan_directory(file_path.parent)
|
||||
|
||||
# Find and link the new file
|
||||
for result in results:
|
||||
if result["file_path"] == str(file_path):
|
||||
sha256 = result["sha256"]
|
||||
civitai_data = fetch_civitai_by_hash(sha256, api_key)
|
||||
if civitai_data:
|
||||
version_id = civitai_data.get("id", 0)
|
||||
model_id = civitai_data.get("modelId", 0)
|
||||
if version_id and model_id:
|
||||
db.link_file_to_civitai(result["id"], model_id, version_id)
|
||||
except Exception:
|
||||
logger.exception("Auto-link failed")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Download Endpoints
|
||||
# =============================================================================
|
||||
@@ -215,7 +204,7 @@ def start_download(req: DownloadRequest, background_tasks: BackgroundTasks) -> d
|
||||
}
|
||||
|
||||
# Start background download
|
||||
background_tasks.add_task(_do_download, version_id, dest_path, api_key, download_id)
|
||||
background_tasks.add_task(_do_download, version_id, dest_path, api_key, download_id, version_info)
|
||||
|
||||
return {
|
||||
"download_id": download_id,
|
||||
|
||||
+58
-55
@@ -1216,6 +1216,30 @@ class TestGalleryRoutesExtended:
|
||||
class TestDownloadBackgroundTasks:
|
||||
"""Tests for download background task functions."""
|
||||
|
||||
@staticmethod
|
||||
def _patch_db_noop(monkeypatch, download_routes_module) -> dict:
|
||||
"""Replace Database with a no-op stub; return a dict capturing register calls."""
|
||||
captured: dict = {}
|
||||
|
||||
class StubDB:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
return False
|
||||
|
||||
def init_schema(self):
|
||||
pass
|
||||
|
||||
def register_downloaded_file(self, dest_path, version_info, api_key=None, console=None):
|
||||
captured["dest_path"] = dest_path
|
||||
captured["version_info"] = version_info
|
||||
captured["api_key"] = api_key
|
||||
return {"file_id": 42, "sha256": "deadbeef", "linked": True, "cached": True, "error": None}
|
||||
|
||||
monkeypatch.setattr(download_routes_module, "Database", lambda: StubDB())
|
||||
return captured
|
||||
|
||||
def test_do_download_success(self, monkeypatch, tmp_path) -> None:
|
||||
"""Test successful download task execution."""
|
||||
from tensors.server import download_routes
|
||||
@@ -1233,13 +1257,18 @@ class TestDownloadBackgroundTasks:
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
||||
monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None)
|
||||
captured = self._patch_db_noop(monkeypatch, download_routes)
|
||||
|
||||
dest_path = tmp_path / "model.safetensors"
|
||||
_do_download(12345, dest_path, None, download_id)
|
||||
version_info = {"id": 999, "modelId": 888, "name": "v1"}
|
||||
_do_download(12345, dest_path, None, download_id, version_info)
|
||||
|
||||
assert download_routes._active_downloads[download_id]["status"] == "completed"
|
||||
assert download_routes._active_downloads[download_id]["progress"] == 100
|
||||
assert download_routes._active_downloads[download_id]["db_file_id"] == 42
|
||||
assert download_routes._active_downloads[download_id]["db_linked"] is True
|
||||
assert download_routes._active_downloads[download_id]["db_cached"] is True
|
||||
assert captured["version_info"] == version_info
|
||||
|
||||
# Cleanup
|
||||
del download_routes._active_downloads[download_id]
|
||||
@@ -1256,7 +1285,7 @@ class TestDownloadBackgroundTasks:
|
||||
monkeypatch.setattr(download_routes, "download_model_with_progress", lambda *args, **kwargs: False)
|
||||
|
||||
dest_path = tmp_path / "model.safetensors"
|
||||
_do_download(12345, dest_path, None, download_id)
|
||||
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||
|
||||
assert download_routes._active_downloads[download_id]["status"] == "failed"
|
||||
assert "error" in download_routes._active_downloads[download_id]
|
||||
@@ -1278,7 +1307,7 @@ class TestDownloadBackgroundTasks:
|
||||
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
||||
|
||||
dest_path = tmp_path / "model.safetensors"
|
||||
_do_download(12345, dest_path, None, download_id)
|
||||
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||
|
||||
assert download_routes._active_downloads[download_id]["status"] == "failed"
|
||||
assert "Network error" in download_routes._active_downloads[download_id]["error"]
|
||||
@@ -1309,10 +1338,10 @@ class TestDownloadBackgroundTasks:
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
||||
monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None)
|
||||
self._patch_db_noop(monkeypatch, download_routes)
|
||||
|
||||
dest_path = tmp_path / "model.safetensors"
|
||||
_do_download(12345, dest_path, None, download_id)
|
||||
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||
|
||||
# Check progress formatting was called
|
||||
assert len(progress_calls) == 3
|
||||
@@ -1320,7 +1349,6 @@ class TestDownloadBackgroundTasks:
|
||||
assert progress_calls[0]["total_str"] == "1.0 KB"
|
||||
assert progress_calls[1]["downloaded_str"] == "500.0 KB"
|
||||
assert progress_calls[2]["downloaded_str"] == "500.0 MB"
|
||||
assert progress_calls[2]["total_str"] == "1.0 GB"
|
||||
|
||||
del download_routes._active_downloads[download_id]
|
||||
|
||||
@@ -1337,74 +1365,49 @@ class TestDownloadBackgroundTasks:
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
||||
monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None)
|
||||
self._patch_db_noop(monkeypatch, download_routes)
|
||||
|
||||
dest_path = tmp_path / "model.safetensors"
|
||||
_do_download(12345, dest_path, None, download_id)
|
||||
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||
|
||||
assert download_routes._active_downloads[download_id]["total_str"] == "Unknown"
|
||||
|
||||
del download_routes._active_downloads[download_id]
|
||||
|
||||
|
||||
class TestAutoLinkFile:
|
||||
"""Tests for _auto_link_file function."""
|
||||
|
||||
def test_auto_link_success(self, monkeypatch, tmp_path, temp_db) -> None:
|
||||
"""Test auto-linking a downloaded file."""
|
||||
def test_do_download_db_error_surfaced(self, monkeypatch, tmp_path) -> None:
|
||||
"""DB errors during register must be surfaced into _active_downloads."""
|
||||
from tensors.server import download_routes
|
||||
from tensors.server.download_routes import _auto_link_file
|
||||
from tensors.server.download_routes import _do_download
|
||||
|
||||
# Create a fake safetensor file
|
||||
file_path = tmp_path / "test.safetensors"
|
||||
file_path.write_bytes(b"fake safetensor data")
|
||||
download_id = "test_db_err"
|
||||
download_routes._active_downloads[download_id] = {"id": download_id, "status": "queued"}
|
||||
|
||||
from tensors import db as db_module
|
||||
monkeypatch.setattr(download_routes, "download_model_with_progress", lambda *a, **kw: True)
|
||||
|
||||
monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path)
|
||||
|
||||
# Mock scan results
|
||||
scanned_files = []
|
||||
|
||||
def mock_scan(directory):
|
||||
return [{"id": 1, "file_path": str(file_path), "sha256": "abc123"}]
|
||||
|
||||
monkeypatch.setattr(temp_db, "scan_directory", mock_scan)
|
||||
|
||||
# Mock CivitAI lookup
|
||||
def mock_fetch_by_hash(sha256, api_key):
|
||||
return {"id": 999, "modelId": 888}
|
||||
|
||||
monkeypatch.setattr(download_routes, "fetch_civitai_by_hash", mock_fetch_by_hash)
|
||||
|
||||
# Mock Database context manager to use temp_db
|
||||
class MockDB:
|
||||
class FailingDB:
|
||||
def __enter__(self):
|
||||
return temp_db
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
return False
|
||||
|
||||
def init_schema(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(download_routes, "Database", MockDB)
|
||||
def register_downloaded_file(self, *args, **kwargs):
|
||||
return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"}
|
||||
|
||||
# This should not raise
|
||||
_auto_link_file(file_path, None)
|
||||
monkeypatch.setattr(download_routes, "Database", lambda: FailingDB())
|
||||
|
||||
def test_auto_link_exception_handled(self, monkeypatch, tmp_path) -> None:
|
||||
"""Test auto-link handles exceptions gracefully."""
|
||||
from tensors.server.download_routes import _auto_link_file
|
||||
dest_path = tmp_path / "model.safetensors"
|
||||
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||
|
||||
# Mock Database to raise exception
|
||||
def mock_db(*args, **kwargs):
|
||||
raise RuntimeError("DB error")
|
||||
# Download itself succeeded; DB layer is reported separately
|
||||
assert download_routes._active_downloads[download_id]["status"] == "completed"
|
||||
assert download_routes._active_downloads[download_id]["db_error"] == "boom"
|
||||
assert download_routes._active_downloads[download_id]["db_linked"] is False
|
||||
|
||||
from tensors.server import download_routes
|
||||
|
||||
monkeypatch.setattr(download_routes, "Database", mock_db)
|
||||
|
||||
file_path = tmp_path / "test.safetensors"
|
||||
# Should not raise
|
||||
_auto_link_file(file_path, None)
|
||||
del download_routes._active_downloads[download_id]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
Reference in New Issue
Block a user