Add HuggingFace model caching to database

- New tables: hf_models, hf_model_tags, hf_safetensor_files
- Cache HF search results automatically
- Add search_hf_models(), get_hf_model() methods
- Include hf_models and hf_safetensor_files in stats

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Adam Ladachowski
2026-02-15 21:18:24 +01:00
parent d79861df53
commit 08e612ffa0
3 changed files with 235 additions and 0 deletions
+162
View File
@@ -582,6 +582,166 @@ class Database:
)
return [row["word"] for row in cur.fetchall()]
# =========================================================================
# HuggingFace Cache Operations
# =========================================================================
def cache_hf_model(self, data: dict[str, Any]) -> int:
    """Insert or refresh one HuggingFace model in the local cache.

    The model row is matched on ``repo_id``: an existing row is updated in
    place (its ``created_at`` is left untouched), otherwise a new row is
    inserted. Tags and safetensor files are added with ``INSERT OR IGNORE``,
    so entries already cached are kept and never removed here.

    Args:
        data: Model data dict. The repo id is read from ``repo_id``, ``id``
            or ``modelId``. Optional keys: ``author``, ``pipeline_tag``,
            ``library_name``, ``downloads``, ``likes``,
            ``trending_score``/``trendingScore``, ``private``, ``gated``,
            ``last_modified``/``lastModified``, ``created_at``/``createdAt``,
            ``tags`` (iterable of str) and ``safetensor_files`` (iterable of
            filenames or ``{"filename": ..., "size": ...}`` dicts). Optional
            keys may be missing or ``None``.

    Returns:
        The internal model ID (``hf_models.id``).

    Raises:
        ValueError: If no repo id can be found in ``data``.
    """
    repo_id = data.get("repo_id") or data.get("id") or data.get("modelId")
    if not repo_id:
        raise ValueError("repo_id is required")

    # An "owner/name" repo id supplies the author when none was given, and
    # the part after the first slash becomes the short model name.
    namespace, slash, short_name = repo_id.partition("/")
    author = data.get("author")
    if slash:
        author = author or namespace
        model_name = short_name
    else:
        model_name = repo_id

    # Accept the raw API casing as well, like modelId/lastModified/createdAt.
    # An explicit None check keeps a legitimate score of 0.
    trending_score = data.get("trending_score")
    if trending_score is None:
        trending_score = data.get("trendingScore")

    # Column values shared by the UPDATE and INSERT statements, in column
    # order. `or 0` also covers keys that are present but None.
    shared_values = (
        author,
        model_name,
        data.get("pipeline_tag"),
        data.get("library_name"),
        data.get("downloads") or 0,
        data.get("likes") or 0,
        trending_score,
        1 if data.get("private") else 0,
        1 if data.get("gated") else 0,
        data.get("last_modified") or data.get("lastModified"),
    )

    cur = self.conn.cursor()
    cur.execute("SELECT id FROM hf_models WHERE repo_id = ?", (repo_id,))
    existing = cur.fetchone()
    if existing:
        model_id = int(existing["id"])
        cur.execute(
            """
            UPDATE hf_models SET
                author = ?, model_name = ?, pipeline_tag = ?, library_name = ?,
                downloads = ?, likes = ?, trending_score = ?,
                is_private = ?, is_gated = ?, last_modified = ?,
                updated_at = datetime('now')
            WHERE id = ?
            """,
            (*shared_values, model_id),
        )
    else:
        cur.execute(
            """
            INSERT INTO hf_models (
                repo_id, author, model_name, pipeline_tag, library_name,
                downloads, likes, trending_score, is_private, is_gated,
                last_modified, created_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                repo_id,
                *shared_values,
                data.get("created_at") or data.get("createdAt"),
            ),
        )
        model_id = cur.lastrowid or 0

    # `or []` tolerates an explicit None, which `.get(key, [])` would not.
    cur.executemany(
        "INSERT OR IGNORE INTO hf_model_tags (hf_model_id, tag) VALUES (?, ?)",
        [(model_id, tag) for tag in data.get("tags") or []],
    )

    for file_info in data.get("safetensor_files") or []:
        if isinstance(file_info, str):
            filename, size = file_info, None
        elif isinstance(file_info, dict):
            filename, size = file_info.get("filename"), file_info.get("size")
        else:
            continue
        if filename is None:
            continue
        cur.execute(
            """
            INSERT OR IGNORE INTO hf_safetensor_files (hf_model_id, filename, size_bytes)
            VALUES (?, ?, ?)
            """,
            (model_id, filename, size),
        )
        if size is not None:
            # The row may already exist (e.g. cached earlier by name only),
            # in which case the insert was ignored; record the known size.
            cur.execute(
                "UPDATE hf_safetensor_files SET size_bytes = ? "
                "WHERE hf_model_id = ? AND filename = ?",
                (size, model_id, filename),
            )

    self.conn.commit()
    return model_id
def search_hf_models(
self,
query: str | None = None,
author: str | None = None,
pipeline_tag: str | None = None,
limit: int = 20,
) -> list[dict[str, Any]]:
"""Search cached HuggingFace models."""
cur = self.conn.cursor()
sql = "SELECT * FROM v_hf_models WHERE 1=1"
params: list[Any] = []
if query:
sql += " AND (repo_id LIKE ? OR model_name LIKE ?)"
params.extend([f"%{query}%", f"%{query}%"])
if author:
sql += " AND author = ?"
params.append(author)
if pipeline_tag:
sql += " AND pipeline_tag = ?"
params.append(pipeline_tag)
sql += " ORDER BY downloads DESC LIMIT ?"
params.append(limit)
cur.execute(sql, params)
return [dict(row) for row in cur.fetchall()]
def get_hf_model(self, repo_id: str) -> dict[str, Any] | None:
"""Get cached HF model by repo_id."""
cur = self.conn.cursor()
cur.execute("SELECT * FROM v_hf_models WHERE repo_id = ?", (repo_id,))
row = cur.fetchone()
return dict(row) if row else None
def get_hf_safetensor_files(self, repo_id: str) -> list[dict[str, Any]]:
    """List the cached safetensor files of one HF model.

    Files are found by joining ``hf_safetensor_files`` to ``hf_models`` on
    the given repo_id. Each result dict has ``filename`` and ``size_bytes``;
    results are sorted by filename. An unknown repo_id gives an empty list.
    """
    file_query = """
        SELECT hsf.filename, hsf.size_bytes
        FROM hf_safetensor_files hsf
        JOIN hf_models hm ON hsf.hf_model_id = hm.id
        WHERE hm.repo_id = ?
        ORDER BY hsf.filename
        """
    cursor = self.conn.cursor()
    cursor.execute(file_query, (repo_id,))
    records = cursor.fetchall()
    return list(map(dict, records))
# =========================================================================
# Statistics
# =========================================================================
@@ -598,6 +758,8 @@ class Database:
"trained_words",
"creators",
"tags",
"hf_models",
"hf_safetensor_files",
]:
cur.execute(f"SELECT COUNT(*) FROM {table}")
stats[table] = cur.fetchone()[0]
+64
View File
@@ -217,6 +217,52 @@ CREATE TABLE IF NOT EXISTS image_resources (
CREATE INDEX IF NOT EXISTS idx_image_resources_image ON image_resources(image_id);
-- ============================================================================
-- HuggingFace Cache Tables
-- ============================================================================
-- One row per cached HuggingFace repository. repo_id is the natural key
-- (UNIQUE); id is the internal key referenced by hf_model_tags and
-- hf_safetensor_files.
-- is_private / is_gated are stored as 0 or 1 by the caching code.
-- created_at is supplied by the caching code on insert; cached_at and
-- updated_at default to the insertion time, and the caching code's UPDATE
-- refreshes only updated_at.
CREATE TABLE IF NOT EXISTS hf_models (
id INTEGER PRIMARY KEY,
repo_id TEXT NOT NULL UNIQUE,
author TEXT,
model_name TEXT NOT NULL,
pipeline_tag TEXT,
library_name TEXT,
downloads INTEGER DEFAULT 0,
likes INTEGER DEFAULT 0,
trending_score REAL,
is_private INTEGER DEFAULT 0,
is_gated INTEGER DEFAULT 0,
last_modified TEXT,
created_at TEXT,
cached_at TEXT DEFAULT (datetime('now')),
updated_at TEXT DEFAULT (datetime('now'))
);
-- NOTE(review): the UNIQUE constraint on repo_id already gives SQLite an
-- index on that column, so idx_hf_models_repo looks redundant - confirm
-- before dropping.
CREATE INDEX IF NOT EXISTS idx_hf_models_repo ON hf_models(repo_id);
-- Supports the exact-match author filter used by the cached-model search.
CREATE INDEX IF NOT EXISTS idx_hf_models_author ON hf_models(author);
-- Supports ordering search results by download count.
CREATE INDEX IF NOT EXISTS idx_hf_models_downloads ON hf_models(downloads);
-- Tags attached to a cached model, one row per (model, tag) pair. The
-- composite primary key is what lets the caching code re-insert the same
-- tag with INSERT OR IGNORE without creating duplicates.
-- NOTE(review): SQLite applies ON DELETE CASCADE only when foreign key
-- enforcement is enabled on the connection (PRAGMA foreign_keys = ON) -
-- confirm the Database class turns it on.
CREATE TABLE IF NOT EXISTS hf_model_tags (
hf_model_id INTEGER NOT NULL,
tag TEXT NOT NULL,
PRIMARY KEY (hf_model_id, tag),
FOREIGN KEY (hf_model_id) REFERENCES hf_models(id) ON DELETE CASCADE
);
-- NOTE(review): hf_model_id is the leading column of the primary key, whose
-- index can already serve lookups by model - this index may be redundant.
CREATE INDEX IF NOT EXISTS idx_hf_model_tags_model ON hf_model_tags(hf_model_id);
-- Safetensor files known for a cached model. A filename is unique within a
-- model (UNIQUE constraint below). size_bytes is nullable: the caching code
-- leaves it NULL when a file is supplied by name only.
-- NOTE(review): SQLite applies ON DELETE CASCADE only when foreign key
-- enforcement is enabled on the connection (PRAGMA foreign_keys = ON) -
-- confirm the Database class turns it on.
CREATE TABLE IF NOT EXISTS hf_safetensor_files (
id INTEGER PRIMARY KEY,
hf_model_id INTEGER NOT NULL,
filename TEXT NOT NULL,
size_bytes INTEGER,
FOREIGN KEY (hf_model_id) REFERENCES hf_models(id) ON DELETE CASCADE,
UNIQUE(hf_model_id, filename)
);
-- NOTE(review): the UNIQUE(hf_model_id, filename) index already starts with
-- hf_model_id, so this index may be redundant.
CREATE INDEX IF NOT EXISTS idx_hf_files_model ON hf_safetensor_files(hf_model_id);
-- ============================================================================
-- Views
-- ============================================================================
@@ -237,6 +283,24 @@ FROM models m
LEFT JOIN creators c ON m.creator_id = c.id
LEFT JOIN model_versions mv ON mv.model_id = m.id AND mv.version_index = 0;
-- Read-side view used by the cached-model search and lookup: one row per
-- model, with its tags folded into a comma-separated string and a count of
-- its safetensor files.
-- Joining both child tables at once yields a tags x files cross product per
-- model; the DISTINCT inside each aggregate collapses those repeats, so do
-- not remove it.
-- The view exposes only the hf_models columns listed below; the table's
-- other columns are not available through it.
CREATE VIEW IF NOT EXISTS v_hf_models AS
SELECT
hm.id,
hm.repo_id,
hm.author,
hm.model_name,
hm.pipeline_tag,
hm.downloads,
hm.likes,
hm.is_gated,
hm.last_modified,
-- NULL when the model has no tags (the LEFT JOIN produces no tag values).
GROUP_CONCAT(DISTINCT hmt.tag) as tags,
-- 0 when the model has no files, because COUNT skips the NULL ids.
COUNT(DISTINCT hsf.id) as safetensor_count
FROM hf_models hm
LEFT JOIN hf_model_tags hmt ON hm.id = hmt.hf_model_id
LEFT JOIN hf_safetensor_files hsf ON hm.id = hsf.hf_model_id
GROUP BY hm.id;
CREATE VIEW IF NOT EXISTS v_local_files_full AS
SELECT
lf.id,
+9
View File
@@ -29,6 +29,7 @@ from tensors.config import (
from tensors.config import (
SortOrder as SortOrderEnum,
)
from tensors.db import Database
from tensors.hf import search_hf_models
logger = logging.getLogger(__name__)
@@ -151,6 +152,14 @@ async def search_models(
if hf_results:
results["huggingface"] = hf_results
# Cache HF models
try:
with Database() as db:
db.init_schema()
for model_data in hf_results:
db.cache_hf_model(model_data)
except Exception as e:
logger.warning("Failed to cache HF search results: %s", e)
return results