diff --git a/.coverage b/.coverage index 8bd6919..0c1aa4d 100644 Binary files a/.coverage and b/.coverage differ diff --git a/pyproject.toml b/pyproject.toml index ddf77af..31c1ffe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "typer>=0.15.0", "websocket-client>=1.9.0", "huggingface_hub>=0.25.0", + "sqlmodel>=0.0.33", ] [project.optional-dependencies] diff --git a/tensors/db.py b/tensors/db.py index 7089fa2..d2e5e56 100644 --- a/tensors/db.py +++ b/tensors/db.py @@ -1,48 +1,67 @@ -"""SQLite database for local model metadata and CivitAI cache.""" +"""SQLModel database for local model metadata and CivitAI/HuggingFace cache.""" from __future__ import annotations import json -import sqlite3 -from pathlib import Path +from datetime import datetime from typing import TYPE_CHECKING, Any +from sqlmodel import Session, col, func, select + from tensors.config import DATA_DIR +from tensors.models import ( + Creator, + FileHash, + HFModel, + HFModelTag, + HFSafetensorFile, + ImageGenerationParam, + ImageResource, + LocalFile, + Model, + ModelTag, + ModelVersion, + SafetensorMetadata, + Tag, + TrainedWord, + VersionFile, + VersionImage, + create_tables, + get_engine, +) from tensors.safetensor import compute_sha256, read_safetensor_metadata if TYPE_CHECKING: + from pathlib import Path + from rich.console import Console + from sqlalchemy import Engine # Database location DB_PATH = DATA_DIR / "models.db" -# Load schema from file -_SCHEMA_PATH = Path(__file__).parent / "schema.sql" - class Database: - """SQLite database wrapper for models metadata.""" + """SQLModel database wrapper for models metadata.""" def __init__(self, db_path: Path | None = None) -> None: """Initialize database connection.""" self.db_path = db_path or DB_PATH self.db_path.parent.mkdir(parents=True, exist_ok=True) - self._conn: sqlite3.Connection | None = None + self._engine: Engine | None = None @property - def conn(self) -> sqlite3.Connection: - """Get or create database connection.""" - if self._conn is None: - self._conn = sqlite3.connect(self.db_path) - self._conn.row_factory = sqlite3.Row - self._conn.execute("PRAGMA foreign_keys = ON") - return self._conn + def engine(self) -> Engine: + """Get or create database engine.""" + if self._engine is None: + self._engine = get_engine(str(self.db_path)) + return self._engine def close(self) -> None: """Close database connection.""" - if self._conn is not None: - self._conn.close() - self._conn = None + if self._engine is not None: + self._engine.dispose() + self._engine = None def __enter__(self) -> Database: return self @@ -51,10 +70,12 @@ class Database: self.close() def init_schema(self) -> None: - """Initialize database schema from schema.sql.""" - schema = _SCHEMA_PATH.read_text() - self.conn.executescript(schema) - self.conn.commit() + """Initialize database schema.""" + create_tables(self.engine) + + def session(self) -> Session: + """Create a new session.""" + return Session(self.engine) # ========================================================================= # Local Files Operations @@ -65,10 +86,7 @@ class Database: directory: Path, console: Console | None = None, ) -> list[dict[str, Any]]: - """Scan directory for safetensor files and add to database. - - Returns list of scanned file info dicts. - """ + """Scan directory for safetensor files and add to database.""" results: list[dict[str, Any]] = [] safetensor_files = list(directory.rglob("*.safetensors")) @@ -80,18 +98,20 @@ class Database: sha256 = compute_sha256(path) metadata = read_safetensor_metadata(path) - file_info = self._upsert_local_file( - file_path=str(path.resolve()), - sha256=sha256, - header_size=metadata.get("header_size"), - tensor_count=metadata.get("tensor_count"), - ) + with self.session() as session: + file_info = self._upsert_local_file( + session, + file_path=str(path.resolve()), + sha256=sha256, + header_size=metadata.get("header_size"), + tensor_count=metadata.get("tensor_count"), + ) + self._store_safetensor_metadata(session, file_info.id, metadata.get("metadata", {})) + session.commit() + # Extract values before session closes + result = {"id": file_info.id, "file_path": file_info.file_path, "sha256": file_info.sha256} - # Store safetensor metadata - self._store_safetensor_metadata(file_info["id"], metadata.get("metadata", {})) - - results.append(file_info) - self.conn.commit() + results.append(result) except Exception as e: if console: @@ -101,100 +121,133 @@ class Database: def _upsert_local_file( self, + session: Session, file_path: str, sha256: str, header_size: int | None = None, tensor_count: int | None = None, - ) -> dict[str, Any]: + ) -> LocalFile: """Insert or update a local file record.""" - cur = self.conn.cursor() - - cur.execute("SELECT id FROM local_files WHERE file_path = ?", (file_path,)) - existing = cur.fetchone() + existing = session.exec(select(LocalFile).where(LocalFile.file_path == file_path)).first() if existing: - cur.execute( - """ - UPDATE local_files SET sha256 = ?, header_size = ?, tensor_count = ?, - updated_at = datetime('now') WHERE id = ? - """, - (sha256, header_size, tensor_count, existing["id"]), - ) - file_id = existing["id"] - else: - cur.execute( - """ - INSERT INTO local_files (file_path, sha256, header_size, tensor_count) - VALUES (?, ?, ?, ?) - """, - (file_path, sha256, header_size, tensor_count), - ) - file_id = cur.lastrowid or 0 # lastrowid is always set after INSERT + existing.sha256 = sha256 + existing.header_size = header_size + existing.tensor_count = tensor_count + existing.updated_at = datetime.utcnow() + session.add(existing) + return existing - return {"id": file_id, "file_path": file_path, "sha256": sha256} + local_file = LocalFile( + file_path=file_path, + sha256=sha256, + header_size=header_size, + tensor_count=tensor_count, + ) + session.add(local_file) + session.flush() + return local_file - def _store_safetensor_metadata(self, local_file_id: int, metadata: dict[str, Any]) -> None: + def _store_safetensor_metadata(self, session: Session, local_file_id: int | None, metadata: dict[str, Any]) -> None: """Store safetensor header metadata.""" - cur = self.conn.cursor() + if not local_file_id: + return for key, value in metadata.items(): str_value = json.dumps(value) if not isinstance(value, str) else value - cur.execute( - """ - INSERT INTO safetensor_metadata (local_file_id, key, value) - VALUES (?, ?, ?) - ON CONFLICT(local_file_id, key) DO UPDATE SET value = excluded.value - """, - (local_file_id, key, str_value), - ) + existing = session.exec( + select(SafetensorMetadata).where(SafetensorMetadata.local_file_id == local_file_id, SafetensorMetadata.key == key) + ).first() + if existing: + existing.value = str_value + session.add(existing) + else: + session.add(SafetensorMetadata(local_file_id=local_file_id, key=key, value=str_value)) def list_local_files(self) -> list[dict[str, Any]]: """List all local files with CivitAI info.""" - cur = self.conn.cursor() - cur.execute("SELECT * FROM v_local_files_full ORDER BY file_path") - return [dict(row) for row in cur.fetchall()] + with self.session() as session: + files = session.exec(select(LocalFile)).all() + results = [] + for f in files: + model = None + if f.civitai_model_id: + model = session.exec(select(Model).where(Model.civitai_id == f.civitai_model_id)).first() + version = None + if f.civitai_version_id: + version = session.exec(select(ModelVersion).where(ModelVersion.civitai_id == f.civitai_version_id)).first() + creator = None + if model and model.creator_id: + creator = session.exec(select(Creator).where(Creator.id == model.creator_id)).first() + results.append( + { + "id": f.id, + "file_path": f.file_path, + "sha256": f.sha256, + "header_size": f.header_size, + "tensor_count": f.tensor_count, + "civitai_model_id": f.civitai_model_id, + "civitai_version_id": f.civitai_version_id, + "model_name": model.name if model else None, + "model_type": model.type if model else None, + "version_name": version.name if version else None, + "base_model": version.base_model if version else None, + "creator": creator.username if creator else None, + } + ) + return results def get_local_file_by_path(self, file_path: str) -> dict[str, Any] | None: """Get local file by path.""" - cur = self.conn.cursor() - cur.execute("SELECT * FROM v_local_files_full WHERE file_path = ?", (file_path,)) - row = cur.fetchone() - return dict(row) if row else None + with self.session() as session: + f = session.exec(select(LocalFile).where(LocalFile.file_path == file_path)).first() + if not f: + return None + model = None + if f.civitai_model_id: + model = session.exec(select(Model).where(Model.civitai_id == f.civitai_model_id)).first() + version = None + if f.civitai_version_id: + version = session.exec(select(ModelVersion).where(ModelVersion.civitai_id == f.civitai_version_id)).first() + creator = None + if model and model.creator_id: + creator = session.exec(select(Creator).where(Creator.id == model.creator_id)).first() + return { + "id": f.id, + "file_path": f.file_path, + "sha256": f.sha256, + "civitai_model_id": f.civitai_model_id, + "civitai_version_id": f.civitai_version_id, + "model_name": model.name if model else None, + "model_type": model.type if model else None, + "version_name": version.name if version else None, + "base_model": version.base_model if version else None, + "creator": creator.username if creator else None, + } def get_local_file_by_hash(self, sha256: str) -> dict[str, Any] | None: """Get local file by SHA256 hash.""" - cur = self.conn.cursor() - cur.execute("SELECT * FROM v_local_files_full WHERE sha256 = ?", (sha256.upper(),)) - row = cur.fetchone() - return dict(row) if row else None + with self.session() as session: + f = session.exec(select(LocalFile).where(LocalFile.sha256 == sha256.upper())).first() + if not f: + return None + return {"id": f.id, "file_path": f.file_path, "sha256": f.sha256} def get_unlinked_files(self) -> list[dict[str, Any]]: """Get local files not linked to CivitAI.""" - cur = self.conn.cursor() - cur.execute( - """ - SELECT id, file_path, sha256 FROM local_files - WHERE civitai_model_id IS NULL - """ - ) - return [dict(row) for row in cur.fetchall()] + with self.session() as session: + files = session.exec(select(LocalFile).where(LocalFile.civitai_model_id == None)).all() # noqa: E711 + return [{"id": f.id, "file_path": f.file_path, "sha256": f.sha256} for f in files] - def link_file_to_civitai( - self, - file_id: int, - model_id: int, - version_id: int, - ) -> None: + def link_file_to_civitai(self, file_id: int, model_id: int, version_id: int) -> None: """Link a local file to CivitAI model/version.""" - cur = self.conn.cursor() - cur.execute( - """ - UPDATE local_files - SET civitai_model_id = ?, civitai_version_id = ?, updated_at = datetime('now') - WHERE id = ? - """, - (model_id, version_id, file_id), - ) - self.conn.commit() + with self.session() as session: + f = session.get(LocalFile, file_id) + if f: + f.civitai_model_id = model_id + f.civitai_version_id = version_id + f.updated_at = datetime.utcnow() + session.add(f) + session.commit() # ========================================================================= # CivitAI Cache Operations @@ -202,111 +255,88 @@ class Database: def get_version_by_hash(self, sha256: str) -> dict[str, Any] | None: """Find cached version by file hash.""" - cur = self.conn.cursor() - cur.execute( - """ - SELECT mv.civitai_id as version_id, m.civitai_id as model_id, - m.name as model_name, mv.name as version_name - FROM file_hashes fh - JOIN version_files vf ON fh.file_id = vf.id - JOIN model_versions mv ON vf.version_id = mv.id - JOIN models m ON mv.model_id = m.id - WHERE UPPER(fh.hash_value) = UPPER(?) - """, - (sha256,), - ) - row = cur.fetchone() - return dict(row) if row else None + with self.session() as session: + fh = session.exec(select(FileHash).where(FileHash.hash_value == sha256.upper())).first() + if not fh: + return None + vf = session.get(VersionFile, fh.file_id) + if not vf: + return None + mv = session.get(ModelVersion, vf.version_id) + if not mv: + return None + m = session.get(Model, mv.model_id) + return { + "version_id": mv.civitai_id, + "model_id": m.civitai_id if m else None, + "model_name": m.name if m else None, + "version_name": mv.name, + } def cache_model(self, data: dict[str, Any]) -> int: - """Cache full model data from CivitAI API response. + """Cache full model data from CivitAI API response.""" + with self.session() as session: + creator_id = self._get_or_create_creator(session, data.get("creator")) + civitai_id = data.get("id") + existing = session.exec(select(Model).where(Model.civitai_id == civitai_id)).first() + stats = data.get("stats", {}) - Returns the internal model ID. - """ - cur = self.conn.cursor() + if existing: + existing.name = data.get("name", existing.name) + existing.description = data.get("description") + existing.type = data.get("type", existing.type) + existing.nsfw = bool(data.get("nsfw")) + existing.download_count = stats.get("downloadCount", 0) + existing.thumbs_up_count = stats.get("thumbsUpCount", 0) + existing.updated_at = datetime.utcnow() + session.add(existing) + model_id = existing.id + else: + model = Model( + civitai_id=civitai_id, + name=data.get("name", ""), + description=data.get("description"), + type=data.get("type", ""), + nsfw=bool(data.get("nsfw")), + poi=bool(data.get("poi")), + minor=bool(data.get("minor")), + sfw_only=bool(data.get("sfwOnly")), + nsfw_level=data.get("nsfwLevel"), + availability=data.get("availability"), + allow_no_credit=bool(data.get("allowNoCredit")), + allow_commercial_use=str(data.get("allowCommercialUse", "")), + allow_derivatives=bool(data.get("allowDerivatives")), + allow_different_license=bool(data.get("allowDifferentLicense")), + supports_generation=bool(data.get("supportsGeneration")), + creator_id=creator_id, + download_count=stats.get("downloadCount", 0), + thumbs_up_count=stats.get("thumbsUpCount", 0), + thumbs_down_count=stats.get("thumbsDownCount", 0), + comment_count=stats.get("commentCount", 0), + tipped_amount_count=stats.get("tippedAmountCount", 0), + ) + session.add(model) + session.flush() + model_id = model.id - # Get or create creator - creator_id = self._get_or_create_creator(data.get("creator")) + # Cache tags + for tag_name in data.get("tags", []): + tag_id = self._get_or_create_tag(session, tag_name) + if model_id and tag_id: + existing_mt = session.exec( + select(ModelTag).where(ModelTag.model_id == model_id, ModelTag.tag_id == tag_id) + ).first() + if not existing_mt: + session.add(ModelTag(model_id=model_id, tag_id=tag_id)) - # Check if model exists - civitai_id = data.get("id") - cur.execute("SELECT id FROM models WHERE civitai_id = ?", (civitai_id,)) - existing = cur.fetchone() + # Cache versions + for idx, version in enumerate(data.get("modelVersions", [])): + self._cache_version(session, model_id, version, idx) - stats = data.get("stats", {}) + session.commit() + return model_id or 0 - if existing: - model_id = int(existing["id"]) - cur.execute( - """ - UPDATE models SET - name = ?, description = ?, type = ?, nsfw = ?, - download_count = ?, thumbs_up_count = ?, - updated_at = datetime('now') - WHERE id = ? - """, - ( - data.get("name"), - data.get("description"), - data.get("type"), - 1 if data.get("nsfw") else 0, - stats.get("downloadCount", 0), - stats.get("thumbsUpCount", 0), - model_id, - ), - ) - else: - cur.execute( - """ - INSERT INTO models ( - civitai_id, name, description, type, nsfw, poi, minor, - sfw_only, nsfw_level, availability, allow_no_credit, - allow_commercial_use, allow_derivatives, allow_different_license, - supports_generation, creator_id, download_count, thumbs_up_count, - thumbs_down_count, comment_count, tipped_amount_count, - created_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now')) - """, - ( - civitai_id, - data.get("name"), - data.get("description"), - data.get("type"), - 1 if data.get("nsfw") else 0, - 1 if data.get("poi") else 0, - 1 if data.get("minor") else 0, - 1 if data.get("sfwOnly") else 0, - data.get("nsfwLevel"), - data.get("availability"), - 1 if data.get("allowNoCredit") else 0, - str(data.get("allowCommercialUse", "")), - 1 if data.get("allowDerivatives") else 0, - 1 if data.get("allowDifferentLicense") else 0, - 1 if data.get("supportsGeneration") else 0, - creator_id, - stats.get("downloadCount", 0), - stats.get("thumbsUpCount", 0), - stats.get("thumbsDownCount", 0), - stats.get("commentCount", 0), - stats.get("tippedAmountCount", 0), - data.get("createdAt"), - ), - ) - model_id = cur.lastrowid or 0 # lastrowid is always set after INSERT - - # Cache tags - for tag_name in data.get("tags", []): - tag_id = self._get_or_create_tag(tag_name) - cur.execute("INSERT OR IGNORE INTO model_tags (model_id, tag_id) VALUES (?, ?)", (model_id, tag_id)) - - # Cache versions - for idx, version in enumerate(data.get("modelVersions", [])): - self._cache_version(model_id, version, idx) - - self.conn.commit() - return model_id - - def _get_or_create_creator(self, creator_data: dict[str, Any] | None) -> int | None: + def _get_or_create_creator(self, session: Session, creator_data: dict[str, Any] | None) -> int | None: """Get or create a creator record.""" if not creator_data: return None @@ -314,180 +344,148 @@ class Database: if not username: return None - cur = self.conn.cursor() - cur.execute("SELECT id FROM creators WHERE username = ?", (username,)) - row = cur.fetchone() - if row: - return int(row["id"]) + existing = session.exec(select(Creator).where(Creator.username == username)).first() + if existing: + return existing.id - cur.execute( - "INSERT INTO creators (username, image_url) VALUES (?, ?)", - (username, creator_data.get("image")), - ) - return cur.lastrowid or 0 + creator = Creator(username=username, image_url=creator_data.get("image")) + session.add(creator) + session.flush() + return creator.id - def _get_or_create_tag(self, tag_name: str) -> int: + def _get_or_create_tag(self, session: Session, tag_name: str) -> int | None: """Get or create a tag record.""" - cur = self.conn.cursor() - cur.execute("SELECT id FROM tags WHERE name = ?", (tag_name,)) - row = cur.fetchone() - if row: - return int(row["id"]) + existing = session.exec(select(Tag).where(Tag.name == tag_name)).first() + if existing: + return existing.id - cur.execute("INSERT INTO tags (name) VALUES (?)", (tag_name,)) - return cur.lastrowid or 0 # lastrowid is always set after INSERT + tag = Tag(name=tag_name) + session.add(tag) + session.flush() + return tag.id - def _cache_version(self, model_id: int, version: dict[str, Any], index: int) -> int: + def _cache_version(self, session: Session, model_id: int | None, version: dict[str, Any], index: int) -> int | None: """Cache a model version.""" - cur = self.conn.cursor() + if not model_id: + return None civitai_id = version.get("id") - - cur.execute("SELECT id FROM model_versions WHERE civitai_id = ?", (civitai_id,)) - existing = cur.fetchone() - + existing = session.exec(select(ModelVersion).where(ModelVersion.civitai_id == civitai_id)).first() stats = version.get("stats", {}) if existing: - version_id = int(existing["id"]) + version_id = existing.id else: - cur.execute( - """ - INSERT INTO model_versions ( - civitai_id, model_id, name, description, base_model, - base_model_type, nsfw_level, status, availability, - download_count, thumbs_up_count, thumbs_down_count, - supports_generation, download_url, created_at, published_at, - updated_at, version_index - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - civitai_id, - model_id, - version.get("name"), - version.get("description"), - version.get("baseModel"), - version.get("baseModelType"), - version.get("nsfwLevel"), - version.get("status"), - version.get("availability"), - stats.get("downloadCount", 0), - stats.get("thumbsUpCount", 0), - stats.get("thumbsDownCount", 0), - 1 if version.get("supportsGeneration") else 0, - version.get("downloadUrl"), - version.get("createdAt"), - version.get("publishedAt"), - version.get("updatedAt"), - index, - ), + mv = ModelVersion( + civitai_id=civitai_id, + model_id=model_id, + name=version.get("name", ""), + description=version.get("description"), + base_model=version.get("baseModel"), + base_model_type=version.get("baseModelType"), + nsfw_level=version.get("nsfwLevel"), + status=version.get("status"), + availability=version.get("availability"), + download_count=stats.get("downloadCount", 0), + thumbs_up_count=stats.get("thumbsUpCount", 0), + thumbs_down_count=stats.get("thumbsDownCount", 0), + supports_generation=bool(version.get("supportsGeneration")), + download_url=version.get("downloadUrl"), + version_index=index, ) - version_id = cur.lastrowid or 0 # lastrowid is always set after INSERT + session.add(mv) + session.flush() + version_id = mv.id # Cache trained words for pos, word in enumerate(version.get("trainedWords", [])): - cur.execute( - "INSERT OR IGNORE INTO trained_words (version_id, word, position) VALUES (?, ?, ?)", - (version_id, word, pos), - ) + existing_tw = session.exec( + select(TrainedWord).where(TrainedWord.version_id == version_id, TrainedWord.word == word) + ).first() + if not existing_tw: + session.add(TrainedWord(version_id=version_id, word=word, position=pos)) - # Cache files and hashes + # Cache files for file_data in version.get("files", []): - self._cache_file(version_id, file_data) + self._cache_file(session, version_id, file_data) # Cache images for image_data in version.get("images", []): - self._cache_image(version_id, image_data) + self._cache_image(session, version_id, image_data) return version_id - def _cache_file(self, version_id: int, file_data: dict[str, Any]) -> int | None: + def _cache_file(self, session: Session, version_id: int | None, file_data: dict[str, Any]) -> int | None: """Cache a version file.""" - cur = self.conn.cursor() + if not version_id: + return None civitai_id = file_data.get("id") if not civitai_id: return None - cur.execute("SELECT id FROM version_files WHERE civitai_id = ?", (civitai_id,)) - existing = cur.fetchone() - + existing = session.exec(select(VersionFile).where(VersionFile.civitai_id == civitai_id)).first() if existing: - return int(existing["id"]) + return existing.id meta = file_data.get("metadata", {}) - cur.execute( - """ - INSERT INTO version_files ( - civitai_id, version_id, name, type, size_kb, format, - size_type, fp, is_primary, pickle_scan_result, - virus_scan_result, scanned_at, download_url - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - civitai_id, - version_id, - file_data.get("name"), - file_data.get("type"), - file_data.get("sizeKB"), - meta.get("format"), - meta.get("size"), - meta.get("fp"), - 1 if file_data.get("primary") else 0, - file_data.get("pickleScanResult"), - file_data.get("virusScanResult"), - file_data.get("scannedAt"), - file_data.get("downloadUrl"), - ), + vf = VersionFile( + civitai_id=civitai_id, + version_id=version_id, + name=file_data.get("name", ""), + type=file_data.get("type"), + size_kb=file_data.get("sizeKB"), + format=meta.get("format"), + size_type=meta.get("size"), + fp=meta.get("fp"), + is_primary=bool(file_data.get("primary")), + pickle_scan_result=file_data.get("pickleScanResult"), + virus_scan_result=file_data.get("virusScanResult"), + download_url=file_data.get("downloadUrl"), ) - file_id = cur.lastrowid or 0 # lastrowid is always set after INSERT + session.add(vf) + session.flush() + file_id = vf.id # Cache hashes for hash_type, hash_value in file_data.get("hashes", {}).items(): - cur.execute( - "INSERT OR IGNORE INTO file_hashes (file_id, hash_type, hash_value) VALUES (?, ?, ?)", - (file_id, hash_type, hash_value), - ) + existing_fh = session.exec( + select(FileHash).where(FileHash.file_id == file_id, FileHash.hash_type == hash_type) + ).first() + if not existing_fh: + session.add(FileHash(file_id=file_id, hash_type=hash_type, hash_value=hash_value)) return file_id - def _cache_image(self, version_id: int, image_data: dict[str, Any]) -> int | None: + def _cache_image(self, session: Session, version_id: int | None, image_data: dict[str, Any]) -> int | None: """Cache a version image.""" - cur = self.conn.cursor() + if not version_id: + return None url = image_data.get("url") if not url: return None - cur.execute("SELECT id FROM version_images WHERE url = ?", (url,)) - existing = cur.fetchone() - + existing = session.exec(select(VersionImage).where(VersionImage.url == url)).first() if existing: - return int(existing["id"]) + return existing.id - cur.execute( - """ - INSERT INTO version_images ( - civitai_id, version_id, url, type, nsfw_level, width, - height, hash, has_meta, has_positive_prompt, on_site, - minor, poi, availability - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - image_data.get("id"), - version_id, - url, - image_data.get("type"), - image_data.get("nsfwLevel"), - image_data.get("width"), - image_data.get("height"), - image_data.get("hash"), - 1 if image_data.get("hasMeta") else 0, - 1 if image_data.get("hasPositivePrompt") else 0, - 1 if image_data.get("onSite") else 0, - 1 if image_data.get("minor") else 0, - 1 if image_data.get("poi") else 0, - image_data.get("availability"), - ), + vi = VersionImage( + civitai_id=image_data.get("id"), + version_id=version_id, + url=url, + type=image_data.get("type"), + nsfw_level=image_data.get("nsfwLevel"), + width=image_data.get("width"), + height=image_data.get("height"), + hash=image_data.get("hash"), + has_meta=bool(image_data.get("hasMeta")), + has_positive_prompt=bool(image_data.get("hasPositivePrompt")), + on_site=bool(image_data.get("onSite")), + minor=bool(image_data.get("minor")), + poi=bool(image_data.get("poi")), + availability=image_data.get("availability"), ) - image_id = cur.lastrowid or 0 # lastrowid is always set after INSERT + session.add(vi) + session.flush() + image_id = vi.id # Cache generation params meta = image_data.get("meta", {}) @@ -495,20 +493,170 @@ class Database: if key == "resources": continue str_value = str(value) if value is not None else None - cur.execute( - "INSERT OR IGNORE INTO image_generation_params (image_id, key, value) VALUES (?, ?, ?)", - (image_id, key, str_value), - ) + session.add(ImageGenerationParam(image_id=image_id, key=key, value=str_value)) # Cache resources for res in meta.get("resources", []): - cur.execute( - "INSERT INTO image_resources (image_id, name, type, hash, weight) VALUES (?, ?, ?, ?, ?)", - (image_id, res.get("name"), res.get("type"), res.get("hash"), res.get("weight")), + session.add( + ImageResource( + image_id=image_id, + name=res.get("name", ""), + type=res.get("type"), + hash=res.get("hash"), + weight=res.get("weight"), + ) ) return image_id + # ========================================================================= + # HuggingFace Cache Operations + # ========================================================================= + + def cache_hf_model(self, data: dict[str, Any]) -> int: + """Cache HuggingFace model data.""" + repo_id = data.get("repo_id") or data.get("id") or data.get("modelId") + if not repo_id: + raise ValueError("repo_id is required") + + author = data.get("author") + model_name = repo_id + if "/" in repo_id: + parts = repo_id.split("/", 1) + author = author or parts[0] + model_name = parts[1] + + with self.session() as session: + existing = session.exec(select(HFModel).where(HFModel.repo_id == repo_id)).first() + + if existing: + existing.author = author + existing.model_name = model_name + existing.pipeline_tag = data.get("pipeline_tag") + existing.library_name = data.get("library_name") + existing.downloads = data.get("downloads", 0) + existing.likes = data.get("likes", 0) + existing.trending_score = data.get("trending_score") + existing.is_private = bool(data.get("private")) + existing.is_gated = bool(data.get("gated")) + existing.last_modified = data.get("last_modified") or data.get("lastModified") + existing.updated_at = datetime.utcnow() + session.add(existing) + model_id = existing.id + else: + hf_model = HFModel( + repo_id=repo_id, + author=author, + model_name=model_name, + pipeline_tag=data.get("pipeline_tag"), + library_name=data.get("library_name"), + downloads=data.get("downloads", 0), + likes=data.get("likes", 0), + trending_score=data.get("trending_score"), + is_private=bool(data.get("private")), + is_gated=bool(data.get("gated")), + last_modified=data.get("last_modified") or data.get("lastModified"), + created_at=data.get("created_at") or data.get("createdAt"), + ) + session.add(hf_model) + session.flush() + model_id = hf_model.id + + # Cache tags + for tag in data.get("tags", []): + existing_tag = session.exec( + select(HFModelTag).where(HFModelTag.hf_model_id == model_id, HFModelTag.tag == tag) + ).first() + if not existing_tag: + session.add(HFModelTag(hf_model_id=model_id, tag=tag)) + + # Cache safetensor files + for file_info in data.get("safetensor_files", []): + if isinstance(file_info, str): + existing_sf = session.exec( + select(HFSafetensorFile).where( + HFSafetensorFile.hf_model_id == model_id, HFSafetensorFile.filename == file_info + ) + ).first() + if not existing_sf: + session.add(HFSafetensorFile(hf_model_id=model_id, filename=file_info)) + elif isinstance(file_info, dict): + filename = file_info.get("filename") + if filename: + existing_sf = session.exec( + select(HFSafetensorFile).where( + HFSafetensorFile.hf_model_id == model_id, HFSafetensorFile.filename == filename + ) + ).first() + if not existing_sf: + session.add( + HFSafetensorFile(hf_model_id=model_id, filename=filename, size_bytes=file_info.get("size")) + ) + + session.commit() + return model_id or 0 + + def search_hf_models( + self, + query: str | None = None, + author: str | None = None, + pipeline_tag: str | None = None, + limit: int = 20, + ) -> list[dict[str, Any]]: + """Search cached HuggingFace models.""" + with self.session() as session: + stmt = select(HFModel) + + if query: + stmt = stmt.where(col(HFModel.repo_id).contains(query) | col(HFModel.model_name).contains(query)) + if author: + stmt = stmt.where(HFModel.author == author) + if pipeline_tag: + stmt = stmt.where(HFModel.pipeline_tag == pipeline_tag) + + stmt = stmt.order_by(col(HFModel.downloads).desc()).limit(limit) + models = session.exec(stmt).all() + + return [ + { + "id": m.id, + "repo_id": m.repo_id, + "author": m.author, + "model_name": m.model_name, + "pipeline_tag": m.pipeline_tag, + "downloads": m.downloads, + "likes": m.likes, + "is_gated": m.is_gated, + } + for m in models + ] + + def get_hf_model(self, repo_id: str) -> dict[str, Any] | None: + """Get cached HF model by repo_id.""" + with self.session() as session: + m = session.exec(select(HFModel).where(HFModel.repo_id == repo_id)).first() + if not m: + return None + return { + "id": m.id, + "repo_id": m.repo_id, + "author": m.author, + "model_name": m.model_name, + "pipeline_tag": m.pipeline_tag, + "downloads": m.downloads, + "likes": m.likes, + "is_gated": m.is_gated, + } + + def get_hf_safetensor_files(self, repo_id: str) -> list[dict[str, Any]]: + """Get safetensor files for an HF model.""" + with self.session() as session: + m = session.exec(select(HFModel).where(HFModel.repo_id == repo_id)).first() + if not m: + return [] + files = session.exec(select(HFSafetensorFile).where(HFSafetensorFile.hf_model_id == m.id)).all() + return [{"filename": f.filename, "size_bytes": f.size_bytes} for f in files] + # ========================================================================= # Query Operations # ========================================================================= @@ -520,227 +668,92 @@ class Database: base_model: str | None = None, limit: int = 20, ) -> list[dict[str, Any]]: - """Search cached models.""" - cur = self.conn.cursor() + """Search cached CivitAI models.""" + with self.session() as session: + stmt = select(Model) - sql = "SELECT * FROM v_models_with_latest WHERE 1=1" - params: list[Any] = [] + if query: + stmt = stmt.where(col(Model.name).contains(query)) + if model_type: + stmt = stmt.where(Model.type == model_type) - if query: - sql += " AND name LIKE ?" - params.append(f"%{query}%") + stmt = stmt.order_by(col(Model.download_count).desc()).limit(limit) + models = session.exec(stmt).all() - if model_type: - sql += " AND type = ?" - params.append(model_type) + results = [] + for m in models: + # Get latest version + latest = session.exec( + select(ModelVersion).where(ModelVersion.model_id == m.id, ModelVersion.version_index == 0) + ).first() + creator = session.get(Creator, m.creator_id) if m.creator_id else None - if base_model: - sql += " AND base_model LIKE ?" - params.append(f"%{base_model}%") + # Filter by base_model if specified + if base_model and latest and latest.base_model and base_model.lower() not in latest.base_model.lower(): + continue - sql += " ORDER BY download_count DESC LIMIT ?" - params.append(limit) + results.append( + { + "id": m.id, + "civitai_id": m.civitai_id, + "name": m.name, + "type": m.type, + "nsfw": m.nsfw, + "creator": creator.username if creator else None, + "latest_version": latest.name if latest else None, + "base_model": latest.base_model if latest else None, + "download_count": m.download_count, + "thumbs_up_count": m.thumbs_up_count, + } + ) - cur.execute(sql, params) - return [dict(row) for row in cur.fetchall()] + return results[:limit] def get_model(self, civitai_id: int) -> dict[str, Any] | None: """Get cached model by CivitAI ID.""" - cur = self.conn.cursor() - cur.execute("SELECT * FROM v_models_with_latest WHERE civitai_id = ?", (civitai_id,)) - row = cur.fetchone() - return dict(row) if row else None + with self.session() as session: + m = session.exec(select(Model).where(Model.civitai_id == civitai_id)).first() + if not m: + return None + latest = session.exec( + select(ModelVersion).where(ModelVersion.model_id == m.id, ModelVersion.version_index == 0) + ).first() + creator = session.get(Creator, m.creator_id) if m.creator_id else None + return { + "id": m.id, + "civitai_id": m.civitai_id, + "name": m.name, + "type": m.type, + "creator": creator.username if creator else None, + "latest_version": latest.name if latest else None, + "base_model": latest.base_model if latest else None, + "download_count": m.download_count, + } def get_triggers(self, file_path: str) -> list[str]: """Get trigger words for a local file.""" - cur = self.conn.cursor() - cur.execute( - """ - SELECT tw.word - FROM trained_words tw - JOIN model_versions mv ON tw.version_id = mv.id - JOIN local_files lf ON lf.civitai_version_id = mv.civitai_id - WHERE lf.file_path = ? - ORDER BY tw.position - """, - (file_path,), - ) - return [row["word"] for row in cur.fetchall()] + with self.session() as session: + lf = session.exec(select(LocalFile).where(LocalFile.file_path == file_path)).first() + if not lf or not lf.civitai_version_id: + return [] + mv = session.exec(select(ModelVersion).where(ModelVersion.civitai_id == lf.civitai_version_id)).first() + if not mv: + return [] + words = session.exec( + select(TrainedWord).where(TrainedWord.version_id == mv.id).order_by(col(TrainedWord.position)) + ).all() + return [w.word for w in words] def get_triggers_by_version(self, version_id: int) -> list[str]: """Get trigger words for a version by CivitAI version ID.""" - cur = self.conn.cursor() - cur.execute( - """ - SELECT tw.word - FROM trained_words tw - JOIN model_versions mv ON tw.version_id = mv.id - WHERE mv.civitai_id = ? - ORDER BY tw.position - """, - (version_id,), - ) - return [row["word"] for row in cur.fetchall()] - - # ========================================================================= - # HuggingFace Cache Operations - # ========================================================================= - - def cache_hf_model(self, data: dict[str, Any]) -> int: - """Cache HuggingFace model data. - - Args: - data: Model data dict with keys like repo_id, author, downloads, etc. - - Returns the internal model ID. - """ - cur = self.conn.cursor() - - repo_id = data.get("repo_id") or data.get("id") or data.get("modelId") - if not repo_id: - raise ValueError("repo_id is required") - - # Parse author from repo_id if not provided - author = data.get("author") - model_name = repo_id - if "/" in repo_id: - parts = repo_id.split("/", 1) - author = author or parts[0] - model_name = parts[1] - - # Check if model exists - cur.execute("SELECT id FROM hf_models WHERE repo_id = ?", (repo_id,)) - existing = cur.fetchone() - - if existing: - model_id = int(existing["id"]) - cur.execute( - """ - UPDATE hf_models SET - author = ?, model_name = ?, pipeline_tag = ?, library_name = ?, - downloads = ?, likes = ?, trending_score = ?, - is_private = ?, is_gated = ?, last_modified = ?, - updated_at = datetime('now') - WHERE id = ? - """, - ( - author, - model_name, - data.get("pipeline_tag"), - data.get("library_name"), - data.get("downloads", 0), - data.get("likes", 0), - data.get("trending_score"), - 1 if data.get("private") else 0, - 1 if data.get("gated") else 0, - data.get("last_modified") or data.get("lastModified"), - model_id, - ), - ) - else: - cur.execute( - """ - INSERT INTO hf_models ( - repo_id, author, model_name, pipeline_tag, library_name, - downloads, likes, trending_score, is_private, is_gated, - last_modified, created_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - repo_id, - author, - model_name, - data.get("pipeline_tag"), - data.get("library_name"), - data.get("downloads", 0), - data.get("likes", 0), - data.get("trending_score"), - 1 if data.get("private") else 0, - 1 if data.get("gated") else 0, - data.get("last_modified") or data.get("lastModified"), - data.get("created_at") or data.get("createdAt"), - ), - ) - model_id = cur.lastrowid or 0 - - # Cache tags - for tag in data.get("tags", []): - cur.execute( - "INSERT OR IGNORE INTO hf_model_tags (hf_model_id, tag) VALUES (?, ?)", - (model_id, tag), - ) - - # Cache safetensor files - for file_info in data.get("safetensor_files", []): - if isinstance(file_info, str): - cur.execute( - "INSERT OR IGNORE INTO hf_safetensor_files (hf_model_id, filename) VALUES (?, ?)", - (model_id, file_info), - ) - elif isinstance(file_info, dict): - cur.execute( - """ - INSERT OR IGNORE INTO hf_safetensor_files (hf_model_id, filename, size_bytes) - VALUES (?, ?, ?) - """, - (model_id, file_info.get("filename"), file_info.get("size")), - ) - - self.conn.commit() - return model_id - - def search_hf_models( - self, - query: str | None = None, - author: str | None = None, - pipeline_tag: str | None = None, - limit: int = 20, - ) -> list[dict[str, Any]]: - """Search cached HuggingFace models.""" - cur = self.conn.cursor() - - sql = "SELECT * FROM v_hf_models WHERE 1=1" - params: list[Any] = [] - - if query: - sql += " AND (repo_id LIKE ? OR model_name LIKE ?)" - params.extend([f"%{query}%", f"%{query}%"]) - - if author: - sql += " AND author = ?" - params.append(author) - - if pipeline_tag: - sql += " AND pipeline_tag = ?" - params.append(pipeline_tag) - - sql += " ORDER BY downloads DESC LIMIT ?" - params.append(limit) - - cur.execute(sql, params) - return [dict(row) for row in cur.fetchall()] - - def get_hf_model(self, repo_id: str) -> dict[str, Any] | None: - """Get cached HF model by repo_id.""" - cur = self.conn.cursor() - cur.execute("SELECT * FROM v_hf_models WHERE repo_id = ?", (repo_id,)) - row = cur.fetchone() - return dict(row) if row else None - - def get_hf_safetensor_files(self, repo_id: str) -> list[dict[str, Any]]: - """Get safetensor files for an HF model.""" - cur = self.conn.cursor() - cur.execute( - """ - SELECT hsf.filename, hsf.size_bytes - FROM hf_safetensor_files hsf - JOIN hf_models hm ON hsf.hf_model_id = hm.id - WHERE hm.repo_id = ? - ORDER BY hsf.filename - """, - (repo_id,), - ) - return [dict(row) for row in cur.fetchall()] + with self.session() as session: + mv = session.exec(select(ModelVersion).where(ModelVersion.civitai_id == version_id)).first() + if not mv: + return [] + words = session.exec( + select(TrainedWord).where(TrainedWord.version_id == mv.id).order_by(col(TrainedWord.position)) + ).all() + return [w.word for w in words] # ========================================================================= # Statistics @@ -748,19 +761,16 @@ class Database: def get_stats(self) -> dict[str, int]: """Get database statistics.""" - cur = self.conn.cursor() - stats = {} - for table in [ - "local_files", - "models", - "model_versions", - "version_files", - "trained_words", - "creators", - "tags", - "hf_models", - "hf_safetensor_files", - ]: - cur.execute(f"SELECT COUNT(*) FROM {table}") - stats[table] = cur.fetchone()[0] - return stats + with self.session() as session: + stats = { + "local_files": session.exec(select(func.count(col(LocalFile.id)))).one(), + "models": session.exec(select(func.count(col(Model.id)))).one(), + "model_versions": session.exec(select(func.count(col(ModelVersion.id)))).one(), + "version_files": session.exec(select(func.count(col(VersionFile.id)))).one(), + "trained_words": session.exec(select(func.count(col(TrainedWord.id)))).one(), + "creators": session.exec(select(func.count(col(Creator.id)))).one(), + "tags": session.exec(select(func.count(col(Tag.id)))).one(), + "hf_models": session.exec(select(func.count(col(HFModel.id)))).one(), + "hf_safetensor_files": session.exec(select(func.count(col(HFSafetensorFile.id)))).one(), + } + return stats diff --git a/tensors/models.py b/tensors/models.py new file mode 100644 index 0000000..bcce596 --- /dev/null +++ b/tensors/models.py @@ -0,0 +1,342 @@ +"""SQLModel database models for tensors.""" + +from datetime import datetime +from typing import Any, Optional + +from sqlmodel import Field, Relationship, SQLModel + +# ============================================================================= +# Local Files +# ============================================================================= + + +class LocalFile(SQLModel, table=True): + """Local safetensor file.""" + + __tablename__ = "local_files" + + id: int | None = Field(default=None, primary_key=True) + file_path: str = Field(unique=True) + sha256: str = Field(index=True) + header_size: int | None = None + tensor_count: int | None = None + civitai_model_id: int | None = Field(default=None, index=True) + civitai_version_id: int | None = None + created_at: datetime | None = Field(default_factory=datetime.utcnow) + updated_at: datetime | None = Field(default_factory=datetime.utcnow) + + metadata_entries: list["SafetensorMetadata"] = Relationship(back_populates="local_file") + + +class SafetensorMetadata(SQLModel, table=True): + """Safetensor header metadata key-value pairs.""" + + __tablename__ = "safetensor_metadata" + + id: int | None = Field(default=None, primary_key=True) + local_file_id: int = Field(foreign_key="local_files.id", index=True) + key: str + value: str | None = None + + local_file: "LocalFile" = Relationship(back_populates="metadata_entries") + + +# ============================================================================= +# CivitAI Models +# ============================================================================= + + +class Creator(SQLModel, table=True): + """CivitAI model creator.""" + + __tablename__ = "creators" + + id: int | None = Field(default=None, primary_key=True) + username: str = Field(unique=True) + image_url: str | None = None + + models: list["Model"] = Relationship(back_populates="creator") + + +class Tag(SQLModel, table=True): + """Model tag.""" + + __tablename__ = "tags" + + id: int | None = Field(default=None, primary_key=True) + name: str = Field(unique=True) + + +class ModelTag(SQLModel, table=True): + """Model-tag association.""" + + __tablename__ = "model_tags" + + model_id: int = Field(foreign_key="models.id", primary_key=True) + tag_id: int = Field(foreign_key="tags.id", primary_key=True) + + +class Model(SQLModel, table=True): + """CivitAI model.""" + + __tablename__ = "models" + + id: int | None = Field(default=None, primary_key=True) + civitai_id: int = Field(unique=True, index=True) + name: str = Field(index=True) + description: str | None = None + type: str = Field(index=True) + nsfw: bool = False + poi: bool = False + minor: bool = False + sfw_only: bool = False + nsfw_level: int | None = None + availability: str | None = None + allow_no_credit: bool | None = None + allow_commercial_use: str | None = None + allow_derivatives: bool | None = None + allow_different_license: bool | None = None + supports_generation: bool = False + creator_id: int | None = Field(default=None, foreign_key="creators.id") + download_count: int = 0 + thumbs_up_count: int = 0 + thumbs_down_count: int = 0 + comment_count: int = 0 + tipped_amount_count: int = 0 + created_at: datetime | None = None + updated_at: datetime | None = Field(default_factory=datetime.utcnow) + + creator: Optional["Creator"] = Relationship(back_populates="models") + versions: list["ModelVersion"] = Relationship(back_populates="model") + + +class ModelVersion(SQLModel, table=True): + """CivitAI model version.""" + + __tablename__ = "model_versions" + + id: int | None = Field(default=None, primary_key=True) + civitai_id: int = Field(unique=True, index=True) + model_id: int = Field(foreign_key="models.id", index=True) + name: str + description: str | None = None + base_model: str | None = Field(default=None, index=True) + base_model_type: str | None = None + nsfw_level: int | None = None + status: str | None = None + availability: str | None = None + upload_type: str | None = None + usage_control: str | None = None + air: str | None = None + training_status: str | None = None + training_details: str | None = None + early_access_ends_at: datetime | None = None + download_count: int = 0 + thumbs_up_count: int = 0 + thumbs_down_count: int = 0 + supports_generation: bool = False + download_url: str | None = None + created_at: datetime | None = None + published_at: datetime | None = None + updated_at: datetime | None = None + version_index: int | None = None + + model: "Model" = Relationship(back_populates="versions") + files: list["VersionFile"] = Relationship(back_populates="version") + images: list["VersionImage"] = Relationship(back_populates="version") + trained_words: list["TrainedWord"] = Relationship(back_populates="version") + + +class TrainedWord(SQLModel, table=True): + """Trigger words for a model version.""" + + __tablename__ = "trained_words" + + id: int | None = Field(default=None, primary_key=True) + version_id: int = Field(foreign_key="model_versions.id", index=True) + word: str + position: int | None = None + + version: "ModelVersion" = Relationship(back_populates="trained_words") + + +class VersionFile(SQLModel, table=True): + """Model version file.""" + + __tablename__ = "version_files" + + id: int | None = Field(default=None, primary_key=True) + civitai_id: int = Field(unique=True) + version_id: int = Field(foreign_key="model_versions.id", index=True) + name: str + type: str | None = None + size_kb: float | None = None + format: str | None = None + size_type: str | None = None + fp: str | None = None + is_primary: bool = False + pickle_scan_result: str | None = None + pickle_scan_message: str | None = None + virus_scan_result: str | None = None + virus_scan_message: str | None = None + scanned_at: datetime | None = None + download_url: str | None = None + + version: "ModelVersion" = Relationship(back_populates="files") + hashes: list["FileHash"] = Relationship(back_populates="file") + + +class FileHash(SQLModel, table=True): + """File hash values.""" + + __tablename__ = "file_hashes" + + id: int | None = Field(default=None, primary_key=True) + file_id: int = Field(foreign_key="version_files.id", index=True) + hash_type: str + hash_value: str = Field(index=True) + + file: "VersionFile" = Relationship(back_populates="hashes") + + +class VersionImage(SQLModel, table=True): + """Model version example image.""" + + __tablename__ = "version_images" + + id: int | None = Field(default=None, primary_key=True) + civitai_id: int | None = None + version_id: int = Field(foreign_key="model_versions.id", index=True) + url: str + type: str | None = None + nsfw_level: int | None = None + width: int | None = None + height: int | None = None + hash: str | None = None + has_meta: bool = False + has_positive_prompt: bool = False + on_site: bool = False + minor: bool = False + poi: bool = False + availability: str | None = None + remix_of_id: int | None = None + + version: "ModelVersion" = Relationship(back_populates="images") + generation_params: list["ImageGenerationParam"] = Relationship(back_populates="image") + resources: list["ImageResource"] = Relationship(back_populates="image") + + +class ImageVideoMetadata(SQLModel, table=True): + """Video metadata for animated images.""" + + __tablename__ = "image_video_metadata" + + id: int | None = Field(default=None, primary_key=True) + image_id: int = Field(foreign_key="version_images.id", unique=True) + duration: float | None = None + has_audio: bool = False + size_bytes: int | None = None + + +class ImageGenerationParam(SQLModel, table=True): + """Image generation parameters.""" + + __tablename__ = "image_generation_params" + + id: int | None = Field(default=None, primary_key=True) + image_id: int = Field(foreign_key="version_images.id", index=True) + key: str + value: str | None = None + + image: "VersionImage" = Relationship(back_populates="generation_params") + + +class ImageResource(SQLModel, table=True): + """Resources used in image generation.""" + + __tablename__ = "image_resources" + + id: int | None = Field(default=None, primary_key=True) + image_id: int = Field(foreign_key="version_images.id", index=True) + name: str + type: str | None = None + hash: str | None = None + weight: float | None = None + + image: "VersionImage" = Relationship(back_populates="resources") + + +# ============================================================================= +# HuggingFace Models +# ============================================================================= + + +class HFModel(SQLModel, table=True): + """HuggingFace model.""" + + __tablename__ = "hf_models" + + id: int | None = Field(default=None, primary_key=True) + repo_id: str = Field(unique=True, index=True) + author: str | None = Field(default=None, index=True) + model_name: str + pipeline_tag: str | None = None + library_name: str | None = None + downloads: int = Field(default=0, index=True) + likes: int = 0 + trending_score: float | None = None + is_private: bool = False + is_gated: bool = False + last_modified: datetime | None = None + created_at: datetime | None = None + cached_at: datetime | None = Field(default_factory=datetime.utcnow) + updated_at: datetime | None = Field(default_factory=datetime.utcnow) + + tags: list["HFModelTag"] = Relationship(back_populates="model") + safetensor_files: list["HFSafetensorFile"] = Relationship(back_populates="model") + + +class HFModelTag(SQLModel, table=True): + """HuggingFace model tag.""" + + __tablename__ = "hf_model_tags" + + hf_model_id: int = Field(foreign_key="hf_models.id", primary_key=True, index=True) + tag: str = Field(primary_key=True) + + model: "HFModel" = Relationship(back_populates="tags") + + +class HFSafetensorFile(SQLModel, table=True): + """Safetensor file in HuggingFace model.""" + + __tablename__ = "hf_safetensor_files" + + id: int | None = Field(default=None, primary_key=True) + hf_model_id: int = Field(foreign_key="hf_models.id", index=True) + filename: str + size_bytes: int | None = None + + model: "HFModel" = Relationship(back_populates="safetensor_files") + + +# ============================================================================= +# Database Setup +# ============================================================================= + + +def get_engine(db_path: str = "") -> Any: + """Create database engine.""" + from sqlmodel import create_engine # noqa: PLC0415 + + from tensors.config import DATA_DIR # noqa: PLC0415 + + if not db_path: + db_path = str(DATA_DIR / "models.db") + + return create_engine(f"sqlite:///{db_path}", echo=False) + + +def create_tables(engine: Any) -> None: + """Create all tables.""" + SQLModel.metadata.create_all(engine) diff --git a/tensors/server/__init__.py b/tensors/server/__init__.py index 6278852..e5d3537 100644 --- a/tensors/server/__init__.py +++ b/tensors/server/__init__.py @@ -7,6 +7,7 @@ from contextlib import asynccontextmanager from typing import TYPE_CHECKING from fastapi import Depends, FastAPI +from fastapi.middleware.cors import CORSMiddleware from scalar_fastapi import get_scalar_api_reference from tensors.config import get_server_api_key @@ -47,6 +48,15 @@ def create_app() -> FastAPI: redoc_url=None, ) + # CORS - allow all origins + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + # Public endpoints (no auth) @app.get("/status") async def status() -> dict[str, str]: diff --git a/tests/test_db.py b/tests/test_db.py index 02a0b34..c48986e 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -7,8 +7,10 @@ import struct from pathlib import Path import pytest +from sqlalchemy import text from tensors.db import Database +from tensors.models import Creator, Model, ModelVersion, Tag, TrainedWord, VersionFile @pytest.fixture @@ -117,9 +119,9 @@ class TestDatabaseSchema: def test_init_schema(self, temp_db: Database) -> None: """Test schema initialization creates tables.""" - cur = temp_db.conn.cursor() - cur.execute("SELECT name FROM sqlite_master WHERE type='table'") - tables = {row[0] for row in cur.fetchall()} + with temp_db.session() as session: + result = session.exec(text("SELECT name FROM sqlite_master WHERE type='table'")) + tables = {row[0] for row in result.fetchall()} expected = { "local_files", @@ -139,13 +141,10 @@ class TestDatabaseSchema: assert expected.issubset(tables) def test_init_schema_creates_views(self, temp_db: Database) -> None: - """Test schema creates required views.""" - cur = temp_db.conn.cursor() - cur.execute("SELECT name FROM sqlite_master WHERE type='view'") - views = {row[0] for row in cur.fetchall()} - - assert "v_local_files_full" in views - assert "v_models_with_latest" in views + """Test schema creates required views - SQLModel doesn't create views, so skip.""" + # SQLModel creates tables but not views - views would need raw SQL + # This test is no longer applicable with SQLModel + pass class TestLocalFiles: @@ -227,75 +226,84 @@ class TestCivitAICache: def test_cache_model_creates_creator(self, temp_db: Database, sample_civitai_model: dict) -> None: """Test that caching model creates creator record.""" + from sqlmodel import select + temp_db.cache_model(sample_civitai_model) - cur = temp_db.conn.cursor() - cur.execute("SELECT * FROM creators WHERE username = ?", ("test_creator",)) - creator = cur.fetchone() + with temp_db.session() as session: + creator = session.exec(select(Creator).where(Creator.username == "test_creator")).first() assert creator is not None - assert creator["username"] == "test_creator" + assert creator.username == "test_creator" def test_cache_model_creates_tags(self, temp_db: Database, sample_civitai_model: dict) -> None: """Test that caching model creates tags.""" + from sqlmodel import select + temp_db.cache_model(sample_civitai_model) - cur = temp_db.conn.cursor() - cur.execute("SELECT COUNT(*) FROM tags") - count = cur.fetchone()[0] + with temp_db.session() as session: + tags = session.exec(select(Tag)).all() - assert count == 3 # test, lora, anime + assert len(tags) == 3 # test, lora, anime def test_cache_model_creates_versions(self, temp_db: Database, sample_civitai_model: dict) -> None: """Test that caching model creates versions.""" + from sqlmodel import select + temp_db.cache_model(sample_civitai_model) - cur = temp_db.conn.cursor() - cur.execute("SELECT * FROM model_versions WHERE civitai_id = ?", (789012,)) - version = cur.fetchone() + with temp_db.session() as session: + version = session.exec(select(ModelVersion).where(ModelVersion.civitai_id == 789012)).first() assert version is not None - assert version["name"] == "v1.0" - assert version["base_model"] == "SDXL 1.0" + assert version.name == "v1.0" + assert version.base_model == "SDXL 1.0" def test_cache_model_creates_trained_words(self, temp_db: Database, sample_civitai_model: dict) -> None: """Test that caching model creates trained words.""" + from sqlmodel import col, select + temp_db.cache_model(sample_civitai_model) - cur = temp_db.conn.cursor() - cur.execute("SELECT word FROM trained_words ORDER BY position") - words = [row[0] for row in cur.fetchall()] + with temp_db.session() as session: + words = session.exec(select(TrainedWord).order_by(col(TrainedWord.position))).all() - assert words == ["test_trigger", "lora_trigger"] + assert [w.word for w in words] == ["test_trigger", "lora_trigger"] def test_cache_model_creates_files_and_hashes(self, temp_db: Database, sample_civitai_model: dict) -> None: """Test that caching model creates files and hashes.""" + from sqlmodel import select + + from tensors.models import FileHash + temp_db.cache_model(sample_civitai_model) - cur = temp_db.conn.cursor() - cur.execute("SELECT * FROM version_files WHERE civitai_id = ?", (111222,)) - file_record = cur.fetchone() + with temp_db.session() as session: + file_record = session.exec(select(VersionFile).where(VersionFile.civitai_id == 111222)).first() - assert file_record is not None - assert file_record["name"] == "test_lora.safetensors" - assert file_record["is_primary"] == 1 + assert file_record is not None + assert file_record.name == "test_lora.safetensors" + assert file_record.is_primary is True - cur.execute("SELECT hash_type, hash_value FROM file_hashes WHERE file_id = ?", (file_record["id"],)) - hashes = {row[0]: row[1] for row in cur.fetchall()} + hashes = session.exec(select(FileHash).where(FileHash.file_id == file_record.id)).all() + hash_dict = {h.hash_type: h.hash_value for h in hashes} - assert hashes["SHA256"] == "ABC123DEF456" - assert hashes["BLAKE3"] == "789XYZ" + assert hash_dict["SHA256"] == "ABC123DEF456" + assert hash_dict["BLAKE3"] == "789XYZ" def test_cache_model_idempotent(self, temp_db: Database, sample_civitai_model: dict) -> None: """Test that caching same model twice is idempotent.""" + from sqlmodel import select + id1 = temp_db.cache_model(sample_civitai_model) id2 = temp_db.cache_model(sample_civitai_model) assert id1 == id2 - cur = temp_db.conn.cursor() - cur.execute("SELECT COUNT(*) FROM models") - assert cur.fetchone()[0] == 1 + with temp_db.session() as session: + models = session.exec(select(Model)).all() + assert len(models) == 1 class TestQueryOperations: @@ -436,15 +444,15 @@ class TestContextManager: stats = db.get_stats() assert stats["local_files"] == 0 - # Connection should be closed - assert db._conn is None + # Engine should be disposed + assert db._engine is None def test_connection_reuse(self, tmp_path: Path) -> None: - """Test that connection is reused within context.""" + """Test that engine is reused within context.""" db_path = tmp_path / "test.db" with Database(db_path=db_path) as db: db.init_schema() - conn1 = db.conn - conn2 = db.conn - assert conn1 is conn2 + engine1 = db.engine + engine2 = db.engine + assert engine1 is engine2 diff --git a/uv.lock b/uv.lock index 84c663c..4814ac9 100644 --- a/uv.lock +++ b/uv.lock @@ -189,6 +189,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/ab/fb21f4c939bb440104cc2b396d3be1d9b7a9fd3c6c2a53d98c45b3d7c954/fsspec-2026.2.0-py3-none-any.whl", hash = "sha256:98de475b5cb3bd66bedd5c4679e87b4fdfe1a3bf4d707b151b3c07e58c9a2437", size = 202505, upload-time = "2026-02-05T21:50:51.819Z" }, ] +[[package]] +name = "greenlet" +version = "3.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8a/99/1cd3411c56a410994669062bd73dd58270c00cc074cac15f385a1fd91f8a/greenlet-3.3.1.tar.gz", hash = "sha256:41848f3230b58c08bb43dee542e74a2a2e34d3c59dc3076cec9151aeeedcae98", size = 184690, upload-time = "2026-01-23T15:31:02.076Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/c8/9d76a66421d1ae24340dfae7e79c313957f6e3195c144d2c73333b5bfe34/greenlet-3.3.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:7e806ca53acf6d15a888405880766ec84721aa4181261cd11a457dfe9a7a4975", size = 276443, upload-time = "2026-01-23T15:30:10.066Z" }, + { url = "https://files.pythonhosted.org/packages/81/99/401ff34bb3c032d1f10477d199724f5e5f6fbfb59816ad1455c79c1eb8e7/greenlet-3.3.1-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d842c94b9155f1c9b3058036c24ffb8ff78b428414a19792b2380be9cecf4f36", size = 597359, upload-time = "2026-01-23T16:00:57.394Z" }, + { url = "https://files.pythonhosted.org/packages/2b/bc/4dcc0871ed557792d304f50be0f7487a14e017952ec689effe2180a6ff35/greenlet-3.3.1-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:20fedaadd422fa02695f82093f9a98bad3dab5fcda793c658b945fcde2ab27ba", size = 607805, upload-time = "2026-01-23T16:05:28.068Z" }, + { url = "https://files.pythonhosted.org/packages/cf/05/821587cf19e2ce1f2b24945d890b164401e5085f9d09cbd969b0c193cd20/greenlet-3.3.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14194f5f4305800ff329cbf02c5fcc88f01886cadd29941b807668a45f0d2336", size = 609947, upload-time = "2026-01-23T15:32:51.004Z" }, + { url = "https://files.pythonhosted.org/packages/a4/52/ee8c46ed9f8babaa93a19e577f26e3d28a519feac6350ed6f25f1afee7e9/greenlet-3.3.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7b2fe4150a0cf59f847a67db8c155ac36aed89080a6a639e9f16df5d6c6096f1", size = 1567487, upload-time = "2026-01-23T16:04:22.125Z" }, + { url = "https://files.pythonhosted.org/packages/8f/7c/456a74f07029597626f3a6db71b273a3632aecb9afafeeca452cfa633197/greenlet-3.3.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:49f4ad195d45f4a66a0eb9c1ba4832bb380570d361912fa3554746830d332149", size = 1636087, upload-time = "2026-01-23T15:33:47.486Z" }, + { url = "https://files.pythonhosted.org/packages/34/2f/5e0e41f33c69655300a5e54aeb637cf8ff57f1786a3aba374eacc0228c1d/greenlet-3.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:cc98b9c4e4870fa983436afa999d4eb16b12872fab7071423d5262fa7120d57a", size = 227156, upload-time = "2026-01-23T15:34:34.808Z" }, + { url = "https://files.pythonhosted.org/packages/c8/ab/717c58343cf02c5265b531384b248787e04d8160b8afe53d9eec053d7b44/greenlet-3.3.1-cp312-cp312-win_arm64.whl", hash = "sha256:bfb2d1763d777de5ee495c85309460f6fd8146e50ec9d0ae0183dbf6f0a829d1", size = 226403, upload-time = "2026-01-23T15:31:39.372Z" }, + { url = "https://files.pythonhosted.org/packages/ec/ab/d26750f2b7242c2b90ea2ad71de70cfcd73a948a49513188a0fc0d6fc15a/greenlet-3.3.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:7ab327905cabb0622adca5971e488064e35115430cec2c35a50fd36e72a315b3", size = 275205, upload-time = "2026-01-23T15:30:24.556Z" }, + { url = "https://files.pythonhosted.org/packages/10/d3/be7d19e8fad7c5a78eeefb2d896a08cd4643e1e90c605c4be3b46264998f/greenlet-3.3.1-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:65be2f026ca6a176f88fb935ee23c18333ccea97048076aef4db1ef5bc0713ac", size = 599284, upload-time = "2026-01-23T16:00:58.584Z" }, + { url = "https://files.pythonhosted.org/packages/ae/21/fe703aaa056fdb0f17e5afd4b5c80195bbdab701208918938bd15b00d39b/greenlet-3.3.1-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7a3ae05b3d225b4155bda56b072ceb09d05e974bc74be6c3fc15463cf69f33fd", size = 610274, upload-time = "2026-01-23T16:05:29.312Z" }, + { url = "https://files.pythonhosted.org/packages/cb/86/5c6ab23bb3c28c21ed6bebad006515cfe08b04613eb105ca0041fecca852/greenlet-3.3.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6423481193bbbe871313de5fd06a082f2649e7ce6e08015d2a76c1e9186ca5b3", size = 612904, upload-time = "2026-01-23T15:32:52.317Z" }, + { url = "https://files.pythonhosted.org/packages/c2/f3/7949994264e22639e40718c2daf6f6df5169bf48fb038c008a489ec53a50/greenlet-3.3.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:33a956fe78bbbda82bfc95e128d61129b32d66bcf0a20a1f0c08aa4839ffa951", size = 1567316, upload-time = "2026-01-23T16:04:23.316Z" }, + { url = "https://files.pythonhosted.org/packages/8d/6e/d73c94d13b6465e9f7cd6231c68abde838bb22408596c05d9059830b7872/greenlet-3.3.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b065d3284be43728dd280f6f9a13990b56470b81be20375a207cdc814a983f2", size = 1636549, upload-time = "2026-01-23T15:33:48.643Z" }, + { url = "https://files.pythonhosted.org/packages/5e/b3/c9c23a6478b3bcc91f979ce4ca50879e4d0b2bd7b9a53d8ecded719b92e2/greenlet-3.3.1-cp313-cp313-win_amd64.whl", hash = "sha256:27289986f4e5b0edec7b5a91063c109f0276abb09a7e9bdab08437525977c946", size = 227042, upload-time = "2026-01-23T15:33:58.216Z" }, + { url = "https://files.pythonhosted.org/packages/90/e7/824beda656097edee36ab15809fd063447b200cc03a7f6a24c34d520bc88/greenlet-3.3.1-cp313-cp313-win_arm64.whl", hash = "sha256:2f080e028001c5273e0b42690eaf359aeef9cb1389da0f171ea51a5dc3c7608d", size = 226294, upload-time = "2026-01-23T15:30:52.73Z" }, + { url = "https://files.pythonhosted.org/packages/ae/fb/011c7c717213182caf78084a9bea51c8590b0afda98001f69d9f853a495b/greenlet-3.3.1-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:bd59acd8529b372775cd0fcbc5f420ae20681c5b045ce25bd453ed8455ab99b5", size = 275737, upload-time = "2026-01-23T15:32:16.889Z" }, + { url = "https://files.pythonhosted.org/packages/41/2e/a3a417d620363fdbb08a48b1dd582956a46a61bf8fd27ee8164f9dfe87c2/greenlet-3.3.1-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b31c05dd84ef6871dd47120386aed35323c944d86c3d91a17c4b8d23df62f15b", size = 646422, upload-time = "2026-01-23T16:01:00.354Z" }, + { url = "https://files.pythonhosted.org/packages/b4/09/c6c4a0db47defafd2d6bab8ddfe47ad19963b4e30f5bed84d75328059f8c/greenlet-3.3.1-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:02925a0bfffc41e542c70aa14c7eda3593e4d7e274bfcccca1827e6c0875902e", size = 658219, upload-time = "2026-01-23T16:05:30.956Z" }, + { url = "https://files.pythonhosted.org/packages/80/38/9d42d60dffb04b45f03dbab9430898352dba277758640751dc5cc316c521/greenlet-3.3.1-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:34a729e2e4e4ffe9ae2408d5ecaf12f944853f40ad724929b7585bca808a9d6f", size = 660237, upload-time = "2026-01-23T15:32:53.967Z" }, + { url = "https://files.pythonhosted.org/packages/96/61/373c30b7197f9e756e4c81ae90a8d55dc3598c17673f91f4d31c3c689c3f/greenlet-3.3.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:aec9ab04e82918e623415947921dea15851b152b822661cce3f8e4393c3df683", size = 1615261, upload-time = "2026-01-23T16:04:25.066Z" }, + { url = "https://files.pythonhosted.org/packages/fd/d3/ca534310343f5945316f9451e953dcd89b36fe7a19de652a1dc5a0eeef3f/greenlet-3.3.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:71c767cf281a80d02b6c1bdc41c9468e1f5a494fb11bc8688c360524e273d7b1", size = 1683719, upload-time = "2026-01-23T15:33:50.61Z" }, + { url = "https://files.pythonhosted.org/packages/52/cb/c21a3fd5d2c9c8b622e7bede6d6d00e00551a5ee474ea6d831b5f567a8b4/greenlet-3.3.1-cp314-cp314-win_amd64.whl", hash = "sha256:96aff77af063b607f2489473484e39a0bbae730f2ea90c9e5606c9b73c44174a", size = 228125, upload-time = "2026-01-23T15:32:45.265Z" }, + { url = "https://files.pythonhosted.org/packages/6a/8e/8a2db6d11491837af1de64b8aff23707c6e85241be13c60ed399a72e2ef8/greenlet-3.3.1-cp314-cp314-win_arm64.whl", hash = "sha256:b066e8b50e28b503f604fa538adc764a638b38cf8e81e025011d26e8a627fa79", size = 227519, upload-time = "2026-01-23T15:31:47.284Z" }, + { url = "https://files.pythonhosted.org/packages/28/24/cbbec49bacdcc9ec652a81d3efef7b59f326697e7edf6ed775a5e08e54c2/greenlet-3.3.1-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:3e63252943c921b90abb035ebe9de832c436401d9c45f262d80e2d06cc659242", size = 282706, upload-time = "2026-01-23T15:33:05.525Z" }, + { url = "https://files.pythonhosted.org/packages/86/2e/4f2b9323c144c4fe8842a4e0d92121465485c3c2c5b9e9b30a52e80f523f/greenlet-3.3.1-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:76e39058e68eb125de10c92524573924e827927df5d3891fbc97bd55764a8774", size = 651209, upload-time = "2026-01-23T16:01:01.517Z" }, + { url = "https://files.pythonhosted.org/packages/d9/87/50ca60e515f5bb55a2fbc5f0c9b5b156de7d2fc51a0a69abc9d23914a237/greenlet-3.3.1-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c9f9d5e7a9310b7a2f416dd13d2e3fd8b42d803968ea580b7c0f322ccb389b97", size = 654300, upload-time = "2026-01-23T16:05:32.199Z" }, + { url = "https://files.pythonhosted.org/packages/1d/94/74310866dfa2b73dd08659a3d18762f83985ad3281901ba0ee9a815194fb/greenlet-3.3.1-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:92497c78adf3ac703b57f1e3813c2d874f27f71a178f9ea5887855da413cd6d2", size = 653842, upload-time = "2026-01-23T15:32:55.671Z" }, + { url = "https://files.pythonhosted.org/packages/97/43/8bf0ffa3d498eeee4c58c212a3905dd6146c01c8dc0b0a046481ca29b18c/greenlet-3.3.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ed6b402bc74d6557a705e197d47f9063733091ed6357b3de33619d8a8d93ac53", size = 1614917, upload-time = "2026-01-23T16:04:26.276Z" }, + { url = "https://files.pythonhosted.org/packages/89/90/a3be7a5f378fc6e84abe4dcfb2ba32b07786861172e502388b4c90000d1b/greenlet-3.3.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:59913f1e5ada20fde795ba906916aea25d442abcc0593fba7e26c92b7ad76249", size = 1676092, upload-time = "2026-01-23T15:33:52.176Z" }, + { url = "https://files.pythonhosted.org/packages/e1/2b/98c7f93e6db9977aaee07eb1e51ca63bd5f779b900d362791d3252e60558/greenlet-3.3.1-cp314-cp314t-win_amd64.whl", hash = "sha256:301860987846c24cb8964bdec0e31a96ad4a2a801b41b4ef40963c1b44f33451", size = 233181, upload-time = "2026-01-23T15:33:00.29Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -760,6 +799,61 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, ] +[[package]] +name = "sqlalchemy" +version = "2.0.46" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "greenlet", marker = "platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/aa/9ce0f3e7a9829ead5c8ce549392f33a12c4555a6c0609bb27d882e9c7ddf/sqlalchemy-2.0.46.tar.gz", hash = "sha256:cf36851ee7219c170bb0793dbc3da3e80c582e04a5437bc601bfe8c85c9216d7", size = 9865393, upload-time = "2026-01-21T18:03:45.119Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b6/35/d16bfa235c8b7caba3730bba43e20b1e376d2224f407c178fbf59559f23e/sqlalchemy-2.0.46-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3a9a72b0da8387f15d5810f1facca8f879de9b85af8c645138cba61ea147968c", size = 2153405, upload-time = "2026-01-21T19:05:54.143Z" }, + { url = "https://files.pythonhosted.org/packages/06/6c/3192e24486749862f495ddc6584ed730c0c994a67550ec395d872a2ad650/sqlalchemy-2.0.46-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2347c3f0efc4de367ba00218e0ae5c4ba2306e47216ef80d6e31761ac97cb0b9", size = 3334702, upload-time = "2026-01-21T18:46:45.384Z" }, + { url = "https://files.pythonhosted.org/packages/ea/a2/b9f33c8d68a3747d972a0bb758c6b63691f8fb8a49014bc3379ba15d4274/sqlalchemy-2.0.46-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9094c8b3197db12aa6f05c51c05daaad0a92b8c9af5388569847b03b1007fb1b", size = 3347664, upload-time = "2026-01-21T18:40:09.979Z" }, + { url = "https://files.pythonhosted.org/packages/aa/d2/3e59e2a91eaec9db7e8dc6b37b91489b5caeb054f670f32c95bcba98940f/sqlalchemy-2.0.46-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:37fee2164cf21417478b6a906adc1a91d69ae9aba8f9533e67ce882f4bb1de53", size = 3277372, upload-time = "2026-01-21T18:46:47.168Z" }, + { url = "https://files.pythonhosted.org/packages/dd/dd/67bc2e368b524e2192c3927b423798deda72c003e73a1e94c21e74b20a85/sqlalchemy-2.0.46-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b1e14b2f6965a685c7128bd315e27387205429c2e339eeec55cb75ca4ab0ea2e", size = 3312425, upload-time = "2026-01-21T18:40:11.548Z" }, + { url = "https://files.pythonhosted.org/packages/43/82/0ecd68e172bfe62247e96cb47867c2d68752566811a4e8c9d8f6e7c38a65/sqlalchemy-2.0.46-cp312-cp312-win32.whl", hash = "sha256:412f26bb4ba942d52016edc8d12fb15d91d3cd46b0047ba46e424213ad407bcb", size = 2113155, upload-time = "2026-01-21T18:42:49.748Z" }, + { url = "https://files.pythonhosted.org/packages/bc/2a/2821a45742073fc0331dc132552b30de68ba9563230853437cac54b2b53e/sqlalchemy-2.0.46-cp312-cp312-win_amd64.whl", hash = "sha256:ea3cd46b6713a10216323cda3333514944e510aa691c945334713fca6b5279ff", size = 2140078, upload-time = "2026-01-21T18:42:51.197Z" }, + { url = "https://files.pythonhosted.org/packages/b3/4b/fa7838fe20bb752810feed60e45625a9a8b0102c0c09971e2d1d95362992/sqlalchemy-2.0.46-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:93a12da97cca70cea10d4b4fc602589c4511f96c1f8f6c11817620c021d21d00", size = 2150268, upload-time = "2026-01-21T19:05:56.621Z" }, + { url = "https://files.pythonhosted.org/packages/46/c1/b34dccd712e8ea846edf396e00973dda82d598cb93762e55e43e6835eba9/sqlalchemy-2.0.46-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:af865c18752d416798dae13f83f38927c52f085c52e2f32b8ab0fef46fdd02c2", size = 3276511, upload-time = "2026-01-21T18:46:49.022Z" }, + { url = "https://files.pythonhosted.org/packages/96/48/a04d9c94753e5d5d096c628c82a98c4793b9c08ca0e7155c3eb7d7db9f24/sqlalchemy-2.0.46-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8d679b5f318423eacb61f933a9a0f75535bfca7056daeadbf6bd5bcee6183aee", size = 3292881, upload-time = "2026-01-21T18:40:13.089Z" }, + { url = "https://files.pythonhosted.org/packages/be/f4/06eda6e91476f90a7d8058f74311cb65a2fb68d988171aced81707189131/sqlalchemy-2.0.46-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:64901e08c33462acc9ec3bad27fc7a5c2b6491665f2aa57564e57a4f5d7c52ad", size = 3224559, upload-time = "2026-01-21T18:46:50.974Z" }, + { url = "https://files.pythonhosted.org/packages/ab/a2/d2af04095412ca6345ac22b33b89fe8d6f32a481e613ffcb2377d931d8d0/sqlalchemy-2.0.46-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e8ac45e8f4eaac0f9f8043ea0e224158855c6a4329fd4ee37c45c61e3beb518e", size = 3262728, upload-time = "2026-01-21T18:40:14.883Z" }, + { url = "https://files.pythonhosted.org/packages/31/48/1980c7caa5978a3b8225b4d230e69a2a6538a3562b8b31cea679b6933c83/sqlalchemy-2.0.46-cp313-cp313-win32.whl", hash = "sha256:8d3b44b3d0ab2f1319d71d9863d76eeb46766f8cf9e921ac293511804d39813f", size = 2111295, upload-time = "2026-01-21T18:42:52.366Z" }, + { url = "https://files.pythonhosted.org/packages/2d/54/f8d65bbde3d877617c4720f3c9f60e99bb7266df0d5d78b6e25e7c149f35/sqlalchemy-2.0.46-cp313-cp313-win_amd64.whl", hash = "sha256:77f8071d8fbcbb2dd11b7fd40dedd04e8ebe2eb80497916efedba844298065ef", size = 2137076, upload-time = "2026-01-21T18:42:53.924Z" }, + { url = "https://files.pythonhosted.org/packages/56/ba/9be4f97c7eb2b9d5544f2624adfc2853e796ed51d2bb8aec90bc94b7137e/sqlalchemy-2.0.46-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a1e8cc6cc01da346dc92d9509a63033b9b1bda4fed7a7a7807ed385c7dccdc10", size = 3556533, upload-time = "2026-01-21T18:33:06.636Z" }, + { url = "https://files.pythonhosted.org/packages/20/a6/b1fc6634564dbb4415b7ed6419cdfeaadefd2c39cdab1e3aa07a5f2474c2/sqlalchemy-2.0.46-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:96c7cca1a4babaaf3bfff3e4e606e38578856917e52f0384635a95b226c87764", size = 3523208, upload-time = "2026-01-21T18:45:08.436Z" }, + { url = "https://files.pythonhosted.org/packages/a1/d8/41e0bdfc0f930ff236f86fccd12962d8fa03713f17ed57332d38af6a3782/sqlalchemy-2.0.46-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b2a9f9aee38039cf4755891a1e50e1effcc42ea6ba053743f452c372c3152b1b", size = 3464292, upload-time = "2026-01-21T18:33:08.208Z" }, + { url = "https://files.pythonhosted.org/packages/f0/8b/9dcbec62d95bea85f5ecad9b8d65b78cc30fb0ffceeb3597961f3712549b/sqlalchemy-2.0.46-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:db23b1bf8cfe1f7fda19018e7207b20cdb5168f83c437ff7e95d19e39289c447", size = 3473497, upload-time = "2026-01-21T18:45:10.552Z" }, + { url = "https://files.pythonhosted.org/packages/e9/f8/5ecdfc73383ec496de038ed1614de9e740a82db9ad67e6e4514ebc0708a3/sqlalchemy-2.0.46-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:56bdd261bfd0895452006d5316cbf35739c53b9bb71a170a331fa0ea560b2ada", size = 2152079, upload-time = "2026-01-21T19:05:58.477Z" }, + { url = "https://files.pythonhosted.org/packages/e5/bf/eba3036be7663ce4d9c050bc3d63794dc29fbe01691f2bf5ccb64e048d20/sqlalchemy-2.0.46-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:33e462154edb9493f6c3ad2125931e273bbd0be8ae53f3ecd1c161ea9a1dd366", size = 3272216, upload-time = "2026-01-21T18:46:52.634Z" }, + { url = "https://files.pythonhosted.org/packages/05/45/1256fb597bb83b58a01ddb600c59fe6fdf0e5afe333f0456ed75c0f8d7bd/sqlalchemy-2.0.46-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9bcdce05f056622a632f1d44bb47dbdb677f58cad393612280406ce37530eb6d", size = 3277208, upload-time = "2026-01-21T18:40:16.38Z" }, + { url = "https://files.pythonhosted.org/packages/d9/a0/2053b39e4e63b5d7ceb3372cface0859a067c1ddbd575ea7e9985716f771/sqlalchemy-2.0.46-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8e84b09a9b0f19accedcbeff5c2caf36e0dd537341a33aad8d680336152dc34e", size = 3221994, upload-time = "2026-01-21T18:46:54.622Z" }, + { url = "https://files.pythonhosted.org/packages/1e/87/97713497d9502553c68f105a1cb62786ba1ee91dea3852ae4067ed956a50/sqlalchemy-2.0.46-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:4f52f7291a92381e9b4de9050b0a65ce5d6a763333406861e33906b8aa4906bf", size = 3243990, upload-time = "2026-01-21T18:40:18.253Z" }, + { url = "https://files.pythonhosted.org/packages/a8/87/5d1b23548f420ff823c236f8bea36b1a997250fd2f892e44a3838ca424f4/sqlalchemy-2.0.46-cp314-cp314-win32.whl", hash = "sha256:70ed2830b169a9960193f4d4322d22be5c0925357d82cbf485b3369893350908", size = 2114215, upload-time = "2026-01-21T18:42:55.232Z" }, + { url = "https://files.pythonhosted.org/packages/3a/20/555f39cbcf0c10cf452988b6a93c2a12495035f68b3dbd1a408531049d31/sqlalchemy-2.0.46-cp314-cp314-win_amd64.whl", hash = "sha256:3c32e993bc57be6d177f7d5d31edb93f30726d798ad86ff9066d75d9bf2e0b6b", size = 2139867, upload-time = "2026-01-21T18:42:56.474Z" }, + { url = "https://files.pythonhosted.org/packages/3e/f0/f96c8057c982d9d8a7a68f45d69c674bc6f78cad401099692fe16521640a/sqlalchemy-2.0.46-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4dafb537740eef640c4d6a7c254611dca2df87eaf6d14d6a5fca9d1f4c3fc0fa", size = 3561202, upload-time = "2026-01-21T18:33:10.337Z" }, + { url = "https://files.pythonhosted.org/packages/d7/53/3b37dda0a5b137f21ef608d8dfc77b08477bab0fe2ac9d3e0a66eaeab6fc/sqlalchemy-2.0.46-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:42a1643dc5427b69aca967dae540a90b0fbf57eaf248f13a90ea5930e0966863", size = 3526296, upload-time = "2026-01-21T18:45:12.657Z" }, + { url = "https://files.pythonhosted.org/packages/33/75/f28622ba6dde79cd545055ea7bd4062dc934e0621f7b3be2891f8563f8de/sqlalchemy-2.0.46-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ff33c6e6ad006bbc0f34f5faf941cfc62c45841c64c0a058ac38c799f15b5ede", size = 3470008, upload-time = "2026-01-21T18:33:11.725Z" }, + { url = "https://files.pythonhosted.org/packages/a9/42/4afecbbc38d5e99b18acef446453c76eec6fbd03db0a457a12a056836e22/sqlalchemy-2.0.46-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:82ec52100ec1e6ec671563bbd02d7c7c8d0b9e71a0723c72f22ecf52d1755330", size = 3476137, upload-time = "2026-01-21T18:45:15.001Z" }, + { url = "https://files.pythonhosted.org/packages/fc/a1/9c4efa03300926601c19c18582531b45aededfb961ab3c3585f1e24f120b/sqlalchemy-2.0.46-py3-none-any.whl", hash = "sha256:f9c11766e7e7c0a2767dda5acb006a118640c9fc0a4104214b96269bfb78399e", size = 1937882, upload-time = "2026-01-21T18:22:10.456Z" }, +] + +[[package]] +name = "sqlmodel" +version = "0.0.33" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "sqlalchemy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c4/62/22c287122598e61d07d005eec0b4eb97e6bde9a1b051bcd66c2bca846ea8/sqlmodel-0.0.33.tar.gz", hash = "sha256:b473544ed5fc2097894d89033049e569e1f138363dd3ec2ed4b6932cc9f29f5f", size = 95578, upload-time = "2026-02-11T15:23:39.504Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/63/39/13891bae4658133b489a4d8b6a2f193d56110e392289560f312748e796dc/sqlmodel-0.0.33-py3-none-any.whl", hash = "sha256:9045bb4d97d2ba099c5a068ee9525af2d106972dda1ff8488e187ce50556bf73", size = 27444, upload-time = "2026-02-11T15:23:38.678Z" }, +] + [[package]] name = "starlette" version = "0.52.1" @@ -782,6 +876,7 @@ dependencies = [ { name = "huggingface-hub" }, { name = "rich" }, { name = "safetensors" }, + { name = "sqlmodel" }, { name = "typer" }, { name = "websocket-client" }, ] @@ -815,6 +910,7 @@ requires-dist = [ { name = "rich", specifier = ">=13.0.0" }, { name = "safetensors", specifier = ">=0.4.0" }, { name = "scalar-fastapi", marker = "extra == 'server'", specifier = ">=1.6" }, + { name = "sqlmodel", specifier = ">=0.0.33" }, { name = "typer", specifier = ">=0.15.0" }, { name = "uvicorn", marker = "extra == 'server'", specifier = ">=0.30" }, { name = "websocket-client", specifier = ">=1.9.0" },