Phase 3: Enhanced Server API
Add comprehensive server API for remote operations:
Step 3.1 - Gallery endpoints:
- GET /api/images - List images (paginated, newest first)
- GET /api/images/{id} - Get image file
- GET /api/images/{id}/meta - Get generation metadata
- POST /api/images/{id}/edit - Update metadata (tags, notes)
- DELETE /api/images/{id} - Delete image + sidecar
- Gallery module with image management and sidecar JSON support
Step 3.2 - Model management:
- GET /api/models - List available checkpoints
- GET /api/models/active - Current loaded model info
- POST /api/models/switch - Switch model (hot reload)
- GET /api/models/loras - List available LoRAs
- GET /api/models/scan - Scan all model directories
Step 3.3 - Download proxy:
- POST /api/download - Start background download from CivitAI
- GET /api/download/status/{id} - Check download progress
- GET /api/download/active - List active downloads
- Auto-scan and link files after download
Step 3.4 - Enhanced generation:
- POST /api/generate - Generate with gallery integration
- Saves images to gallery with metadata sidecar
- Supports all sd-server params
- GET /api/samplers, /api/schedulers - List options
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -11,10 +11,10 @@
|
|||||||
- [x] Step 2.3: Add `/api/db/*` endpoints (files, models, triggers, scan, link, cache, stats)
|
- [x] Step 2.3: Add `/api/db/*` endpoints (files, models, triggers, scan, link, cache, stats)
|
||||||
|
|
||||||
## Phase 3: Enhanced Server API
|
## Phase 3: Enhanced Server API
|
||||||
- [ ] Step 3.1: Add `/api/images` gallery endpoints (list, get, delete, edit)
|
- [x] Step 3.1: Add `/api/images` gallery endpoints (list, get, delete, edit)
|
||||||
- [ ] Step 3.2: Add `/api/models` endpoints (list, active, switch, loras)
|
- [x] Step 3.2: Add `/api/models` endpoints (list, active, switch, loras)
|
||||||
- [ ] Step 3.3: Add `/api/download` endpoint (CivitAI proxy download)
|
- [x] Step 3.3: Add `/api/download` endpoint (CivitAI proxy download)
|
||||||
- [ ] Step 3.4: Enhance `/api/generate` (gallery integration, full params)
|
- [x] Step 3.4: Enhance `/api/generate` (gallery integration, full params)
|
||||||
|
|
||||||
## Phase 4: Client Mode for tsr CLI
|
## Phase 4: Client Mode for tsr CLI
|
||||||
- [ ] Step 4.1: Create `tensors/client.py` (TsrClient HTTP wrapper)
|
- [ ] Step 4.1: Create `tensors/client.py` (TsrClient HTTP wrapper)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ CONFIG_FILE = CONFIG_DIR / "config.toml"
|
|||||||
DATA_DIR = Path(os.environ.get("XDG_DATA_HOME", Path.home() / ".local" / "share")) / "tensors"
|
DATA_DIR = Path(os.environ.get("XDG_DATA_HOME", Path.home() / ".local" / "share")) / "tensors"
|
||||||
MODELS_DIR = DATA_DIR / "models"
|
MODELS_DIR = DATA_DIR / "models"
|
||||||
METADATA_DIR = DATA_DIR / "metadata"
|
METADATA_DIR = DATA_DIR / "metadata"
|
||||||
|
GALLERY_DIR = DATA_DIR / "gallery"
|
||||||
|
|
||||||
# Legacy config for migration
|
# Legacy config for migration
|
||||||
LEGACY_RC_FILE = Path.home() / ".sftrc"
|
LEGACY_RC_FILE = Path.home() / ".sftrc"
|
||||||
|
|||||||
@@ -10,7 +10,11 @@ import httpx
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from tensors.server.db_routes import create_db_router
|
from tensors.server.db_routes import create_db_router
|
||||||
|
from tensors.server.download_routes import create_download_router
|
||||||
|
from tensors.server.gallery_routes import create_gallery_router
|
||||||
|
from tensors.server.generate_routes import create_generate_router
|
||||||
from tensors.server.models import ServerConfig
|
from tensors.server.models import ServerConfig
|
||||||
|
from tensors.server.models_routes import create_models_router
|
||||||
from tensors.server.process import ProcessManager
|
from tensors.server.process import ProcessManager
|
||||||
from tensors.server.routes import create_router
|
from tensors.server.routes import create_router
|
||||||
|
|
||||||
@@ -42,7 +46,11 @@ def create_app(config: ServerConfig | None = None) -> FastAPI:
|
|||||||
pm.stop()
|
pm.stop()
|
||||||
|
|
||||||
app = FastAPI(title="sd-server wrapper", lifespan=lifespan)
|
app = FastAPI(title="sd-server wrapper", lifespan=lifespan)
|
||||||
app.include_router(create_db_router()) # Must be first to avoid catch-all conflict
|
app.include_router(create_db_router()) # Must be before catch-all proxy
|
||||||
|
app.include_router(create_gallery_router()) # Must be before catch-all proxy
|
||||||
|
app.include_router(create_models_router(pm)) # Must be before catch-all proxy
|
||||||
|
app.include_router(create_download_router()) # Must be before catch-all proxy
|
||||||
|
app.include_router(create_generate_router(pm)) # Must be before catch-all proxy
|
||||||
app.include_router(create_router(pm))
|
app.include_router(create_router(pm))
|
||||||
app.state.pm = pm
|
app.state.pm = pm
|
||||||
return app
|
return app
|
||||||
|
|||||||
@@ -0,0 +1,230 @@
|
|||||||
|
"""FastAPI route handlers for CivitAI download proxy."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||||
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
|
|
||||||
|
from tensors.api import download_model, fetch_civitai_by_hash, fetch_civitai_model, fetch_civitai_model_version
|
||||||
|
from tensors.config import MODELS_DIR, load_api_key
|
||||||
|
from tensors.db import Database
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/download", tags=["download"])
|
||||||
|
|
||||||
|
# Track active downloads
|
||||||
|
_active_downloads: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Request/Response Models
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadRequest(PydanticBaseModel):
|
||||||
|
"""Request body for downloading a model."""
|
||||||
|
|
||||||
|
version_id: int | None = None
|
||||||
|
model_id: int | None = None
|
||||||
|
hash: str | None = None
|
||||||
|
output_dir: str | None = None # Override default path
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Helper Functions
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_version_id(
|
||||||
|
version_id: int | None,
|
||||||
|
model_id: int | None,
|
||||||
|
hash_val: str | None,
|
||||||
|
api_key: str | None,
|
||||||
|
) -> tuple[int | None, dict[str, Any] | None]:
|
||||||
|
"""Resolve version ID and get version info."""
|
||||||
|
if version_id:
|
||||||
|
info = fetch_civitai_model_version(version_id, api_key)
|
||||||
|
return version_id, info
|
||||||
|
|
||||||
|
if hash_val:
|
||||||
|
info = fetch_civitai_by_hash(hash_val.upper(), api_key)
|
||||||
|
if info:
|
||||||
|
return info.get("id"), info
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
if model_id:
|
||||||
|
model_data = fetch_civitai_model(model_id, api_key)
|
||||||
|
if model_data:
|
||||||
|
versions = model_data.get("modelVersions", [])
|
||||||
|
if versions:
|
||||||
|
latest = versions[0]
|
||||||
|
return latest.get("id"), latest
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_output_dir(version_info: dict[str, Any], override: str | None) -> Path:
|
||||||
|
"""Determine output directory based on model type."""
|
||||||
|
if override:
|
||||||
|
return Path(override)
|
||||||
|
|
||||||
|
model_type = version_info.get("model", {}).get("type", "Checkpoint")
|
||||||
|
|
||||||
|
# Map type to directory
|
||||||
|
type_dirs = {
|
||||||
|
"Checkpoint": MODELS_DIR / "checkpoints",
|
||||||
|
"LORA": MODELS_DIR / "loras",
|
||||||
|
"LoCon": MODELS_DIR / "loras",
|
||||||
|
"TextualInversion": MODELS_DIR / "embeddings",
|
||||||
|
"VAE": MODELS_DIR / "vae",
|
||||||
|
"Controlnet": MODELS_DIR / "controlnet",
|
||||||
|
}
|
||||||
|
|
||||||
|
return type_dirs.get(model_type, MODELS_DIR / "other")
|
||||||
|
|
||||||
|
|
||||||
|
def _do_download(
|
||||||
|
version_id: int,
|
||||||
|
dest_path: Path,
|
||||||
|
api_key: str | None,
|
||||||
|
download_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Background task to perform the download."""
|
||||||
|
try:
|
||||||
|
_active_downloads[download_id]["status"] = "downloading"
|
||||||
|
|
||||||
|
# Create a mock console for download progress
|
||||||
|
from io import StringIO # noqa: PLC0415
|
||||||
|
|
||||||
|
from rich.console import Console # noqa: PLC0415
|
||||||
|
|
||||||
|
output = StringIO()
|
||||||
|
console = Console(file=output, force_terminal=False)
|
||||||
|
|
||||||
|
success = download_model(version_id, dest_path, api_key, console, resume=True)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
_active_downloads[download_id]["status"] = "completed"
|
||||||
|
_active_downloads[download_id]["path"] = str(dest_path)
|
||||||
|
|
||||||
|
# Auto-scan and link the downloaded file
|
||||||
|
_auto_link_file(dest_path, api_key)
|
||||||
|
else:
|
||||||
|
_active_downloads[download_id]["status"] = "failed"
|
||||||
|
_active_downloads[download_id]["error"] = "Download failed"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Download failed")
|
||||||
|
_active_downloads[download_id]["status"] = "failed"
|
||||||
|
_active_downloads[download_id]["error"] = str(e)
|
||||||
|
|
||||||
|
|
||||||
|
def _auto_link_file(file_path: Path, api_key: str | None) -> None:
|
||||||
|
"""Auto-scan and link the downloaded file to CivitAI."""
|
||||||
|
try:
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
# Scan the single file
|
||||||
|
results = db.scan_directory(file_path.parent)
|
||||||
|
|
||||||
|
# Find and link the new file
|
||||||
|
for result in results:
|
||||||
|
if result["file_path"] == str(file_path):
|
||||||
|
sha256 = result["sha256"]
|
||||||
|
civitai_data = fetch_civitai_by_hash(sha256, api_key)
|
||||||
|
if civitai_data:
|
||||||
|
version_id = civitai_data.get("id", 0)
|
||||||
|
model_id = civitai_data.get("modelId", 0)
|
||||||
|
if version_id and model_id:
|
||||||
|
db.link_file_to_civitai(result["id"], model_id, version_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Auto-link failed")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Download Endpoints
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("")
|
||||||
|
def start_download(req: DownloadRequest, background_tasks: BackgroundTasks) -> dict[str, Any]:
|
||||||
|
"""Start a model download (async with progress tracking)."""
|
||||||
|
api_key = load_api_key()
|
||||||
|
|
||||||
|
# Resolve version ID
|
||||||
|
version_id, version_info = _resolve_version_id(
|
||||||
|
req.version_id,
|
||||||
|
req.model_id,
|
||||||
|
req.hash,
|
||||||
|
api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not version_id or not version_info:
|
||||||
|
raise HTTPException(status_code=404, detail="Model/version not found on CivitAI")
|
||||||
|
|
||||||
|
# Get output directory
|
||||||
|
output_dir = _get_output_dir(version_info, req.output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Get filename from version info
|
||||||
|
files = version_info.get("files", [])
|
||||||
|
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
|
||||||
|
|
||||||
|
if not primary_file:
|
||||||
|
raise HTTPException(status_code=400, detail="No files found for this version")
|
||||||
|
|
||||||
|
filename = primary_file.get("name", f"model-{version_id}.safetensors")
|
||||||
|
dest_path = output_dir / filename
|
||||||
|
|
||||||
|
# Create download tracking entry
|
||||||
|
download_id = f"{version_id}_{int(__import__('time').time())}"
|
||||||
|
_active_downloads[download_id] = {
|
||||||
|
"id": download_id,
|
||||||
|
"version_id": version_id,
|
||||||
|
"status": "queued",
|
||||||
|
"path": str(dest_path),
|
||||||
|
"filename": filename,
|
||||||
|
"model_name": version_info.get("model", {}).get("name", "Unknown"),
|
||||||
|
"version_name": version_info.get("name", "Unknown"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Start background download
|
||||||
|
background_tasks.add_task(_do_download, version_id, dest_path, api_key, download_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"download_id": download_id,
|
||||||
|
"status": "queued",
|
||||||
|
"version_id": version_id,
|
||||||
|
"destination": str(dest_path),
|
||||||
|
"model_name": version_info.get("model", {}).get("name"),
|
||||||
|
"version_name": version_info.get("name"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/status/{download_id}")
|
||||||
|
def get_download_status(download_id: str) -> dict[str, Any]:
|
||||||
|
"""Get status of a download."""
|
||||||
|
if download_id not in _active_downloads:
|
||||||
|
raise HTTPException(status_code=404, detail="Download not found")
|
||||||
|
|
||||||
|
return _active_downloads[download_id]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/active")
|
||||||
|
def list_active_downloads() -> dict[str, Any]:
|
||||||
|
"""List all active/recent downloads."""
|
||||||
|
return {
|
||||||
|
"downloads": list(_active_downloads.values()),
|
||||||
|
"total": len(_active_downloads),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_download_router() -> APIRouter:
|
||||||
|
"""Return the download API router."""
|
||||||
|
return router
|
||||||
@@ -0,0 +1,183 @@
|
|||||||
|
"""Image gallery management for generated images."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from tensors.config import GALLERY_DIR
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GalleryImage:
|
||||||
|
"""Represents an image in the gallery."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
path: Path
|
||||||
|
created_at: float
|
||||||
|
width: int | None = None
|
||||||
|
height: int | None = None
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def meta_path(self) -> Path:
|
||||||
|
"""Path to the sidecar metadata JSON file."""
|
||||||
|
return self.path.with_suffix(".json")
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary for API response."""
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"path": str(self.path),
|
||||||
|
"filename": self.path.name,
|
||||||
|
"created_at": self.created_at,
|
||||||
|
"width": self.width,
|
||||||
|
"height": self.height,
|
||||||
|
"has_metadata": self.meta_path.exists(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Gallery:
|
||||||
|
"""Manages the image gallery directory."""
|
||||||
|
|
||||||
|
def __init__(self, gallery_dir: Path | None = None) -> None:
|
||||||
|
"""Initialize gallery with directory path."""
|
||||||
|
self.gallery_dir = gallery_dir or GALLERY_DIR
|
||||||
|
self.gallery_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def list_images(
|
||||||
|
self,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
newest_first: bool = True,
|
||||||
|
) -> list[GalleryImage]:
|
||||||
|
"""List images in the gallery, paginated."""
|
||||||
|
images: list[GalleryImage] = []
|
||||||
|
|
||||||
|
for path in self.gallery_dir.glob("*.png"):
|
||||||
|
img = self._load_image(path)
|
||||||
|
if img:
|
||||||
|
images.append(img)
|
||||||
|
|
||||||
|
# Sort by creation time
|
||||||
|
images.sort(key=lambda x: x.created_at, reverse=newest_first)
|
||||||
|
|
||||||
|
# Apply pagination
|
||||||
|
return images[offset : offset + limit]
|
||||||
|
|
||||||
|
def get_image(self, image_id: str) -> GalleryImage | None:
|
||||||
|
"""Get an image by ID."""
|
||||||
|
# ID is the filename stem
|
||||||
|
path = self.gallery_dir / f"{image_id}.png"
|
||||||
|
if not path.exists():
|
||||||
|
return None
|
||||||
|
return self._load_image(path)
|
||||||
|
|
||||||
|
def get_metadata(self, image_id: str) -> dict[str, Any] | None:
|
||||||
|
"""Get metadata for an image."""
|
||||||
|
meta_path = self.gallery_dir / f"{image_id}.json"
|
||||||
|
if not meta_path.exists():
|
||||||
|
return None
|
||||||
|
result: dict[str, Any] = json.loads(meta_path.read_text())
|
||||||
|
return result
|
||||||
|
|
||||||
|
def update_metadata(self, image_id: str, updates: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
"""Update metadata for an image (merge with existing)."""
|
||||||
|
meta_path = self.gallery_dir / f"{image_id}.json"
|
||||||
|
img_path = self.gallery_dir / f"{image_id}.png"
|
||||||
|
|
||||||
|
if not img_path.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Load existing or create new
|
||||||
|
metadata = json.loads(meta_path.read_text()) if meta_path.exists() else {}
|
||||||
|
|
||||||
|
# Merge updates
|
||||||
|
metadata.update(updates)
|
||||||
|
metadata["updated_at"] = time.time()
|
||||||
|
|
||||||
|
# Save
|
||||||
|
meta_path.write_text(json.dumps(metadata, indent=2))
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
def delete_image(self, image_id: str) -> bool:
|
||||||
|
"""Delete an image and its metadata."""
|
||||||
|
img_path = self.gallery_dir / f"{image_id}.png"
|
||||||
|
meta_path = self.gallery_dir / f"{image_id}.json"
|
||||||
|
|
||||||
|
if not img_path.exists():
|
||||||
|
return False
|
||||||
|
|
||||||
|
img_path.unlink()
|
||||||
|
if meta_path.exists():
|
||||||
|
meta_path.unlink()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def save_image(
|
||||||
|
self,
|
||||||
|
image_data: bytes,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
seed: int | None = None,
|
||||||
|
) -> GalleryImage:
|
||||||
|
"""Save an image to the gallery with optional metadata."""
|
||||||
|
timestamp = int(time.time() * 1000) # milliseconds
|
||||||
|
seed_str = str(seed) if seed is not None else "0"
|
||||||
|
image_id = f"{timestamp}_{seed_str}"
|
||||||
|
|
||||||
|
img_path = self.gallery_dir / f"{image_id}.png"
|
||||||
|
img_path.write_bytes(image_data)
|
||||||
|
|
||||||
|
# Save metadata if provided
|
||||||
|
if metadata:
|
||||||
|
meta = metadata.copy()
|
||||||
|
meta["created_at"] = time.time()
|
||||||
|
meta["seed"] = seed
|
||||||
|
meta_path = img_path.with_suffix(".json")
|
||||||
|
meta_path.write_text(json.dumps(meta, indent=2))
|
||||||
|
|
||||||
|
return self._load_image(img_path) or GalleryImage(
|
||||||
|
id=image_id,
|
||||||
|
path=img_path,
|
||||||
|
created_at=time.time(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_image(self, path: Path) -> GalleryImage | None:
|
||||||
|
"""Load image info from path."""
|
||||||
|
if not path.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
image_id = path.stem
|
||||||
|
stat = path.stat()
|
||||||
|
|
||||||
|
# Try to get dimensions from metadata or PIL
|
||||||
|
width: int | None = None
|
||||||
|
height: int | None = None
|
||||||
|
metadata: dict[str, Any] = {}
|
||||||
|
|
||||||
|
meta_path = path.with_suffix(".json")
|
||||||
|
if meta_path.exists():
|
||||||
|
try:
|
||||||
|
metadata = json.loads(meta_path.read_text())
|
||||||
|
width = metadata.get("width")
|
||||||
|
height = metadata.get("height")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return GalleryImage(
|
||||||
|
id=image_id,
|
||||||
|
path=path,
|
||||||
|
created_at=stat.st_mtime,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
def count(self) -> int:
|
||||||
|
"""Count total images in gallery."""
|
||||||
|
return len(list(self.gallery_dir.glob("*.png")))
|
||||||
@@ -0,0 +1,149 @@
|
|||||||
|
"""FastAPI route handlers for image gallery endpoints."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
|
|
||||||
|
from tensors.server.gallery import Gallery
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/images", tags=["gallery"])
|
||||||
|
|
||||||
|
# Shared gallery instance
|
||||||
|
_gallery: Gallery | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_gallery() -> Gallery:
|
||||||
|
"""Get or create the gallery instance."""
|
||||||
|
global _gallery # noqa: PLW0603
|
||||||
|
if _gallery is None:
|
||||||
|
_gallery = Gallery()
|
||||||
|
return _gallery
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Request/Response Models
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataUpdate(PydanticBaseModel):
|
||||||
|
"""Request body for updating image metadata."""
|
||||||
|
|
||||||
|
tags: list[str] | None = None
|
||||||
|
notes: str | None = None
|
||||||
|
rating: int | None = None
|
||||||
|
favorite: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Gallery Endpoints
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
def list_images(
|
||||||
|
limit: int = Query(default=50, le=200, description="Max images to return"),
|
||||||
|
offset: int = Query(default=0, ge=0, description="Offset for pagination"),
|
||||||
|
newest_first: bool = Query(default=True, description="Sort newest first"),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""List images in the gallery, paginated."""
|
||||||
|
gallery = get_gallery()
|
||||||
|
images = gallery.list_images(limit=limit, offset=offset, newest_first=newest_first)
|
||||||
|
total = gallery.count()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"images": [img.to_dict() for img in images],
|
||||||
|
"total": total,
|
||||||
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{image_id}")
|
||||||
|
def get_image(image_id: str) -> FileResponse:
|
||||||
|
"""Get an image file by ID."""
|
||||||
|
gallery = get_gallery()
|
||||||
|
image = gallery.get_image(image_id)
|
||||||
|
|
||||||
|
if not image:
|
||||||
|
raise HTTPException(status_code=404, detail="Image not found")
|
||||||
|
|
||||||
|
return FileResponse(
|
||||||
|
path=image.path,
|
||||||
|
media_type="image/png",
|
||||||
|
filename=image.path.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{image_id}/meta")
|
||||||
|
def get_image_metadata(image_id: str) -> dict[str, Any]:
|
||||||
|
"""Get metadata for an image."""
|
||||||
|
gallery = get_gallery()
|
||||||
|
image = gallery.get_image(image_id)
|
||||||
|
|
||||||
|
if not image:
|
||||||
|
raise HTTPException(status_code=404, detail="Image not found")
|
||||||
|
|
||||||
|
metadata = gallery.get_metadata(image_id) or {}
|
||||||
|
return {
|
||||||
|
"id": image_id,
|
||||||
|
"path": str(image.path),
|
||||||
|
"created_at": image.created_at,
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{image_id}/edit")
|
||||||
|
def edit_image_metadata(image_id: str, updates: MetadataUpdate) -> dict[str, Any]:
|
||||||
|
"""Update metadata for an image."""
|
||||||
|
gallery = get_gallery()
|
||||||
|
|
||||||
|
# Build update dict from non-None values
|
||||||
|
update_dict: dict[str, Any] = {}
|
||||||
|
if updates.tags is not None:
|
||||||
|
update_dict["tags"] = updates.tags
|
||||||
|
if updates.notes is not None:
|
||||||
|
update_dict["notes"] = updates.notes
|
||||||
|
if updates.rating is not None:
|
||||||
|
update_dict["rating"] = updates.rating
|
||||||
|
if updates.favorite is not None:
|
||||||
|
update_dict["favorite"] = updates.favorite
|
||||||
|
|
||||||
|
result = gallery.update_metadata(image_id, update_dict)
|
||||||
|
if result is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Image not found")
|
||||||
|
|
||||||
|
return {"id": image_id, "metadata": result}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{image_id}")
|
||||||
|
def delete_image(image_id: str) -> dict[str, Any]:
|
||||||
|
"""Delete an image and its metadata."""
|
||||||
|
gallery = get_gallery()
|
||||||
|
deleted = gallery.delete_image(image_id)
|
||||||
|
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(status_code=404, detail="Image not found")
|
||||||
|
|
||||||
|
return {"deleted": True, "id": image_id}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/stats/summary")
|
||||||
|
def gallery_stats() -> dict[str, Any]:
|
||||||
|
"""Get gallery statistics."""
|
||||||
|
gallery = get_gallery()
|
||||||
|
return {
|
||||||
|
"total_images": gallery.count(),
|
||||||
|
"gallery_dir": str(gallery.gallery_dir),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_gallery_router() -> APIRouter:
|
||||||
|
"""Return the gallery API router."""
|
||||||
|
return router
|
||||||
@@ -0,0 +1,199 @@
|
|||||||
|
"""FastAPI route handlers for image generation with gallery integration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from tensors.server.gallery import Gallery
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from tensors.server.process import ProcessManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Request/Response Models
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateRequest(PydanticBaseModel):
|
||||||
|
"""Request body for image generation."""
|
||||||
|
|
||||||
|
prompt: str
|
||||||
|
negative_prompt: str = ""
|
||||||
|
width: int = Field(default=512, ge=64, le=2048)
|
||||||
|
height: int = Field(default=512, ge=64, le=2048)
|
||||||
|
steps: int = Field(default=20, ge=1, le=150)
|
||||||
|
cfg_scale: float = Field(default=7.0, ge=0, le=30)
|
||||||
|
seed: int = -1
|
||||||
|
sampler_name: str = ""
|
||||||
|
scheduler: str = ""
|
||||||
|
batch_size: int = Field(default=1, ge=1, le=16)
|
||||||
|
save_to_gallery: bool = True
|
||||||
|
return_base64: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Helper Functions
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _build_sd_request(req: GenerateRequest) -> dict[str, Any]:
|
||||||
|
"""Build request body for sd-server."""
|
||||||
|
body: dict[str, Any] = {
|
||||||
|
"prompt": req.prompt,
|
||||||
|
"negative_prompt": req.negative_prompt,
|
||||||
|
"width": req.width,
|
||||||
|
"height": req.height,
|
||||||
|
"steps": req.steps,
|
||||||
|
"cfg_scale": req.cfg_scale,
|
||||||
|
"seed": req.seed,
|
||||||
|
"batch_size": req.batch_size,
|
||||||
|
}
|
||||||
|
if req.sampler_name:
|
||||||
|
body["sampler_name"] = req.sampler_name
|
||||||
|
if req.scheduler:
|
||||||
|
body["scheduler"] = req.scheduler
|
||||||
|
return body
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_info(info: Any) -> dict[str, Any]:
|
||||||
|
"""Parse info from sd-server response."""
|
||||||
|
if isinstance(info, str):
|
||||||
|
import json # noqa: PLC0415
|
||||||
|
|
||||||
|
try:
|
||||||
|
return dict(json.loads(info))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return {"raw": info}
|
||||||
|
return info if isinstance(info, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
|
def _process_image(
|
||||||
|
img_b64: str,
|
||||||
|
index: int,
|
||||||
|
seed: int,
|
||||||
|
req: GenerateRequest,
|
||||||
|
gallery: Gallery,
|
||||||
|
model: str | None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Process a single generated image."""
|
||||||
|
image_bytes = base64.b64decode(img_b64)
|
||||||
|
image_info: dict[str, Any] = {"index": index, "seed": seed}
|
||||||
|
|
||||||
|
if req.save_to_gallery:
|
||||||
|
metadata = {
|
||||||
|
"prompt": req.prompt,
|
||||||
|
"negative_prompt": req.negative_prompt,
|
||||||
|
"width": req.width,
|
||||||
|
"height": req.height,
|
||||||
|
"steps": req.steps,
|
||||||
|
"cfg_scale": req.cfg_scale,
|
||||||
|
"sampler": req.sampler_name,
|
||||||
|
"scheduler": req.scheduler,
|
||||||
|
"model": model,
|
||||||
|
"generated_at": time.time(),
|
||||||
|
}
|
||||||
|
gallery_img = gallery.save_image(image_bytes, metadata=metadata, seed=seed)
|
||||||
|
image_info["id"] = gallery_img.id
|
||||||
|
image_info["path"] = str(gallery_img.path)
|
||||||
|
|
||||||
|
if req.return_base64:
|
||||||
|
image_info["base64"] = img_b64
|
||||||
|
|
||||||
|
return image_info
|
||||||
|
|
||||||
|
|
||||||
|
def _check_server_running(pm: ProcessManager) -> None:
|
||||||
|
"""Check if sd-server is running, raise HTTPException if not."""
|
||||||
|
if pm.proc is None or pm.proc.poll() is not None:
|
||||||
|
raise HTTPException(status_code=503, detail="sd-server is not running")
|
||||||
|
if pm.config is None:
|
||||||
|
raise HTTPException(status_code=503, detail="sd-server not configured")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Router Factory
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def create_generate_router(pm: ProcessManager) -> APIRouter:
|
||||||
|
"""Build a router with /api/generate endpoint."""
|
||||||
|
router = APIRouter(prefix="/api", tags=["generate"])
|
||||||
|
gallery = Gallery()
|
||||||
|
|
||||||
|
@router.post("/generate")
|
||||||
|
async def generate(req: GenerateRequest) -> dict[str, Any]:
|
||||||
|
"""Generate images with gallery integration."""
|
||||||
|
_check_server_running(pm)
|
||||||
|
assert pm.config is not None # Verified by _check_server_running
|
||||||
|
|
||||||
|
body = _build_sd_request(req)
|
||||||
|
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/txt2img"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=300) as client:
|
||||||
|
response = await client.post(url, json=body)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
logger.exception("Generation failed")
|
||||||
|
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||||
|
|
||||||
|
images_data = result.get("images", [])
|
||||||
|
info = _parse_info(result.get("info", {}))
|
||||||
|
all_seeds = info.get("all_seeds", [req.seed] * len(images_data))
|
||||||
|
|
||||||
|
output_images = [
|
||||||
|
_process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, pm.config.model)
|
||||||
|
for i, img_b64 in enumerate(images_data)
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"images": output_images,
|
||||||
|
"parameters": result.get("parameters", body),
|
||||||
|
"info": info,
|
||||||
|
"saved_to_gallery": req.save_to_gallery,
|
||||||
|
"total": len(output_images),
|
||||||
|
}
|
||||||
|
|
||||||
|
@router.get("/samplers")
|
||||||
|
async def list_samplers() -> dict[str, Any]:
|
||||||
|
"""List available samplers from sd-server."""
|
||||||
|
_check_server_running(pm)
|
||||||
|
assert pm.config is not None
|
||||||
|
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/samplers"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=30) as client:
|
||||||
|
response = await client.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
return {"samplers": response.json()}
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||||
|
|
||||||
|
@router.get("/schedulers")
|
||||||
|
async def list_schedulers() -> dict[str, Any]:
|
||||||
|
"""List available schedulers from sd-server."""
|
||||||
|
_check_server_running(pm)
|
||||||
|
assert pm.config is not None
|
||||||
|
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/schedulers"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=30) as client:
|
||||||
|
response = await client.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
return {"schedulers": response.json()}
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||||
|
|
||||||
|
return router
|
||||||
@@ -0,0 +1,171 @@
|
|||||||
|
"""FastAPI route handlers for model management endpoints."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
|
|
||||||
|
from tensors.config import MODELS_DIR
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from tensors.server.process import ProcessManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Request/Response Models
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class SwitchModelRequest(PydanticBaseModel):
|
||||||
|
"""Request body for switching models."""
|
||||||
|
|
||||||
|
model: str # Path to model file
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Helper Functions
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def scan_models(directory: Path, extensions: tuple[str, ...] = (".safetensors", ".gguf")) -> list[dict[str, Any]]:
|
||||||
|
"""Scan directory for model files."""
|
||||||
|
models: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
if not directory.exists():
|
||||||
|
return models
|
||||||
|
|
||||||
|
for ext in extensions:
|
||||||
|
for path in directory.rglob(f"*{ext}"):
|
||||||
|
stat = path.stat()
|
||||||
|
models.append(
|
||||||
|
{
|
||||||
|
"name": path.stem,
|
||||||
|
"path": str(path),
|
||||||
|
"filename": path.name,
|
||||||
|
"size_mb": round(stat.st_size / (1024 * 1024), 2),
|
||||||
|
"modified": stat.st_mtime,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sort by name
|
||||||
|
models.sort(key=lambda x: x["name"].lower())
|
||||||
|
return models
|
||||||
|
|
||||||
|
|
||||||
|
def scan_loras(directory: Path | None = None) -> list[dict[str, Any]]:
|
||||||
|
"""Scan for LoRA files."""
|
||||||
|
lora_dir = directory or MODELS_DIR / "loras"
|
||||||
|
return scan_models(lora_dir, extensions=(".safetensors",))
|
||||||
|
|
||||||
|
|
||||||
|
def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]:
|
||||||
|
"""Scan for checkpoint files."""
|
||||||
|
checkpoint_dir = directory or MODELS_DIR / "checkpoints"
|
||||||
|
return scan_models(checkpoint_dir, extensions=(".safetensors", ".gguf"))
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Router Factory
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def create_models_router(pm: ProcessManager) -> APIRouter:
|
||||||
|
"""Build a router with /api/models/* endpoints."""
|
||||||
|
router = APIRouter(prefix="/api/models", tags=["models"])
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
def list_models() -> dict[str, Any]:
|
||||||
|
"""List available checkpoint models."""
|
||||||
|
checkpoints = scan_checkpoints()
|
||||||
|
return {
|
||||||
|
"models": checkpoints,
|
||||||
|
"total": len(checkpoints),
|
||||||
|
}
|
||||||
|
|
||||||
|
@router.get("/active")
|
||||||
|
def get_active_model() -> dict[str, Any]:
|
||||||
|
"""Get information about the currently loaded model."""
|
||||||
|
status = pm.status()
|
||||||
|
config = pm.config
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
return {
|
||||||
|
"loaded": False,
|
||||||
|
"model": None,
|
||||||
|
"status": status.get("status"),
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"loaded": status.get("status") == "running",
|
||||||
|
"model": config.model,
|
||||||
|
"pid": status.get("pid"),
|
||||||
|
"port": config.port,
|
||||||
|
"status": status.get("status"),
|
||||||
|
}
|
||||||
|
|
||||||
|
@router.post("/switch")
|
||||||
|
async def switch_model(req: SwitchModelRequest) -> JSONResponse:
|
||||||
|
"""Switch to a different model (hot reload)."""
|
||||||
|
model_path = Path(req.model)
|
||||||
|
|
||||||
|
# Validate model exists
|
||||||
|
if not model_path.exists():
|
||||||
|
raise HTTPException(status_code=400, detail=f"Model not found: {req.model}")
|
||||||
|
|
||||||
|
# Use existing reload logic
|
||||||
|
from tensors.server.models import ServerConfig # noqa: PLC0415
|
||||||
|
|
||||||
|
new_config = ServerConfig(
|
||||||
|
model=req.model,
|
||||||
|
port=pm.config.port if pm.config else 1234,
|
||||||
|
args=pm.config.args if pm.config else [],
|
||||||
|
)
|
||||||
|
|
||||||
|
pm.stop()
|
||||||
|
pm.start(new_config)
|
||||||
|
ready = await pm.wait_ready()
|
||||||
|
|
||||||
|
if not ready:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "sd-server failed to become ready", "model": req.model},
|
||||||
|
status_code=503,
|
||||||
|
)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
{
|
||||||
|
"ok": True,
|
||||||
|
"model": req.model,
|
||||||
|
"pid": pm.proc.pid if pm.proc else None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/loras")
|
||||||
|
def list_loras() -> dict[str, Any]:
|
||||||
|
"""List available LoRA files."""
|
||||||
|
loras = scan_loras()
|
||||||
|
return {
|
||||||
|
"loras": loras,
|
||||||
|
"total": len(loras),
|
||||||
|
}
|
||||||
|
|
||||||
|
@router.get("/scan")
|
||||||
|
def scan_all_models() -> dict[str, Any]:
|
||||||
|
"""Scan all model directories."""
|
||||||
|
checkpoints = scan_checkpoints()
|
||||||
|
loras = scan_loras()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"checkpoints": checkpoints,
|
||||||
|
"loras": loras,
|
||||||
|
"total_checkpoints": len(checkpoints),
|
||||||
|
"total_loras": len(loras),
|
||||||
|
}
|
||||||
|
|
||||||
|
return router
|
||||||
Reference in New Issue
Block a user