fix: populate model cache tables after download so db list resolves names
The CLI download flow only set civitai_model_id/version_id on local_files without caching the full model payload, so 'tsr db list' joined against empty models/versions/creators tables and showed every linked file as 'unlinked'. The server's _auto_link_file path had additional bugs: resolved-vs-unresolved path comparison after rescan, redundant CivitAI hash lookup, and silent failure swallowed by 'completed' status. - New Database.register_downloaded_file() consolidates hashing, metadata storage, FK linking, and cache_model() into a single idempotent call shared by both CLI and server paths. - Server _do_download now passes version_info straight through and surfaces db_file_id/db_linked/db_cached/db_error onto _active_downloads. - Drops the broken _auto_link_file rescan helper.
This commit is contained in:
+12
-39
@@ -604,56 +604,29 @@ def download(
|
|||||||
|
|
||||||
|
|
||||||
def _add_downloaded_file_to_db(dest_path: Path, version_info: dict[str, Any]) -> None:
|
def _add_downloaded_file_to_db(dest_path: Path, version_info: dict[str, Any]) -> None:
|
||||||
"""Add a downloaded file to the database and link to CivitAI.
|
"""Add a downloaded file to the database, link to CivitAI, and cache full model data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dest_path: Path to the downloaded file
|
dest_path: Path to the downloaded file
|
||||||
version_info: CivitAI version info response
|
version_info: CivitAI version info response
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
console.print("[dim]Adding to database...[/dim]")
|
console.print("[dim]Adding to database...[/dim]")
|
||||||
|
api_key = load_api_key()
|
||||||
# Compute SHA256 hash
|
|
||||||
sha256 = compute_sha256(dest_path, console)
|
|
||||||
|
|
||||||
# Read safetensor metadata
|
|
||||||
metadata = read_safetensor_metadata(dest_path)
|
|
||||||
|
|
||||||
# Extract CivitAI IDs
|
|
||||||
civitai_version_id = version_info.get("id")
|
|
||||||
civitai_model_id = version_info.get("modelId") or version_info.get("model", {}).get("id")
|
|
||||||
|
|
||||||
with Database() as db:
|
with Database() as db:
|
||||||
db.init_schema()
|
db.init_schema()
|
||||||
with db.session() as session:
|
result = db.register_downloaded_file(dest_path, version_info, api_key=api_key, console=console)
|
||||||
# Add local file record
|
|
||||||
local_file = db._upsert_local_file(
|
|
||||||
session,
|
|
||||||
file_path=str(dest_path.resolve()),
|
|
||||||
sha256=sha256,
|
|
||||||
header_size=metadata.get("header_size"),
|
|
||||||
tensor_count=metadata.get("tensor_count"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store safetensor metadata
|
if result["error"]:
|
||||||
db._store_safetensor_metadata(session, local_file.id, metadata.get("metadata", {}))
|
console.print(f"[yellow]Warning: Could not add to database: {result['error']}[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
# Link to CivitAI if we have the IDs
|
console.print(f"[green]Added to database (id={result['file_id']})[/green]")
|
||||||
if civitai_model_id and civitai_version_id:
|
if result["linked"]:
|
||||||
local_file.civitai_model_id = civitai_model_id
|
civitai_version_id = version_info.get("id")
|
||||||
local_file.civitai_version_id = civitai_version_id
|
civitai_model_id = version_info.get("modelId") or version_info.get("model", {}).get("id")
|
||||||
session.add(local_file)
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
file_id = local_file.id
|
|
||||||
|
|
||||||
# Report success
|
|
||||||
console.print(f"[green]Added to database (id={file_id})[/green]")
|
|
||||||
if civitai_model_id and civitai_version_id:
|
|
||||||
console.print(f"[green]Linked to CivitAI model={civitai_model_id} version={civitai_version_id}[/green]")
|
console.print(f"[green]Linked to CivitAI model={civitai_model_id} version={civitai_version_id}[/green]")
|
||||||
|
if result["cached"]:
|
||||||
except Exception as e:
|
console.print("[green]Cached model metadata[/green]")
|
||||||
console.print(f"[yellow]Warning: Could not add to database: {e}[/yellow]")
|
|
||||||
|
|
||||||
|
|
||||||
def _display_download_info(
|
def _display_download_info(
|
||||||
|
|||||||
@@ -257,6 +257,71 @@ class Database:
|
|||||||
session.add(f)
|
session.add(f)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
def register_downloaded_file(
|
||||||
|
self,
|
||||||
|
dest_path: Path,
|
||||||
|
version_info: dict[str, Any],
|
||||||
|
api_key: str | None = None,
|
||||||
|
console: Console | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Register a freshly-downloaded file: hash, store metadata, link, and cache full model.
|
||||||
|
|
||||||
|
Idempotent and shared by the CLI download flow and the FastAPI background worker so
|
||||||
|
both paths produce identical DB state (local_file row + cached models/versions/tags
|
||||||
|
so ``db list`` can resolve names, triggers, and base_model).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dest_path: Path to the downloaded safetensor file.
|
||||||
|
version_info: CivitAI ``model-versions/{id}`` response (already fetched).
|
||||||
|
api_key: Optional CivitAI API key for the model fetch.
|
||||||
|
console: Optional Rich console for hash progress output.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``{"file_id": int, "sha256": str, "linked": bool, "cached": bool, "error": str | None}``
|
||||||
|
"""
|
||||||
|
# Lazy import to avoid pulling httpx into modules that only need DB ops
|
||||||
|
from tensors.api import fetch_civitai_model # noqa: PLC0415
|
||||||
|
|
||||||
|
result: dict[str, Any] = {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": None}
|
||||||
|
try:
|
||||||
|
sha256 = compute_sha256(dest_path, console)
|
||||||
|
metadata = read_safetensor_metadata(dest_path)
|
||||||
|
|
||||||
|
civitai_version_id = version_info.get("id")
|
||||||
|
civitai_model_id = version_info.get("modelId") or version_info.get("model", {}).get("id")
|
||||||
|
|
||||||
|
with self.session() as session:
|
||||||
|
local_file = self._upsert_local_file(
|
||||||
|
session,
|
||||||
|
file_path=str(dest_path.resolve()),
|
||||||
|
sha256=sha256,
|
||||||
|
header_size=metadata.get("header_size"),
|
||||||
|
tensor_count=metadata.get("tensor_count"),
|
||||||
|
)
|
||||||
|
self._store_safetensor_metadata(session, local_file.id, metadata.get("metadata", {}))
|
||||||
|
|
||||||
|
if civitai_model_id and civitai_version_id:
|
||||||
|
local_file.civitai_model_id = civitai_model_id
|
||||||
|
local_file.civitai_version_id = civitai_version_id
|
||||||
|
session.add(local_file)
|
||||||
|
result["linked"] = True
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
result["file_id"] = local_file.id
|
||||||
|
result["sha256"] = sha256
|
||||||
|
|
||||||
|
# Cache full model metadata so db list can resolve names/triggers/base_model.
|
||||||
|
# The version endpoint payload is too sparse for cache_model() (no creator, tags,
|
||||||
|
# or full modelVersions list), so we fetch the model endpoint here.
|
||||||
|
if civitai_model_id:
|
||||||
|
model_data = fetch_civitai_model(civitai_model_id, api_key, console)
|
||||||
|
if model_data:
|
||||||
|
self.cache_model(model_data)
|
||||||
|
result["cached"] = True
|
||||||
|
except Exception as e: # surface any failure to caller without crashing the download
|
||||||
|
result["error"] = str(e)
|
||||||
|
return result
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# CivitAI Cache Operations
|
# CivitAI Cache Operations
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|||||||
@@ -106,6 +106,7 @@ def _do_download(
|
|||||||
dest_path: Path,
|
dest_path: Path,
|
||||||
api_key: str | None,
|
api_key: str | None,
|
||||||
download_id: str,
|
download_id: str,
|
||||||
|
version_info: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Background task to perform the download."""
|
"""Background task to perform the download."""
|
||||||
try:
|
try:
|
||||||
@@ -133,8 +134,18 @@ def _do_download(
|
|||||||
_active_downloads[download_id]["progress"] = 100
|
_active_downloads[download_id]["progress"] = 100
|
||||||
_active_downloads[download_id]["path"] = str(dest_path)
|
_active_downloads[download_id]["path"] = str(dest_path)
|
||||||
|
|
||||||
# Auto-scan and link the downloaded file
|
# Register the file in DB: hash, link to CivitAI IDs from version_info,
|
||||||
_auto_link_file(dest_path, api_key)
|
# and cache full model metadata so /api/db/* endpoints return resolved data.
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
db_result = db.register_downloaded_file(dest_path, version_info, api_key=api_key)
|
||||||
|
|
||||||
|
_active_downloads[download_id]["db_file_id"] = db_result["file_id"]
|
||||||
|
_active_downloads[download_id]["db_linked"] = db_result["linked"]
|
||||||
|
_active_downloads[download_id]["db_cached"] = db_result["cached"]
|
||||||
|
if db_result["error"]:
|
||||||
|
_active_downloads[download_id]["db_error"] = db_result["error"]
|
||||||
|
logger.error("DB register failed for %s: %s", dest_path, db_result["error"])
|
||||||
else:
|
else:
|
||||||
_active_downloads[download_id]["status"] = "failed"
|
_active_downloads[download_id]["status"] = "failed"
|
||||||
_active_downloads[download_id]["error"] = "Download failed"
|
_active_downloads[download_id]["error"] = "Download failed"
|
||||||
@@ -145,28 +156,6 @@ def _do_download(
|
|||||||
_active_downloads[download_id]["error"] = str(e)
|
_active_downloads[download_id]["error"] = str(e)
|
||||||
|
|
||||||
|
|
||||||
def _auto_link_file(file_path: Path, api_key: str | None) -> None:
|
|
||||||
"""Auto-scan and link the downloaded file to CivitAI."""
|
|
||||||
try:
|
|
||||||
with Database() as db:
|
|
||||||
db.init_schema()
|
|
||||||
# Scan the single file
|
|
||||||
results = db.scan_directory(file_path.parent)
|
|
||||||
|
|
||||||
# Find and link the new file
|
|
||||||
for result in results:
|
|
||||||
if result["file_path"] == str(file_path):
|
|
||||||
sha256 = result["sha256"]
|
|
||||||
civitai_data = fetch_civitai_by_hash(sha256, api_key)
|
|
||||||
if civitai_data:
|
|
||||||
version_id = civitai_data.get("id", 0)
|
|
||||||
model_id = civitai_data.get("modelId", 0)
|
|
||||||
if version_id and model_id:
|
|
||||||
db.link_file_to_civitai(result["id"], model_id, version_id)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Auto-link failed")
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Download Endpoints
|
# Download Endpoints
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -215,7 +204,7 @@ def start_download(req: DownloadRequest, background_tasks: BackgroundTasks) -> d
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Start background download
|
# Start background download
|
||||||
background_tasks.add_task(_do_download, version_id, dest_path, api_key, download_id)
|
background_tasks.add_task(_do_download, version_id, dest_path, api_key, download_id, version_info)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"download_id": download_id,
|
"download_id": download_id,
|
||||||
|
|||||||
+58
-55
@@ -1216,6 +1216,30 @@ class TestGalleryRoutesExtended:
|
|||||||
class TestDownloadBackgroundTasks:
|
class TestDownloadBackgroundTasks:
|
||||||
"""Tests for download background task functions."""
|
"""Tests for download background task functions."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _patch_db_noop(monkeypatch, download_routes_module) -> dict:
|
||||||
|
"""Replace Database with a no-op stub; return a dict capturing register calls."""
|
||||||
|
captured: dict = {}
|
||||||
|
|
||||||
|
class StubDB:
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def init_schema(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def register_downloaded_file(self, dest_path, version_info, api_key=None, console=None):
|
||||||
|
captured["dest_path"] = dest_path
|
||||||
|
captured["version_info"] = version_info
|
||||||
|
captured["api_key"] = api_key
|
||||||
|
return {"file_id": 42, "sha256": "deadbeef", "linked": True, "cached": True, "error": None}
|
||||||
|
|
||||||
|
monkeypatch.setattr(download_routes_module, "Database", lambda: StubDB())
|
||||||
|
return captured
|
||||||
|
|
||||||
def test_do_download_success(self, monkeypatch, tmp_path) -> None:
|
def test_do_download_success(self, monkeypatch, tmp_path) -> None:
|
||||||
"""Test successful download task execution."""
|
"""Test successful download task execution."""
|
||||||
from tensors.server import download_routes
|
from tensors.server import download_routes
|
||||||
@@ -1233,13 +1257,18 @@ class TestDownloadBackgroundTasks:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
||||||
monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None)
|
captured = self._patch_db_noop(monkeypatch, download_routes)
|
||||||
|
|
||||||
dest_path = tmp_path / "model.safetensors"
|
dest_path = tmp_path / "model.safetensors"
|
||||||
_do_download(12345, dest_path, None, download_id)
|
version_info = {"id": 999, "modelId": 888, "name": "v1"}
|
||||||
|
_do_download(12345, dest_path, None, download_id, version_info)
|
||||||
|
|
||||||
assert download_routes._active_downloads[download_id]["status"] == "completed"
|
assert download_routes._active_downloads[download_id]["status"] == "completed"
|
||||||
assert download_routes._active_downloads[download_id]["progress"] == 100
|
assert download_routes._active_downloads[download_id]["progress"] == 100
|
||||||
|
assert download_routes._active_downloads[download_id]["db_file_id"] == 42
|
||||||
|
assert download_routes._active_downloads[download_id]["db_linked"] is True
|
||||||
|
assert download_routes._active_downloads[download_id]["db_cached"] is True
|
||||||
|
assert captured["version_info"] == version_info
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
del download_routes._active_downloads[download_id]
|
del download_routes._active_downloads[download_id]
|
||||||
@@ -1256,7 +1285,7 @@ class TestDownloadBackgroundTasks:
|
|||||||
monkeypatch.setattr(download_routes, "download_model_with_progress", lambda *args, **kwargs: False)
|
monkeypatch.setattr(download_routes, "download_model_with_progress", lambda *args, **kwargs: False)
|
||||||
|
|
||||||
dest_path = tmp_path / "model.safetensors"
|
dest_path = tmp_path / "model.safetensors"
|
||||||
_do_download(12345, dest_path, None, download_id)
|
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||||
|
|
||||||
assert download_routes._active_downloads[download_id]["status"] == "failed"
|
assert download_routes._active_downloads[download_id]["status"] == "failed"
|
||||||
assert "error" in download_routes._active_downloads[download_id]
|
assert "error" in download_routes._active_downloads[download_id]
|
||||||
@@ -1278,7 +1307,7 @@ class TestDownloadBackgroundTasks:
|
|||||||
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
||||||
|
|
||||||
dest_path = tmp_path / "model.safetensors"
|
dest_path = tmp_path / "model.safetensors"
|
||||||
_do_download(12345, dest_path, None, download_id)
|
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||||
|
|
||||||
assert download_routes._active_downloads[download_id]["status"] == "failed"
|
assert download_routes._active_downloads[download_id]["status"] == "failed"
|
||||||
assert "Network error" in download_routes._active_downloads[download_id]["error"]
|
assert "Network error" in download_routes._active_downloads[download_id]["error"]
|
||||||
@@ -1309,10 +1338,10 @@ class TestDownloadBackgroundTasks:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
||||||
monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None)
|
self._patch_db_noop(monkeypatch, download_routes)
|
||||||
|
|
||||||
dest_path = tmp_path / "model.safetensors"
|
dest_path = tmp_path / "model.safetensors"
|
||||||
_do_download(12345, dest_path, None, download_id)
|
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||||
|
|
||||||
# Check progress formatting was called
|
# Check progress formatting was called
|
||||||
assert len(progress_calls) == 3
|
assert len(progress_calls) == 3
|
||||||
@@ -1320,7 +1349,6 @@ class TestDownloadBackgroundTasks:
|
|||||||
assert progress_calls[0]["total_str"] == "1.0 KB"
|
assert progress_calls[0]["total_str"] == "1.0 KB"
|
||||||
assert progress_calls[1]["downloaded_str"] == "500.0 KB"
|
assert progress_calls[1]["downloaded_str"] == "500.0 KB"
|
||||||
assert progress_calls[2]["downloaded_str"] == "500.0 MB"
|
assert progress_calls[2]["downloaded_str"] == "500.0 MB"
|
||||||
assert progress_calls[2]["total_str"] == "1.0 GB"
|
|
||||||
|
|
||||||
del download_routes._active_downloads[download_id]
|
del download_routes._active_downloads[download_id]
|
||||||
|
|
||||||
@@ -1337,74 +1365,49 @@ class TestDownloadBackgroundTasks:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download)
|
||||||
monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None)
|
self._patch_db_noop(monkeypatch, download_routes)
|
||||||
|
|
||||||
dest_path = tmp_path / "model.safetensors"
|
dest_path = tmp_path / "model.safetensors"
|
||||||
_do_download(12345, dest_path, None, download_id)
|
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||||
|
|
||||||
assert download_routes._active_downloads[download_id]["total_str"] == "Unknown"
|
assert download_routes._active_downloads[download_id]["total_str"] == "Unknown"
|
||||||
|
|
||||||
del download_routes._active_downloads[download_id]
|
del download_routes._active_downloads[download_id]
|
||||||
|
|
||||||
|
def test_do_download_db_error_surfaced(self, monkeypatch, tmp_path) -> None:
|
||||||
class TestAutoLinkFile:
|
"""DB errors during register must be surfaced into _active_downloads."""
|
||||||
"""Tests for _auto_link_file function."""
|
|
||||||
|
|
||||||
def test_auto_link_success(self, monkeypatch, tmp_path, temp_db) -> None:
|
|
||||||
"""Test auto-linking a downloaded file."""
|
|
||||||
from tensors.server import download_routes
|
from tensors.server import download_routes
|
||||||
from tensors.server.download_routes import _auto_link_file
|
from tensors.server.download_routes import _do_download
|
||||||
|
|
||||||
# Create a fake safetensor file
|
download_id = "test_db_err"
|
||||||
file_path = tmp_path / "test.safetensors"
|
download_routes._active_downloads[download_id] = {"id": download_id, "status": "queued"}
|
||||||
file_path.write_bytes(b"fake safetensor data")
|
|
||||||
|
|
||||||
from tensors import db as db_module
|
monkeypatch.setattr(download_routes, "download_model_with_progress", lambda *a, **kw: True)
|
||||||
|
|
||||||
monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path)
|
class FailingDB:
|
||||||
|
|
||||||
# Mock scan results
|
|
||||||
scanned_files = []
|
|
||||||
|
|
||||||
def mock_scan(directory):
|
|
||||||
return [{"id": 1, "file_path": str(file_path), "sha256": "abc123"}]
|
|
||||||
|
|
||||||
monkeypatch.setattr(temp_db, "scan_directory", mock_scan)
|
|
||||||
|
|
||||||
# Mock CivitAI lookup
|
|
||||||
def mock_fetch_by_hash(sha256, api_key):
|
|
||||||
return {"id": 999, "modelId": 888}
|
|
||||||
|
|
||||||
monkeypatch.setattr(download_routes, "fetch_civitai_by_hash", mock_fetch_by_hash)
|
|
||||||
|
|
||||||
# Mock Database context manager to use temp_db
|
|
||||||
class MockDB:
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return temp_db
|
return self
|
||||||
|
|
||||||
def __exit__(self, *args):
|
def __exit__(self, *args):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def init_schema(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
monkeypatch.setattr(download_routes, "Database", MockDB)
|
def register_downloaded_file(self, *args, **kwargs):
|
||||||
|
return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"}
|
||||||
|
|
||||||
# This should not raise
|
monkeypatch.setattr(download_routes, "Database", lambda: FailingDB())
|
||||||
_auto_link_file(file_path, None)
|
|
||||||
|
|
||||||
def test_auto_link_exception_handled(self, monkeypatch, tmp_path) -> None:
|
dest_path = tmp_path / "model.safetensors"
|
||||||
"""Test auto-link handles exceptions gracefully."""
|
_do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1})
|
||||||
from tensors.server.download_routes import _auto_link_file
|
|
||||||
|
|
||||||
# Mock Database to raise exception
|
# Download itself succeeded; DB layer is reported separately
|
||||||
def mock_db(*args, **kwargs):
|
assert download_routes._active_downloads[download_id]["status"] == "completed"
|
||||||
raise RuntimeError("DB error")
|
assert download_routes._active_downloads[download_id]["db_error"] == "boom"
|
||||||
|
assert download_routes._active_downloads[download_id]["db_linked"] is False
|
||||||
|
|
||||||
from tensors.server import download_routes
|
del download_routes._active_downloads[download_id]
|
||||||
|
|
||||||
monkeypatch.setattr(download_routes, "Database", mock_db)
|
|
||||||
|
|
||||||
file_path = tmp_path / "test.safetensors"
|
|
||||||
# Should not raise
|
|
||||||
_auto_link_file(file_path, None)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
Reference in New Issue
Block a user