diff --git a/TODO.md b/TODO.md
index 9f2dab9..d7f4e34 100644
--- a/TODO.md
+++ b/TODO.md
@@ -1,12 +1,12 @@
 # TODO: tsr Server/Client Architecture
 
-## Phase 1: Model-Specific Docker Images
-- [ ] Step 1.1: Create `rocm-docker/model-defaults.toml` (optimal params per model family)
-- [ ] Step 1.2: Parameterize `Dockerfile.sd-server` with `MODEL_FAMILY` build arg
-- [ ] Step 1.3: Create `rocm-docker/build-all.sh` (build all model variants)
+## Phase 1: Model-Specific Docker Images (SKIPPED)
+- [x] Step 1.1: ~~Create `rocm-docker/model-defaults.toml`~~ (skipped)
+- [x] Step 1.2: ~~Parameterize `Dockerfile.sd-server`~~ (skipped)
+- [x] Step 1.3: ~~Create `rocm-docker/build-all.sh`~~ (skipped)
 
 ## Phase 2: Models Database in tensors
-- [ ] Step 2.1: Create `tensors/db.py` + `tensors/schema.sql` (SQLite wrapper, schema, CRUD)
+- [x] Step 2.1: Create `tensors/db.py` + `tensors/schema.sql` (SQLite wrapper, schema, CRUD)
 - [ ] Step 2.2: Add `tsr db` CLI commands (scan, link, cache, list, search, triggers)
 - [ ] Step 2.3: Add `/api/db/*` endpoints (files, models, triggers, scan, link)
 
@@ -21,10 +21,10 @@
 - [ ] Step 4.2: Add `[remotes]` config section + `--remote` flag support
 - [ ] Step 4.3: Update CLI commands with `--remote` support (generate, images, models, dl, db)
 
-## Phase 5: Docker Deployment Automation
-- [ ] Step 5.1: Create `rocm-docker/docker-compose.yml` (multi-model setup)
-- [ ] Step 5.2: Create `rocm-docker/deploy.sh` (one-command deploy)
-- [ ] Step 5.3: Create `rocm-docker/tsr-server.service` (systemd unit)
+## Phase 5: Docker Deployment Automation (SKIPPED)
+- [x] Step 5.1: ~~Create `rocm-docker/docker-compose.yml`~~ (skipped)
+- [x] Step 5.2: ~~Create `rocm-docker/deploy.sh`~~ (skipped)
+- [x] Step 5.3: ~~Create `rocm-docker/tsr-server.service`~~ (skipped)
 
 ## Phase 6: Tests
 - [ ] Step 6.1: `tests/test_db.py` (database module tests)
diff --git a/tensors/db.py b/tensors/db.py
new file mode 100644
index 0000000..63ec993
--- /dev/null
+++ b/tensors/db.py
@@ -0,0 +1,619 @@
+"""SQLite database for local model metadata and CivitAI cache."""
+
+from __future__ import annotations
+
+import json
+import sqlite3
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+from tensors.config import DATA_DIR
+from tensors.safetensor import compute_sha256, read_safetensor_metadata
+
+if TYPE_CHECKING:
+    from rich.console import Console
+
+# Database location
+DB_PATH = DATA_DIR / "models.db"
+
+# Location of the schema script that Database.init_schema() executes
+_SCHEMA_PATH = Path(__file__).parent / "schema.sql"
+
+
+class Database:
+    """SQLite database wrapper for models metadata."""
+
+    def __init__(self, db_path: Path | None = None) -> None:
+        """Record the database path and create its parent directory; ``conn`` connects lazily."""
+        self.db_path = db_path or DB_PATH
+        self.db_path.parent.mkdir(parents=True, exist_ok=True)
+        self._conn: sqlite3.Connection | None = None
+
+    @property
+    def conn(self) -> sqlite3.Connection:
+        """Get or create database connection."""
+        if self._conn is None:
+            self._conn = sqlite3.connect(self.db_path)
+            self._conn.row_factory = sqlite3.Row
+            self._conn.execute("PRAGMA foreign_keys = ON")
+        return self._conn
+
+    def close(self) -> None:
+        """Close database connection."""
+        if self._conn is not None:
+            self._conn.close()
+            self._conn = None
+
+    def __enter__(self) -> Database:
+        return self
+
+    def __exit__(self, *exc: object) -> None:
+        self.close()
+
+    def init_schema(self) -> None:
+        """Initialize database schema from schema.sql."""
+        schema = _SCHEMA_PATH.read_text(encoding="utf-8")
+        self.conn.executescript(schema)
+        self.conn.commit()
+
+    # =========================================================================
+    # Local Files Operations
+    # =========================================================================
+
+    def scan_directory(
+        self,
+        directory: Path,
+        console: Console | None = None,
+    ) -> list[dict[str, Any]]:
+        """Scan directory for safetensor files and add to database.
+
+        Each file is committed individually. A file that fails to hash, parse
+        or store is rolled back, reported on ``console`` (if given) and skipped.
+
+        Returns list of scanned file info dicts.
+        """
+        results: list[dict[str, Any]] = []
+        safetensor_files = sorted(directory.rglob("*.safetensors"))
+
+        for path in safetensor_files:
+            if console:
+                console.print(f"[dim]Scanning {path.name}...[/dim]")
+
+            try:
+                sha256 = compute_sha256(path)
+                metadata = read_safetensor_metadata(path)
+
+                file_info = self._upsert_local_file(
+                    file_path=str(path.resolve()),
+                    sha256=sha256,
+                    header_size=metadata.get("header_size"),
+                    tensor_count=metadata.get("tensor_count"),
+                )
+
+                # Store safetensor metadata
+                self._store_safetensor_metadata(file_info["id"], metadata.get("metadata") or {})
+
+                self.conn.commit()
+                results.append(file_info)
+
+            except Exception as e:
+                # Discard this file's partial writes so they are not committed
+                # together with the next file that scans successfully.
+                self.conn.rollback()
+                if console:
+                    console.print(f"[red]Error scanning {path.name}: {e}[/red]")
+
+        return results
+
+    def _upsert_local_file(
+        self,
+        file_path: str,
+        sha256: str,
+        header_size: int | None = None,
+        tensor_count: int | None = None,
+    ) -> dict[str, Any]:
+        """Insert or update a local file record."""
+        cur = self.conn.cursor()
+
+        cur.execute("SELECT id FROM local_files WHERE file_path = ?", (file_path,))
+        existing = cur.fetchone()
+
+        if existing:
+            cur.execute(
+                """
+                UPDATE local_files SET sha256 = ?, header_size = ?, tensor_count = ?,
+                    updated_at = datetime('now') WHERE id = ?
+                """,
+                (sha256, header_size, tensor_count, existing["id"]),
+            )
+            file_id = existing["id"]
+        else:
+            cur.execute(
+                """
+                INSERT INTO local_files (file_path, sha256, header_size, tensor_count)
+                VALUES (?, ?, ?, ?)
+                """,
+                (file_path, sha256, header_size, tensor_count),
+            )
+            file_id = cur.lastrowid or 0  # lastrowid is always set after INSERT
+
+        return {"id": file_id, "file_path": file_path, "sha256": sha256}
+
+    def _store_safetensor_metadata(self, local_file_id: int, metadata: dict[str, Any]) -> None:
+        """Store safetensor header metadata."""
+        cur = self.conn.cursor()
+        for key, value in metadata.items():
+            str_value = json.dumps(value) if not isinstance(value, str) else value
+            cur.execute(
+                """
+                INSERT INTO safetensor_metadata (local_file_id, key, value)
+                VALUES (?, ?, ?)
+                ON CONFLICT(local_file_id, key) DO UPDATE SET value = excluded.value
+                """,
+                (local_file_id, key, str_value),
+            )
+
+    def list_local_files(self) -> list[dict[str, Any]]:
+        """List all local files with CivitAI info."""
+        cur = self.conn.cursor()
+        cur.execute("SELECT * FROM v_local_files_full ORDER BY file_path")
+        return [dict(row) for row in cur.fetchall()]
+
+    def get_local_file_by_path(self, file_path: str) -> dict[str, Any] | None:
+        """Get local file by path."""
+        cur = self.conn.cursor()
+        cur.execute("SELECT * FROM v_local_files_full WHERE file_path = ?", (file_path,))
+        row = cur.fetchone()
+        return dict(row) if row else None
+
+    def get_local_file_by_hash(self, sha256: str) -> dict[str, Any] | None:
+        """Get local file by SHA256 hash."""
+        cur = self.conn.cursor()
+        cur.execute("SELECT * FROM v_local_files_full WHERE sha256 = ?", (sha256.upper(),))
+        row = cur.fetchone()
+        return dict(row) if row else None
+
+    def get_unlinked_files(self) -> list[dict[str, Any]]:
+        """Get local files not linked to CivitAI."""
+        cur = self.conn.cursor()
+        cur.execute(
+            """
+            SELECT id, file_path, sha256 FROM local_files
+            WHERE civitai_model_id IS NULL
+            """
+        )
+        return [dict(row) for row in cur.fetchall()]
+
+    def link_file_to_civitai(
+        self,
+        file_id: int,
+        model_id: int,
+        version_id: int,
+    ) -> None:
+        """Link a local file to CivitAI model/version."""
+        cur = self.conn.cursor()
+        cur.execute(
+            """
+            UPDATE local_files
+            SET civitai_model_id = ?, civitai_version_id = ?, updated_at = datetime('now')
+            WHERE id = ?
+            """,
+            (model_id, version_id, file_id),
+        )
+        self.conn.commit()
+
+    # =========================================================================
+    # CivitAI Cache Operations
+    # =========================================================================
+
+    def get_version_by_hash(self, sha256: str) -> dict[str, Any] | None:
+        """Find cached version by file hash."""
+        cur = self.conn.cursor()
+        cur.execute(
+            """
+            SELECT mv.civitai_id as version_id, m.civitai_id as model_id,
+                   m.name as model_name, mv.name as version_name
+            FROM file_hashes fh
+            JOIN version_files vf ON fh.file_id = vf.id
+            JOIN model_versions mv ON vf.version_id = mv.id
+            JOIN models m ON mv.model_id = m.id
+            WHERE UPPER(fh.hash_value) = UPPER(?)
+            """,
+            (sha256,),
+        )
+        row = cur.fetchone()
+        return dict(row) if row else None
+
+    def cache_model(self, data: dict[str, Any]) -> int:
+        """Cache full model data from CivitAI API response.
+
+        Returns the internal model ID.
+        """
+        cur = self.conn.cursor()
+
+        # Get or create creator
+        creator_id = self._get_or_create_creator(data.get("creator"))
+
+        # Check if model exists
+        civitai_id = data.get("id")
+        cur.execute("SELECT id FROM models WHERE civitai_id = ?", (civitai_id,))
+        existing = cur.fetchone()
+
+        stats = data.get("stats") or {}
+
+        if existing:
+            model_id = int(existing["id"])
+            cur.execute(
+                """
+                UPDATE models SET
+                    name = ?, description = ?, type = ?, nsfw = ?,
+                    download_count = ?, thumbs_up_count = ?,
+                    updated_at = datetime('now')
+                WHERE id = ?
+                """,
+                (
+                    data.get("name"),
+                    data.get("description"),
+                    data.get("type"),
+                    1 if data.get("nsfw") else 0,
+                    stats.get("downloadCount", 0),
+                    stats.get("thumbsUpCount", 0),
+                    model_id,
+                ),
+            )
+        else:
+            cur.execute(
+                """
+                INSERT INTO models (
+                    civitai_id, name, description, type, nsfw, poi, minor,
+                    sfw_only, nsfw_level, availability, allow_no_credit,
+                    allow_commercial_use, allow_derivatives, allow_different_license,
+                    supports_generation, creator_id, download_count, thumbs_up_count,
+                    thumbs_down_count, comment_count, tipped_amount_count,
+                    created_at, updated_at
+                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
+                """,
+                (
+                    civitai_id,
+                    data.get("name"),
+                    data.get("description"),
+                    data.get("type"),
+                    1 if data.get("nsfw") else 0,
+                    1 if data.get("poi") else 0,
+                    1 if data.get("minor") else 0,
+                    1 if data.get("sfwOnly") else 0,
+                    data.get("nsfwLevel"),
+                    data.get("availability"),
+                    1 if data.get("allowNoCredit") else 0,
+                    str(data.get("allowCommercialUse", "")),
+                    1 if data.get("allowDerivatives") else 0,
+                    1 if data.get("allowDifferentLicense") else 0,
+                    1 if data.get("supportsGeneration") else 0,
+                    creator_id,
+                    stats.get("downloadCount", 0),
+                    stats.get("thumbsUpCount", 0),
+                    stats.get("thumbsDownCount", 0),
+                    stats.get("commentCount", 0),
+                    stats.get("tippedAmountCount", 0),
+                    data.get("createdAt"),
+                ),
+            )
+            model_id = cur.lastrowid or 0  # lastrowid is always set after INSERT
+
+        # Cache tags
+        for tag_name in data.get("tags") or []:
+            tag_id = self._get_or_create_tag(tag_name)
+            cur.execute("INSERT OR IGNORE INTO model_tags (model_id, tag_id) VALUES (?, ?)", (model_id, tag_id))
+
+        # Cache versions
+        for idx, version in enumerate(data.get("modelVersions") or []):
+            self._cache_version(model_id, version, idx)
+
+        self.conn.commit()
+        return model_id
+
+    def _get_or_create_creator(self, creator_data: dict[str, Any] | None) -> int | None:
+        """Get or create a creator record."""
+        if not creator_data:
+            return None
+        username = creator_data.get("username")
+        if not username:
+            return None
+
+        cur = self.conn.cursor()
+        cur.execute("SELECT id FROM creators WHERE username = ?", (username,))
+        row = cur.fetchone()
+        if row:
+            return int(row["id"])
+
+        cur.execute(
+            "INSERT INTO creators (username, image_url) VALUES (?, ?)",
+            (username, creator_data.get("image")),
+        )
+        return cur.lastrowid or 0
+
+    def _get_or_create_tag(self, tag_name: str) -> int:
+        """Get or create a tag record."""
+        cur = self.conn.cursor()
+        cur.execute("SELECT id FROM tags WHERE name = ?", (tag_name,))
+        row = cur.fetchone()
+        if row:
+            return int(row["id"])
+
+        cur.execute("INSERT INTO tags (name) VALUES (?)", (tag_name,))
+        return cur.lastrowid or 0  # lastrowid is always set after INSERT
+
+    def _cache_version(self, model_id: int, version: dict[str, Any], index: int) -> int:
+        """Cache a model version with its words, files and images; returns the internal ID."""
+        cur = self.conn.cursor()
+        civitai_id = version.get("id")
+
+        cur.execute("SELECT id FROM model_versions WHERE civitai_id = ?", (civitai_id,))
+        existing = cur.fetchone()
+
+        stats = version.get("stats") or {}
+
+        if existing:
+            version_id = int(existing["id"])
+            # Keep the ordering current: v_models_with_latest picks the row with
+            # version_index = 0, so a stale index would yield two "latest" rows.
+            cur.execute(
+                "UPDATE model_versions SET version_index = ? WHERE id = ?",
+                (index, version_id),
+            )
+        else:
+            cur.execute(
+                """
+                INSERT INTO model_versions (
+                    civitai_id, model_id, name, description, base_model,
+                    base_model_type, nsfw_level, status, availability,
+                    download_count, thumbs_up_count, thumbs_down_count,
+                    supports_generation, download_url, created_at, published_at,
+                    updated_at, version_index
+                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+                """,
+                (
+                    civitai_id,
+                    model_id,
+                    version.get("name"),
+                    version.get("description"),
+                    version.get("baseModel"),
+                    version.get("baseModelType"),
+                    version.get("nsfwLevel"),
+                    version.get("status"),
+                    version.get("availability"),
+                    stats.get("downloadCount", 0),
+                    stats.get("thumbsUpCount", 0),
+                    stats.get("thumbsDownCount", 0),
+                    1 if version.get("supportsGeneration") else 0,
+                    version.get("downloadUrl"),
+                    version.get("createdAt"),
+                    version.get("publishedAt"),
+                    version.get("updatedAt"),
+                    index,
+                ),
+            )
+            version_id = cur.lastrowid or 0  # lastrowid is always set after INSERT
+
+        # Cache trained words. Replace rather than append, so that re-caching a
+        # version never duplicates its words and the positions stay in sync.
+        cur.execute("DELETE FROM trained_words WHERE version_id = ?", (version_id,))
+        for pos, word in enumerate(version.get("trainedWords") or []):
+            cur.execute(
+                "INSERT OR IGNORE INTO trained_words (version_id, word, position) VALUES (?, ?, ?)",
+                (version_id, word, pos),
+            )
+
+        # Cache files and hashes
+        for file_data in version.get("files") or []:
+            self._cache_file(version_id, file_data)
+
+        # Cache images
+        for image_data in version.get("images") or []:
+            self._cache_image(version_id, image_data)
+
+        return version_id
+
+    def _cache_file(self, version_id: int, file_data: dict[str, Any]) -> int | None:
+        """Cache a version file."""
+        cur = self.conn.cursor()
+        civitai_id = file_data.get("id")
+        if not civitai_id:
+            return None
+
+        cur.execute("SELECT id FROM version_files WHERE civitai_id = ?", (civitai_id,))
+        existing = cur.fetchone()
+
+        if existing:
+            return int(existing["id"])
+
+        meta = file_data.get("metadata") or {}
+        cur.execute(
+            """
+            INSERT INTO version_files (
+                civitai_id, version_id, name, type, size_kb, format,
+                size_type, fp, is_primary, pickle_scan_result,
+                virus_scan_result, scanned_at, download_url
+            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+            """,
+            (
+                civitai_id,
+                version_id,
+                file_data.get("name"),
+                file_data.get("type"),
+                file_data.get("sizeKB"),
+                meta.get("format"),
+                meta.get("size"),
+                meta.get("fp"),
+                1 if file_data.get("primary") else 0,
+                file_data.get("pickleScanResult"),
+                file_data.get("virusScanResult"),
+                file_data.get("scannedAt"),
+                file_data.get("downloadUrl"),
+            ),
+        )
+        file_id = cur.lastrowid or 0  # lastrowid is always set after INSERT
+
+        # Cache hashes
+        for hash_type, hash_value in (file_data.get("hashes") or {}).items():
+            cur.execute(
+                "INSERT OR IGNORE INTO file_hashes (file_id, hash_type, hash_value) VALUES (?, ?, ?)",
+                (file_id, hash_type, hash_value),
+            )
+
+        return file_id
+
+    def _cache_image(self, version_id: int, image_data: dict[str, Any]) -> int | None:
+        """Cache a version image."""
+        cur = self.conn.cursor()
+        url = image_data.get("url")
+        if not url:
+            return None
+
+        cur.execute("SELECT id FROM version_images WHERE url = ?", (url,))
+        existing = cur.fetchone()
+
+        if existing:
+            return int(existing["id"])
+
+        cur.execute(
+            """
+            INSERT INTO version_images (
+                civitai_id, version_id, url, type, nsfw_level, width,
+                height, hash, has_meta, has_positive_prompt, on_site,
+                minor, poi, availability
+            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+            """,
+            (
+                image_data.get("id"),
+                version_id,
+                url,
+                image_data.get("type"),
+                image_data.get("nsfwLevel"),
+                image_data.get("width"),
+                image_data.get("height"),
+                image_data.get("hash"),
+                1 if image_data.get("hasMeta") else 0,
+                1 if image_data.get("hasPositivePrompt") else 0,
+                1 if image_data.get("onSite") else 0,
+                1 if image_data.get("minor") else 0,
+                1 if image_data.get("poi") else 0,
+                image_data.get("availability"),
+            ),
+        )
+        image_id = cur.lastrowid or 0  # lastrowid is always set after INSERT
+
+        # Cache generation params ("meta" may be missing or explicitly None)
+        meta = image_data.get("meta") or {}
+        for key, value in meta.items():
+            if key == "resources":
+                continue
+            str_value = str(value) if value is not None else None
+            cur.execute(
+                "INSERT OR IGNORE INTO image_generation_params (image_id, key, value) VALUES (?, ?, ?)",
+                (image_id, key, str_value),
+            )
+
+        # Cache resources
+        for res in meta.get("resources") or []:
+            cur.execute(
+                "INSERT INTO image_resources (image_id, name, type, hash, weight) VALUES (?, ?, ?, ?, ?)",
+                (image_id, res.get("name"), res.get("type"), res.get("hash"), res.get("weight")),
+            )
+
+        return image_id
+
+    # =========================================================================
+    # Query Operations
+    # =========================================================================
+
+    def search_models(
+        self,
+        query: str | None = None,
+        model_type: str | None = None,
+        base_model: str | None = None,
+        limit: int = 20,
+    ) -> list[dict[str, Any]]:
+        """Search cached models."""
+        cur = self.conn.cursor()
+
+        sql = "SELECT * FROM v_models_with_latest WHERE 1=1"
+        params: list[Any] = []
+
+        if query:
+            sql += " AND name LIKE ?"
+            params.append(f"%{query}%")
+
+        if model_type:
+            sql += " AND type = ?"
+            params.append(model_type)
+
+        if base_model:
+            sql += " AND base_model LIKE ?"
+            params.append(f"%{base_model}%")
+
+        sql += " ORDER BY download_count DESC LIMIT ?"
+        params.append(limit)
+
+        cur.execute(sql, params)
+        return [dict(row) for row in cur.fetchall()]
+
+    def get_model(self, civitai_id: int) -> dict[str, Any] | None:
+        """Get cached model by CivitAI ID."""
+        cur = self.conn.cursor()
+        cur.execute("SELECT * FROM v_models_with_latest WHERE civitai_id = ?", (civitai_id,))
+        row = cur.fetchone()
+        return dict(row) if row else None
+
+    def get_triggers(self, file_path: str) -> list[str]:
+        """Get trigger words for a local file."""
+        cur = self.conn.cursor()
+        cur.execute(
+            """
+            SELECT tw.word
+            FROM trained_words tw
+            JOIN model_versions mv ON tw.version_id = mv.id
+            JOIN local_files lf ON lf.civitai_version_id = mv.civitai_id
+            WHERE lf.file_path = ?
+            ORDER BY tw.position
+            """,
+            (file_path,),
+        )
+        return [row["word"] for row in cur.fetchall()]
+
+    def get_triggers_by_version(self, version_id: int) -> list[str]:
+        """Get trigger words for a version by CivitAI version ID."""
+        cur = self.conn.cursor()
+        cur.execute(
+            """
+            SELECT tw.word
+            FROM trained_words tw
+            JOIN model_versions mv ON tw.version_id = mv.id
+            WHERE mv.civitai_id = ?
+            ORDER BY tw.position
+            """,
+            (version_id,),
+        )
+        return [row["word"] for row in cur.fetchall()]
+
+    # =========================================================================
+    # Statistics
+    # =========================================================================
+
+    def get_stats(self) -> dict[str, int]:
+        """Return the row count of each main table, keyed by table name."""
+        cur = self.conn.cursor()
+        stats: dict[str, int] = {}
+        for table in [
+            "local_files",
+            "models",
+            "model_versions",
+            "version_files",
+            "trained_words",
+            "creators",
+            "tags",
+        ]:
+            # Table names come only from the fixed list above, never from input.
+            cur.execute(f"SELECT COUNT(*) FROM {table}")
+            stats[table] = cur.fetchone()[0]
+        return stats
diff --git a/tensors/safetensor.py b/tensors/safetensor.py
index 710d57b..f47ae2e 100644
--- a/tensors/safetensor.py
+++ b/tensors/safetensor.py
@@ -60,28 +60,37 @@ def read_safetensor_metadata(file_path: Path) -> dict[str, Any]:
     }
 
 
-def compute_sha256(file_path: Path, console: Console) -> str:
-    """Compute SHA256 hash of a file with progress display."""
+def compute_sha256(file_path: Path, console: Console | None = None) -> str:
+    """Compute SHA256 hash of a file with optional progress display.
+
+    If console is provided, shows a progress bar. Otherwise computes silently.
+    """
     file_size = file_path.stat().st_size
     sha256 = hashlib.sha256()
     chunk_size = 1024 * 1024 * 8  # 8MB chunks
 
-    with Progress(
-        SpinnerColumn(),
-        TextColumn("[progress.description]{task.description}"),
-        BarColumn(),
-        TaskProgressColumn(),
-        DownloadColumn(),
-        TransferSpeedColumn(),
-        TimeRemainingColumn(),
-        console=console,
-    ) as progress:
-        task = progress.add_task(f"[cyan]Hashing {file_path.name}...", total=file_size)
+    if console is not None:
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            DownloadColumn(),
+            TransferSpeedColumn(),
+            TimeRemainingColumn(),
+            console=console,
+        ) as progress:
+            task = progress.add_task(f"[cyan]Hashing {file_path.name}...", total=file_size)
+            with file_path.open("rb") as f:
+                while chunk := f.read(chunk_size):
+                    sha256.update(chunk)
+                    progress.update(task, advance=len(chunk))
+    else:
+        # Silent mode - no progress display
         with file_path.open("rb") as f:
             while chunk := f.read(chunk_size):
                 sha256.update(chunk)
-                progress.update(task, advance=len(chunk))
 
     return sha256.hexdigest().upper()
 
 
diff --git a/tensors/schema.sql b/tensors/schema.sql
new file mode 100644
index 0000000..80412a6
--- /dev/null
+++ b/tensors/schema.sql
@@ -0,0 +1,259 @@
+-- Models Database Schema
+-- SQLite database for local model metadata storage and CivitAI model information cache.
+
+-- ============================================================================
+-- Core Tables: Local Files
+-- ============================================================================
+
+CREATE TABLE IF NOT EXISTS local_files (
+    id INTEGER PRIMARY KEY,
+    file_path TEXT NOT NULL UNIQUE,
+    sha256 TEXT NOT NULL,
+    header_size INTEGER,
+    tensor_count INTEGER,
+    civitai_model_id INTEGER,  -- joined to models.civitai_id (not models.id); NULL until linked
+    civitai_version_id INTEGER,  -- joined to model_versions.civitai_id (not model_versions.id)
+    created_at TEXT DEFAULT (datetime('now')),
+    updated_at TEXT DEFAULT (datetime('now'))
+);
+
+CREATE INDEX IF NOT EXISTS idx_local_files_sha256 ON local_files(sha256);
+CREATE INDEX IF NOT EXISTS idx_local_files_civitai_model ON local_files(civitai_model_id);
+
+CREATE TABLE IF NOT EXISTS safetensor_metadata (
+    id INTEGER PRIMARY KEY,
+    local_file_id INTEGER NOT NULL,
+    key TEXT NOT NULL,
+    value TEXT,  -- strings stored as-is, other values JSON-encoded by the writer
+    FOREIGN KEY (local_file_id) REFERENCES local_files(id) ON DELETE CASCADE,
+    UNIQUE(local_file_id, key)
+);
+
+CREATE INDEX IF NOT EXISTS idx_safetensor_metadata_file ON safetensor_metadata(local_file_id);
+
+-- ============================================================================
+-- CivitAI Cache Tables
+-- ============================================================================
+
+CREATE TABLE IF NOT EXISTS creators (
+    id INTEGER PRIMARY KEY,
+    username TEXT NOT NULL UNIQUE,
+    image_url TEXT
+);
+
+CREATE TABLE IF NOT EXISTS models (
+    id INTEGER PRIMARY KEY,  -- internal ID, referenced by model_versions.model_id
+    civitai_id INTEGER UNIQUE NOT NULL,  -- ID assigned by CivitAI
+    name TEXT NOT NULL,
+    description TEXT,
+    type TEXT NOT NULL,
+    nsfw INTEGER DEFAULT 0,
+    poi INTEGER DEFAULT 0,
+    minor INTEGER DEFAULT 0,
+    sfw_only INTEGER DEFAULT 0,
+    nsfw_level INTEGER,
+    availability TEXT,
+    allow_no_credit INTEGER,
+    allow_commercial_use TEXT,
+    allow_derivatives INTEGER,
+    allow_different_license INTEGER,
+    supports_generation INTEGER DEFAULT 0,
+    creator_id INTEGER,
+    download_count INTEGER DEFAULT 0,
+    thumbs_up_count INTEGER DEFAULT 0,
+    thumbs_down_count INTEGER DEFAULT 0,
+    comment_count INTEGER DEFAULT 0,
+    tipped_amount_count INTEGER DEFAULT 0,
+    created_at TEXT,
+    updated_at TEXT,
+    FOREIGN KEY (creator_id) REFERENCES creators(id)
+);
+
+CREATE INDEX IF NOT EXISTS idx_models_civitai ON models(civitai_id);
+CREATE INDEX IF NOT EXISTS idx_models_type ON models(type);
+CREATE INDEX IF NOT EXISTS idx_models_name ON models(name);
+
+CREATE TABLE IF NOT EXISTS tags (
+    id INTEGER PRIMARY KEY,
+    name TEXT NOT NULL UNIQUE
+);
+
+CREATE TABLE IF NOT EXISTS model_tags (
+    model_id INTEGER NOT NULL,
+    tag_id INTEGER NOT NULL,
+    PRIMARY KEY (model_id, tag_id),
+    FOREIGN KEY (model_id) REFERENCES models(id) ON DELETE CASCADE,
+    FOREIGN KEY (tag_id) REFERENCES tags(id) ON DELETE CASCADE
+);
+
+CREATE TABLE IF NOT EXISTS model_versions (
+    id INTEGER PRIMARY KEY,  -- internal ID, referenced by the per-version child tables
+    civitai_id INTEGER UNIQUE NOT NULL,  -- version ID assigned by CivitAI
+    model_id INTEGER NOT NULL,
+    name TEXT NOT NULL,
+    description TEXT,
+    base_model TEXT,
+    base_model_type TEXT,
+    nsfw_level INTEGER,
+    status TEXT,
+    availability TEXT,
+    upload_type TEXT,
+    usage_control TEXT,
+    air TEXT,
+    training_status TEXT,
+    training_details TEXT,
+    early_access_ends_at TEXT,
+    download_count INTEGER DEFAULT 0,
+    thumbs_up_count INTEGER DEFAULT 0,
+    thumbs_down_count INTEGER DEFAULT 0,
+    supports_generation INTEGER DEFAULT 0,
+    download_url TEXT,
+    created_at TEXT,
+    published_at TEXT,
+    updated_at TEXT,
+    version_index INTEGER,  -- position in the model's version list; 0 is what v_models_with_latest selects
+    FOREIGN KEY (model_id) REFERENCES models(id) ON DELETE CASCADE
+);
+
+CREATE INDEX IF NOT EXISTS idx_model_versions_civitai ON model_versions(civitai_id);
+CREATE INDEX IF NOT EXISTS idx_model_versions_model ON model_versions(model_id);
+CREATE INDEX IF NOT EXISTS idx_model_versions_base ON model_versions(base_model);
+
+CREATE TABLE IF NOT EXISTS trained_words (
+    id INTEGER PRIMARY KEY,
+    version_id INTEGER NOT NULL,
+    word TEXT NOT NULL,  -- unique per version, enforced by idx_trained_words_version_word
+    position INTEGER,
+    FOREIGN KEY (version_id) REFERENCES model_versions(id) ON DELETE CASCADE
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS idx_trained_words_version_word ON trained_words(version_id, word);
+
+CREATE TABLE IF NOT EXISTS version_files (
+    id INTEGER PRIMARY KEY,  -- internal ID, referenced by file_hashes.file_id
+    civitai_id INTEGER UNIQUE NOT NULL,  -- file ID assigned by CivitAI
+    version_id INTEGER NOT NULL,
+    name TEXT NOT NULL,
+    type TEXT,
+    size_kb REAL,
+    format TEXT,
+    size_type TEXT,
+    fp TEXT,
+    is_primary INTEGER DEFAULT 0,
+    pickle_scan_result TEXT,
+    pickle_scan_message TEXT,
+    virus_scan_result TEXT,
+    virus_scan_message TEXT,
+    scanned_at TEXT,
+    download_url TEXT,
+    FOREIGN KEY (version_id) REFERENCES model_versions(id) ON DELETE CASCADE
+);
+
+CREATE INDEX IF NOT EXISTS idx_version_files_version ON version_files(version_id);
+
+CREATE TABLE IF NOT EXISTS file_hashes (
+    id INTEGER PRIMARY KEY,
+    file_id INTEGER NOT NULL,
+    hash_type TEXT NOT NULL,  -- one row per hash algorithm and file
+    hash_value TEXT NOT NULL,
+    FOREIGN KEY (file_id) REFERENCES version_files(id) ON DELETE CASCADE,
+    UNIQUE(file_id, hash_type)
+);
+
+CREATE INDEX IF NOT EXISTS idx_file_hashes_file ON file_hashes(file_id);
+CREATE INDEX IF NOT EXISTS idx_file_hashes_value ON file_hashes(hash_value);
+
+CREATE TABLE IF NOT EXISTS version_images (
+    id INTEGER PRIMARY KEY,  -- internal ID, referenced by the image_* child tables
+    civitai_id INTEGER,  -- nullable and not unique, unlike the other civitai_id columns
+    version_id INTEGER NOT NULL,
+    url TEXT NOT NULL,
+    type TEXT,
+    nsfw_level INTEGER,
+    width INTEGER,
+    height INTEGER,
+    hash TEXT,
+    has_meta INTEGER DEFAULT 0,
+    has_positive_prompt INTEGER DEFAULT 0,
+    on_site INTEGER DEFAULT 0,
+    minor INTEGER DEFAULT 0,
+    poi INTEGER DEFAULT 0,
+    availability TEXT,
+    remix_of_id INTEGER,
+    FOREIGN KEY (version_id) REFERENCES model_versions(id) ON DELETE CASCADE
+);
+
+CREATE INDEX IF NOT EXISTS idx_version_images_version ON version_images(version_id);
+
+CREATE TABLE IF NOT EXISTS image_video_metadata (
+    id INTEGER PRIMARY KEY,
+    image_id INTEGER NOT NULL UNIQUE,  -- at most one metadata row per image
+    duration REAL,
+    has_audio INTEGER DEFAULT 0,
+    size_bytes INTEGER,
+    FOREIGN KEY (image_id) REFERENCES version_images(id) ON DELETE CASCADE
+);
+
+CREATE TABLE IF NOT EXISTS image_generation_params (
+    id INTEGER PRIMARY KEY,
+    image_id INTEGER NOT NULL,
+    key TEXT NOT NULL,
+    value TEXT,
+    FOREIGN KEY (image_id) REFERENCES version_images(id) ON DELETE CASCADE,
+    UNIQUE(image_id, key)
+);
+
+CREATE INDEX IF NOT EXISTS idx_image_params_image ON image_generation_params(image_id);
+
+CREATE TABLE IF NOT EXISTS image_resources (
+    id INTEGER PRIMARY KEY,
+    image_id INTEGER NOT NULL,
+    name TEXT NOT NULL,
+    type TEXT,
+    hash TEXT,
+    weight REAL,
+    FOREIGN KEY (image_id) REFERENCES version_images(id) ON DELETE CASCADE
+);
+
+CREATE INDEX IF NOT EXISTS idx_image_resources_image ON image_resources(image_id);
+
+-- ============================================================================
+-- Views
+-- ============================================================================
+
+CREATE VIEW IF NOT EXISTS v_models_with_latest AS
+SELECT
+    m.id,
+    m.civitai_id,
+    m.name,
+    m.type,
+    m.nsfw,
+    c.username as creator,
+    mv.name as latest_version,  -- the version whose version_index = 0; NULL if none is cached
+    mv.base_model,
+    m.download_count,
+    m.thumbs_up_count
+FROM models m
+LEFT JOIN creators c ON m.creator_id = c.id
+LEFT JOIN model_versions mv ON mv.model_id = m.id AND mv.version_index = 0;
+
+CREATE VIEW IF NOT EXISTS v_local_files_full AS
+SELECT
+    lf.id,
+    lf.file_path,
+    lf.sha256,
+    lf.header_size,
+    lf.tensor_count,
+    lf.civitai_model_id,
+    lf.civitai_version_id,
+    m.name as model_name,  -- model/version/creator columns are NULL when no cached row matches
+    m.type as model_type,
+    mv.name as version_name,
+    mv.base_model,
+    c.username as creator,
+    lf.created_at,
+    lf.updated_at
+FROM local_files lf
+LEFT JOIN models m ON lf.civitai_model_id = m.civitai_id
+LEFT JOIN model_versions mv ON lf.civitai_version_id = mv.civitai_id
+LEFT JOIN creators c ON m.creator_id = c.id;