Phase 3: Enhanced Server API

Add comprehensive server API for remote operations:

Step 3.1 - Gallery endpoints:
- GET /api/images - List images (paginated, newest first)
- GET /api/images/{id} - Get image file
- GET /api/images/{id}/meta - Get generation metadata
- POST /api/images/{id}/edit - Update metadata (tags, notes)
- DELETE /api/images/{id} - Delete image + sidecar
- Gallery module with image management and sidecar JSON support

Step 3.2 - Model management:
- GET /api/models - List available checkpoints
- GET /api/models/active - Current loaded model info
- POST /api/models/switch - Switch model (hot reload)
- GET /api/models/loras - List available LoRAs
- GET /api/models/scan - Scan all model directories

Step 3.3 - Download proxy:
- POST /api/download - Start background download from CivitAI
- GET /api/download/status/{id} - Check download progress
- GET /api/download/active - List active downloads
- Auto-scan and link files after download

Step 3.4 - Enhanced generation:
- POST /api/generate - Generate with gallery integration
- Saves images to gallery with metadata sidecar
- Supports all sd-server params
- GET /api/samplers, /api/schedulers - List options

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Adam Ladachowski
2026-02-14 01:40:25 +01:00
parent 18b3268738
commit 11a289ebd0
8 changed files with 946 additions and 5 deletions
+4 -4
View File
@@ -11,10 +11,10 @@
- [x] Step 2.3: Add `/api/db/*` endpoints (files, models, triggers, scan, link, cache, stats) - [x] Step 2.3: Add `/api/db/*` endpoints (files, models, triggers, scan, link, cache, stats)
## Phase 3: Enhanced Server API ## Phase 3: Enhanced Server API
- [ ] Step 3.1: Add `/api/images` gallery endpoints (list, get, delete, edit) - [x] Step 3.1: Add `/api/images` gallery endpoints (list, get, delete, edit)
- [ ] Step 3.2: Add `/api/models` endpoints (list, active, switch, loras) - [x] Step 3.2: Add `/api/models` endpoints (list, active, switch, loras)
- [ ] Step 3.3: Add `/api/download` endpoint (CivitAI proxy download) - [x] Step 3.3: Add `/api/download` endpoint (CivitAI proxy download)
- [ ] Step 3.4: Enhance `/api/generate` (gallery integration, full params) - [x] Step 3.4: Enhance `/api/generate` (gallery integration, full params)
## Phase 4: Client Mode for tsr CLI ## Phase 4: Client Mode for tsr CLI
- [ ] Step 4.1: Create `tensors/client.py` (TsrClient HTTP wrapper) - [ ] Step 4.1: Create `tensors/client.py` (TsrClient HTTP wrapper)
+1
View File
@@ -20,6 +20,7 @@ CONFIG_FILE = CONFIG_DIR / "config.toml"
DATA_DIR = Path(os.environ.get("XDG_DATA_HOME", Path.home() / ".local" / "share")) / "tensors" DATA_DIR = Path(os.environ.get("XDG_DATA_HOME", Path.home() / ".local" / "share")) / "tensors"
MODELS_DIR = DATA_DIR / "models" MODELS_DIR = DATA_DIR / "models"
METADATA_DIR = DATA_DIR / "metadata" METADATA_DIR = DATA_DIR / "metadata"
GALLERY_DIR = DATA_DIR / "gallery"
# Legacy config for migration # Legacy config for migration
LEGACY_RC_FILE = Path.home() / ".sftrc" LEGACY_RC_FILE = Path.home() / ".sftrc"
+9 -1
View File
@@ -10,7 +10,11 @@ import httpx
from fastapi import FastAPI from fastapi import FastAPI
from tensors.server.db_routes import create_db_router from tensors.server.db_routes import create_db_router
from tensors.server.download_routes import create_download_router
from tensors.server.gallery_routes import create_gallery_router
from tensors.server.generate_routes import create_generate_router
from tensors.server.models import ServerConfig from tensors.server.models import ServerConfig
from tensors.server.models_routes import create_models_router
from tensors.server.process import ProcessManager from tensors.server.process import ProcessManager
from tensors.server.routes import create_router from tensors.server.routes import create_router
@@ -42,7 +46,11 @@ def create_app(config: ServerConfig | None = None) -> FastAPI:
pm.stop() pm.stop()
app = FastAPI(title="sd-server wrapper", lifespan=lifespan) app = FastAPI(title="sd-server wrapper", lifespan=lifespan)
app.include_router(create_db_router()) # Must be first to avoid catch-all conflict app.include_router(create_db_router()) # Must be before catch-all proxy
app.include_router(create_gallery_router()) # Must be before catch-all proxy
app.include_router(create_models_router(pm)) # Must be before catch-all proxy
app.include_router(create_download_router()) # Must be before catch-all proxy
app.include_router(create_generate_router(pm)) # Must be before catch-all proxy
app.include_router(create_router(pm)) app.include_router(create_router(pm))
app.state.pm = pm app.state.pm = pm
return app return app
+230
View File
@@ -0,0 +1,230 @@
"""FastAPI route handlers for CivitAI download proxy."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
from fastapi import APIRouter, BackgroundTasks, HTTPException
from pydantic import BaseModel as PydanticBaseModel
from tensors.api import download_model, fetch_civitai_by_hash, fetch_civitai_model, fetch_civitai_model_version
from tensors.config import MODELS_DIR, load_api_key
from tensors.db import Database
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router shared by every endpoint below; all paths are mounted under /api/download.
router = APIRouter(prefix="/api/download", tags=["download"])
# Track active downloads: download_id -> tracking record (id, status, path, ...).
# Records are created by start_download and updated by _do_download. Nothing
# visible in this module removes them, so finished downloads stay listed.
_active_downloads: dict[str, dict[str, Any]] = {}
# =============================================================================
# Request/Response Models
# =============================================================================
class DownloadRequest(PydanticBaseModel):
    """Payload accepted by ``POST /api/download``.

    The model is identified by one of ``version_id``, ``hash`` or ``model_id``;
    when several are supplied they are tried in that order.
    """

    # CivitAI model-version ID; takes precedence over the other identifiers.
    version_id: int | None = None
    # CivitAI model ID; the first entry of its ``modelVersions`` list is used.
    model_id: int | None = None
    # File hash to look up on CivitAI (upper-cased before the lookup).
    hash: str | None = None
    # Destination directory; when unset it is derived from the model type.
    output_dir: str | None = None
# =============================================================================
# Helper Functions
# =============================================================================
def _resolve_version_id(
version_id: int | None,
model_id: int | None,
hash_val: str | None,
api_key: str | None,
) -> tuple[int | None, dict[str, Any] | None]:
"""Resolve version ID and get version info."""
if version_id:
info = fetch_civitai_model_version(version_id, api_key)
return version_id, info
if hash_val:
info = fetch_civitai_by_hash(hash_val.upper(), api_key)
if info:
return info.get("id"), info
return None, None
if model_id:
model_data = fetch_civitai_model(model_id, api_key)
if model_data:
versions = model_data.get("modelVersions", [])
if versions:
latest = versions[0]
return latest.get("id"), latest
return None, None
return None, None
def _get_output_dir(version_info: dict[str, Any], override: str | None) -> Path:
"""Determine output directory based on model type."""
if override:
return Path(override)
model_type = version_info.get("model", {}).get("type", "Checkpoint")
# Map type to directory
type_dirs = {
"Checkpoint": MODELS_DIR / "checkpoints",
"LORA": MODELS_DIR / "loras",
"LoCon": MODELS_DIR / "loras",
"TextualInversion": MODELS_DIR / "embeddings",
"VAE": MODELS_DIR / "vae",
"Controlnet": MODELS_DIR / "controlnet",
}
return type_dirs.get(model_type, MODELS_DIR / "other")
def _do_download(
    version_id: int,
    dest_path: Path,
    api_key: str | None,
    download_id: str,
) -> None:
    """Run one download as a background task and record its outcome.

    The tracking record in ``_active_downloads`` moves from ``downloading`` to
    either ``completed`` (followed by an auto-link attempt) or ``failed`` with
    an ``error`` message. Exceptions are logged and stored on the record
    rather than propagated.
    """
    try:
        _active_downloads[download_id]["status"] = "downloading"

        # download_model wants a rich Console; give it one that writes into an
        # in-memory buffer so progress output is discarded.
        from io import StringIO  # noqa: PLC0415

        from rich.console import Console  # noqa: PLC0415

        sink = StringIO()
        quiet_console = Console(file=sink, force_terminal=False)

        succeeded = download_model(version_id, dest_path, api_key, quiet_console, resume=True)
        if not succeeded:
            _active_downloads[download_id].update(status="failed", error="Download failed")
            return

        _active_downloads[download_id].update(status="completed", path=str(dest_path))
        # Auto-scan and link the downloaded file
        _auto_link_file(dest_path, api_key)
    except Exception as e:
        logger.exception("Download failed")
        _active_downloads[download_id].update(status="failed", error=str(e))
def _auto_link_file(file_path: Path, api_key: str | None) -> None:
    """Register a freshly downloaded file in the database and link it to CivitAI.

    The file's parent directory is scanned; for each scan result whose path
    equals ``file_path`` the SHA-256 is looked up on CivitAI and, when both a
    version ID and a model ID come back, the file row is linked to them.
    Best effort: any exception is logged and swallowed.
    """
    try:
        with Database() as db:
            db.init_schema()

            for entry in db.scan_directory(file_path.parent):
                if entry["file_path"] != str(file_path):
                    continue

                civitai_data = fetch_civitai_by_hash(entry["sha256"], api_key)
                if not civitai_data:
                    continue

                version_id = civitai_data.get("id", 0)
                model_id = civitai_data.get("modelId", 0)
                if version_id and model_id:
                    db.link_file_to_civitai(entry["id"], model_id, version_id)
    except Exception:
        logger.exception("Auto-link failed")
# =============================================================================
# Download Endpoints
# =============================================================================
@router.post("")
def start_download(req: DownloadRequest, background_tasks: BackgroundTasks) -> dict[str, Any]:
    """Start a model download (async with progress tracking).

    The version is resolved on CivitAI, a destination path is derived, a
    tracking record is stored in ``_active_downloads`` and the transfer itself
    is handed to ``background_tasks``.

    Raises:
        HTTPException: 400 when the request names no model at all or the
            version has no files; 404 when CivitAI does not know the
            model/version.

    Returns:
        The new ``download_id`` plus the resolved version and destination.
    """
    # Kept local: this module does not import ``time`` at top level.
    import time  # noqa: PLC0415

    # Without any identifier the lookup can only fail; report it as a bad request.
    if not (req.version_id or req.model_id or req.hash):
        raise HTTPException(status_code=400, detail="Provide one of version_id, model_id or hash")

    api_key = load_api_key()

    # Resolve version ID
    version_id, version_info = _resolve_version_id(
        req.version_id,
        req.model_id,
        req.hash,
        api_key,
    )
    if not version_id or not version_info:
        raise HTTPException(status_code=404, detail="Model/version not found on CivitAI")

    # Pick the primary file, falling back to the first listed one.
    files = version_info.get("files") or []
    primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
    if not primary_file:
        raise HTTPException(status_code=400, detail="No files found for this version")

    # The name comes from a remote service: keep only its final component so
    # it cannot point outside the output directory.
    raw_name = str(primary_file.get("name") or "")
    filename = Path(raw_name.replace("\\", "/")).name
    if filename in {"", ".", ".."}:
        filename = f"model-{version_id}.safetensors"

    # Create the directory only once the request is known to be valid.
    output_dir = _get_output_dir(version_info, req.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    dest_path = output_dir / filename

    # Same "<version>_<unix seconds>" shape as before; a counter is appended
    # only when that ID is already taken, so an existing record is never
    # overwritten by a second request in the same second.
    base_id = f"{version_id}_{int(time.time())}"
    download_id = base_id
    attempt = 1
    while download_id in _active_downloads:
        download_id = f"{base_id}_{attempt}"
        attempt += 1

    model_info = version_info.get("model", {})
    _active_downloads[download_id] = {
        "id": download_id,
        "version_id": version_id,
        "status": "queued",
        "path": str(dest_path),
        "filename": filename,
        "model_name": model_info.get("name", "Unknown"),
        "version_name": version_info.get("name", "Unknown"),
    }

    # Start background download
    background_tasks.add_task(_do_download, version_id, dest_path, api_key, download_id)

    return {
        "download_id": download_id,
        "status": "queued",
        "version_id": version_id,
        "destination": str(dest_path),
        "model_name": model_info.get("name"),
        "version_name": version_info.get("name"),
    }
@router.get("/status/{download_id}")
def get_download_status(download_id: str) -> dict[str, Any]:
    """Return the tracking record of one download, or 404 if the ID is unknown."""
    record = _active_downloads.get(download_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Download not found")
    return record
@router.get("/active")
def list_active_downloads() -> dict[str, Any]:
    """Return every tracked download record together with the record count."""
    total = len(_active_downloads)
    records = list(_active_downloads.values())
    return {"downloads": records, "total": total}
def create_download_router() -> APIRouter:
    """Expose this module's shared ``/api/download`` router to the app factory."""
    download_router = router
    return download_router
+183
View File
@@ -0,0 +1,183 @@
"""Image gallery management for generated images."""
from __future__ import annotations
import json
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from tensors.config import GALLERY_DIR
if TYPE_CHECKING:
from pathlib import Path
@dataclass
class GalleryImage:
    """One gallery entry: an image file plus what is known about it."""

    # Image identifier.
    id: str
    # Location of the image file.
    path: Path
    # Creation timestamp.
    created_at: float
    # Pixel dimensions, when known.
    width: int | None = None
    height: int | None = None
    # Metadata mapping; empty by default.
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def meta_path(self) -> Path:
        """Location of the sidecar JSON: the image path with a ``.json`` suffix."""
        return self.path.with_suffix(".json")

    def to_dict(self) -> dict[str, Any]:
        """Build the JSON-serialisable summary used in API responses.

        ``has_metadata`` is checked against the filesystem at call time.
        """
        image_file = self.path
        return dict(
            id=self.id,
            path=str(image_file),
            filename=image_file.name,
            created_at=self.created_at,
            width=self.width,
            height=self.height,
            has_metadata=self.meta_path.exists(),
        )
class Gallery:
    """Manages the image gallery directory.

    Every image is stored as ``<id>.png`` directly inside ``gallery_dir``,
    optionally accompanied by a ``<id>.json`` sidecar holding its metadata.
    An unreadable or malformed sidecar is treated as "no metadata" so that a
    single bad file cannot break listing or editing.
    """

    def __init__(self, gallery_dir: Path | None = None) -> None:
        """Initialize gallery with directory path (created if missing)."""
        self.gallery_dir = gallery_dir or GALLERY_DIR
        self.gallery_dir.mkdir(parents=True, exist_ok=True)

    def list_images(
        self,
        limit: int = 50,
        offset: int = 0,
        newest_first: bool = True,
    ) -> list[GalleryImage]:
        """List images in the gallery, paginated.

        Args:
            limit: Maximum number of images to return; negatives count as 0.
            offset: Number of images to skip; negatives count as 0.
            newest_first: Sort by file modification time, newest first.
        """
        loaded = (self._load_image(path) for path in self.gallery_dir.glob("*.png"))
        images = [img for img in loaded if img is not None]
        images.sort(key=lambda img: img.created_at, reverse=newest_first)

        # Clamp so a negative value cannot turn into a from-the-end slice.
        start = max(offset, 0)
        return images[start : start + max(limit, 0)]

    def get_image(self, image_id: str) -> GalleryImage | None:
        """Get an image by ID (the filename stem), or None if it does not exist."""
        paths = self._paths_for(image_id)
        if paths is None:
            return None
        return self._load_image(paths[0])

    def get_metadata(self, image_id: str) -> dict[str, Any] | None:
        """Get the sidecar metadata for an image.

        Returns None when the sidecar is missing, unreadable, or does not
        contain a JSON object.
        """
        paths = self._paths_for(image_id)
        if paths is None:
            return None
        return self._read_sidecar(paths[1])

    def update_metadata(self, image_id: str, updates: dict[str, Any]) -> dict[str, Any] | None:
        """Merge ``updates`` into an image's metadata and save it.

        An ``updated_at`` timestamp is always written. A missing or unusable
        sidecar is replaced by a fresh one.

        Returns:
            The merged metadata, or None when the image does not exist.
        """
        paths = self._paths_for(image_id)
        if paths is None:
            return None
        img_path, meta_path = paths
        if not img_path.exists():
            return None

        metadata = self._read_sidecar(meta_path) or {}
        metadata.update(updates)
        metadata["updated_at"] = time.time()
        meta_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
        return metadata

    def delete_image(self, image_id: str) -> bool:
        """Delete an image and its sidecar; return False if the image is absent."""
        paths = self._paths_for(image_id)
        if paths is None:
            return False
        img_path, meta_path = paths

        # Attempt the unlink directly so a concurrent delete cannot slip in
        # between an existence check and the removal.
        try:
            img_path.unlink()
        except FileNotFoundError:
            return False
        meta_path.unlink(missing_ok=True)
        return True

    def save_image(
        self,
        image_data: bytes,
        metadata: dict[str, Any] | None = None,
        seed: int | None = None,
    ) -> GalleryImage:
        """Save an image to the gallery with optional metadata.

        The ID is ``<unix milliseconds>_<seed>`` (seed ``0`` when None). If
        that file already exists the timestamp part is bumped until the name
        is free, so an earlier image is never overwritten.
        """
        timestamp = int(time.time() * 1000)  # milliseconds
        seed_str = str(seed) if seed is not None else "0"
        image_id = f"{timestamp}_{seed_str}"
        img_path = self.gallery_dir / f"{image_id}.png"
        while img_path.exists():
            timestamp += 1
            image_id = f"{timestamp}_{seed_str}"
            img_path = self.gallery_dir / f"{image_id}.png"

        img_path.write_bytes(image_data)

        # An empty/None metadata mapping produces no sidecar.
        if metadata:
            meta = metadata.copy()
            meta["created_at"] = time.time()
            meta["seed"] = seed
            img_path.with_suffix(".json").write_text(json.dumps(meta, indent=2), encoding="utf-8")

        return self._load_image(img_path) or GalleryImage(
            id=image_id,
            path=img_path,
            created_at=time.time(),
        )

    def count(self) -> int:
        """Count total images in gallery."""
        return sum(1 for _ in self.gallery_dir.glob("*.png"))

    def _paths_for(self, image_id: str) -> tuple[Path, Path] | None:
        """Return ``(image_path, sidecar_path)`` for an ID, or None if the ID is unusable.

        IDs are plain filename stems; anything containing a path separator or
        NUL is rejected so it cannot address files outside the gallery.
        """
        if not image_id or any(ch in image_id for ch in ("/", "\\", "\0")):
            return None
        return self.gallery_dir / f"{image_id}.png", self.gallery_dir / f"{image_id}.json"

    @staticmethod
    def _read_sidecar(meta_path: Path) -> dict[str, Any] | None:
        """Read a sidecar file, returning None when it is missing or unusable."""
        try:
            data = json.loads(meta_path.read_text(encoding="utf-8"))
        except FileNotFoundError:
            return None
        except (OSError, ValueError):
            # ValueError covers both invalid JSON and undecodable bytes.
            import logging  # noqa: PLC0415

            logging.getLogger(__name__).warning("Ignoring unreadable sidecar %s", meta_path)
            return None
        return data if isinstance(data, dict) else None

    def _load_image(self, path: Path) -> GalleryImage | None:
        """Build a GalleryImage for ``path``; None if the file cannot be stat'ed.

        Width and height are taken from the sidecar when it provides them.
        """
        try:
            stat = path.stat()
        except OSError:
            return None

        sidecar = self._read_sidecar(path.with_suffix(".json")) or {}
        return GalleryImage(
            id=path.stem,
            path=path,
            created_at=stat.st_mtime,
            width=sidecar.get("width"),
            height=sidecar.get("height"),
            metadata=sidecar,
        )
+149
View File
@@ -0,0 +1,149 @@
"""FastAPI route handlers for image gallery endpoints."""
from __future__ import annotations
import logging
from typing import Any
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import FileResponse
from pydantic import BaseModel as PydanticBaseModel
from tensors.server.gallery import Gallery
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router shared by every endpoint below; all paths are mounted under /api/images.
router = APIRouter(prefix="/api/images", tags=["gallery"])
# Shared gallery instance; stays None until get_gallery() is first called.
_gallery: Gallery | None = None
def get_gallery() -> Gallery:
    """Return the module-wide ``Gallery``, constructing it on first use."""
    global _gallery  # noqa: PLW0603
    if _gallery is not None:
        return _gallery
    _gallery = Gallery()
    return _gallery
# =============================================================================
# Request/Response Models
# =============================================================================
class MetadataUpdate(PydanticBaseModel):
    """Payload accepted by ``POST /api/images/{id}/edit``.

    Every field is optional; a field left as ``None`` is not written to the
    image's metadata.
    """

    # Free-form labels attached to the image.
    tags: list[str] | None = None
    # Free-text note.
    notes: str | None = None
    # Numeric rating; no range is enforced here.
    rating: int | None = None
    # Favourite flag.
    favorite: bool | None = None
# =============================================================================
# Gallery Endpoints
# =============================================================================
@router.get("")
def list_images(
    # ge=0: a negative limit would otherwise be used as a negative slice end
    # and silently drop images from the end of the listing.
    limit: int = Query(default=50, ge=0, le=200, description="Max images to return"),
    offset: int = Query(default=0, ge=0, description="Offset for pagination"),
    newest_first: bool = Query(default=True, description="Sort newest first"),
) -> dict[str, Any]:
    """List images in the gallery, paginated.

    Returns:
        The requested page under ``images`` plus ``total`` (number of images
        in the whole gallery, not just this page) and the echoed ``limit`` and
        ``offset``.
    """
    gallery = get_gallery()
    images = gallery.list_images(limit=limit, offset=offset, newest_first=newest_first)

    return {
        "images": [img.to_dict() for img in images],
        "total": gallery.count(),
        "limit": limit,
        "offset": offset,
    }
@router.get("/{image_id}")
def get_image(image_id: str) -> FileResponse:
    """Serve the PNG file of one gallery image, or 404 if the ID is unknown."""
    found = get_gallery().get_image(image_id)
    if found is None:
        raise HTTPException(status_code=404, detail="Image not found")

    image_file = found.path
    return FileResponse(path=image_file, media_type="image/png", filename=image_file.name)
@router.get("/{image_id}/meta")
def get_image_metadata(image_id: str) -> dict[str, Any]:
    """Return an image's stored metadata (empty dict if none), or 404 if the image is unknown."""
    gallery = get_gallery()

    found = gallery.get_image(image_id)
    if found is None:
        raise HTTPException(status_code=404, detail="Image not found")

    sidecar = gallery.get_metadata(image_id)
    return {
        "id": image_id,
        "path": str(found.path),
        "created_at": found.created_at,
        "metadata": sidecar if sidecar else {},
    }
@router.post("/{image_id}/edit")
def edit_image_metadata(image_id: str, updates: MetadataUpdate) -> dict[str, Any]:
    """Merge the supplied fields into an image's metadata; 404 if the image is unknown.

    Only fields that were actually set (not ``None``) are written.
    """
    editable_fields = ("tags", "notes", "rating", "favorite")
    update_dict: dict[str, Any] = {
        name: getattr(updates, name) for name in editable_fields if getattr(updates, name) is not None
    }

    merged = get_gallery().update_metadata(image_id, update_dict)
    if merged is None:
        raise HTTPException(status_code=404, detail="Image not found")
    return {"id": image_id, "metadata": merged}
@router.delete("/{image_id}")
def delete_image(image_id: str) -> dict[str, Any]:
    """Remove an image together with its metadata; 404 if the image is unknown."""
    if get_gallery().delete_image(image_id):
        return {"deleted": True, "id": image_id}
    raise HTTPException(status_code=404, detail="Image not found")
@router.get("/stats/summary")
def gallery_stats() -> dict[str, Any]:
    """Report how many images the gallery holds and where it lives on disk."""
    gallery = get_gallery()
    image_total = gallery.count()
    location = str(gallery.gallery_dir)
    return {"total_images": image_total, "gallery_dir": location}
def create_gallery_router() -> APIRouter:
    """Expose this module's shared ``/api/images`` router to the app factory."""
    gallery_router = router
    return gallery_router
+199
View File
@@ -0,0 +1,199 @@
"""FastAPI route handlers for image generation with gallery integration."""
from __future__ import annotations
import base64
import logging
import time
from typing import TYPE_CHECKING, Any
import httpx
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel as PydanticBaseModel
from pydantic import Field
from tensors.server.gallery import Gallery
if TYPE_CHECKING:
from tensors.server.process import ProcessManager
logger = logging.getLogger(__name__)
# =============================================================================
# Request/Response Models
# =============================================================================
class GenerateRequest(PydanticBaseModel):
    """Payload accepted by ``POST /api/generate``.

    The generation parameters are forwarded to sd-server's txt2img endpoint;
    ``save_to_gallery`` and ``return_base64`` only control what this wrapper
    does with the result.
    """

    # Text prompts.
    prompt: str
    negative_prompt: str = ""
    # Output size in pixels, each side limited to 64..2048.
    width: int = Field(default=512, ge=64, le=2048)
    height: int = Field(default=512, ge=64, le=2048)
    # Sampling steps (1..150) and CFG scale (0..30).
    steps: int = Field(default=20, ge=1, le=150)
    cfg_scale: float = Field(default=7.0, ge=0, le=30)
    # Seed; the default is -1.
    seed: int = -1
    # Sampler / scheduler names; an empty string means "do not send the field".
    sampler_name: str = ""
    scheduler: str = ""
    # Images per request (1..16).
    batch_size: int = Field(default=1, ge=1, le=16)
    # Store each image (with a metadata sidecar) in the gallery.
    save_to_gallery: bool = True
    # Echo the base64 image data back in the response.
    return_base64: bool = False
# =============================================================================
# Helper Functions
# =============================================================================
def _build_sd_request(req: GenerateRequest) -> dict[str, Any]:
"""Build request body for sd-server."""
body: dict[str, Any] = {
"prompt": req.prompt,
"negative_prompt": req.negative_prompt,
"width": req.width,
"height": req.height,
"steps": req.steps,
"cfg_scale": req.cfg_scale,
"seed": req.seed,
"batch_size": req.batch_size,
}
if req.sampler_name:
body["sampler_name"] = req.sampler_name
if req.scheduler:
body["scheduler"] = req.scheduler
return body
def _parse_info(info: Any) -> dict[str, Any]:
"""Parse info from sd-server response."""
if isinstance(info, str):
import json # noqa: PLC0415
try:
return dict(json.loads(info))
except json.JSONDecodeError:
return {"raw": info}
return info if isinstance(info, dict) else {}
def _process_image(
img_b64: str,
index: int,
seed: int,
req: GenerateRequest,
gallery: Gallery,
model: str | None,
) -> dict[str, Any]:
"""Process a single generated image."""
image_bytes = base64.b64decode(img_b64)
image_info: dict[str, Any] = {"index": index, "seed": seed}
if req.save_to_gallery:
metadata = {
"prompt": req.prompt,
"negative_prompt": req.negative_prompt,
"width": req.width,
"height": req.height,
"steps": req.steps,
"cfg_scale": req.cfg_scale,
"sampler": req.sampler_name,
"scheduler": req.scheduler,
"model": model,
"generated_at": time.time(),
}
gallery_img = gallery.save_image(image_bytes, metadata=metadata, seed=seed)
image_info["id"] = gallery_img.id
image_info["path"] = str(gallery_img.path)
if req.return_base64:
image_info["base64"] = img_b64
return image_info
def _check_server_running(pm: ProcessManager) -> None:
"""Check if sd-server is running, raise HTTPException if not."""
if pm.proc is None or pm.proc.poll() is not None:
raise HTTPException(status_code=503, detail="sd-server is not running")
if pm.config is None:
raise HTTPException(status_code=503, detail="sd-server not configured")
# =============================================================================
# Router Factory
# =============================================================================
def create_generate_router(pm: ProcessManager) -> APIRouter:
    """Build a router with the ``/api/generate``, ``/api/samplers`` and ``/api/schedulers`` endpoints.

    Args:
        pm: Process manager for the local sd-server that requests are sent to.

    Returns:
        A new ``APIRouter`` with prefix ``/api``. Every handler answers 503
        when sd-server is not running and 502 when sd-server returns an error
        or an unusable response.
    """
    router = APIRouter(prefix="/api", tags=["generate"])
    gallery = Gallery()

    def _sd_url(endpoint: str) -> str:
        """Return the sd-server URL for ``endpoint`` after checking the server is up."""
        _check_server_running(pm)
        assert pm.config is not None  # Verified by _check_server_running
        return f"http://127.0.0.1:{pm.config.port}/sdapi/v1/{endpoint}"

    async def _proxy_get(endpoint: str) -> Any:
        """GET ``endpoint`` from sd-server and return the decoded JSON body."""
        url = _sd_url(endpoint)
        try:
            async with httpx.AsyncClient(timeout=30) as client:
                response = await client.get(url)
                response.raise_for_status()
                return response.json()
        except (httpx.HTTPError, ValueError) as e:
            # ValueError: the body was not valid JSON.
            raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e

    @router.post("/generate")
    async def generate(req: GenerateRequest) -> dict[str, Any]:
        """Generate images with gallery integration."""
        url = _sd_url("txt2img")
        assert pm.config is not None  # Verified by _check_server_running (via _sd_url)
        # Capture the model before awaiting, so the recorded model is the one
        # that was configured when the request was sent.
        model = pm.config.model

        body = _build_sd_request(req)
        try:
            async with httpx.AsyncClient(timeout=300) as client:
                response = await client.post(url, json=body)
                response.raise_for_status()
                result = response.json()
        except (httpx.HTTPError, ValueError) as e:
            # ValueError: the body was not valid JSON.
            logger.exception("Generation failed")
            raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e

        if not isinstance(result, dict):
            raise HTTPException(status_code=502, detail="sd-server error: unexpected response format")

        images_data = result.get("images") or []
        info = _parse_info(result.get("info", {}))

        # Without a usable seed list every image is attributed to the
        # requested seed; a list that is too short falls back to seed + index.
        all_seeds = info.get("all_seeds")
        if not isinstance(all_seeds, list):
            all_seeds = [req.seed] * len(images_data)

        output_images = [
            _process_image(
                img_b64,
                i,
                all_seeds[i] if i < len(all_seeds) else req.seed + i,
                req,
                gallery,
                model,
            )
            for i, img_b64 in enumerate(images_data)
        ]

        return {
            "images": output_images,
            "parameters": result.get("parameters", body),
            "info": info,
            "saved_to_gallery": req.save_to_gallery,
            "total": len(output_images),
        }

    @router.get("/samplers")
    async def list_samplers() -> dict[str, Any]:
        """List available samplers from sd-server."""
        return {"samplers": await _proxy_get("samplers")}

    @router.get("/schedulers")
    async def list_schedulers() -> dict[str, Any]:
        """List available schedulers from sd-server."""
        return {"schedulers": await _proxy_get("schedulers")}

    return router
+171
View File
@@ -0,0 +1,171 @@
"""FastAPI route handlers for model management endpoints."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any
from fastapi import APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel as PydanticBaseModel
from tensors.config import MODELS_DIR
if TYPE_CHECKING:
from tensors.server.process import ProcessManager
logger = logging.getLogger(__name__)
# =============================================================================
# Request/Response Models
# =============================================================================
class SwitchModelRequest(PydanticBaseModel):
    """Payload accepted by ``POST /api/models/switch``."""

    # Filesystem path of the model file to load.
    model: str
# =============================================================================
# Helper Functions
# =============================================================================
def scan_models(directory: Path, extensions: tuple[str, ...] = (".safetensors", ".gguf")) -> list[dict[str, Any]]:
    """Recursively collect model files below ``directory``.

    Args:
        directory: Root to search; a missing directory yields an empty list.
        extensions: Filename endings (including the dot) that count as models.

    Returns:
        One dict per regular file with ``name`` (stem), ``path``, ``filename``,
        ``size_mb`` (rounded to two decimals) and ``modified`` (mtime), sorted
        case-insensitively by name with the path as tie-breaker.
    """
    if not directory.exists():
        return []

    wanted = tuple(extensions)
    models: list[dict[str, Any]] = []

    # One walk for all extensions instead of one walk per extension.
    for path in directory.rglob("*"):
        if not path.name.endswith(wanted):
            continue
        try:
            # Skip directories that merely look like model files.
            if not path.is_file():
                continue
            stat = path.stat()
        except OSError:
            # Vanished mid-scan or unreadable; leave it out rather than abort.
            continue
        models.append(
            {
                "name": path.stem,
                "path": str(path),
                "filename": path.name,
                "size_mb": round(stat.st_size / (1024 * 1024), 2),
                "modified": stat.st_mtime,
            }
        )

    models.sort(key=lambda m: (m["name"].lower(), m["path"]))
    return models
def scan_loras(directory: Path | None = None) -> list[dict[str, Any]]:
    """List ``.safetensors`` files under ``directory`` (default: ``MODELS_DIR / "loras"``)."""
    search_root = directory if directory else MODELS_DIR / "loras"
    lora_extensions = (".safetensors",)
    return scan_models(search_root, extensions=lora_extensions)
def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]:
    """List ``.safetensors``/``.gguf`` files under ``directory`` (default: ``MODELS_DIR / "checkpoints"``)."""
    search_root = directory if directory else MODELS_DIR / "checkpoints"
    checkpoint_extensions = (".safetensors", ".gguf")
    return scan_models(search_root, extensions=checkpoint_extensions)
# =============================================================================
# Router Factory
# =============================================================================
def create_models_router(pm: ProcessManager) -> APIRouter:
    """Build a router with the ``/api/models/*`` endpoints.

    Args:
        pm: Process manager controlling sd-server; used to report the active
            model and to restart the server when the model is switched.

    Returns:
        A new ``APIRouter`` with prefix ``/api/models``.
    """
    router = APIRouter(prefix="/api/models", tags=["models"])

    @router.get("")
    def list_models() -> dict[str, Any]:
        """List available checkpoint models."""
        checkpoints = scan_checkpoints()
        return {
            "models": checkpoints,
            "total": len(checkpoints),
        }

    @router.get("/active")
    def get_active_model() -> dict[str, Any]:
        """Get information about the currently loaded model.

        ``loaded`` is true only when the process manager reports ``running``.
        Without a configuration only ``loaded``, ``model`` and ``status`` are
        returned.
        """
        status = pm.status()
        config = pm.config

        if config is None:
            return {
                "loaded": False,
                "model": None,
                "status": status.get("status"),
            }

        return {
            "loaded": status.get("status") == "running",
            "model": config.model,
            "pid": status.get("pid"),
            "port": config.port,
            "status": status.get("status"),
        }

    @router.post("/switch")
    async def switch_model(req: SwitchModelRequest) -> JSONResponse:
        """Switch to a different model by restarting sd-server with it.

        Port and extra arguments are carried over from the current
        configuration (port 1234 and no arguments when there is none).

        Raises:
            HTTPException: 400 when ``req.model`` is not an existing file.

        Returns:
            200 with ``ok``, ``model`` and ``pid`` once the server is ready,
            or 503 when it does not become ready.
        """
        model_path = Path(req.model)

        # Require a regular file: a directory would pass a bare existence
        # check and only fail after the running server had been stopped.
        if not model_path.is_file():
            raise HTTPException(status_code=400, detail=f"Model not found: {req.model}")

        # Use existing reload logic
        from tensors.server.models import ServerConfig  # noqa: PLC0415

        current = pm.config
        new_config = ServerConfig(
            model=req.model,
            port=current.port if current else 1234,
            args=current.args if current else [],
        )

        pm.stop()
        pm.start(new_config)

        ready = await pm.wait_ready()
        if not ready:
            return JSONResponse(
                {"error": "sd-server failed to become ready", "model": req.model},
                status_code=503,
            )

        return JSONResponse(
            {
                "ok": True,
                "model": req.model,
                "pid": pm.proc.pid if pm.proc else None,
            }
        )

    @router.get("/loras")
    def list_loras() -> dict[str, Any]:
        """List available LoRA files."""
        loras = scan_loras()
        return {
            "loras": loras,
            "total": len(loras),
        }

    @router.get("/scan")
    def scan_all_models() -> dict[str, Any]:
        """Scan the checkpoint and LoRA directories and return both listings with counts."""
        checkpoints = scan_checkpoints()
        loras = scan_loras()

        return {
            "checkpoints": checkpoints,
            "loras": loras,
            "total_checkpoints": len(checkpoints),
            "total_loras": len(loras),
        }

    return router