Phase 3: Enhanced Server API
Add comprehensive server API for remote operations:
Step 3.1 - Gallery endpoints:
- GET /api/images - List images (paginated, newest first)
- GET /api/images/{id} - Get image file
- GET /api/images/{id}/meta - Get generation metadata
- POST /api/images/{id}/edit - Update metadata (tags, notes)
- DELETE /api/images/{id} - Delete image + sidecar
- Gallery module with image management and sidecar JSON support
Step 3.2 - Model management:
- GET /api/models - List available checkpoints
- GET /api/models/active - Current loaded model info
- POST /api/models/switch - Switch model (hot reload)
- GET /api/models/loras - List available LoRAs
- GET /api/models/scan - Scan all model directories
Step 3.3 - Download proxy:
- POST /api/download - Start background download from CivitAI
- GET /api/download/status/{id} - Check download progress
- GET /api/download/active - List active downloads
- Auto-scan and link files after download
Step 3.4 - Enhanced generation:
- POST /api/generate - Generate with gallery integration
- Saves images to gallery with metadata sidecar
- Supports all sd-server params
- GET /api/samplers, /api/schedulers - List options
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -11,10 +11,10 @@
|
||||
- [x] Step 2.3: Add `/api/db/*` endpoints (files, models, triggers, scan, link, cache, stats)
|
||||
|
||||
## Phase 3: Enhanced Server API
|
||||
- [ ] Step 3.1: Add `/api/images` gallery endpoints (list, get, delete, edit)
|
||||
- [ ] Step 3.2: Add `/api/models` endpoints (list, active, switch, loras)
|
||||
- [ ] Step 3.3: Add `/api/download` endpoint (CivitAI proxy download)
|
||||
- [ ] Step 3.4: Enhance `/api/generate` (gallery integration, full params)
|
||||
- [x] Step 3.1: Add `/api/images` gallery endpoints (list, get, delete, edit)
|
||||
- [x] Step 3.2: Add `/api/models` endpoints (list, active, switch, loras)
|
||||
- [x] Step 3.3: Add `/api/download` endpoint (CivitAI proxy download)
|
||||
- [x] Step 3.4: Enhance `/api/generate` (gallery integration, full params)
|
||||
|
||||
## Phase 4: Client Mode for tsr CLI
|
||||
- [ ] Step 4.1: Create `tensors/client.py` (TsrClient HTTP wrapper)
|
||||
|
||||
@@ -20,6 +20,7 @@ CONFIG_FILE = CONFIG_DIR / "config.toml"
|
||||
DATA_DIR = Path(os.environ.get("XDG_DATA_HOME", Path.home() / ".local" / "share")) / "tensors"
|
||||
MODELS_DIR = DATA_DIR / "models"
|
||||
METADATA_DIR = DATA_DIR / "metadata"
|
||||
GALLERY_DIR = DATA_DIR / "gallery"
|
||||
|
||||
# Legacy config for migration
|
||||
LEGACY_RC_FILE = Path.home() / ".sftrc"
|
||||
|
||||
@@ -10,7 +10,11 @@ import httpx
|
||||
from fastapi import FastAPI
|
||||
|
||||
from tensors.server.db_routes import create_db_router
|
||||
from tensors.server.download_routes import create_download_router
|
||||
from tensors.server.gallery_routes import create_gallery_router
|
||||
from tensors.server.generate_routes import create_generate_router
|
||||
from tensors.server.models import ServerConfig
|
||||
from tensors.server.models_routes import create_models_router
|
||||
from tensors.server.process import ProcessManager
|
||||
from tensors.server.routes import create_router
|
||||
|
||||
@@ -42,7 +46,11 @@ def create_app(config: ServerConfig | None = None) -> FastAPI:
|
||||
pm.stop()
|
||||
|
||||
app = FastAPI(title="sd-server wrapper", lifespan=lifespan)
|
||||
app.include_router(create_db_router()) # Must be first to avoid catch-all conflict
|
||||
app.include_router(create_db_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_gallery_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_models_router(pm)) # Must be before catch-all proxy
|
||||
app.include_router(create_download_router()) # Must be before catch-all proxy
|
||||
app.include_router(create_generate_router(pm)) # Must be before catch-all proxy
|
||||
app.include_router(create_router(pm))
|
||||
app.state.pm = pm
|
||||
return app
|
||||
|
||||
@@ -0,0 +1,230 @@
|
||||
"""FastAPI route handlers for CivitAI download proxy."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
|
||||
from tensors.api import download_model, fetch_civitai_by_hash, fetch_civitai_model, fetch_civitai_model_version
|
||||
from tensors.config import MODELS_DIR, load_api_key
|
||||
from tensors.db import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/download", tags=["download"])
|
||||
|
||||
# Track active downloads
|
||||
_active_downloads: dict[str, dict[str, Any]] = {}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Request/Response Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class DownloadRequest(PydanticBaseModel):
|
||||
"""Request body for downloading a model."""
|
||||
|
||||
version_id: int | None = None
|
||||
model_id: int | None = None
|
||||
hash: str | None = None
|
||||
output_dir: str | None = None # Override default path
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _resolve_version_id(
|
||||
version_id: int | None,
|
||||
model_id: int | None,
|
||||
hash_val: str | None,
|
||||
api_key: str | None,
|
||||
) -> tuple[int | None, dict[str, Any] | None]:
|
||||
"""Resolve version ID and get version info."""
|
||||
if version_id:
|
||||
info = fetch_civitai_model_version(version_id, api_key)
|
||||
return version_id, info
|
||||
|
||||
if hash_val:
|
||||
info = fetch_civitai_by_hash(hash_val.upper(), api_key)
|
||||
if info:
|
||||
return info.get("id"), info
|
||||
return None, None
|
||||
|
||||
if model_id:
|
||||
model_data = fetch_civitai_model(model_id, api_key)
|
||||
if model_data:
|
||||
versions = model_data.get("modelVersions", [])
|
||||
if versions:
|
||||
latest = versions[0]
|
||||
return latest.get("id"), latest
|
||||
return None, None
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def _get_output_dir(version_info: dict[str, Any], override: str | None) -> Path:
|
||||
"""Determine output directory based on model type."""
|
||||
if override:
|
||||
return Path(override)
|
||||
|
||||
model_type = version_info.get("model", {}).get("type", "Checkpoint")
|
||||
|
||||
# Map type to directory
|
||||
type_dirs = {
|
||||
"Checkpoint": MODELS_DIR / "checkpoints",
|
||||
"LORA": MODELS_DIR / "loras",
|
||||
"LoCon": MODELS_DIR / "loras",
|
||||
"TextualInversion": MODELS_DIR / "embeddings",
|
||||
"VAE": MODELS_DIR / "vae",
|
||||
"Controlnet": MODELS_DIR / "controlnet",
|
||||
}
|
||||
|
||||
return type_dirs.get(model_type, MODELS_DIR / "other")
|
||||
|
||||
|
||||
def _do_download(
|
||||
version_id: int,
|
||||
dest_path: Path,
|
||||
api_key: str | None,
|
||||
download_id: str,
|
||||
) -> None:
|
||||
"""Background task to perform the download."""
|
||||
try:
|
||||
_active_downloads[download_id]["status"] = "downloading"
|
||||
|
||||
# Create a mock console for download progress
|
||||
from io import StringIO # noqa: PLC0415
|
||||
|
||||
from rich.console import Console # noqa: PLC0415
|
||||
|
||||
output = StringIO()
|
||||
console = Console(file=output, force_terminal=False)
|
||||
|
||||
success = download_model(version_id, dest_path, api_key, console, resume=True)
|
||||
|
||||
if success:
|
||||
_active_downloads[download_id]["status"] = "completed"
|
||||
_active_downloads[download_id]["path"] = str(dest_path)
|
||||
|
||||
# Auto-scan and link the downloaded file
|
||||
_auto_link_file(dest_path, api_key)
|
||||
else:
|
||||
_active_downloads[download_id]["status"] = "failed"
|
||||
_active_downloads[download_id]["error"] = "Download failed"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Download failed")
|
||||
_active_downloads[download_id]["status"] = "failed"
|
||||
_active_downloads[download_id]["error"] = str(e)
|
||||
|
||||
|
||||
def _auto_link_file(file_path: Path, api_key: str | None) -> None:
|
||||
"""Auto-scan and link the downloaded file to CivitAI."""
|
||||
try:
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
# Scan the single file
|
||||
results = db.scan_directory(file_path.parent)
|
||||
|
||||
# Find and link the new file
|
||||
for result in results:
|
||||
if result["file_path"] == str(file_path):
|
||||
sha256 = result["sha256"]
|
||||
civitai_data = fetch_civitai_by_hash(sha256, api_key)
|
||||
if civitai_data:
|
||||
version_id = civitai_data.get("id", 0)
|
||||
model_id = civitai_data.get("modelId", 0)
|
||||
if version_id and model_id:
|
||||
db.link_file_to_civitai(result["id"], model_id, version_id)
|
||||
except Exception:
|
||||
logger.exception("Auto-link failed")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Download Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("")
|
||||
def start_download(req: DownloadRequest, background_tasks: BackgroundTasks) -> dict[str, Any]:
|
||||
"""Start a model download (async with progress tracking)."""
|
||||
api_key = load_api_key()
|
||||
|
||||
# Resolve version ID
|
||||
version_id, version_info = _resolve_version_id(
|
||||
req.version_id,
|
||||
req.model_id,
|
||||
req.hash,
|
||||
api_key,
|
||||
)
|
||||
|
||||
if not version_id or not version_info:
|
||||
raise HTTPException(status_code=404, detail="Model/version not found on CivitAI")
|
||||
|
||||
# Get output directory
|
||||
output_dir = _get_output_dir(version_info, req.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get filename from version info
|
||||
files = version_info.get("files", [])
|
||||
primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
|
||||
|
||||
if not primary_file:
|
||||
raise HTTPException(status_code=400, detail="No files found for this version")
|
||||
|
||||
filename = primary_file.get("name", f"model-{version_id}.safetensors")
|
||||
dest_path = output_dir / filename
|
||||
|
||||
# Create download tracking entry
|
||||
download_id = f"{version_id}_{int(__import__('time').time())}"
|
||||
_active_downloads[download_id] = {
|
||||
"id": download_id,
|
||||
"version_id": version_id,
|
||||
"status": "queued",
|
||||
"path": str(dest_path),
|
||||
"filename": filename,
|
||||
"model_name": version_info.get("model", {}).get("name", "Unknown"),
|
||||
"version_name": version_info.get("name", "Unknown"),
|
||||
}
|
||||
|
||||
# Start background download
|
||||
background_tasks.add_task(_do_download, version_id, dest_path, api_key, download_id)
|
||||
|
||||
return {
|
||||
"download_id": download_id,
|
||||
"status": "queued",
|
||||
"version_id": version_id,
|
||||
"destination": str(dest_path),
|
||||
"model_name": version_info.get("model", {}).get("name"),
|
||||
"version_name": version_info.get("name"),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/status/{download_id}")
|
||||
def get_download_status(download_id: str) -> dict[str, Any]:
|
||||
"""Get status of a download."""
|
||||
if download_id not in _active_downloads:
|
||||
raise HTTPException(status_code=404, detail="Download not found")
|
||||
|
||||
return _active_downloads[download_id]
|
||||
|
||||
|
||||
@router.get("/active")
|
||||
def list_active_downloads() -> dict[str, Any]:
|
||||
"""List all active/recent downloads."""
|
||||
return {
|
||||
"downloads": list(_active_downloads.values()),
|
||||
"total": len(_active_downloads),
|
||||
}
|
||||
|
||||
|
||||
def create_download_router() -> APIRouter:
|
||||
"""Return the download API router."""
|
||||
return router
|
||||
@@ -0,0 +1,183 @@
|
||||
"""Image gallery management for generated images."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from tensors.config import GALLERY_DIR
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class GalleryImage:
|
||||
"""Represents an image in the gallery."""
|
||||
|
||||
id: str
|
||||
path: Path
|
||||
created_at: float
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def meta_path(self) -> Path:
|
||||
"""Path to the sidecar metadata JSON file."""
|
||||
return self.path.with_suffix(".json")
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for API response."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"path": str(self.path),
|
||||
"filename": self.path.name,
|
||||
"created_at": self.created_at,
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
"has_metadata": self.meta_path.exists(),
|
||||
}
|
||||
|
||||
|
||||
class Gallery:
|
||||
"""Manages the image gallery directory."""
|
||||
|
||||
def __init__(self, gallery_dir: Path | None = None) -> None:
|
||||
"""Initialize gallery with directory path."""
|
||||
self.gallery_dir = gallery_dir or GALLERY_DIR
|
||||
self.gallery_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def list_images(
|
||||
self,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
newest_first: bool = True,
|
||||
) -> list[GalleryImage]:
|
||||
"""List images in the gallery, paginated."""
|
||||
images: list[GalleryImage] = []
|
||||
|
||||
for path in self.gallery_dir.glob("*.png"):
|
||||
img = self._load_image(path)
|
||||
if img:
|
||||
images.append(img)
|
||||
|
||||
# Sort by creation time
|
||||
images.sort(key=lambda x: x.created_at, reverse=newest_first)
|
||||
|
||||
# Apply pagination
|
||||
return images[offset : offset + limit]
|
||||
|
||||
def get_image(self, image_id: str) -> GalleryImage | None:
|
||||
"""Get an image by ID."""
|
||||
# ID is the filename stem
|
||||
path = self.gallery_dir / f"{image_id}.png"
|
||||
if not path.exists():
|
||||
return None
|
||||
return self._load_image(path)
|
||||
|
||||
def get_metadata(self, image_id: str) -> dict[str, Any] | None:
|
||||
"""Get metadata for an image."""
|
||||
meta_path = self.gallery_dir / f"{image_id}.json"
|
||||
if not meta_path.exists():
|
||||
return None
|
||||
result: dict[str, Any] = json.loads(meta_path.read_text())
|
||||
return result
|
||||
|
||||
def update_metadata(self, image_id: str, updates: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Update metadata for an image (merge with existing)."""
|
||||
meta_path = self.gallery_dir / f"{image_id}.json"
|
||||
img_path = self.gallery_dir / f"{image_id}.png"
|
||||
|
||||
if not img_path.exists():
|
||||
return None
|
||||
|
||||
# Load existing or create new
|
||||
metadata = json.loads(meta_path.read_text()) if meta_path.exists() else {}
|
||||
|
||||
# Merge updates
|
||||
metadata.update(updates)
|
||||
metadata["updated_at"] = time.time()
|
||||
|
||||
# Save
|
||||
meta_path.write_text(json.dumps(metadata, indent=2))
|
||||
return metadata
|
||||
|
||||
def delete_image(self, image_id: str) -> bool:
|
||||
"""Delete an image and its metadata."""
|
||||
img_path = self.gallery_dir / f"{image_id}.png"
|
||||
meta_path = self.gallery_dir / f"{image_id}.json"
|
||||
|
||||
if not img_path.exists():
|
||||
return False
|
||||
|
||||
img_path.unlink()
|
||||
if meta_path.exists():
|
||||
meta_path.unlink()
|
||||
|
||||
return True
|
||||
|
||||
def save_image(
|
||||
self,
|
||||
image_data: bytes,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
seed: int | None = None,
|
||||
) -> GalleryImage:
|
||||
"""Save an image to the gallery with optional metadata."""
|
||||
timestamp = int(time.time() * 1000) # milliseconds
|
||||
seed_str = str(seed) if seed is not None else "0"
|
||||
image_id = f"{timestamp}_{seed_str}"
|
||||
|
||||
img_path = self.gallery_dir / f"{image_id}.png"
|
||||
img_path.write_bytes(image_data)
|
||||
|
||||
# Save metadata if provided
|
||||
if metadata:
|
||||
meta = metadata.copy()
|
||||
meta["created_at"] = time.time()
|
||||
meta["seed"] = seed
|
||||
meta_path = img_path.with_suffix(".json")
|
||||
meta_path.write_text(json.dumps(meta, indent=2))
|
||||
|
||||
return self._load_image(img_path) or GalleryImage(
|
||||
id=image_id,
|
||||
path=img_path,
|
||||
created_at=time.time(),
|
||||
)
|
||||
|
||||
def _load_image(self, path: Path) -> GalleryImage | None:
|
||||
"""Load image info from path."""
|
||||
if not path.exists():
|
||||
return None
|
||||
|
||||
image_id = path.stem
|
||||
stat = path.stat()
|
||||
|
||||
# Try to get dimensions from metadata or PIL
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
meta_path = path.with_suffix(".json")
|
||||
if meta_path.exists():
|
||||
try:
|
||||
metadata = json.loads(meta_path.read_text())
|
||||
width = metadata.get("width")
|
||||
height = metadata.get("height")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return GalleryImage(
|
||||
id=image_id,
|
||||
path=path,
|
||||
created_at=stat.st_mtime,
|
||||
width=width,
|
||||
height=height,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def count(self) -> int:
|
||||
"""Count total images in gallery."""
|
||||
return len(list(self.gallery_dir.glob("*.png")))
|
||||
@@ -0,0 +1,149 @@
|
||||
"""FastAPI route handlers for image gallery endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
|
||||
from tensors.server.gallery import Gallery
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/images", tags=["gallery"])
|
||||
|
||||
# Shared gallery instance
|
||||
_gallery: Gallery | None = None
|
||||
|
||||
|
||||
def get_gallery() -> Gallery:
|
||||
"""Get or create the gallery instance."""
|
||||
global _gallery # noqa: PLW0603
|
||||
if _gallery is None:
|
||||
_gallery = Gallery()
|
||||
return _gallery
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Request/Response Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class MetadataUpdate(PydanticBaseModel):
|
||||
"""Request body for updating image metadata."""
|
||||
|
||||
tags: list[str] | None = None
|
||||
notes: str | None = None
|
||||
rating: int | None = None
|
||||
favorite: bool | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Gallery Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_images(
|
||||
limit: int = Query(default=50, le=200, description="Max images to return"),
|
||||
offset: int = Query(default=0, ge=0, description="Offset for pagination"),
|
||||
newest_first: bool = Query(default=True, description="Sort newest first"),
|
||||
) -> dict[str, Any]:
|
||||
"""List images in the gallery, paginated."""
|
||||
gallery = get_gallery()
|
||||
images = gallery.list_images(limit=limit, offset=offset, newest_first=newest_first)
|
||||
total = gallery.count()
|
||||
|
||||
return {
|
||||
"images": [img.to_dict() for img in images],
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{image_id}")
|
||||
def get_image(image_id: str) -> FileResponse:
|
||||
"""Get an image file by ID."""
|
||||
gallery = get_gallery()
|
||||
image = gallery.get_image(image_id)
|
||||
|
||||
if not image:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
return FileResponse(
|
||||
path=image.path,
|
||||
media_type="image/png",
|
||||
filename=image.path.name,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{image_id}/meta")
|
||||
def get_image_metadata(image_id: str) -> dict[str, Any]:
|
||||
"""Get metadata for an image."""
|
||||
gallery = get_gallery()
|
||||
image = gallery.get_image(image_id)
|
||||
|
||||
if not image:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
metadata = gallery.get_metadata(image_id) or {}
|
||||
return {
|
||||
"id": image_id,
|
||||
"path": str(image.path),
|
||||
"created_at": image.created_at,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{image_id}/edit")
|
||||
def edit_image_metadata(image_id: str, updates: MetadataUpdate) -> dict[str, Any]:
|
||||
"""Update metadata for an image."""
|
||||
gallery = get_gallery()
|
||||
|
||||
# Build update dict from non-None values
|
||||
update_dict: dict[str, Any] = {}
|
||||
if updates.tags is not None:
|
||||
update_dict["tags"] = updates.tags
|
||||
if updates.notes is not None:
|
||||
update_dict["notes"] = updates.notes
|
||||
if updates.rating is not None:
|
||||
update_dict["rating"] = updates.rating
|
||||
if updates.favorite is not None:
|
||||
update_dict["favorite"] = updates.favorite
|
||||
|
||||
result = gallery.update_metadata(image_id, update_dict)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
return {"id": image_id, "metadata": result}
|
||||
|
||||
|
||||
@router.delete("/{image_id}")
|
||||
def delete_image(image_id: str) -> dict[str, Any]:
|
||||
"""Delete an image and its metadata."""
|
||||
gallery = get_gallery()
|
||||
deleted = gallery.delete_image(image_id)
|
||||
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
return {"deleted": True, "id": image_id}
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
def gallery_stats() -> dict[str, Any]:
|
||||
"""Get gallery statistics."""
|
||||
gallery = get_gallery()
|
||||
return {
|
||||
"total_images": gallery.count(),
|
||||
"gallery_dir": str(gallery.gallery_dir),
|
||||
}
|
||||
|
||||
|
||||
def create_gallery_router() -> APIRouter:
|
||||
"""Return the gallery API router."""
|
||||
return router
|
||||
@@ -0,0 +1,199 @@
|
||||
"""FastAPI route handlers for image generation with gallery integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from tensors.server.gallery import Gallery
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensors.server.process import ProcessManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Request/Response Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class GenerateRequest(PydanticBaseModel):
|
||||
"""Request body for image generation."""
|
||||
|
||||
prompt: str
|
||||
negative_prompt: str = ""
|
||||
width: int = Field(default=512, ge=64, le=2048)
|
||||
height: int = Field(default=512, ge=64, le=2048)
|
||||
steps: int = Field(default=20, ge=1, le=150)
|
||||
cfg_scale: float = Field(default=7.0, ge=0, le=30)
|
||||
seed: int = -1
|
||||
sampler_name: str = ""
|
||||
scheduler: str = ""
|
||||
batch_size: int = Field(default=1, ge=1, le=16)
|
||||
save_to_gallery: bool = True
|
||||
return_base64: bool = False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _build_sd_request(req: GenerateRequest) -> dict[str, Any]:
|
||||
"""Build request body for sd-server."""
|
||||
body: dict[str, Any] = {
|
||||
"prompt": req.prompt,
|
||||
"negative_prompt": req.negative_prompt,
|
||||
"width": req.width,
|
||||
"height": req.height,
|
||||
"steps": req.steps,
|
||||
"cfg_scale": req.cfg_scale,
|
||||
"seed": req.seed,
|
||||
"batch_size": req.batch_size,
|
||||
}
|
||||
if req.sampler_name:
|
||||
body["sampler_name"] = req.sampler_name
|
||||
if req.scheduler:
|
||||
body["scheduler"] = req.scheduler
|
||||
return body
|
||||
|
||||
|
||||
def _parse_info(info: Any) -> dict[str, Any]:
|
||||
"""Parse info from sd-server response."""
|
||||
if isinstance(info, str):
|
||||
import json # noqa: PLC0415
|
||||
|
||||
try:
|
||||
return dict(json.loads(info))
|
||||
except json.JSONDecodeError:
|
||||
return {"raw": info}
|
||||
return info if isinstance(info, dict) else {}
|
||||
|
||||
|
||||
def _process_image(
|
||||
img_b64: str,
|
||||
index: int,
|
||||
seed: int,
|
||||
req: GenerateRequest,
|
||||
gallery: Gallery,
|
||||
model: str | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Process a single generated image."""
|
||||
image_bytes = base64.b64decode(img_b64)
|
||||
image_info: dict[str, Any] = {"index": index, "seed": seed}
|
||||
|
||||
if req.save_to_gallery:
|
||||
metadata = {
|
||||
"prompt": req.prompt,
|
||||
"negative_prompt": req.negative_prompt,
|
||||
"width": req.width,
|
||||
"height": req.height,
|
||||
"steps": req.steps,
|
||||
"cfg_scale": req.cfg_scale,
|
||||
"sampler": req.sampler_name,
|
||||
"scheduler": req.scheduler,
|
||||
"model": model,
|
||||
"generated_at": time.time(),
|
||||
}
|
||||
gallery_img = gallery.save_image(image_bytes, metadata=metadata, seed=seed)
|
||||
image_info["id"] = gallery_img.id
|
||||
image_info["path"] = str(gallery_img.path)
|
||||
|
||||
if req.return_base64:
|
||||
image_info["base64"] = img_b64
|
||||
|
||||
return image_info
|
||||
|
||||
|
||||
def _check_server_running(pm: ProcessManager) -> None:
|
||||
"""Check if sd-server is running, raise HTTPException if not."""
|
||||
if pm.proc is None or pm.proc.poll() is not None:
|
||||
raise HTTPException(status_code=503, detail="sd-server is not running")
|
||||
if pm.config is None:
|
||||
raise HTTPException(status_code=503, detail="sd-server not configured")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Router Factory
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_generate_router(pm: ProcessManager) -> APIRouter:
|
||||
"""Build a router with /api/generate endpoint."""
|
||||
router = APIRouter(prefix="/api", tags=["generate"])
|
||||
gallery = Gallery()
|
||||
|
||||
@router.post("/generate")
|
||||
async def generate(req: GenerateRequest) -> dict[str, Any]:
|
||||
"""Generate images with gallery integration."""
|
||||
_check_server_running(pm)
|
||||
assert pm.config is not None # Verified by _check_server_running
|
||||
|
||||
body = _build_sd_request(req)
|
||||
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/txt2img"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
response = await client.post(url, json=body)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
except httpx.HTTPError as e:
|
||||
logger.exception("Generation failed")
|
||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||
|
||||
images_data = result.get("images", [])
|
||||
info = _parse_info(result.get("info", {}))
|
||||
all_seeds = info.get("all_seeds", [req.seed] * len(images_data))
|
||||
|
||||
output_images = [
|
||||
_process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, pm.config.model)
|
||||
for i, img_b64 in enumerate(images_data)
|
||||
]
|
||||
|
||||
return {
|
||||
"images": output_images,
|
||||
"parameters": result.get("parameters", body),
|
||||
"info": info,
|
||||
"saved_to_gallery": req.save_to_gallery,
|
||||
"total": len(output_images),
|
||||
}
|
||||
|
||||
@router.get("/samplers")
|
||||
async def list_samplers() -> dict[str, Any]:
|
||||
"""List available samplers from sd-server."""
|
||||
_check_server_running(pm)
|
||||
assert pm.config is not None
|
||||
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/samplers"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
return {"samplers": response.json()}
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||
|
||||
@router.get("/schedulers")
|
||||
async def list_schedulers() -> dict[str, Any]:
|
||||
"""List available schedulers from sd-server."""
|
||||
_check_server_running(pm)
|
||||
assert pm.config is not None
|
||||
url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/schedulers"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
return {"schedulers": response.json()}
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||
|
||||
return router
|
||||
@@ -0,0 +1,171 @@
|
||||
"""FastAPI route handlers for model management endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
|
||||
from tensors.config import MODELS_DIR
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensors.server.process import ProcessManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Request/Response Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SwitchModelRequest(PydanticBaseModel):
|
||||
"""Request body for switching models."""
|
||||
|
||||
model: str # Path to model file
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def scan_models(directory: Path, extensions: tuple[str, ...] = (".safetensors", ".gguf")) -> list[dict[str, Any]]:
|
||||
"""Scan directory for model files."""
|
||||
models: list[dict[str, Any]] = []
|
||||
|
||||
if not directory.exists():
|
||||
return models
|
||||
|
||||
for ext in extensions:
|
||||
for path in directory.rglob(f"*{ext}"):
|
||||
stat = path.stat()
|
||||
models.append(
|
||||
{
|
||||
"name": path.stem,
|
||||
"path": str(path),
|
||||
"filename": path.name,
|
||||
"size_mb": round(stat.st_size / (1024 * 1024), 2),
|
||||
"modified": stat.st_mtime,
|
||||
}
|
||||
)
|
||||
|
||||
# Sort by name
|
||||
models.sort(key=lambda x: x["name"].lower())
|
||||
return models
|
||||
|
||||
|
||||
def scan_loras(directory: Path | None = None) -> list[dict[str, Any]]:
|
||||
"""Scan for LoRA files."""
|
||||
lora_dir = directory or MODELS_DIR / "loras"
|
||||
return scan_models(lora_dir, extensions=(".safetensors",))
|
||||
|
||||
|
||||
def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]:
|
||||
"""Scan for checkpoint files."""
|
||||
checkpoint_dir = directory or MODELS_DIR / "checkpoints"
|
||||
return scan_models(checkpoint_dir, extensions=(".safetensors", ".gguf"))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Router Factory
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_models_router(pm: ProcessManager) -> APIRouter:
|
||||
"""Build a router with /api/models/* endpoints."""
|
||||
router = APIRouter(prefix="/api/models", tags=["models"])
|
||||
|
||||
@router.get("")
|
||||
def list_models() -> dict[str, Any]:
|
||||
"""List available checkpoint models."""
|
||||
checkpoints = scan_checkpoints()
|
||||
return {
|
||||
"models": checkpoints,
|
||||
"total": len(checkpoints),
|
||||
}
|
||||
|
||||
@router.get("/active")
|
||||
def get_active_model() -> dict[str, Any]:
|
||||
"""Get information about the currently loaded model."""
|
||||
status = pm.status()
|
||||
config = pm.config
|
||||
|
||||
if config is None:
|
||||
return {
|
||||
"loaded": False,
|
||||
"model": None,
|
||||
"status": status.get("status"),
|
||||
}
|
||||
|
||||
return {
|
||||
"loaded": status.get("status") == "running",
|
||||
"model": config.model,
|
||||
"pid": status.get("pid"),
|
||||
"port": config.port,
|
||||
"status": status.get("status"),
|
||||
}
|
||||
|
||||
@router.post("/switch")
|
||||
async def switch_model(req: SwitchModelRequest) -> JSONResponse:
|
||||
"""Switch to a different model (hot reload)."""
|
||||
model_path = Path(req.model)
|
||||
|
||||
# Validate model exists
|
||||
if not model_path.exists():
|
||||
raise HTTPException(status_code=400, detail=f"Model not found: {req.model}")
|
||||
|
||||
# Use existing reload logic
|
||||
from tensors.server.models import ServerConfig # noqa: PLC0415
|
||||
|
||||
new_config = ServerConfig(
|
||||
model=req.model,
|
||||
port=pm.config.port if pm.config else 1234,
|
||||
args=pm.config.args if pm.config else [],
|
||||
)
|
||||
|
||||
pm.stop()
|
||||
pm.start(new_config)
|
||||
ready = await pm.wait_ready()
|
||||
|
||||
if not ready:
|
||||
return JSONResponse(
|
||||
{"error": "sd-server failed to become ready", "model": req.model},
|
||||
status_code=503,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"ok": True,
|
||||
"model": req.model,
|
||||
"pid": pm.proc.pid if pm.proc else None,
|
||||
}
|
||||
)
|
||||
|
||||
@router.get("/loras")
|
||||
def list_loras() -> dict[str, Any]:
|
||||
"""List available LoRA files."""
|
||||
loras = scan_loras()
|
||||
return {
|
||||
"loras": loras,
|
||||
"total": len(loras),
|
||||
}
|
||||
|
||||
@router.get("/scan")
|
||||
def scan_all_models() -> dict[str, Any]:
|
||||
"""Scan all model directories."""
|
||||
checkpoints = scan_checkpoints()
|
||||
loras = scan_loras()
|
||||
|
||||
return {
|
||||
"checkpoints": checkpoints,
|
||||
"loras": loras,
|
||||
"total_checkpoints": len(checkpoints),
|
||||
"total_loras": len(loras),
|
||||
}
|
||||
|
||||
return router
|
||||
Reference in New Issue
Block a user