💬 Commit message: Update 2026-02-15 06:15:02, 11 files, 1240 lines

📁 Files changed: 11
📝 Lines changed: 1240

  • client.py
  • __init__.py
  • _http.py
  • generation.py
  • info.py
  • params.py
  • util.py
  • generate_routes.py
  • models_routes.py
  • routes.py
  • sd_client.py
This commit is contained in:
Adam Ladachowski
2026-02-15 06:15:02 +01:00
parent e016c01370
commit c419e443ae
11 changed files with 0 additions and 1240 deletions
-292
View File
@@ -1,292 +0,0 @@
"""HTTP client for remote tsr server API."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import httpx
if TYPE_CHECKING:
from collections.abc import Iterator
from pathlib import Path
class TsrClientError(Exception):
"""Error from TsrClient operations."""
class TsrClient:
"""HTTP client wrapper for tsr server API.
Usage:
with TsrClient("http://junkpile:8080") as client:
images = client.list_images()
result = client.generate("a cat")
"""
def __init__(self, base_url: str, timeout: float = 300.0) -> None:
"""Initialize client with server URL."""
self.base_url = base_url.rstrip("/")
self.timeout = timeout
self._client: httpx.Client | None = None
def __enter__(self) -> TsrClient:
self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout)
return self
def __exit__(self, *exc: object) -> None:
if self._client:
self._client.close()
self._client = None
@property
def client(self) -> httpx.Client:
"""Get the HTTP client, creating if needed."""
if self._client is None:
self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout)
return self._client
def _get(self, path: str, params: dict[str, Any] | None = None) -> Any:
"""Make GET request."""
try:
resp = self.client.get(path, params=params)
resp.raise_for_status()
return resp.json()
except httpx.HTTPStatusError as e:
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
except httpx.RequestError as e:
raise TsrClientError(f"Request failed: {e}") from e
def _post(self, path: str, json: dict[str, Any] | None = None) -> Any:
"""Make POST request."""
try:
resp = self.client.post(path, json=json)
resp.raise_for_status()
return resp.json()
except httpx.HTTPStatusError as e:
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
except httpx.RequestError as e:
raise TsrClientError(f"Request failed: {e}") from e
def _delete(self, path: str) -> Any:
"""Make DELETE request."""
try:
resp = self.client.delete(path)
resp.raise_for_status()
return resp.json()
except httpx.HTTPStatusError as e:
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
except httpx.RequestError as e:
raise TsrClientError(f"Request failed: {e}") from e
# =========================================================================
# Server Status
# =========================================================================
def status(self) -> dict[str, Any]:
"""Get server status."""
return dict(self._get("/status"))
# =========================================================================
# Gallery / Images
# =========================================================================
def list_images(self, limit: int = 50, offset: int = 0) -> dict[str, Any]:
"""List images in gallery."""
return dict(self._get("/api/images", params={"limit": limit, "offset": offset}))
def get_image_meta(self, image_id: str) -> dict[str, Any]:
"""Get metadata for an image."""
return dict(self._get(f"/api/images/{image_id}/meta"))
def delete_image(self, image_id: str) -> dict[str, Any]:
"""Delete an image."""
return dict(self._delete(f"/api/images/{image_id}"))
def edit_image(self, image_id: str, updates: dict[str, Any]) -> dict[str, Any]:
"""Update image metadata."""
return dict(self._post(f"/api/images/{image_id}/edit", json=updates))
def download_image(self, image_id: str) -> bytes:
"""Download image file bytes."""
try:
resp = self.client.get(f"/api/images/{image_id}")
resp.raise_for_status()
return resp.content
except httpx.HTTPStatusError as e:
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
except httpx.RequestError as e:
raise TsrClientError(f"Request failed: {e}") from e
# =========================================================================
# Models
# =========================================================================
def list_models(self) -> dict[str, Any]:
"""List available models."""
return dict(self._get("/api/models"))
def get_active_model(self) -> dict[str, Any]:
"""Get currently active model."""
return dict(self._get("/api/models/active"))
def switch_model(self, model_path: str) -> dict[str, Any]:
"""Switch to a different model."""
return dict(self._post("/api/models/switch", json={"model": model_path}))
def list_loras(self) -> dict[str, Any]:
"""List available LoRAs."""
return dict(self._get("/api/models/loras"))
def scan_models(self) -> dict[str, Any]:
"""Scan model directories."""
return dict(self._get("/api/models/scan"))
# =========================================================================
# Generation
# =========================================================================
def generate(
self,
prompt: str,
negative_prompt: str = "",
width: int = 512,
height: int = 512,
steps: int = 20,
cfg_scale: float = 7.0,
seed: int = -1,
sampler_name: str = "",
scheduler: str = "",
batch_size: int = 1,
save_to_gallery: bool = True,
return_base64: bool = False,
) -> dict[str, Any]:
"""Generate images."""
body = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"width": width,
"height": height,
"steps": steps,
"cfg_scale": cfg_scale,
"seed": seed,
"sampler_name": sampler_name,
"scheduler": scheduler,
"batch_size": batch_size,
"save_to_gallery": save_to_gallery,
"return_base64": return_base64,
}
return dict(self._post("/api/generate", json=body))
def list_samplers(self) -> dict[str, Any]:
"""List available samplers."""
return dict(self._get("/api/samplers"))
def list_schedulers(self) -> dict[str, Any]:
"""List available schedulers."""
return dict(self._get("/api/schedulers"))
# =========================================================================
# Download
# =========================================================================
def start_download(
self,
version_id: int | None = None,
model_id: int | None = None,
hash_val: str | None = None,
output_dir: str | None = None,
) -> dict[str, Any]:
"""Start a model download from CivitAI."""
body: dict[str, Any] = {}
if version_id:
body["version_id"] = version_id
if model_id:
body["model_id"] = model_id
if hash_val:
body["hash"] = hash_val
if output_dir:
body["output_dir"] = output_dir
return dict(self._post("/api/download", json=body))
def get_download_status(self, download_id: str) -> dict[str, Any]:
"""Get download status."""
return dict(self._get(f"/api/download/status/{download_id}"))
def list_downloads(self) -> dict[str, Any]:
"""List active downloads."""
return dict(self._get("/api/download/active"))
# =========================================================================
# Database
# =========================================================================
def db_list_files(self) -> list[dict[str, Any]]:
"""List local files in database."""
return list(self._get("/api/db/files"))
def db_search_models(
self,
query: str | None = None,
model_type: str | None = None,
base_model: str | None = None,
limit: int = 20,
) -> list[dict[str, Any]]:
"""Search cached models."""
params: dict[str, Any] = {"limit": limit}
if query:
params["query"] = query
if model_type:
params["type"] = model_type
if base_model:
params["base"] = base_model
return list(self._get("/api/db/models", params=params))
def db_get_model(self, civitai_id: int) -> dict[str, Any]:
"""Get cached model by CivitAI ID."""
return dict(self._get(f"/api/db/models/{civitai_id}"))
def db_get_triggers(self, file_path: str | None = None, version_id: int | None = None) -> list[str]:
"""Get trigger words."""
if version_id:
return list(self._get(f"/api/db/triggers/{version_id}"))
if file_path:
return list(self._get("/api/db/triggers", params={"file_path": file_path}))
return []
def db_stats(self) -> dict[str, Any]:
"""Get database statistics."""
return dict(self._get("/api/db/stats"))
def db_scan(self, directory: str) -> dict[str, Any]:
"""Scan directory for safetensor files."""
return dict(self._post("/api/db/scan", json={"directory": directory}))
def db_link(self) -> dict[str, Any]:
"""Link unlinked files to CivitAI."""
return dict(self._post("/api/db/link"))
def db_cache(self, model_id: int) -> dict[str, Any]:
"""Cache CivitAI model data."""
return dict(self._post("/api/db/cache", json={"model_id": model_id}))
# =========================================================================
# Streaming Downloads
# =========================================================================
def stream_image(self, image_id: str) -> Iterator[bytes]:
"""Stream image download in chunks."""
try:
with self.client.stream("GET", f"/api/images/{image_id}") as resp:
resp.raise_for_status()
yield from resp.iter_bytes(chunk_size=1024 * 64)
except httpx.HTTPStatusError as e:
raise TsrClientError(f"HTTP {e.response.status_code}") from e
except httpx.RequestError as e:
raise TsrClientError(f"Request failed: {e}") from e
def save_image_to(self, image_id: str, dest: Path) -> Path:
"""Download and save image to file."""
content = self.download_image(image_id)
dest.write_bytes(content)
return dest
-43
View File
@@ -1,43 +0,0 @@
"""sd-server Python client — modular, httpx-based."""
from __future__ import annotations
from typing import Any
from tensors.generate._http import HttpTransport
from tensors.generate.generation import GenerationAPI
from tensors.generate.info import InfoAPI
from tensors.generate.params import Img2ImgParams, Txt2ImgParams
from tensors.generate.util import save_images
__all__ = [
"Img2ImgParams",
"SDClient",
"Txt2ImgParams",
"save_images",
]
class SDClient:
"""Composite client for sd-server.
Usage::
with SDClient() as c:
c.info.models()
images = c.generate.txt2img(Txt2ImgParams(prompt="a cat"))
"""
def __init__(self, host: str = "127.0.0.1", port: int = 1234) -> None:
self._http = HttpTransport(f"http://{host}:{port}")
self.info = InfoAPI(self._http)
self.generate = GenerationAPI(self._http)
def close(self) -> None:
self._http.close()
def __enter__(self) -> SDClient:
return self
def __exit__(self, *exc: Any) -> None:
self.close()
-46
View File
@@ -1,46 +0,0 @@
"""HTTP transport layer wrapping httpx."""
from __future__ import annotations
import logging
from typing import Any
import httpx
logger = logging.getLogger(__name__)
class HttpTransport:
def __init__(self, base_url: str, timeout: float = 300.0) -> None:
self._client = httpx.Client(base_url=base_url, timeout=timeout)
logger.debug("transport ready: %s", base_url)
def get(self, path: str) -> Any:
logger.debug("GET %s", path)
try:
r = self._client.get(path)
r.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error("GET %s%d: %s", path, e.response.status_code, e.response.text[:200])
raise
except httpx.RequestError as e:
logger.error("GET %s connection failed: %s", path, e)
raise
return r.json()
def post(self, path: str, json: dict[str, Any]) -> Any:
logger.debug("POST %s", path)
try:
r = self._client.post(path, json=json)
r.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error("POST %s%d: %s", path, e.response.status_code, e.response.text[:200])
raise
except httpx.RequestError as e:
logger.error("POST %s connection failed: %s", path, e)
raise
return r.json()
def close(self) -> None:
self._client.close()
logger.debug("transport closed")
-34
View File
@@ -1,34 +0,0 @@
"""Image generation endpoints."""
from __future__ import annotations
import base64
import logging
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from tensors.generate._http import HttpTransport
from tensors.generate.params import Img2ImgParams, Txt2ImgParams
logger = logging.getLogger(__name__)
class GenerationAPI:
def __init__(self, http: HttpTransport) -> None:
self._http = http
def txt2img(self, params: Txt2ImgParams) -> list[bytes]:
"""Generate images from text prompt."""
logger.info("txt2img: '%s' %dx%d steps=%d", params.prompt[:60], params.width, params.height, params.steps)
data = self._http.post("/sdapi/v1/txt2img", params.to_body())
images = [base64.b64decode(img) for img in data["images"]]
logger.info("txt2img: got %d image(s)", len(images))
return images
def img2img(self, params: Img2ImgParams) -> list[bytes]:
"""Generate images from image + text prompt."""
logger.info("img2img: '%s' strength=%.2f steps=%d", params.prompt[:60], params.denoising_strength, params.steps)
data = self._http.post("/sdapi/v1/img2img", params.to_body())
images = [base64.b64decode(img) for img in data["images"]]
logger.info("img2img: got %d image(s)", len(images))
return images
-42
View File
@@ -1,42 +0,0 @@
"""Model and server info endpoints."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from tensors.generate._http import HttpTransport
logger = logging.getLogger(__name__)
class InfoAPI:
def __init__(self, http: HttpTransport) -> None:
self._http = http
def models(self) -> list[dict[str, Any]]:
"""List loaded models (OpenAI /v1/models)."""
return self._http.get("/v1/models")["data"] # type: ignore[no-any-return]
def sd_models(self) -> list[dict[str, Any]]:
"""Detailed model info (sdapi)."""
return self._http.get("/sdapi/v1/sd-models") # type: ignore[no-any-return]
def options(self) -> dict[str, Any]:
"""Current server options."""
return self._http.get("/sdapi/v1/options") # type: ignore[no-any-return]
def loras(self) -> list[dict[str, Any]]:
"""Available LoRAs from --lora-model-dir."""
result: list[dict[str, Any]] = self._http.get("/sdapi/v1/loras")
logger.info("found %d lora(s)", len(result))
return result
def samplers(self) -> list[str]:
"""Available sampler names."""
return [s["name"] for s in self._http.get("/sdapi/v1/samplers")]
def schedulers(self) -> list[str]:
"""Available scheduler names."""
return [s["name"] for s in self._http.get("/sdapi/v1/schedulers")]
-100
View File
@@ -1,100 +0,0 @@
"""Generation parameter dataclasses."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from pathlib import Path
from tensors.generate.util import to_b64
@dataclass
class Txt2ImgParams:
prompt: str
negative_prompt: str = ""
width: int = 512
height: int = 512
steps: int = 20
cfg_scale: float = 7.0
seed: int = -1
batch_size: int = 1
sampler_name: str = ""
scheduler: str = ""
clip_skip: int = -1
lora: list[dict[str, Any]] | None = None
def to_body(self) -> dict[str, Any]:
body = {
"prompt": self.prompt,
"negative_prompt": self.negative_prompt,
"width": self.width,
"height": self.height,
"steps": self.steps,
"cfg_scale": self.cfg_scale,
"seed": self.seed,
"batch_size": self.batch_size,
}
if self.sampler_name:
body["sampler_name"] = self.sampler_name
if self.scheduler:
body["scheduler"] = self.scheduler
if self.clip_skip > 0:
body["clip_skip"] = self.clip_skip
if self.lora:
body["lora"] = self.lora
return body
@dataclass
class Img2ImgParams:
prompt: str
init_image: str | bytes | Path
negative_prompt: str = ""
width: int = -1
height: int = -1
steps: int = 20
cfg_scale: float = 7.0
denoising_strength: float = 0.75
seed: int = -1
batch_size: int = 1
sampler_name: str = ""
scheduler: str = ""
clip_skip: int = -1
mask: str | bytes | Path | None = None
inpainting_mask_invert: bool = False
lora: list[dict[str, Any]] | None = None
extra_images: list[str | bytes | Path] = field(default_factory=list)
def to_body(self) -> dict[str, Any]:
body: dict[str, Any] = {
"prompt": self.prompt,
"negative_prompt": self.negative_prompt,
"steps": self.steps,
"cfg_scale": self.cfg_scale,
"denoising_strength": self.denoising_strength,
"seed": self.seed,
"batch_size": self.batch_size,
"init_images": [to_b64(self.init_image)],
}
if self.width > 0:
body["width"] = self.width
if self.height > 0:
body["height"] = self.height
if self.mask is not None:
body["mask"] = to_b64(self.mask)
if self.inpainting_mask_invert:
body["inpainting_mask_invert"] = 1
if self.sampler_name:
body["sampler_name"] = self.sampler_name
if self.scheduler:
body["scheduler"] = self.scheduler
if self.clip_skip > 0:
body["clip_skip"] = self.clip_skip
if self.lora:
body["lora"] = self.lora
if self.extra_images:
body["extra_images"] = [to_b64(img) for img in self.extra_images]
return body
-37
View File
@@ -1,37 +0,0 @@
"""Utility functions for image encoding and file I/O."""
import base64
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
def to_b64(image: str | bytes | Path) -> str:
"""Convert a file path, raw bytes, or base64 string to base64."""
if isinstance(image, (str, Path)):
path = Path(image)
if path.exists():
logger.debug("encoding file: %s", path)
return base64.b64encode(path.read_bytes()).decode()
return str(image)
if isinstance(image, bytes):
return base64.b64encode(image).decode()
raise TypeError(f"unsupported image type: {type(image)}")
def save_images(
images: list[bytes],
output_dir: str = ".",
prefix: str = "output",
) -> list[Path]:
"""Write raw PNG bytes to numbered files. Returns saved paths."""
out = Path(output_dir)
out.mkdir(parents=True, exist_ok=True)
paths = []
for i, data in enumerate(images):
path = out / f"{prefix}_{i:04d}.png"
path.write_bytes(data)
logger.info("saved: %s", path)
paths.append(path)
return paths
-216
View File
@@ -1,216 +0,0 @@
"""FastAPI route handlers for image generation with gallery integration."""
from __future__ import annotations
import base64
import logging
import time
from pathlib import Path
from typing import Any
import httpx
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel as PydanticBaseModel
from pydantic import Field
from tensors.server.gallery import Gallery
from tensors.server.sd_client import get_sd_headers
logger = logging.getLogger(__name__)
# =============================================================================
# Request/Response Models
# =============================================================================
class LoraConfig(PydanticBaseModel):
"""LoRA configuration for sd-server."""
path: str
multiplier: float = Field(default=1.0, ge=0.0, le=2.0)
class GenerateRequest(PydanticBaseModel):
"""Request body for image generation."""
prompt: str
negative_prompt: str = ""
width: int = Field(default=512, ge=64, le=2048)
height: int = Field(default=512, ge=64, le=2048)
steps: int = Field(default=20, ge=1, le=150)
cfg_scale: float = Field(default=7.0, ge=0, le=30)
seed: int = -1
sampler_name: str = ""
scheduler: str = ""
batch_size: int = Field(default=1, ge=1, le=16)
save_to_gallery: bool = True
return_base64: bool = False
lora: LoraConfig | None = None
# =============================================================================
# Helper Functions
# =============================================================================
def _build_sd_request(req: GenerateRequest) -> dict[str, Any]:
"""Build request body for sd-server."""
prompt = req.prompt
# sd-server expects LoRA in prompt as <lora:name:weight> syntax
if req.lora:
lora_name = Path(req.lora.path).stem
lora_tag = f"<lora:{lora_name}:{req.lora.multiplier}>"
prompt = f"{prompt} {lora_tag}"
body: dict[str, Any] = {
"prompt": prompt,
"negative_prompt": req.negative_prompt,
"width": req.width,
"height": req.height,
"steps": req.steps,
"cfg_scale": req.cfg_scale,
"seed": req.seed,
"batch_size": req.batch_size,
}
if req.sampler_name:
body["sampler_name"] = req.sampler_name
if req.scheduler:
body["scheduler"] = req.scheduler
return body
def _parse_info(info: Any) -> dict[str, Any]:
"""Parse info from sd-server response."""
if isinstance(info, str):
import json # noqa: PLC0415
try:
return dict(json.loads(info))
except json.JSONDecodeError:
return {"raw": info}
return info if isinstance(info, dict) else {}
def _process_image(
img_b64: str,
index: int,
seed: int,
req: GenerateRequest,
gallery: Gallery,
model: str | None,
) -> dict[str, Any]:
"""Process a single generated image."""
image_bytes = base64.b64decode(img_b64)
image_info: dict[str, Any] = {"index": index, "seed": seed}
if req.save_to_gallery:
metadata = {
"prompt": req.prompt,
"negative_prompt": req.negative_prompt,
"width": req.width,
"height": req.height,
"steps": req.steps,
"cfg_scale": req.cfg_scale,
"sampler": req.sampler_name,
"scheduler": req.scheduler,
"model": model,
"lora": req.lora.path if req.lora else None,
"lora_weight": req.lora.multiplier if req.lora else None,
"generated_at": time.time(),
}
gallery_img = gallery.save_image(image_bytes, metadata=metadata, seed=seed)
image_info["id"] = gallery_img.id
image_info["path"] = str(gallery_img.path)
if req.return_base64:
image_info["base64"] = img_b64
return image_info
# =============================================================================
# Router Factory
# =============================================================================
def create_generate_router() -> APIRouter: # noqa: PLR0915
"""Build a router with /api/generate endpoint."""
router = APIRouter(prefix="/api", tags=["generate"])
gallery = Gallery()
@router.post("/generate")
async def generate(request: Request, req: GenerateRequest) -> dict[str, Any]:
"""Generate images with gallery integration."""
sd_server_url = request.app.state.sd_server_url
body = _build_sd_request(req)
url = f"{sd_server_url}/sdapi/v1/txt2img"
try:
headers = get_sd_headers(request)
async with httpx.AsyncClient(timeout=300) as client:
response = await client.post(url, json=body, headers=headers)
response.raise_for_status()
result = response.json()
except httpx.ConnectError as e:
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
except httpx.HTTPError as e:
logger.exception("Generation failed")
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
images_data = result.get("images", [])
info = _parse_info(result.get("info", {}))
all_seeds = info.get("all_seeds", [req.seed] * len(images_data))
# Get model info from sd-server response if available
model_name = info.get("sd_model_name") or info.get("model")
output_images = [
_process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, model_name)
for i, img_b64 in enumerate(images_data)
]
return {
"images": output_images,
"parameters": result.get("parameters", body),
"info": info,
"saved_to_gallery": req.save_to_gallery,
"total": len(output_images),
}
@router.get("/samplers")
async def list_samplers(request: Request) -> dict[str, Any]:
"""List available samplers from sd-server."""
sd_server_url = request.app.state.sd_server_url
url = f"{sd_server_url}/sdapi/v1/samplers"
try:
headers = get_sd_headers(request)
async with httpx.AsyncClient(timeout=30) as client:
response = await client.get(url, headers=headers)
response.raise_for_status()
return {"samplers": response.json()}
except httpx.ConnectError as e:
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
except httpx.HTTPError as e:
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
@router.get("/schedulers")
async def list_schedulers(request: Request) -> dict[str, Any]:
"""List available schedulers from sd-server."""
sd_server_url = request.app.state.sd_server_url
url = f"{sd_server_url}/sdapi/v1/schedulers"
try:
headers = get_sd_headers(request)
async with httpx.AsyncClient(timeout=30) as client:
response = await client.get(url, headers=headers)
response.raise_for_status()
return {"schedulers": response.json()}
except httpx.ConnectError as e:
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
except httpx.HTTPError as e:
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
return router
-311
View File
@@ -1,311 +0,0 @@
"""FastAPI route handlers for model management endpoints."""
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from tensors.config import MODELS_DIR
from tensors.db import Database
from tensors.server.sd_client import get_sd_headers
logger = logging.getLogger(__name__)
_HTTP_OK = 200
_SD_ENV_FILE = Path("/etc/default/sd-server")
class SwitchModelRequest(BaseModel):
"""Request body for switching models."""
model: str # Model filename or full path
async def _run_command(*args: str) -> tuple[int, str, str]:
"""Run a shell command and return (returncode, stdout, stderr)."""
proc = await asyncio.create_subprocess_exec(
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
return proc.returncode or 0, stdout.decode(), stderr.decode()
def _read_env_file() -> dict[str, str]:
"""Read the sd-server environment file."""
env: dict[str, str] = {}
if _SD_ENV_FILE.exists():
for raw_line in _SD_ENV_FILE.read_text().splitlines():
line = raw_line.strip()
if line and not line.startswith("#") and "=" in line:
key, _, value = line.partition("=")
env[key.strip()] = value.strip()
return env
def _write_env_file(env: dict[str, str]) -> str:
"""Generate env file content."""
lines = ["# sd-server configuration"]
for key, value in env.items():
lines.append(f"{key}={value}")
return "\n".join(lines) + "\n"
# Keywords for detecting base model category
_SD15_KEYWORDS = ("sd15", "sd1.5", "sd-1.5", "sd_1.5", "1.5", "sd-1-", "v1-5")
_LARGE_KEYWORDS = ("sdxl", "xl", "pony", "illustrious", "ilust", "noob", "animagine")
def _detect_model_category(name: str) -> str:
"""Detect model category from filename. Returns 'sd15' or 'large'."""
name_lower = name.lower()
# Check SD 1.5 keywords first
for kw in _SD15_KEYWORDS:
if kw in name_lower:
return "sd15"
# Check large model keywords
for kw in _LARGE_KEYWORDS:
if kw in name_lower:
return "large"
# Default to large (SDXL/Pony/Illustrious are more common now)
return "large"
# =============================================================================
# Helper Functions
# =============================================================================
def scan_models(directory: Path, extensions: tuple[str, ...] = (".safetensors", ".gguf")) -> list[dict[str, Any]]:
"""Scan directory for model files."""
models: list[dict[str, Any]] = []
if not directory.exists():
return models
for ext in extensions:
for path in directory.rglob(f"*{ext}"):
stat = path.stat()
name = path.stem
models.append(
{
"name": name,
"path": str(path),
"filename": path.name,
"size_mb": round(stat.st_size / (1024 * 1024), 2),
"modified": stat.st_mtime,
"category": _detect_model_category(name),
}
)
# Sort by name
models.sort(key=lambda x: x["name"].lower())
return models
def _enrich_with_metadata(models: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Enrich model data with CivitAI metadata from database."""
try:
with Database() as db:
db.init_schema()
for model in models:
file_path = model.get("path", "")
file_info = db.get_local_file_by_path(file_path)
if file_info and file_info.get("civitai_model_id"):
# Add human-readable name
model["display_name"] = file_info.get("model_name") or model["name"]
model["base_model"] = file_info.get("base_model")
model["model_type"] = file_info.get("model_type")
model["civitai_model_id"] = file_info.get("civitai_model_id")
model["civitai_version_id"] = file_info.get("civitai_version_id")
# Get thumbnail from version images
version_id = file_info.get("civitai_version_id")
if version_id:
cur = db.conn.cursor()
cur.execute(
"""
SELECT url FROM version_images
WHERE version_id = (SELECT id FROM model_versions WHERE civitai_id = ?)
ORDER BY id LIMIT 1
""",
(version_id,),
)
row = cur.fetchone()
if row:
model["thumbnail_url"] = row[0]
# Get trigger words
triggers = db.get_triggers_by_version(version_id) if version_id else []
model["triggers"] = triggers[:5] # Limit to first 5
else:
model["display_name"] = model["name"]
except Exception as e:
logger.warning("Failed to enrich models with metadata: %s", e)
# Fallback: just use filename as display name
for model in models:
if "display_name" not in model:
model["display_name"] = model["name"]
return models
def scan_loras(directory: Path | None = None) -> list[dict[str, Any]]:
"""Scan for LoRA files."""
lora_dir = directory or MODELS_DIR / "loras"
return scan_models(lora_dir, extensions=(".safetensors", ".gguf"))
def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]:
"""Scan for checkpoint files."""
checkpoint_dir = directory or MODELS_DIR / "checkpoints"
return scan_models(checkpoint_dir, extensions=(".safetensors", ".gguf"))
# =============================================================================
# Router Factory
# =============================================================================
def create_models_router() -> APIRouter: # noqa: PLR0915
"""Build a router with /api/models/* endpoints."""
router = APIRouter(prefix="/api/models", tags=["models"])
@router.get("")
def list_models() -> dict[str, Any]:
"""List available checkpoint models with metadata."""
checkpoints = scan_checkpoints()
checkpoints = _enrich_with_metadata(checkpoints)
return {
"models": checkpoints,
"total": len(checkpoints),
}
@router.get("/active")
async def get_active_model(request: Request) -> dict[str, Any]:
"""Get information about the currently loaded model from sd-server."""
import httpx # noqa: PLC0415
sd_server_url = request.app.state.sd_server_url
# Try to get current model from sd-server's options endpoint
try:
headers = get_sd_headers(request)
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(f"{sd_server_url}/sdapi/v1/options", headers=headers)
if response.status_code == _HTTP_OK:
options = response.json()
model_name = options.get("sd_model_checkpoint")
return {
"loaded": True,
"model": model_name,
"sd_server_url": sd_server_url,
}
except httpx.HTTPError:
pass
return {
"loaded": False,
"model": None,
"sd_server_url": sd_server_url,
"error": "Cannot connect to sd-server",
}
@router.get("/loras")
def list_loras() -> dict[str, Any]:
"""List available LoRA files with metadata."""
loras = scan_loras()
loras = _enrich_with_metadata(loras)
return {
"loras": loras,
"total": len(loras),
}
@router.get("/scan")
def scan_all_models() -> dict[str, Any]:
"""Scan all model directories."""
checkpoints = scan_checkpoints()
loras = scan_loras()
return {
"checkpoints": checkpoints,
"loras": loras,
"total_checkpoints": len(checkpoints),
"total_loras": len(loras),
}
@router.post("/switch")
async def switch_model(req: SwitchModelRequest) -> dict[str, Any]:
"""Switch sd-server to a different model by updating env and restarting."""
# Find the model file
checkpoints = scan_checkpoints()
model_path: str | None = None
for cp in checkpoints:
if cp["filename"] == req.model or cp["path"] == req.model or cp["name"] == req.model:
model_path = cp["path"]
break
if not model_path:
raise HTTPException(status_code=404, detail=f"Model not found: {req.model}")
# Read current env, update SD_MODEL
env = _read_env_file()
old_model = env.get("SD_MODEL", "")
env["SD_MODEL"] = model_path
# Write new env file via sudo tee
new_content = _write_env_file(env)
proc = await asyncio.create_subprocess_exec(
"sudo", "tee", str(_SD_ENV_FILE),
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
_, tee_stderr = await proc.communicate(new_content.encode())
if proc.returncode != 0:
raise HTTPException(status_code=500, detail=f"Failed to write env file: {tee_stderr.decode()}")
# Restart sd-server
returncode, _stdout, restart_stderr = await _run_command("sudo", "systemctl", "restart", "sd-server")
if returncode != 0:
raise HTTPException(status_code=500, detail=f"Failed to restart sd-server: {restart_stderr}")
logger.info(f"Switched model from {old_model} to {model_path}")
return {
"ok": True,
"old_model": old_model,
"new_model": model_path,
"message": "Model switched, sd-server restarting",
}
@router.get("/status")
async def sd_server_status() -> dict[str, Any]:
"""Get sd-server systemd service status."""
_returncode, stdout, _stderr = await _run_command("systemctl", "is-active", "sd-server")
is_active = stdout.strip() == "active"
env = _read_env_file()
return {
"service": "sd-server",
"active": is_active,
"status": stdout.strip(),
"current_model": env.get("SD_MODEL"),
"host": env.get("SD_HOST"),
"port": env.get("SD_PORT"),
}
return router
-80
View File
@@ -1,80 +0,0 @@
"""FastAPI route handlers for the sd-server wrapper API."""
from __future__ import annotations
import logging
from typing import Any
import httpx
from fastapi import APIRouter, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
from tensors.server.sd_client import get_sd_headers
logger = logging.getLogger(__name__)
def create_router() -> APIRouter:
"""Build a router with /status and catch-all proxy."""
router = APIRouter()
@router.get("/status")
async def status(request: Request) -> dict[str, Any]:
"""Check if the external sd-server is reachable."""
sd_server_url = request.app.state.sd_server_url
try:
headers = get_sd_headers(request)
async with httpx.AsyncClient(timeout=5) as client:
r = await client.get(sd_server_url, headers=headers)
return {
"status": "ok",
"sd_server_url": sd_server_url,
"sd_server_status": r.status_code,
}
except httpx.HTTPError as e:
return {
"status": "error",
"sd_server_url": sd_server_url,
"error": str(e),
}
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
async def proxy(request: Request, path: str) -> Response:
"""Proxy all requests to the external sd-server."""
sd_server_url = request.app.state.sd_server_url
url = f"{sd_server_url}/{path}"
if request.url.query:
url = f"{url}?{request.url.query}"
body = await request.body()
headers = dict(request.headers)
headers.pop("host", None)
# Add API key if configured
headers.update(get_sd_headers(request))
client = request.app.state.client
try:
upstream = await client.request(
method=request.method,
url=url,
headers=headers,
content=body,
timeout=300,
)
return StreamingResponse(
content=upstream.iter_bytes(),
status_code=upstream.status_code,
headers=dict(upstream.headers),
)
except httpx.ConnectError:
return JSONResponse(
{"error": f"Cannot connect to sd-server at {sd_server_url}"},
status_code=503,
)
except httpx.TimeoutException:
return JSONResponse(
{"error": f"Timeout connecting to sd-server at {sd_server_url}"},
status_code=504,
)
return router
-39
View File
@@ -1,39 +0,0 @@
"""HTTP client utilities for sd-server communication."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import httpx
if TYPE_CHECKING:
from fastapi import Request
def get_sd_headers(request: Request) -> dict[str, str]:
"""Get headers for sd-server requests, including API key if configured."""
headers: dict[str, str] = {}
api_key = getattr(request.app.state, "sd_server_api_key", None)
if api_key:
headers["X-API-Key"] = api_key
return headers
async def sd_get(request: Request, path: str, *, timeout: float = 30) -> httpx.Response:
"""Make a GET request to sd-server."""
url = f"{request.app.state.sd_server_url}/{path.lstrip('/')}"
headers = get_sd_headers(request)
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.get(url, headers=headers)
response.raise_for_status()
return response
async def sd_post(request: Request, path: str, *, json: dict[str, Any] | None = None, timeout: float = 300) -> httpx.Response:
"""Make a POST request to sd-server."""
url = f"{request.app.state.sd_server_url}/{path.lstrip('/')}"
headers = get_sd_headers(request)
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(url, json=json, headers=headers)
response.raise_for_status()
return response