This commit is contained in:
Adam Ladachowski
2026-02-22 06:23:25 +00:00
parent 5ddfb07448
commit 82eb0d3b5c
3 changed files with 105 additions and 4 deletions
BIN
View File
Binary file not shown.
+102
View File
@@ -473,6 +473,62 @@ def download(
if not success:
raise typer.Exit(1)
# Add downloaded file to database and link to CivitAI
_add_downloaded_file_to_db(dest_path, version_info)
def _add_downloaded_file_to_db(dest_path: Path, version_info: dict[str, Any]) -> None:
"""Add a downloaded file to the database and link to CivitAI.
Args:
dest_path: Path to the downloaded file
version_info: CivitAI version info response
"""
try:
console.print("[dim]Adding to database...[/dim]")
# Compute SHA256 hash
sha256 = compute_sha256(dest_path, console)
# Read safetensor metadata
metadata = read_safetensor_metadata(dest_path)
# Extract CivitAI IDs
civitai_version_id = version_info.get("id")
civitai_model_id = version_info.get("modelId") or version_info.get("model", {}).get("id")
with Database() as db:
db.init_schema()
with db.session() as session:
# Add local file record
local_file = db._upsert_local_file(
session,
file_path=str(dest_path.resolve()),
sha256=sha256,
header_size=metadata.get("header_size"),
tensor_count=metadata.get("tensor_count"),
)
# Store safetensor metadata
db._store_safetensor_metadata(session, local_file.id, metadata.get("metadata", {}))
# Link to CivitAI if we have the IDs
if civitai_model_id and civitai_version_id:
local_file.civitai_model_id = civitai_model_id
local_file.civitai_version_id = civitai_version_id
session.add(local_file)
session.commit()
file_id = local_file.id
# Report success
console.print(f"[green]Added to database (id={file_id})[/green]")
if civitai_model_id and civitai_version_id:
console.print(f"[green]Linked to CivitAI model={civitai_model_id} version={civitai_version_id}[/green]")
except Exception as e:
console.print(f"[yellow]Warning: Could not add to database: {e}[/yellow]")
def _display_download_info(
version_info: dict[str, Any],
@@ -670,6 +726,52 @@ def db_link(
else:
console.print(f"[green]Linked {len(linked)} file(s)[/green]")
# Cache model data for newly linked files
if linked:
_cache_linked_models(db, key, linked, json_output)
def _cache_linked_models(
db: Database,
api_key: str | None,
linked: list[dict[str, Any]],
json_output: bool,
) -> None:
"""Fetch and cache full model data for linked files.
Args:
db: Database instance (already initialized)
api_key: CivitAI API key
linked: List of linked file info dicts with model_id
json_output: Whether to suppress console output
"""
# Collect unique model IDs
model_ids: set[int] = {item["model_id"] for item in linked if item.get("model_id")}
# Find which models are not yet cached
uncached_ids: list[int] = []
for model_id in model_ids:
if db.get_model(model_id) is None:
uncached_ids.append(model_id)
if not uncached_ids:
return
if not json_output:
console.print(f"[cyan]Caching {len(uncached_ids)} model(s)...[/cyan]")
cached: list[dict[str, Any]] = []
for model_id in uncached_ids:
model_data = fetch_civitai_model(model_id, api_key, console if not json_output else None)
if model_data:
db.cache_model(model_data)
cached.append({"model_id": model_id, "name": model_data.get("name", "")})
if not json_output:
console.print(f" [green]✓[/green] Cached: {model_data.get('name', 'N/A')}")
if not json_output and cached:
console.print(f"[green]Cached {len(cached)} model(s)[/green]")
@db_app.command("cache")
def db_cache(
Generated
+3 -4
View File
@@ -882,7 +882,6 @@ wheels = [
[[package]]
name = "tensors"
version = "0.1.18"
source = { editable = "." }
dependencies = [
{ name = "httpx" },
@@ -1027,16 +1026,16 @@ wheels = [
[[package]]
name = "virtualenv"
version = "20.37.0"
version = "20.38.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "distlib" },
{ name = "filelock" },
{ name = "platformdirs" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c1/ef/d9d4ce633df789bf3430bd81fb0d8b9d9465dfc1d1f0deb3fb62cd80f5c2/virtualenv-20.37.0.tar.gz", hash = "sha256:6f7e2064ed470aa7418874e70b6369d53b66bcd9e9fd5389763e96b6c94ccb7c", size = 5864710, upload-time = "2026-02-16T16:17:59.42Z" }
sdist = { url = "https://files.pythonhosted.org/packages/d2/03/a94d404ca09a89a7301a7008467aed525d4cdeb9186d262154dd23208709/virtualenv-20.38.0.tar.gz", hash = "sha256:94f39b1abaea5185bf7ea5a46702b56f1d0c9aa2f41a6c2b8b0af4ddc74c10a7", size = 5864558, upload-time = "2026-02-19T07:48:02.385Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/42/4b/6cf85b485be7ec29db837ec2a1d8cd68bc1147b1abf23d8636c5bd65b3cc/virtualenv-20.37.0-py3-none-any.whl", hash = "sha256:5d3951c32d57232ae3569d4de4cc256c439e045135ebf43518131175d9be435d", size = 5837480, upload-time = "2026-02-16T16:17:57.341Z" },
{ url = "https://files.pythonhosted.org/packages/42/d7/394801755d4c8684b655d35c665aea7836ec68320304f62ab3c94395b442/virtualenv-20.38.0-py3-none-any.whl", hash = "sha256:d6e78e5889de3a4742df2d3d44e779366325a90cf356f15621fddace82431794", size = 5837778, upload-time = "2026-02-19T07:47:59.778Z" },
]
[[package]]