Update
This commit is contained in:
+102
@@ -473,6 +473,62 @@ def download(
|
||||
if not success:
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Add downloaded file to database and link to CivitAI
|
||||
_add_downloaded_file_to_db(dest_path, version_info)
|
||||
|
||||
|
||||
def _add_downloaded_file_to_db(dest_path: Path, version_info: dict[str, Any]) -> None:
|
||||
"""Add a downloaded file to the database and link to CivitAI.
|
||||
|
||||
Args:
|
||||
dest_path: Path to the downloaded file
|
||||
version_info: CivitAI version info response
|
||||
"""
|
||||
try:
|
||||
console.print("[dim]Adding to database...[/dim]")
|
||||
|
||||
# Compute SHA256 hash
|
||||
sha256 = compute_sha256(dest_path, console)
|
||||
|
||||
# Read safetensor metadata
|
||||
metadata = read_safetensor_metadata(dest_path)
|
||||
|
||||
# Extract CivitAI IDs
|
||||
civitai_version_id = version_info.get("id")
|
||||
civitai_model_id = version_info.get("modelId") or version_info.get("model", {}).get("id")
|
||||
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
with db.session() as session:
|
||||
# Add local file record
|
||||
local_file = db._upsert_local_file(
|
||||
session,
|
||||
file_path=str(dest_path.resolve()),
|
||||
sha256=sha256,
|
||||
header_size=metadata.get("header_size"),
|
||||
tensor_count=metadata.get("tensor_count"),
|
||||
)
|
||||
|
||||
# Store safetensor metadata
|
||||
db._store_safetensor_metadata(session, local_file.id, metadata.get("metadata", {}))
|
||||
|
||||
# Link to CivitAI if we have the IDs
|
||||
if civitai_model_id and civitai_version_id:
|
||||
local_file.civitai_model_id = civitai_model_id
|
||||
local_file.civitai_version_id = civitai_version_id
|
||||
session.add(local_file)
|
||||
|
||||
session.commit()
|
||||
file_id = local_file.id
|
||||
|
||||
# Report success
|
||||
console.print(f"[green]Added to database (id={file_id})[/green]")
|
||||
if civitai_model_id and civitai_version_id:
|
||||
console.print(f"[green]Linked to CivitAI model={civitai_model_id} version={civitai_version_id}[/green]")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[yellow]Warning: Could not add to database: {e}[/yellow]")
|
||||
|
||||
|
||||
def _display_download_info(
|
||||
version_info: dict[str, Any],
|
||||
@@ -670,6 +726,52 @@ def db_link(
|
||||
else:
|
||||
console.print(f"[green]Linked {len(linked)} file(s)[/green]")
|
||||
|
||||
# Cache model data for newly linked files
|
||||
if linked:
|
||||
_cache_linked_models(db, key, linked, json_output)
|
||||
|
||||
|
||||
def _cache_linked_models(
|
||||
db: Database,
|
||||
api_key: str | None,
|
||||
linked: list[dict[str, Any]],
|
||||
json_output: bool,
|
||||
) -> None:
|
||||
"""Fetch and cache full model data for linked files.
|
||||
|
||||
Args:
|
||||
db: Database instance (already initialized)
|
||||
api_key: CivitAI API key
|
||||
linked: List of linked file info dicts with model_id
|
||||
json_output: Whether to suppress console output
|
||||
"""
|
||||
# Collect unique model IDs
|
||||
model_ids: set[int] = {item["model_id"] for item in linked if item.get("model_id")}
|
||||
|
||||
# Find which models are not yet cached
|
||||
uncached_ids: list[int] = []
|
||||
for model_id in model_ids:
|
||||
if db.get_model(model_id) is None:
|
||||
uncached_ids.append(model_id)
|
||||
|
||||
if not uncached_ids:
|
||||
return
|
||||
|
||||
if not json_output:
|
||||
console.print(f"[cyan]Caching {len(uncached_ids)} model(s)...[/cyan]")
|
||||
|
||||
cached: list[dict[str, Any]] = []
|
||||
for model_id in uncached_ids:
|
||||
model_data = fetch_civitai_model(model_id, api_key, console if not json_output else None)
|
||||
if model_data:
|
||||
db.cache_model(model_data)
|
||||
cached.append({"model_id": model_id, "name": model_data.get("name", "")})
|
||||
if not json_output:
|
||||
console.print(f" [green]✓[/green] Cached: {model_data.get('name', 'N/A')}")
|
||||
|
||||
if not json_output and cached:
|
||||
console.print(f"[green]Cached {len(cached)} model(s)[/green]")
|
||||
|
||||
|
||||
@db_app.command("cache")
|
||||
def db_cache(
|
||||
|
||||
@@ -882,7 +882,6 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "tensors"
|
||||
version = "0.1.18"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "httpx" },
|
||||
@@ -1027,16 +1026,16 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "virtualenv"
|
||||
version = "20.37.0"
|
||||
version = "20.38.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "distlib" },
|
||||
{ name = "filelock" },
|
||||
{ name = "platformdirs" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c1/ef/d9d4ce633df789bf3430bd81fb0d8b9d9465dfc1d1f0deb3fb62cd80f5c2/virtualenv-20.37.0.tar.gz", hash = "sha256:6f7e2064ed470aa7418874e70b6369d53b66bcd9e9fd5389763e96b6c94ccb7c", size = 5864710, upload-time = "2026-02-16T16:17:59.42Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d2/03/a94d404ca09a89a7301a7008467aed525d4cdeb9186d262154dd23208709/virtualenv-20.38.0.tar.gz", hash = "sha256:94f39b1abaea5185bf7ea5a46702b56f1d0c9aa2f41a6c2b8b0af4ddc74c10a7", size = 5864558, upload-time = "2026-02-19T07:48:02.385Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/42/4b/6cf85b485be7ec29db837ec2a1d8cd68bc1147b1abf23d8636c5bd65b3cc/virtualenv-20.37.0-py3-none-any.whl", hash = "sha256:5d3951c32d57232ae3569d4de4cc256c439e045135ebf43518131175d9be435d", size = 5837480, upload-time = "2026-02-16T16:17:57.341Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/42/d7/394801755d4c8684b655d35c665aea7836ec68320304f62ab3c94395b442/virtualenv-20.38.0-py3-none-any.whl", hash = "sha256:d6e78e5889de3a4742df2d3d44e779366325a90cf356f15621fddace82431794", size = 5837778, upload-time = "2026-02-19T07:47:59.778Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user