diff --git a/tensors/cli.py b/tensors/cli.py index 13db473..546de11 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -604,56 +604,29 @@ def download( def _add_downloaded_file_to_db(dest_path: Path, version_info: dict[str, Any]) -> None: - """Add a downloaded file to the database and link to CivitAI. + """Add a downloaded file to the database, link to CivitAI, and cache full model data. Args: dest_path: Path to the downloaded file version_info: CivitAI version info response """ - try: - console.print("[dim]Adding to database...[/dim]") + console.print("[dim]Adding to database...[/dim]") + api_key = load_api_key() + with Database() as db: + db.init_schema() + result = db.register_downloaded_file(dest_path, version_info, api_key=api_key, console=console) - # Compute SHA256 hash - sha256 = compute_sha256(dest_path, console) + if result["error"]: + console.print(f"[yellow]Warning: Could not add to database: {result['error']}[/yellow]") + return - # Read safetensor metadata - metadata = read_safetensor_metadata(dest_path) - - # Extract CivitAI IDs + console.print(f"[green]Added to database (id={result['file_id']})[/green]") + if result["linked"]: civitai_version_id = version_info.get("id") civitai_model_id = version_info.get("modelId") or version_info.get("model", {}).get("id") - - with Database() as db: - db.init_schema() - with db.session() as session: - # Add local file record - local_file = db._upsert_local_file( - session, - file_path=str(dest_path.resolve()), - sha256=sha256, - header_size=metadata.get("header_size"), - tensor_count=metadata.get("tensor_count"), - ) - - # Store safetensor metadata - db._store_safetensor_metadata(session, local_file.id, metadata.get("metadata", {})) - - # Link to CivitAI if we have the IDs - if civitai_model_id and civitai_version_id: - local_file.civitai_model_id = civitai_model_id - local_file.civitai_version_id = civitai_version_id - session.add(local_file) - - session.commit() - file_id = local_file.id - - # Report success - console.print(f"[green]Added to database (id={file_id})[/green]") - if civitai_model_id and civitai_version_id: - console.print(f"[green]Linked to CivitAI model={civitai_model_id} version={civitai_version_id}[/green]") - - except Exception as e: - console.print(f"[yellow]Warning: Could not add to database: {e}[/yellow]") + console.print(f"[green]Linked to CivitAI model={civitai_model_id} version={civitai_version_id}[/green]") + if result["cached"]: + console.print("[green]Cached model metadata[/green]") def _display_download_info( diff --git a/tensors/db.py b/tensors/db.py index 7df0ba3..5e97540 100644 --- a/tensors/db.py +++ b/tensors/db.py @@ -257,6 +257,71 @@ class Database: session.add(f) session.commit() + def register_downloaded_file( + self, + dest_path: Path, + version_info: dict[str, Any], + api_key: str | None = None, + console: Console | None = None, + ) -> dict[str, Any]: + """Register a freshly-downloaded file: hash, store metadata, link, and cache full model. + + Idempotent and shared by the CLI download flow and the FastAPI background worker so + both paths produce identical DB state (local_file row + cached models/versions/tags + so ``db list`` can resolve names, triggers, and base_model). + + Args: + dest_path: Path to the downloaded safetensor file. + version_info: CivitAI ``model-versions/{id}`` response (already fetched). + api_key: Optional CivitAI API key for the model fetch. + console: Optional Rich console for hash progress output. + + Returns: + ``{"file_id": int, "sha256": str, "linked": bool, "cached": bool, "error": str | None}`` + """ + # Lazy import to avoid pulling httpx into modules that only need DB ops + from tensors.api import fetch_civitai_model # noqa: PLC0415 + + result: dict[str, Any] = {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": None} + try: + sha256 = compute_sha256(dest_path, console) + metadata = read_safetensor_metadata(dest_path) + + civitai_version_id = version_info.get("id") + civitai_model_id = version_info.get("modelId") or version_info.get("model", {}).get("id") + + with self.session() as session: + local_file = self._upsert_local_file( + session, + file_path=str(dest_path.resolve()), + sha256=sha256, + header_size=metadata.get("header_size"), + tensor_count=metadata.get("tensor_count"), + ) + self._store_safetensor_metadata(session, local_file.id, metadata.get("metadata", {})) + + if civitai_model_id and civitai_version_id: + local_file.civitai_model_id = civitai_model_id + local_file.civitai_version_id = civitai_version_id + session.add(local_file) + result["linked"] = True + + session.commit() + result["file_id"] = local_file.id + result["sha256"] = sha256 + + # Cache full model metadata so db list can resolve names/triggers/base_model. + # The version endpoint payload is too sparse for cache_model() (no creator, tags, + # or full modelVersions list), so we fetch the model endpoint here. + if civitai_model_id: + model_data = fetch_civitai_model(civitai_model_id, api_key, console) + if model_data: + self.cache_model(model_data) + result["cached"] = True + except Exception as e: # surface any failure to caller without crashing the download + result["error"] = str(e) + return result + # ========================================================================= # CivitAI Cache Operations # ========================================================================= diff --git a/tensors/server/download_routes.py b/tensors/server/download_routes.py index bea0f3a..a1978f0 100644 --- a/tensors/server/download_routes.py +++ b/tensors/server/download_routes.py @@ -106,6 +106,7 @@ def _do_download( dest_path: Path, api_key: str | None, download_id: str, + version_info: dict[str, Any], ) -> None: """Background task to perform the download.""" try: @@ -133,8 +134,18 @@ def _do_download( _active_downloads[download_id]["progress"] = 100 _active_downloads[download_id]["path"] = str(dest_path) - # Auto-scan and link the downloaded file - _auto_link_file(dest_path, api_key) + # Register the file in DB: hash, link to CivitAI IDs from version_info, + # and cache full model metadata so /api/db/* endpoints return resolved data. + with Database() as db: + db.init_schema() + db_result = db.register_downloaded_file(dest_path, version_info, api_key=api_key) + + _active_downloads[download_id]["db_file_id"] = db_result["file_id"] + _active_downloads[download_id]["db_linked"] = db_result["linked"] + _active_downloads[download_id]["db_cached"] = db_result["cached"] + if db_result["error"]: + _active_downloads[download_id]["db_error"] = db_result["error"] + logger.error("DB register failed for %s: %s", dest_path, db_result["error"]) else: _active_downloads[download_id]["status"] = "failed" _active_downloads[download_id]["error"] = "Download failed" @@ -145,28 +156,6 @@ def _do_download( _active_downloads[download_id]["error"] = str(e) -def _auto_link_file(file_path: Path, api_key: str | None) -> None: - """Auto-scan and link the downloaded file to CivitAI.""" - try: - with Database() as db: - db.init_schema() - # Scan the single file - results = db.scan_directory(file_path.parent) - - # Find and link the new file - for result in results: - if result["file_path"] == str(file_path): - sha256 = result["sha256"] - civitai_data = fetch_civitai_by_hash(sha256, api_key) - if civitai_data: - version_id = civitai_data.get("id", 0) - model_id = civitai_data.get("modelId", 0) - if version_id and model_id: - db.link_file_to_civitai(result["id"], model_id, version_id) - except Exception: - logger.exception("Auto-link failed") - - # ============================================================================= # Download Endpoints # ============================================================================= @@ -215,7 +204,7 @@ def start_download(req: DownloadRequest, background_tasks: BackgroundTasks) -> d } # Start background download - background_tasks.add_task(_do_download, version_id, dest_path, api_key, download_id) + background_tasks.add_task(_do_download, version_id, dest_path, api_key, download_id, version_info) return { "download_id": download_id, diff --git a/tests/test_server.py b/tests/test_server.py index b3930a1..339f7ed 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1216,6 +1216,30 @@ class TestGalleryRoutesExtended: class TestDownloadBackgroundTasks: """Tests for download background task functions.""" + @staticmethod + def _patch_db_noop(monkeypatch, download_routes_module) -> dict: + """Replace Database with a no-op stub; return a dict capturing register calls.""" + captured: dict = {} + + class StubDB: + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + def init_schema(self): + pass + + def register_downloaded_file(self, dest_path, version_info, api_key=None, console=None): + captured["dest_path"] = dest_path + captured["version_info"] = version_info + captured["api_key"] = api_key + return {"file_id": 42, "sha256": "deadbeef", "linked": True, "cached": True, "error": None} + + monkeypatch.setattr(download_routes_module, "Database", lambda: StubDB()) + return captured + def test_do_download_success(self, monkeypatch, tmp_path) -> None: """Test successful download task execution.""" from tensors.server import download_routes @@ -1233,13 +1257,18 @@ class TestDownloadBackgroundTasks: return True monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download) - monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None) + captured = self._patch_db_noop(monkeypatch, download_routes) dest_path = tmp_path / "model.safetensors" - _do_download(12345, dest_path, None, download_id) + version_info = {"id": 999, "modelId": 888, "name": "v1"} + _do_download(12345, dest_path, None, download_id, version_info) assert download_routes._active_downloads[download_id]["status"] == "completed" assert download_routes._active_downloads[download_id]["progress"] == 100 + assert download_routes._active_downloads[download_id]["db_file_id"] == 42 + assert download_routes._active_downloads[download_id]["db_linked"] is True + assert download_routes._active_downloads[download_id]["db_cached"] is True + assert captured["version_info"] == version_info # Cleanup del download_routes._active_downloads[download_id] @@ -1256,7 +1285,7 @@ class TestDownloadBackgroundTasks: monkeypatch.setattr(download_routes, "download_model_with_progress", lambda *args, **kwargs: False) dest_path = tmp_path / "model.safetensors" - _do_download(12345, dest_path, None, download_id) + _do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1}) assert download_routes._active_downloads[download_id]["status"] == "failed" assert "error" in download_routes._active_downloads[download_id] @@ -1278,7 +1307,7 @@ class TestDownloadBackgroundTasks: monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download) dest_path = tmp_path / "model.safetensors" - _do_download(12345, dest_path, None, download_id) + _do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1}) assert download_routes._active_downloads[download_id]["status"] == "failed" assert "Network error" in download_routes._active_downloads[download_id]["error"] @@ -1309,10 +1338,10 @@ class TestDownloadBackgroundTasks: return True monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download) - monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None) + self._patch_db_noop(monkeypatch, download_routes) dest_path = tmp_path / "model.safetensors" - _do_download(12345, dest_path, None, download_id) + _do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1}) # Check progress formatting was called assert len(progress_calls) == 3 @@ -1320,7 +1349,6 @@ class TestDownloadBackgroundTasks: assert progress_calls[0]["total_str"] == "1.0 KB" assert progress_calls[1]["downloaded_str"] == "500.0 KB" assert progress_calls[2]["downloaded_str"] == "500.0 MB" - assert progress_calls[2]["total_str"] == "1.0 GB" del download_routes._active_downloads[download_id] @@ -1337,74 +1365,49 @@ class TestDownloadBackgroundTasks: return True monkeypatch.setattr(download_routes, "download_model_with_progress", mock_download) - monkeypatch.setattr(download_routes, "_auto_link_file", lambda *args: None) + self._patch_db_noop(monkeypatch, download_routes) dest_path = tmp_path / "model.safetensors" - _do_download(12345, dest_path, None, download_id) + _do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1}) assert download_routes._active_downloads[download_id]["total_str"] == "Unknown" del download_routes._active_downloads[download_id] - -class TestAutoLinkFile: - """Tests for _auto_link_file function.""" - - def test_auto_link_success(self, monkeypatch, tmp_path, temp_db) -> None: - """Test auto-linking a downloaded file.""" + def test_do_download_db_error_surfaced(self, monkeypatch, tmp_path) -> None: + """DB errors during register must be surfaced into _active_downloads.""" from tensors.server import download_routes - from tensors.server.download_routes import _auto_link_file + from tensors.server.download_routes import _do_download - # Create a fake safetensor file - file_path = tmp_path / "test.safetensors" - file_path.write_bytes(b"fake safetensor data") + download_id = "test_db_err" + download_routes._active_downloads[download_id] = {"id": download_id, "status": "queued"} - from tensors import db as db_module + monkeypatch.setattr(download_routes, "download_model_with_progress", lambda *a, **kw: True) - monkeypatch.setattr(db_module, "DB_PATH", temp_db.db_path) - - # Mock scan results - scanned_files = [] - - def mock_scan(directory): - return [{"id": 1, "file_path": str(file_path), "sha256": "abc123"}] - - monkeypatch.setattr(temp_db, "scan_directory", mock_scan) - - # Mock CivitAI lookup - def mock_fetch_by_hash(sha256, api_key): - return {"id": 999, "modelId": 888} - - monkeypatch.setattr(download_routes, "fetch_civitai_by_hash", mock_fetch_by_hash) - - # Mock Database context manager to use temp_db - class MockDB: + class FailingDB: def __enter__(self): - return temp_db + return self def __exit__(self, *args): + return False + + def init_schema(self): pass - monkeypatch.setattr(download_routes, "Database", MockDB) + def register_downloaded_file(self, *args, **kwargs): + return {"file_id": None, "sha256": None, "linked": False, "cached": False, "error": "boom"} - # This should not raise - _auto_link_file(file_path, None) + monkeypatch.setattr(download_routes, "Database", lambda: FailingDB()) - def test_auto_link_exception_handled(self, monkeypatch, tmp_path) -> None: - """Test auto-link handles exceptions gracefully.""" - from tensors.server.download_routes import _auto_link_file + dest_path = tmp_path / "model.safetensors" + _do_download(12345, dest_path, None, download_id, {"id": 1, "modelId": 1}) - # Mock Database to raise exception - def mock_db(*args, **kwargs): - raise RuntimeError("DB error") + # Download itself succeeded; DB layer is reported separately + assert download_routes._active_downloads[download_id]["status"] == "completed" + assert download_routes._active_downloads[download_id]["db_error"] == "boom" + assert download_routes._active_downloads[download_id]["db_linked"] is False - from tensors.server import download_routes - - monkeypatch.setattr(download_routes, "Database", mock_db) - - file_path = tmp_path / "test.safetensors" - # Should not raise - _auto_link_file(file_path, None) + del download_routes._active_downloads[download_id] # =============================================================================