diff --git a/TODO.md b/TODO.md index 7337a8b..4192422 100644 --- a/TODO.md +++ b/TODO.md @@ -11,10 +11,10 @@ - [x] Step 2.3: Add `/api/db/*` endpoints (files, models, triggers, scan, link, cache, stats) ## Phase 3: Enhanced Server API -- [ ] Step 3.1: Add `/api/images` gallery endpoints (list, get, delete, edit) -- [ ] Step 3.2: Add `/api/models` endpoints (list, active, switch, loras) -- [ ] Step 3.3: Add `/api/download` endpoint (CivitAI proxy download) -- [ ] Step 3.4: Enhance `/api/generate` (gallery integration, full params) +- [x] Step 3.1: Add `/api/images` gallery endpoints (list, get, delete, edit) +- [x] Step 3.2: Add `/api/models` endpoints (list, active, switch, loras) +- [x] Step 3.3: Add `/api/download` endpoint (CivitAI proxy download) +- [x] Step 3.4: Enhance `/api/generate` (gallery integration, full params) ## Phase 4: Client Mode for tsr CLI - [ ] Step 4.1: Create `tensors/client.py` (TsrClient HTTP wrapper) diff --git a/tensors/config.py b/tensors/config.py index 9dd14bd..c4c530b 100644 --- a/tensors/config.py +++ b/tensors/config.py @@ -20,6 +20,7 @@ CONFIG_FILE = CONFIG_DIR / "config.toml" DATA_DIR = Path(os.environ.get("XDG_DATA_HOME", Path.home() / ".local" / "share")) / "tensors" MODELS_DIR = DATA_DIR / "models" METADATA_DIR = DATA_DIR / "metadata" +GALLERY_DIR = DATA_DIR / "gallery" # Legacy config for migration LEGACY_RC_FILE = Path.home() / ".sftrc" diff --git a/tensors/server/__init__.py b/tensors/server/__init__.py index 6d35cb1..7c4e8dc 100644 --- a/tensors/server/__init__.py +++ b/tensors/server/__init__.py @@ -10,7 +10,11 @@ import httpx from fastapi import FastAPI from tensors.server.db_routes import create_db_router +from tensors.server.download_routes import create_download_router +from tensors.server.gallery_routes import create_gallery_router +from tensors.server.generate_routes import create_generate_router from tensors.server.models import ServerConfig +from tensors.server.models_routes import create_models_router from tensors.server.process import ProcessManager from tensors.server.routes import create_router @@ -42,7 +46,11 @@ def create_app(config: ServerConfig | None = None) -> FastAPI: pm.stop() app = FastAPI(title="sd-server wrapper", lifespan=lifespan) - app.include_router(create_db_router()) # Must be first to avoid catch-all conflict + app.include_router(create_db_router()) # Must be before catch-all proxy + app.include_router(create_gallery_router()) # Must be before catch-all proxy + app.include_router(create_models_router(pm)) # Must be before catch-all proxy + app.include_router(create_download_router()) # Must be before catch-all proxy + app.include_router(create_generate_router(pm)) # Must be before catch-all proxy app.include_router(create_router(pm)) app.state.pm = pm return app diff --git a/tensors/server/download_routes.py b/tensors/server/download_routes.py new file mode 100644 index 0000000..63c371f --- /dev/null +++ b/tensors/server/download_routes.py @@ -0,0 +1,230 @@ +"""FastAPI route handlers for CivitAI download proxy.""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from fastapi import APIRouter, BackgroundTasks, HTTPException +from pydantic import BaseModel as PydanticBaseModel + +from tensors.api import download_model, fetch_civitai_by_hash, fetch_civitai_model, fetch_civitai_model_version +from tensors.config import MODELS_DIR, load_api_key +from tensors.db import Database + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/download", tags=["download"]) + +# Track active downloads +_active_downloads: dict[str, dict[str, Any]] = {} + + +# ============================================================================= +# Request/Response Models +# ============================================================================= + + +class DownloadRequest(PydanticBaseModel): + """Request body for downloading a model.""" + + version_id: int | None = None + model_id: int | None = None + hash: str | None = None + output_dir: str | None = None # Override default path + + +# ============================================================================= +# Helper Functions +# ============================================================================= + + +def _resolve_version_id( + version_id: int | None, + model_id: int | None, + hash_val: str | None, + api_key: str | None, +) -> tuple[int | None, dict[str, Any] | None]: + """Resolve version ID and get version info.""" + if version_id: + info = fetch_civitai_model_version(version_id, api_key) + return version_id, info + + if hash_val: + info = fetch_civitai_by_hash(hash_val.upper(), api_key) + if info: + return info.get("id"), info + return None, None + + if model_id: + model_data = fetch_civitai_model(model_id, api_key) + if model_data: + versions = model_data.get("modelVersions", []) + if versions: + latest = versions[0] + return latest.get("id"), latest + return None, None + + return None, None + + +def _get_output_dir(version_info: dict[str, Any], override: str | None) -> Path: + """Determine output directory based on model type.""" + if override: + return Path(override) + + model_type = version_info.get("model", {}).get("type", "Checkpoint") + + # Map type to directory + type_dirs = { + "Checkpoint": MODELS_DIR / "checkpoints", + "LORA": MODELS_DIR / "loras", + "LoCon": MODELS_DIR / "loras", + "TextualInversion": MODELS_DIR / "embeddings", + "VAE": MODELS_DIR / "vae", + "Controlnet": MODELS_DIR / "controlnet", + } + + return type_dirs.get(model_type, MODELS_DIR / "other") + + +def _do_download( + version_id: int, + dest_path: Path, + api_key: str | None, + download_id: str, +) -> None: + """Background task to perform the download.""" + try: + _active_downloads[download_id]["status"] = "downloading" + + # Create a mock console for download progress + from io import StringIO # noqa: PLC0415 + + from rich.console import Console # noqa: PLC0415 + + output = StringIO() + console = Console(file=output, force_terminal=False) + + success = download_model(version_id, dest_path, api_key, console, resume=True) + + if success: + _active_downloads[download_id]["status"] = "completed" + _active_downloads[download_id]["path"] = str(dest_path) + + # Auto-scan and link the downloaded file + _auto_link_file(dest_path, api_key) + else: + _active_downloads[download_id]["status"] = "failed" + _active_downloads[download_id]["error"] = "Download failed" + + except Exception as e: + logger.exception("Download failed") + _active_downloads[download_id]["status"] = "failed" + _active_downloads[download_id]["error"] = str(e) + + +def _auto_link_file(file_path: Path, api_key: str | None) -> None: + """Auto-scan and link the downloaded file to CivitAI.""" + try: + with Database() as db: + db.init_schema() + # Scan the single file + results = db.scan_directory(file_path.parent) + + # Find and link the new file + for result in results: + if result["file_path"] == str(file_path): + sha256 = result["sha256"] + civitai_data = fetch_civitai_by_hash(sha256, api_key) + if civitai_data: + version_id = civitai_data.get("id", 0) + model_id = civitai_data.get("modelId", 0) + if version_id and model_id: + db.link_file_to_civitai(result["id"], model_id, version_id) + except Exception: + logger.exception("Auto-link failed") + + +# ============================================================================= +# Download Endpoints +# ============================================================================= + + +@router.post("") +def start_download(req: DownloadRequest, background_tasks: BackgroundTasks) -> dict[str, Any]: + """Start a model download (async with progress tracking).""" + api_key = load_api_key() + + # Resolve version ID + version_id, version_info = _resolve_version_id( + req.version_id, + req.model_id, + req.hash, + api_key, + ) + + if not version_id or not version_info: + raise HTTPException(status_code=404, detail="Model/version not found on CivitAI") + + # Get output directory + output_dir = _get_output_dir(version_info, req.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Get filename from version info + files = version_info.get("files", []) + primary_file = next((f for f in files if f.get("primary")), files[0] if files else None) + + if not primary_file: + raise HTTPException(status_code=400, detail="No files found for this version") + + filename = primary_file.get("name", f"model-{version_id}.safetensors") + dest_path = output_dir / filename + + # Create download tracking entry + download_id = f"{version_id}_{int(__import__('time').time())}" + _active_downloads[download_id] = { + "id": download_id, + "version_id": version_id, + "status": "queued", + "path": str(dest_path), + "filename": filename, + "model_name": version_info.get("model", {}).get("name", "Unknown"), + "version_name": version_info.get("name", "Unknown"), + } + + # Start background download + background_tasks.add_task(_do_download, version_id, dest_path, api_key, download_id) + + return { + "download_id": download_id, + "status": "queued", + "version_id": version_id, + "destination": str(dest_path), + "model_name": version_info.get("model", {}).get("name"), + "version_name": version_info.get("name"), + } + + +@router.get("/status/{download_id}") +def get_download_status(download_id: str) -> dict[str, Any]: + """Get status of a download.""" + if download_id not in _active_downloads: + raise HTTPException(status_code=404, detail="Download not found") + + return _active_downloads[download_id] + + +@router.get("/active") +def list_active_downloads() -> dict[str, Any]: + """List all active/recent downloads.""" + return { + "downloads": list(_active_downloads.values()), + "total": len(_active_downloads), + } + + +def create_download_router() -> APIRouter: + """Return the download API router.""" + return router diff --git a/tensors/server/gallery.py b/tensors/server/gallery.py new file mode 100644 index 0000000..b3b11f8 --- /dev/null +++ b/tensors/server/gallery.py @@ -0,0 +1,183 @@ +"""Image gallery management for generated images.""" + +from __future__ import annotations + +import json +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from tensors.config import GALLERY_DIR + +if TYPE_CHECKING: + from pathlib import Path + + +@dataclass +class GalleryImage: + """Represents an image in the gallery.""" + + id: str + path: Path + created_at: float + width: int | None = None + height: int | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def meta_path(self) -> Path: + """Path to the sidecar metadata JSON file.""" + return self.path.with_suffix(".json") + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for API response.""" + return { + "id": self.id, + "path": str(self.path), + "filename": self.path.name, + "created_at": self.created_at, + "width": self.width, + "height": self.height, + "has_metadata": self.meta_path.exists(), + } + + +class Gallery: + """Manages the image gallery directory.""" + + def __init__(self, gallery_dir: Path | None = None) -> None: + """Initialize gallery with directory path.""" + self.gallery_dir = gallery_dir or GALLERY_DIR + self.gallery_dir.mkdir(parents=True, exist_ok=True) + + def list_images( + self, + limit: int = 50, + offset: int = 0, + newest_first: bool = True, + ) -> list[GalleryImage]: + """List images in the gallery, paginated.""" + images: list[GalleryImage] = [] + + for path in self.gallery_dir.glob("*.png"): + img = self._load_image(path) + if img: + images.append(img) + + # Sort by creation time + images.sort(key=lambda x: x.created_at, reverse=newest_first) + + # Apply pagination + return images[offset : offset + limit] + + def get_image(self, image_id: str) -> GalleryImage | None: + """Get an image by ID.""" + # ID is the filename stem + path = self.gallery_dir / f"{image_id}.png" + if not path.exists(): + return None + return self._load_image(path) + + def get_metadata(self, image_id: str) -> dict[str, Any] | None: + """Get metadata for an image.""" + meta_path = self.gallery_dir / f"{image_id}.json" + if not meta_path.exists(): + return None + result: dict[str, Any] = json.loads(meta_path.read_text()) + return result + + def update_metadata(self, image_id: str, updates: dict[str, Any]) -> dict[str, Any] | None: + """Update metadata for an image (merge with existing).""" + meta_path = self.gallery_dir / f"{image_id}.json" + img_path = self.gallery_dir / f"{image_id}.png" + + if not img_path.exists(): + return None + + # Load existing or create new + metadata = json.loads(meta_path.read_text()) if meta_path.exists() else {} + + # Merge updates + metadata.update(updates) + metadata["updated_at"] = time.time() + + # Save + meta_path.write_text(json.dumps(metadata, indent=2)) + return metadata + + def delete_image(self, image_id: str) -> bool: + """Delete an image and its metadata.""" + img_path = self.gallery_dir / f"{image_id}.png" + meta_path = self.gallery_dir / f"{image_id}.json" + + if not img_path.exists(): + return False + + img_path.unlink() + if meta_path.exists(): + meta_path.unlink() + + return True + + def save_image( + self, + image_data: bytes, + metadata: dict[str, Any] | None = None, + seed: int | None = None, + ) -> GalleryImage: + """Save an image to the gallery with optional metadata.""" + timestamp = int(time.time() * 1000) # milliseconds + seed_str = str(seed) if seed is not None else "0" + image_id = f"{timestamp}_{seed_str}" + + img_path = self.gallery_dir / f"{image_id}.png" + img_path.write_bytes(image_data) + + # Save metadata if provided + if metadata: + meta = metadata.copy() + meta["created_at"] = time.time() + meta["seed"] = seed + meta_path = img_path.with_suffix(".json") + meta_path.write_text(json.dumps(meta, indent=2)) + + return self._load_image(img_path) or GalleryImage( + id=image_id, + path=img_path, + created_at=time.time(), + ) + + def _load_image(self, path: Path) -> GalleryImage | None: + """Load image info from path.""" + if not path.exists(): + return None + + image_id = path.stem + stat = path.stat() + + # Try to get dimensions from metadata or PIL + width: int | None = None + height: int | None = None + metadata: dict[str, Any] = {} + + meta_path = path.with_suffix(".json") + if meta_path.exists(): + try: + metadata = json.loads(meta_path.read_text()) + width = metadata.get("width") + height = metadata.get("height") + except json.JSONDecodeError: + pass + + return GalleryImage( + id=image_id, + path=path, + created_at=stat.st_mtime, + width=width, + height=height, + metadata=metadata, + ) + + def count(self) -> int: + """Count total images in gallery.""" + return len(list(self.gallery_dir.glob("*.png"))) diff --git a/tensors/server/gallery_routes.py b/tensors/server/gallery_routes.py new file mode 100644 index 0000000..b875382 --- /dev/null +++ b/tensors/server/gallery_routes.py @@ -0,0 +1,149 @@ +"""FastAPI route handlers for image gallery endpoints.""" + +from __future__ import annotations + +import logging +from typing import Any + +from fastapi import APIRouter, HTTPException, Query +from fastapi.responses import FileResponse +from pydantic import BaseModel as PydanticBaseModel + +from tensors.server.gallery import Gallery + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/images", tags=["gallery"]) + +# Shared gallery instance +_gallery: Gallery | None = None + + +def get_gallery() -> Gallery: + """Get or create the gallery instance.""" + global _gallery # noqa: PLW0603 + if _gallery is None: + _gallery = Gallery() + return _gallery + + +# ============================================================================= +# Request/Response Models +# ============================================================================= + + +class MetadataUpdate(PydanticBaseModel): + """Request body for updating image metadata.""" + + tags: list[str] | None = None + notes: str | None = None + rating: int | None = None + favorite: bool | None = None + + +# ============================================================================= +# Gallery Endpoints +# ============================================================================= + + +@router.get("") +def list_images( + limit: int = Query(default=50, le=200, description="Max images to return"), + offset: int = Query(default=0, ge=0, description="Offset for pagination"), + newest_first: bool = Query(default=True, description="Sort newest first"), +) -> dict[str, Any]: + """List images in the gallery, paginated.""" + gallery = get_gallery() + images = gallery.list_images(limit=limit, offset=offset, newest_first=newest_first) + total = gallery.count() + + return { + "images": [img.to_dict() for img in images], + "total": total, + "limit": limit, + "offset": offset, + } + + +@router.get("/{image_id}") +def get_image(image_id: str) -> FileResponse: + """Get an image file by ID.""" + gallery = get_gallery() + image = gallery.get_image(image_id) + + if not image: + raise HTTPException(status_code=404, detail="Image not found") + + return FileResponse( + path=image.path, + media_type="image/png", + filename=image.path.name, + ) + + +@router.get("/{image_id}/meta") +def get_image_metadata(image_id: str) -> dict[str, Any]: + """Get metadata for an image.""" + gallery = get_gallery() + image = gallery.get_image(image_id) + + if not image: + raise HTTPException(status_code=404, detail="Image not found") + + metadata = gallery.get_metadata(image_id) or {} + return { + "id": image_id, + "path": str(image.path), + "created_at": image.created_at, + "metadata": metadata, + } + + +@router.post("/{image_id}/edit") +def edit_image_metadata(image_id: str, updates: MetadataUpdate) -> dict[str, Any]: + """Update metadata for an image.""" + gallery = get_gallery() + + # Build update dict from non-None values + update_dict: dict[str, Any] = {} + if updates.tags is not None: + update_dict["tags"] = updates.tags + if updates.notes is not None: + update_dict["notes"] = updates.notes + if updates.rating is not None: + update_dict["rating"] = updates.rating + if updates.favorite is not None: + update_dict["favorite"] = updates.favorite + + result = gallery.update_metadata(image_id, update_dict) + if result is None: + raise HTTPException(status_code=404, detail="Image not found") + + return {"id": image_id, "metadata": result} + + +@router.delete("/{image_id}") +def delete_image(image_id: str) -> dict[str, Any]: + """Delete an image and its metadata.""" + gallery = get_gallery() + deleted = gallery.delete_image(image_id) + + if not deleted: + raise HTTPException(status_code=404, detail="Image not found") + + return {"deleted": True, "id": image_id} + + +@router.get("/stats/summary") +def gallery_stats() -> dict[str, Any]: + """Get gallery statistics.""" + gallery = get_gallery() + return { + "total_images": gallery.count(), + "gallery_dir": str(gallery.gallery_dir), + } + + +def create_gallery_router() -> APIRouter: + """Return the gallery API router.""" + return router diff --git a/tensors/server/generate_routes.py b/tensors/server/generate_routes.py new file mode 100644 index 0000000..863c8ed --- /dev/null +++ b/tensors/server/generate_routes.py @@ -0,0 +1,199 @@ +"""FastAPI route handlers for image generation with gallery integration.""" + +from __future__ import annotations + +import base64 +import logging +import time +from typing import TYPE_CHECKING, Any + +import httpx +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel as PydanticBaseModel +from pydantic import Field + +from tensors.server.gallery import Gallery + +if TYPE_CHECKING: + from tensors.server.process import ProcessManager + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Request/Response Models +# ============================================================================= + + +class GenerateRequest(PydanticBaseModel): + """Request body for image generation.""" + + prompt: str + negative_prompt: str = "" + width: int = Field(default=512, ge=64, le=2048) + height: int = Field(default=512, ge=64, le=2048) + steps: int = Field(default=20, ge=1, le=150) + cfg_scale: float = Field(default=7.0, ge=0, le=30) + seed: int = -1 + sampler_name: str = "" + scheduler: str = "" + batch_size: int = Field(default=1, ge=1, le=16) + save_to_gallery: bool = True + return_base64: bool = False + + +# ============================================================================= +# Helper Functions +# ============================================================================= + + +def _build_sd_request(req: GenerateRequest) -> dict[str, Any]: + """Build request body for sd-server.""" + body: dict[str, Any] = { + "prompt": req.prompt, + "negative_prompt": req.negative_prompt, + "width": req.width, + "height": req.height, + "steps": req.steps, + "cfg_scale": req.cfg_scale, + "seed": req.seed, + "batch_size": req.batch_size, + } + if req.sampler_name: + body["sampler_name"] = req.sampler_name + if req.scheduler: + body["scheduler"] = req.scheduler + return body + + +def _parse_info(info: Any) -> dict[str, Any]: + """Parse info from sd-server response.""" + if isinstance(info, str): + import json # noqa: PLC0415 + + try: + return dict(json.loads(info)) + except json.JSONDecodeError: + return {"raw": info} + return info if isinstance(info, dict) else {} + + +def _process_image( + img_b64: str, + index: int, + seed: int, + req: GenerateRequest, + gallery: Gallery, + model: str | None, +) -> dict[str, Any]: + """Process a single generated image.""" + image_bytes = base64.b64decode(img_b64) + image_info: dict[str, Any] = {"index": index, "seed": seed} + + if req.save_to_gallery: + metadata = { + "prompt": req.prompt, + "negative_prompt": req.negative_prompt, + "width": req.width, + "height": req.height, + "steps": req.steps, + "cfg_scale": req.cfg_scale, + "sampler": req.sampler_name, + "scheduler": req.scheduler, + "model": model, + "generated_at": time.time(), + } + gallery_img = gallery.save_image(image_bytes, metadata=metadata, seed=seed) + image_info["id"] = gallery_img.id + image_info["path"] = str(gallery_img.path) + + if req.return_base64: + image_info["base64"] = img_b64 + + return image_info + + +def _check_server_running(pm: ProcessManager) -> None: + """Check if sd-server is running, raise HTTPException if not.""" + if pm.proc is None or pm.proc.poll() is not None: + raise HTTPException(status_code=503, detail="sd-server is not running") + if pm.config is None: + raise HTTPException(status_code=503, detail="sd-server not configured") + + +# ============================================================================= +# Router Factory +# ============================================================================= + + +def create_generate_router(pm: ProcessManager) -> APIRouter: + """Build a router with /api/generate endpoint.""" + router = APIRouter(prefix="/api", tags=["generate"]) + gallery = Gallery() + + @router.post("/generate") + async def generate(req: GenerateRequest) -> dict[str, Any]: + """Generate images with gallery integration.""" + _check_server_running(pm) + assert pm.config is not None # Verified by _check_server_running + + body = _build_sd_request(req) + url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/txt2img" + + try: + async with httpx.AsyncClient(timeout=300) as client: + response = await client.post(url, json=body) + response.raise_for_status() + result = response.json() + except httpx.HTTPError as e: + logger.exception("Generation failed") + raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e + + images_data = result.get("images", []) + info = _parse_info(result.get("info", {})) + all_seeds = info.get("all_seeds", [req.seed] * len(images_data)) + + output_images = [ + _process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, pm.config.model) + for i, img_b64 in enumerate(images_data) + ] + + return { + "images": output_images, + "parameters": result.get("parameters", body), + "info": info, + "saved_to_gallery": req.save_to_gallery, + "total": len(output_images), + } + + @router.get("/samplers") + async def list_samplers() -> dict[str, Any]: + """List available samplers from sd-server.""" + _check_server_running(pm) + assert pm.config is not None + url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/samplers" + + try: + async with httpx.AsyncClient(timeout=30) as client: + response = await client.get(url) + response.raise_for_status() + return {"samplers": response.json()} + except httpx.HTTPError as e: + raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e + + @router.get("/schedulers") + async def list_schedulers() -> dict[str, Any]: + """List available schedulers from sd-server.""" + _check_server_running(pm) + assert pm.config is not None + url = f"http://127.0.0.1:{pm.config.port}/sdapi/v1/schedulers" + + try: + async with httpx.AsyncClient(timeout=30) as client: + response = await client.get(url) + response.raise_for_status() + return {"schedulers": response.json()} + except httpx.HTTPError as e: + raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e + + return router diff --git a/tensors/server/models_routes.py b/tensors/server/models_routes.py new file mode 100644 index 0000000..47251c5 --- /dev/null +++ b/tensors/server/models_routes.py @@ -0,0 +1,171 @@ +"""FastAPI route handlers for model management endpoints.""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from fastapi import APIRouter, HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel as PydanticBaseModel + +from tensors.config import MODELS_DIR + +if TYPE_CHECKING: + from tensors.server.process import ProcessManager + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Request/Response Models +# ============================================================================= + + +class SwitchModelRequest(PydanticBaseModel): + """Request body for switching models.""" + + model: str # Path to model file + + +# ============================================================================= +# Helper Functions +# ============================================================================= + + +def scan_models(directory: Path, extensions: tuple[str, ...] = (".safetensors", ".gguf")) -> list[dict[str, Any]]: + """Scan directory for model files.""" + models: list[dict[str, Any]] = [] + + if not directory.exists(): + return models + + for ext in extensions: + for path in directory.rglob(f"*{ext}"): + stat = path.stat() + models.append( + { + "name": path.stem, + "path": str(path), + "filename": path.name, + "size_mb": round(stat.st_size / (1024 * 1024), 2), + "modified": stat.st_mtime, + } + ) + + # Sort by name + models.sort(key=lambda x: x["name"].lower()) + return models + + +def scan_loras(directory: Path | None = None) -> list[dict[str, Any]]: + """Scan for LoRA files.""" + lora_dir = directory or MODELS_DIR / "loras" + return scan_models(lora_dir, extensions=(".safetensors",)) + + +def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]: + """Scan for checkpoint files.""" + checkpoint_dir = directory or MODELS_DIR / "checkpoints" + return scan_models(checkpoint_dir, extensions=(".safetensors", ".gguf")) + + +# ============================================================================= +# Router Factory +# ============================================================================= + + +def create_models_router(pm: ProcessManager) -> APIRouter: + """Build a router with /api/models/* endpoints.""" + router = APIRouter(prefix="/api/models", tags=["models"]) + + @router.get("") + def list_models() -> dict[str, Any]: + """List available checkpoint models.""" + checkpoints = scan_checkpoints() + return { + "models": checkpoints, + "total": len(checkpoints), + } + + @router.get("/active") + def get_active_model() -> dict[str, Any]: + """Get information about the currently loaded model.""" + status = pm.status() + config = pm.config + + if config is None: + return { + "loaded": False, + "model": None, + "status": status.get("status"), + } + + return { + "loaded": status.get("status") == "running", + "model": config.model, + "pid": status.get("pid"), + "port": config.port, + "status": status.get("status"), + } + + @router.post("/switch") + async def switch_model(req: SwitchModelRequest) -> JSONResponse: + """Switch to a different model (hot reload).""" + model_path = Path(req.model) + + # Validate model exists + if not model_path.exists(): + raise HTTPException(status_code=400, detail=f"Model not found: {req.model}") + + # Use existing reload logic + from tensors.server.models import ServerConfig # noqa: PLC0415 + + new_config = ServerConfig( + model=req.model, + port=pm.config.port if pm.config else 1234, + args=pm.config.args if pm.config else [], + ) + + pm.stop() + pm.start(new_config) + ready = await pm.wait_ready() + + if not ready: + return JSONResponse( + {"error": "sd-server failed to become ready", "model": req.model}, + status_code=503, + ) + + return JSONResponse( + { + "ok": True, + "model": req.model, + "pid": pm.proc.pid if pm.proc else None, + } + ) + + @router.get("/loras") + def list_loras() -> dict[str, Any]: + """List available LoRA files.""" + loras = scan_loras() + return { + "loras": loras, + "total": len(loras), + } + + @router.get("/scan") + def scan_all_models() -> dict[str, Any]: + """Scan all model directories.""" + checkpoints = scan_checkpoints() + loras = scan_loras() + + return { + "checkpoints": checkpoints, + "loras": loras, + "total_checkpoints": len(checkpoints), + "total_loras": len(loras), + } + + return router