diff --git a/.coverage b/.coverage index 2639881..edbdd6f 100644 Binary files a/.coverage and b/.coverage differ diff --git a/tensors/cli.py b/tensors/cli.py index 9ed37a7..6cea831 100644 --- a/tensors/cli.py +++ b/tensors/cli.py @@ -473,6 +473,62 @@ def download( if not success: raise typer.Exit(1) + # Add downloaded file to database and link to CivitAI + _add_downloaded_file_to_db(dest_path, version_info) + + +def _add_downloaded_file_to_db(dest_path: Path, version_info: dict[str, Any]) -> None: + """Add a downloaded file to the database and link to CivitAI. + + Args: + dest_path: Path to the downloaded file + version_info: CivitAI version info response + """ + try: + console.print("[dim]Adding to database...[/dim]") + + # Compute SHA256 hash + sha256 = compute_sha256(dest_path, console) + + # Read safetensor metadata + metadata = read_safetensor_metadata(dest_path) + + # Extract CivitAI IDs + civitai_version_id = version_info.get("id") + civitai_model_id = version_info.get("modelId") or version_info.get("model", {}).get("id") + + with Database() as db: + db.init_schema() + with db.session() as session: + # Add local file record + local_file = db._upsert_local_file( + session, + file_path=str(dest_path.resolve()), + sha256=sha256, + header_size=metadata.get("header_size"), + tensor_count=metadata.get("tensor_count"), + ) + + # Store safetensor metadata + db._store_safetensor_metadata(session, local_file.id, metadata.get("metadata", {})) + + # Link to CivitAI if we have the IDs + if civitai_model_id and civitai_version_id: + local_file.civitai_model_id = civitai_model_id + local_file.civitai_version_id = civitai_version_id + session.add(local_file) + + session.commit() + file_id = local_file.id + + # Report success + console.print(f"[green]Added to database (id={file_id})[/green]") + if civitai_model_id and civitai_version_id: + console.print(f"[green]Linked to CivitAI model={civitai_model_id} version={civitai_version_id}[/green]") + + except Exception as e: + console.print(f"[yellow]Warning: Could not add to database: {e}[/yellow]") + def _display_download_info( version_info: dict[str, Any], @@ -670,6 +726,52 @@ def db_link( else: console.print(f"[green]Linked {len(linked)} file(s)[/green]") + # Cache model data for newly linked files + if linked: + _cache_linked_models(db, key, linked, json_output) + + +def _cache_linked_models( + db: Database, + api_key: str | None, + linked: list[dict[str, Any]], + json_output: bool, +) -> None: + """Fetch and cache full model data for linked files. + + Args: + db: Database instance (already initialized) + api_key: CivitAI API key + linked: List of linked file info dicts with model_id + json_output: Whether to suppress console output + """ + # Collect unique model IDs + model_ids: set[int] = {item["model_id"] for item in linked if item.get("model_id")} + + # Find which models are not yet cached + uncached_ids: list[int] = [] + for model_id in model_ids: + if db.get_model(model_id) is None: + uncached_ids.append(model_id) + + if not uncached_ids: + return + + if not json_output: + console.print(f"[cyan]Caching {len(uncached_ids)} model(s)...[/cyan]") + + cached: list[dict[str, Any]] = [] + for model_id in uncached_ids: + model_data = fetch_civitai_model(model_id, api_key, console if not json_output else None) + if model_data: + db.cache_model(model_data) + cached.append({"model_id": model_id, "name": model_data.get("name", "")}) + if not json_output: + console.print(f" [green]✓[/green] Cached: {model_data.get('name', 'N/A')}") + + if not json_output and cached: + console.print(f"[green]Cached {len(cached)} model(s)[/green]") + @db_app.command("cache") def db_cache( diff --git a/uv.lock b/uv.lock index 93cbaee..8de96ea 100644 --- a/uv.lock +++ b/uv.lock @@ -882,7 +882,6 @@ wheels = [ [[package]] name = "tensors" -version = "0.1.18" source = { editable = "." } dependencies = [ { name = "httpx" }, @@ -1027,16 +1026,16 @@ wheels = [ [[package]] name = "virtualenv" -version = "20.37.0" +version = "20.38.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "distlib" }, { name = "filelock" }, { name = "platformdirs" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c1/ef/d9d4ce633df789bf3430bd81fb0d8b9d9465dfc1d1f0deb3fb62cd80f5c2/virtualenv-20.37.0.tar.gz", hash = "sha256:6f7e2064ed470aa7418874e70b6369d53b66bcd9e9fd5389763e96b6c94ccb7c", size = 5864710, upload-time = "2026-02-16T16:17:59.42Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d2/03/a94d404ca09a89a7301a7008467aed525d4cdeb9186d262154dd23208709/virtualenv-20.38.0.tar.gz", hash = "sha256:94f39b1abaea5185bf7ea5a46702b56f1d0c9aa2f41a6c2b8b0af4ddc74c10a7", size = 5864558, upload-time = "2026-02-19T07:48:02.385Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/42/4b/6cf85b485be7ec29db837ec2a1d8cd68bc1147b1abf23d8636c5bd65b3cc/virtualenv-20.37.0-py3-none-any.whl", hash = "sha256:5d3951c32d57232ae3569d4de4cc256c439e045135ebf43518131175d9be435d", size = 5837480, upload-time = "2026-02-16T16:17:57.341Z" }, + { url = "https://files.pythonhosted.org/packages/42/d7/394801755d4c8684b655d35c665aea7836ec68320304f62ab3c94395b442/virtualenv-20.38.0-py3-none-any.whl", hash = "sha256:d6e78e5889de3a4742df2d3d44e779366325a90cf356f15621fddace82431794", size = 5837778, upload-time = "2026-02-19T07:47:59.778Z" }, ] [[package]]