Add HuggingFace model caching to database
- New tables: hf_models, hf_model_tags, hf_safetensor_files - Cache HF search results automatically - Add search_hf_models(), get_hf_model() methods - Include hf_models in stats Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
+162
@@ -582,6 +582,166 @@ class Database:
|
||||
)
|
||||
return [row["word"] for row in cur.fetchall()]
|
||||
|
||||
# =========================================================================
|
||||
# HuggingFace Cache Operations
|
||||
# =========================================================================
|
||||
|
||||
def cache_hf_model(self, data: dict[str, Any]) -> int:
|
||||
"""Cache HuggingFace model data.
|
||||
|
||||
Args:
|
||||
data: Model data dict with keys like repo_id, author, downloads, etc.
|
||||
|
||||
Returns the internal model ID.
|
||||
"""
|
||||
cur = self.conn.cursor()
|
||||
|
||||
repo_id = data.get("repo_id") or data.get("id") or data.get("modelId")
|
||||
if not repo_id:
|
||||
raise ValueError("repo_id is required")
|
||||
|
||||
# Parse author from repo_id if not provided
|
||||
author = data.get("author")
|
||||
model_name = repo_id
|
||||
if "/" in repo_id:
|
||||
parts = repo_id.split("/", 1)
|
||||
author = author or parts[0]
|
||||
model_name = parts[1]
|
||||
|
||||
# Check if model exists
|
||||
cur.execute("SELECT id FROM hf_models WHERE repo_id = ?", (repo_id,))
|
||||
existing = cur.fetchone()
|
||||
|
||||
if existing:
|
||||
model_id = int(existing["id"])
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE hf_models SET
|
||||
author = ?, model_name = ?, pipeline_tag = ?, library_name = ?,
|
||||
downloads = ?, likes = ?, trending_score = ?,
|
||||
is_private = ?, is_gated = ?, last_modified = ?,
|
||||
updated_at = datetime('now')
|
||||
WHERE id = ?
|
||||
""",
|
||||
(
|
||||
author,
|
||||
model_name,
|
||||
data.get("pipeline_tag"),
|
||||
data.get("library_name"),
|
||||
data.get("downloads", 0),
|
||||
data.get("likes", 0),
|
||||
data.get("trending_score"),
|
||||
1 if data.get("private") else 0,
|
||||
1 if data.get("gated") else 0,
|
||||
data.get("last_modified") or data.get("lastModified"),
|
||||
model_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO hf_models (
|
||||
repo_id, author, model_name, pipeline_tag, library_name,
|
||||
downloads, likes, trending_score, is_private, is_gated,
|
||||
last_modified, created_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
repo_id,
|
||||
author,
|
||||
model_name,
|
||||
data.get("pipeline_tag"),
|
||||
data.get("library_name"),
|
||||
data.get("downloads", 0),
|
||||
data.get("likes", 0),
|
||||
data.get("trending_score"),
|
||||
1 if data.get("private") else 0,
|
||||
1 if data.get("gated") else 0,
|
||||
data.get("last_modified") or data.get("lastModified"),
|
||||
data.get("created_at") or data.get("createdAt"),
|
||||
),
|
||||
)
|
||||
model_id = cur.lastrowid or 0
|
||||
|
||||
# Cache tags
|
||||
for tag in data.get("tags", []):
|
||||
cur.execute(
|
||||
"INSERT OR IGNORE INTO hf_model_tags (hf_model_id, tag) VALUES (?, ?)",
|
||||
(model_id, tag),
|
||||
)
|
||||
|
||||
# Cache safetensor files
|
||||
for file_info in data.get("safetensor_files", []):
|
||||
if isinstance(file_info, str):
|
||||
cur.execute(
|
||||
"INSERT OR IGNORE INTO hf_safetensor_files (hf_model_id, filename) VALUES (?, ?)",
|
||||
(model_id, file_info),
|
||||
)
|
||||
elif isinstance(file_info, dict):
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO hf_safetensor_files (hf_model_id, filename, size_bytes)
|
||||
VALUES (?, ?, ?)
|
||||
""",
|
||||
(model_id, file_info.get("filename"), file_info.get("size")),
|
||||
)
|
||||
|
||||
self.conn.commit()
|
||||
return model_id
|
||||
|
||||
def search_hf_models(
|
||||
self,
|
||||
query: str | None = None,
|
||||
author: str | None = None,
|
||||
pipeline_tag: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search cached HuggingFace models."""
|
||||
cur = self.conn.cursor()
|
||||
|
||||
sql = "SELECT * FROM v_hf_models WHERE 1=1"
|
||||
params: list[Any] = []
|
||||
|
||||
if query:
|
||||
sql += " AND (repo_id LIKE ? OR model_name LIKE ?)"
|
||||
params.extend([f"%{query}%", f"%{query}%"])
|
||||
|
||||
if author:
|
||||
sql += " AND author = ?"
|
||||
params.append(author)
|
||||
|
||||
if pipeline_tag:
|
||||
sql += " AND pipeline_tag = ?"
|
||||
params.append(pipeline_tag)
|
||||
|
||||
sql += " ORDER BY downloads DESC LIMIT ?"
|
||||
params.append(limit)
|
||||
|
||||
cur.execute(sql, params)
|
||||
return [dict(row) for row in cur.fetchall()]
|
||||
|
||||
def get_hf_model(self, repo_id: str) -> dict[str, Any] | None:
|
||||
"""Get cached HF model by repo_id."""
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("SELECT * FROM v_hf_models WHERE repo_id = ?", (repo_id,))
|
||||
row = cur.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def get_hf_safetensor_files(self, repo_id: str) -> list[dict[str, Any]]:
|
||||
"""Get safetensor files for an HF model."""
|
||||
cur = self.conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT hsf.filename, hsf.size_bytes
|
||||
FROM hf_safetensor_files hsf
|
||||
JOIN hf_models hm ON hsf.hf_model_id = hm.id
|
||||
WHERE hm.repo_id = ?
|
||||
ORDER BY hsf.filename
|
||||
""",
|
||||
(repo_id,),
|
||||
)
|
||||
return [dict(row) for row in cur.fetchall()]
|
||||
|
||||
# =========================================================================
|
||||
# Statistics
|
||||
# =========================================================================
|
||||
@@ -598,6 +758,8 @@ class Database:
|
||||
"trained_words",
|
||||
"creators",
|
||||
"tags",
|
||||
"hf_models",
|
||||
"hf_safetensor_files",
|
||||
]:
|
||||
cur.execute(f"SELECT COUNT(*) FROM {table}")
|
||||
stats[table] = cur.fetchone()[0]
|
||||
|
||||
@@ -217,6 +217,52 @@ CREATE TABLE IF NOT EXISTS image_resources (
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_image_resources_image ON image_resources(image_id);
|
||||
|
||||
-- ============================================================================
|
||||
-- HuggingFace Cache Tables
|
||||
-- ============================================================================
|
||||
|
||||
CREATE TABLE IF NOT EXISTS hf_models (
|
||||
id INTEGER PRIMARY KEY,
|
||||
repo_id TEXT NOT NULL UNIQUE,
|
||||
author TEXT,
|
||||
model_name TEXT NOT NULL,
|
||||
pipeline_tag TEXT,
|
||||
library_name TEXT,
|
||||
downloads INTEGER DEFAULT 0,
|
||||
likes INTEGER DEFAULT 0,
|
||||
trending_score REAL,
|
||||
is_private INTEGER DEFAULT 0,
|
||||
is_gated INTEGER DEFAULT 0,
|
||||
last_modified TEXT,
|
||||
created_at TEXT,
|
||||
cached_at TEXT DEFAULT (datetime('now')),
|
||||
updated_at TEXT DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_hf_models_repo ON hf_models(repo_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_hf_models_author ON hf_models(author);
|
||||
CREATE INDEX IF NOT EXISTS idx_hf_models_downloads ON hf_models(downloads);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS hf_model_tags (
|
||||
hf_model_id INTEGER NOT NULL,
|
||||
tag TEXT NOT NULL,
|
||||
PRIMARY KEY (hf_model_id, tag),
|
||||
FOREIGN KEY (hf_model_id) REFERENCES hf_models(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_hf_model_tags_model ON hf_model_tags(hf_model_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS hf_safetensor_files (
|
||||
id INTEGER PRIMARY KEY,
|
||||
hf_model_id INTEGER NOT NULL,
|
||||
filename TEXT NOT NULL,
|
||||
size_bytes INTEGER,
|
||||
FOREIGN KEY (hf_model_id) REFERENCES hf_models(id) ON DELETE CASCADE,
|
||||
UNIQUE(hf_model_id, filename)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_hf_files_model ON hf_safetensor_files(hf_model_id);
|
||||
|
||||
-- ============================================================================
|
||||
-- Views
|
||||
-- ============================================================================
|
||||
@@ -237,6 +283,24 @@ FROM models m
|
||||
LEFT JOIN creators c ON m.creator_id = c.id
|
||||
LEFT JOIN model_versions mv ON mv.model_id = m.id AND mv.version_index = 0;
|
||||
|
||||
CREATE VIEW IF NOT EXISTS v_hf_models AS
|
||||
SELECT
|
||||
hm.id,
|
||||
hm.repo_id,
|
||||
hm.author,
|
||||
hm.model_name,
|
||||
hm.pipeline_tag,
|
||||
hm.downloads,
|
||||
hm.likes,
|
||||
hm.is_gated,
|
||||
hm.last_modified,
|
||||
GROUP_CONCAT(DISTINCT hmt.tag) as tags,
|
||||
COUNT(DISTINCT hsf.id) as safetensor_count
|
||||
FROM hf_models hm
|
||||
LEFT JOIN hf_model_tags hmt ON hm.id = hmt.hf_model_id
|
||||
LEFT JOIN hf_safetensor_files hsf ON hm.id = hsf.hf_model_id
|
||||
GROUP BY hm.id;
|
||||
|
||||
CREATE VIEW IF NOT EXISTS v_local_files_full AS
|
||||
SELECT
|
||||
lf.id,
|
||||
|
||||
@@ -29,6 +29,7 @@ from tensors.config import (
|
||||
from tensors.config import (
|
||||
SortOrder as SortOrderEnum,
|
||||
)
|
||||
from tensors.db import Database
|
||||
from tensors.hf import search_hf_models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -151,6 +152,14 @@ async def search_models(
|
||||
|
||||
if hf_results:
|
||||
results["huggingface"] = hf_results
|
||||
# Cache HF models
|
||||
try:
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
for model_data in hf_results:
|
||||
db.cache_hf_model(model_data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cache HF search results: %s", e)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user