fix: populate model cache tables after download so db list resolves names
The CLI download flow only set civitai_model_id/version_id on local_files without caching the full model payload, so 'tsr db list' joined against empty models/versions/creators tables and showed every linked file as 'unlinked'. The server's _auto_link_file path had additional bugs: resolved-vs-unresolved path comparison after rescan, redundant CivitAI hash lookup, and silent failure swallowed by 'completed' status. - New Database.register_downloaded_file() consolidates hashing, metadata storage, FK linking, and cache_model() into a single idempotent call shared by both CLI and server paths. - Server _do_download now passes version_info straight through and surfaces db_file_id/db_linked/db_cached/db_error onto _active_downloads. - Drops the broken _auto_link_file rescan helper.
This commit is contained in:
+58
-55
@@ -1216,6 +1216,30 @@ class TestGalleryRoutesExtended:
|
||||
class TestDownloadBackgroundTasks:
|
||||
"""Tests for download background task functions."""
|
||||
|
||||
@staticmethod
|
||||
def _patch_db_noop(monkeypatch, download_routes_module) -> dict:
|
||||
"""Replace Database with a no-op stub; return a dict capturing register calls."""
|
||||
captured: dict = {}
|
||||
|
||||
class StubDB:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
return False
|
||||
|
||||
def init_schema(self):
|
||||
pass
|
||||
|
||||
def register_downloaded_file(self, dest_path, version_info, api_key=None, console=None):
|
||||
captured["dest_path"] = dest_path
|
||||
captured["version_info"] = version_info
|
||||
captured["api_key"] = api_key
|
||||
return {"file_id": 42, "sha256": "deadbeef", "linked": True, "cached": True, "error": None}
|
||||
|
||||
monkeypatch.setattr(download_routes_module, "Database", lambda: StubDB())
|
||||
return captured
|
||||
|
||||
def test_do_download_success(self, monkeypatch, tmp_path) -> None:
|
||||
"""Test successful download task execution."""
|
||||
from tensors.server import download_routes
|
||||
@@ -1233,13 +1257,18 @@ class TestDownloadBackgroundTasks:
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
||||
monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None)
|
||||
captured = self._patch_db_noop(monkeypatch, download_routes)
|
||||
|
||||
dest_path = tmp_path / "model.safetensors"
|
||||
_do_download(12345, dest_path, None, download_id)
|
||||
version_info = {"id": 999, "modelId": 888, "name": "v1"}
|
||||
_do_download(12345, dest_path, None, download_id, version_info)
|
||||
|
||||
assert download_routes._active_downloads[download_id]["status"] == "completed"
|
||||
assert download_routes._active_downloads[download_id]["progress"] == 100
|
||||
assert download_routes._active_downloads[download_id]["db_file_id"] == 42
|
||||
assert download_routes._active_downloads[download_id]["db_linked"] is True
|
||||
assert download_routes._active_downloads[download_id]["db_cached"] is True
|
||||
assert captured["version_info"] == version_info
|
||||
|
||||
# Cleanup
|
||||
del download_routes._active_downloads[download_id]
|
||||
@@ -1256,7 +1285,7 @@ class TestDownloadBackgroundTasks:
|
||||
monkeypatch.setattr(download_routes, "download_model_with_progress", lambda *args, **kwargs: False)
|
||||
|
||||
dest_path = tmp_path / "model.safetensors"
|
||||
_do_download(12345, dest_path, None, download_id)
|
||||
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||
|
||||
assert download_routes._active_downloads[download_id]["status"] == "failed"
|
||||
assert "error" in download_routes._active_downloads[download_id]
|
||||
@@ -1278,7 +1307,7 @@ class TestDownloadBackgroundTasks:
|
||||
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
||||
|
||||
dest_path = tmp_path / "model.safetensors"
|
||||
_do_download(12345, dest_path, None, download_id)
|
||||
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||
|
||||
assert download_routes._active_downloads[download_id]["status"] == "failed"
|
||||
assert "Network error" in download_routes._active_downloads[download_id]["error"]
|
||||
@@ -1309,10 +1338,10 @@ class TestDownloadBackgroundTasks:
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
||||
monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None)
|
||||
self._patch_db_noop(monkeypatch, download_routes)
|
||||
|
||||
dest_path = tmp_path / "model.safetensors"
|
||||
_do_download(12345, dest_path, None, download_id)
|
||||
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||
|
||||
# Check progress formatting was called
|
||||
assert len(progress_calls) == 3
|
||||
@@ -1320,7 +1349,6 @@ class TestDownloadBackgroundTasks:
|
||||
assert progress_calls[0]["total_str"] == "1.0 KB"
|
||||
assert progress_calls[1]["downloaded_str"] == "500.0 KB"
|
||||
assert progress_calls[2]["downloaded_str"] == "500.0 MB"
|
||||
assert progress_calls[2]["total_str"] == "1.0 GB"
|
||||
|
||||
del download_routes._active_downloads[download_id]
|
||||
|
||||
@@ -1337,74 +1365,49 @@ class TestDownloadBackgroundTasks:
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
||||
monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None)
|
||||
self._patch_db_noop(monkeypatch, download_routes)
|
||||
|
||||
dest_path = tmp_path / "model.safetensors"
|
||||
_do_download(12345, dest_path, None, download_id)
|
||||
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||
|
||||
assert download_routes._active_downloads[download_id]["total_str"] == "Unknown"
|
||||
|
||||
del download_routes._active_downloads[download_id]
|
||||
|
||||
|
||||
class TestAutoLinkFile:
|
||||
"""Tests for _auto_link_file function."""
|
||||
|
||||
def test_auto_link_success(self, monkeypatch, tmp_path, temp_db) -> None:
|
||||
"""Test auto-linking a downloaded file."""
|
||||
def test_do_download_db_error_surfaced(self, monkeypatch, tmp_path) -> None:
|
||||
"""DB errors during register must be surfaced into _active_downloads."""
|
||||
from tensors.server import download_routes
|
||||
from tensors.server.download_routes import _auto_link_file
|
||||
from tensors.server.download_routes import _do_download
|
||||
|
||||
# Create a fake safetensor file
|
||||
file_path = tmp_path / "test.safetensors"
|
||||
file_path.write_bytes(b"fake safetensor data")
|
||||
download_id = "test_db_err"
|
||||
download_routes._active_downloads[download_id] = {"id": download_id, "status": "queued"}
|
||||
|
||||
from tensors import db as db_module
|
||||
monkeypatch.setattr(download_routes, "download_model_with_progress", lambda *a, **kw: True)
|
||||
|
||||
monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path)
|
||||
|
||||
# Mock scan results
|
||||
scanned_files = []
|
||||
|
||||
def mock_scan(directory):
|
||||
return [{"id": 1, "file_path": str(file_path), "sha256": "abc123"}]
|
||||
|
||||
monkeypatch.setattr(temp_db, "scan_directory", mock_scan)
|
||||
|
||||
# Mock CivitAI lookup
|
||||
def mock_fetch_by_hash(sha256, api_key):
|
||||
return {"id": 999, "modelId": 888}
|
||||
|
||||
monkeypatch.setattr(download_routes, "fetch_civitai_by_hash", mock_fetch_by_hash)
|
||||
|
||||
# Mock Database context manager to use temp_db
|
||||
class MockDB:
|
||||
class FailingDB:
|
||||
def __enter__(self):
|
||||
return temp_db
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
return False
|
||||
|
||||
def init_schema(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(download_routes, "Database", MockDB)
|
||||
def register_downloaded_file(self, *args, **kwargs):
|
||||
return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"}
|
||||
|
||||
# This should not raise
|
||||
_auto_link_file(file_path, None)
|
||||
monkeypatch.setattr(download_routes, "Database", lambda: FailingDB())
|
||||
|
||||
def test_auto_link_exception_handled(self, monkeypatch, tmp_path) -> None:
|
||||
"""Test auto-link handles exceptions gracefully."""
|
||||
from tensors.server.download_routes import _auto_link_file
|
||||
dest_path = tmp_path / "model.safetensors"
|
||||
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||
|
||||
# Mock Database to raise exception
|
||||
def mock_db(*args, **kwargs):
|
||||
raise RuntimeError("DB error")
|
||||
# Download itself succeeded; DB layer is reported separately
|
||||
assert download_routes._active_downloads[download_id]["status"] == "completed"
|
||||
assert download_routes._active_downloads[download_id]["db_error"] == "boom"
|
||||
assert download_routes._active_downloads[download_id]["db_linked"] is False
|
||||
|
||||
from tensors.server import download_routes
|
||||
|
||||
monkeypatch.setattr(download_routes, "Database", mock_db)
|
||||
|
||||
file_path = tmp_path / "test.safetensors"
|
||||
# Should not raise
|
||||
_auto_link_file(file_path, None)
|
||||
del download_routes._active_downloads[download_id]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
Reference in New Issue
Block a user