diff --git a/tensors/db.py b/tensors/db.py
index 63ec993..7089fa2 100644
--- a/tensors/db.py
+++ b/tensors/db.py
@@ -582,6 +582,194 @@ class Database:
         )
         return [row["word"] for row in cur.fetchall()]
 
+    # =========================================================================
+    # HuggingFace Cache Operations
+    # =========================================================================
+
+    def cache_hf_model(self, data: dict[str, Any]) -> int:
+        """Insert or refresh one HuggingFace model in the local cache.
+
+        The model row, its tags and its safetensor files are written in a
+        single transaction.
+
+        Args:
+            data: Model data dict. The repo id is read from ``repo_id``,
+                ``id`` or ``modelId``. ``tags`` is a list of strings and
+                ``safetensor_files`` a list of filenames or of dicts with
+                ``filename`` and ``size`` keys.
+
+        Returns:
+            The internal ``hf_models.id`` of the cached row.
+
+        Raises:
+            ValueError: If ``data`` carries no repo id.
+        """
+        repo_id = data.get("repo_id") or data.get("id") or data.get("modelId")
+        if not repo_id:
+            raise ValueError("repo_id is required")
+
+        # Parse author from repo_id if not provided
+        author = data.get("author")
+        model_name = repo_id
+        if "/" in repo_id:
+            owner, model_name = repo_id.split("/", 1)
+            author = author or owner
+
+        # Column values shared by the UPDATE and INSERT statements below.
+        # "or 0" also covers counters that arrive as an explicit None.
+        fields = (
+            author,
+            model_name,
+            data.get("pipeline_tag"),
+            data.get("library_name"),
+            data.get("downloads") or 0,
+            data.get("likes") or 0,
+            data.get("trending_score"),
+            1 if data.get("private") else 0,
+            1 if data.get("gated") else 0,
+            data.get("last_modified") or data.get("lastModified"),
+        )
+
+        # The connection context manager commits on success and rolls back on
+        # error, so a failure cannot leave a half-written model behind.
+        with self.conn:
+            cur = self.conn.cursor()
+            cur.execute("SELECT id FROM hf_models WHERE repo_id = ?", (repo_id,))
+            existing = cur.fetchone()
+
+            if existing:
+                model_id = int(existing["id"])
+                cur.execute(
+                    """
+                    UPDATE hf_models SET
+                        author = ?, model_name = ?, pipeline_tag = ?, library_name = ?,
+                        downloads = ?, likes = ?, trending_score = ?,
+                        is_private = ?, is_gated = ?, last_modified = ?,
+                        updated_at = datetime('now')
+                    WHERE id = ?
+                    """,
+                    (*fields, model_id),
+                )
+            else:
+                cur.execute(
+                    """
+                    INSERT INTO hf_models (
+                        repo_id, author, model_name, pipeline_tag, library_name,
+                        downloads, likes, trending_score, is_private, is_gated,
+                        last_modified, created_at
+                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+                    """,
+                    (
+                        repo_id,
+                        *fields,
+                        data.get("created_at") or data.get("createdAt"),
+                    ),
+                )
+                model_id = cur.lastrowid or 0
+
+            tags = data.get("tags")
+            if tags:
+                # Replace the tag set so tags removed upstream do not linger.
+                cur.execute("DELETE FROM hf_model_tags WHERE hf_model_id = ?", (model_id,))
+                cur.executemany(
+                    "INSERT OR IGNORE INTO hf_model_tags (hf_model_id, tag) VALUES (?, ?)",
+                    [(model_id, tag) for tag in tags],
+                )
+
+            for file_info in data.get("safetensor_files") or []:
+                if isinstance(file_info, str):
+                    filename, size = file_info, None
+                elif isinstance(file_info, dict):
+                    filename, size = file_info.get("filename"), file_info.get("size")
+                else:
+                    continue
+                if not filename:
+                    continue  # hf_safetensor_files.filename is NOT NULL
+                cur.execute(
+                    """
+                    INSERT OR IGNORE INTO hf_safetensor_files (hf_model_id, filename, size_bytes)
+                    VALUES (?, ?, ?)
+                    """,
+                    (model_id, filename, size),
+                )
+                if size is not None:
+                    # INSERT OR IGNORE keeps an existing row; refresh its size.
+                    cur.execute(
+                        """
+                        UPDATE hf_safetensor_files SET size_bytes = ?
+                        WHERE hf_model_id = ? AND filename = ?
+                        """,
+                        (size, model_id, filename),
+                    )
+
+        return model_id
+
+    def search_hf_models(
+        self,
+        query: str | None = None,
+        author: str | None = None,
+        pipeline_tag: str | None = None,
+        limit: int = 20,
+    ) -> list[dict[str, Any]]:
+        """Search cached HuggingFace models, most downloaded first.
+
+        Args:
+            query: Substring matched literally against repo_id and model_name.
+            author: Exact author filter.
+            pipeline_tag: Exact pipeline tag filter.
+            limit: Maximum number of rows to return.
+
+        Returns:
+            Rows of the v_hf_models view as dicts.
+        """
+        cur = self.conn.cursor()
+
+        sql = "SELECT * FROM v_hf_models WHERE 1=1"
+        params: list[Any] = []
+
+        if query:
+            # Escape LIKE wildcards so "%" and "_" in the query match literally
+            # (underscores are common in model names).
+            escaped = query.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
+            sql += " AND (repo_id LIKE ? ESCAPE '\\' OR model_name LIKE ? ESCAPE '\\')"
+            params.extend([f"%{escaped}%", f"%{escaped}%"])
+
+        if author:
+            sql += " AND author = ?"
+            params.append(author)
+
+        if pipeline_tag:
+            sql += " AND pipeline_tag = ?"
+            params.append(pipeline_tag)
+
+        sql += " ORDER BY downloads DESC LIMIT ?"
+        params.append(limit)
+
+        cur.execute(sql, params)
+        return [dict(row) for row in cur.fetchall()]
+
+    def get_hf_model(self, repo_id: str) -> dict[str, Any] | None:
+        """Get cached HF model by repo_id."""
+        cur = self.conn.cursor()
+        cur.execute("SELECT * FROM v_hf_models WHERE repo_id = ?", (repo_id,))
+        row = cur.fetchone()
+        return dict(row) if row else None
+
+    def get_hf_safetensor_files(self, repo_id: str) -> list[dict[str, Any]]:
+        """Get safetensor files for an HF model."""
+        cur = self.conn.cursor()
+        cur.execute(
+            """
+            SELECT hsf.filename, hsf.size_bytes
+            FROM hf_safetensor_files hsf
+            JOIN hf_models hm ON hsf.hf_model_id = hm.id
+            WHERE hm.repo_id = ?
+            ORDER BY hsf.filename
+            """,
+            (repo_id,),
+        )
+        return [dict(row) for row in cur.fetchall()]
+
     # =========================================================================
     # Statistics
     # =========================================================================
@@ -598,6 +786,8 @@ class Database:
             "trained_words",
             "creators",
             "tags",
+            "hf_models",
+            "hf_safetensor_files",
         ]:
             cur.execute(f"SELECT COUNT(*) FROM {table}")
             stats[table] = cur.fetchone()[0]
diff --git a/tensors/schema.sql b/tensors/schema.sql
index 80412a6..78b0fc1 100644
--- a/tensors/schema.sql
+++ b/tensors/schema.sql
@@ -217,6 +217,50 @@ CREATE TABLE IF NOT EXISTS image_resources (
 
 CREATE INDEX IF NOT EXISTS idx_image_resources_image ON image_resources(image_id);
 
+-- ============================================================================
+-- HuggingFace Cache Tables
+-- ============================================================================
+
+-- repo_id needs no separate index: its UNIQUE constraint already creates one.
+CREATE TABLE IF NOT EXISTS hf_models (
+    id INTEGER PRIMARY KEY,
+    repo_id TEXT NOT NULL UNIQUE,
+    author TEXT,
+    model_name TEXT NOT NULL,
+    pipeline_tag TEXT,
+    library_name TEXT,
+    downloads INTEGER DEFAULT 0,
+    likes INTEGER DEFAULT 0,
+    trending_score REAL,
+    is_private INTEGER DEFAULT 0,
+    is_gated INTEGER DEFAULT 0,
+    last_modified TEXT,
+    created_at TEXT,
+    cached_at TEXT DEFAULT (datetime('now')),
+    updated_at TEXT DEFAULT (datetime('now'))
+);
+
+CREATE INDEX IF NOT EXISTS idx_hf_models_author ON hf_models(author);
+CREATE INDEX IF NOT EXISTS idx_hf_models_downloads ON hf_models(downloads);
+
+-- The composite PRIMARY KEY index already serves lookups by hf_model_id.
+CREATE TABLE IF NOT EXISTS hf_model_tags (
+    hf_model_id INTEGER NOT NULL,
+    tag TEXT NOT NULL,
+    PRIMARY KEY (hf_model_id, tag),
+    FOREIGN KEY (hf_model_id) REFERENCES hf_models(id) ON DELETE CASCADE
+);
+
+-- UNIQUE(hf_model_id, filename) already serves lookups by hf_model_id.
+CREATE TABLE IF NOT EXISTS hf_safetensor_files (
+    id INTEGER PRIMARY KEY,
+    hf_model_id INTEGER NOT NULL,
+    filename TEXT NOT NULL,
+    size_bytes INTEGER,
+    FOREIGN KEY (hf_model_id) REFERENCES hf_models(id) ON DELETE CASCADE,
+    UNIQUE(hf_model_id, filename)
+);
+
 -- ============================================================================
 -- Views
 -- ============================================================================
@@ -237,6 +281,23 @@ FROM models m
 LEFT JOIN creators c ON m.creator_id = c.id
 LEFT JOIN model_versions mv ON mv.model_id = m.id AND mv.version_index = 0;
 
+-- Tags and file counts come from correlated subqueries: joining both child
+-- tables at once would build a tags x files cross product for every model.
+CREATE VIEW IF NOT EXISTS v_hf_models AS
+SELECT
+    hm.id,
+    hm.repo_id,
+    hm.author,
+    hm.model_name,
+    hm.pipeline_tag,
+    hm.downloads,
+    hm.likes,
+    hm.is_gated,
+    hm.last_modified,
+    (SELECT GROUP_CONCAT(hmt.tag) FROM hf_model_tags hmt WHERE hmt.hf_model_id = hm.id) AS tags,
+    (SELECT COUNT(*) FROM hf_safetensor_files hsf WHERE hsf.hf_model_id = hm.id) AS safetensor_count
+FROM hf_models hm;
+
 CREATE VIEW IF NOT EXISTS v_local_files_full AS
 SELECT
     lf.id,
diff --git a/tensors/server/search_routes.py b/tensors/server/search_routes.py
index 9337545..839ca39 100644
--- a/tensors/server/search_routes.py
+++ b/tensors/server/search_routes.py
@@ -29,6 +29,7 @@ from tensors.config import (
 from tensors.config import (
     SortOrder as SortOrderEnum,
 )
+from tensors.db import Database
 from tensors.hf import search_hf_models
 
 logger = logging.getLogger(__name__)
@@ -151,6 +152,14 @@ async def search_models(
 
         if hf_results:
             results["huggingface"] = hf_results
+            # Cache HF models
+            try:
+                with Database() as db:
+                    db.init_schema()
+                    for model_data in hf_results:
+                        db.cache_hf_model(model_data)
+            except Exception as e:
+                logger.warning("Failed to cache HF search results: %s", e)
 
     return results
 