💬 Commit message: Update 2026-02-15 06:15:02, 11 files, 1240 lines
📁 Files changed: 11 📝 Lines changed: 1240 • client.py • __init__.py • _http.py • generation.py • info.py • params.py • util.py • generate_routes.py • models_routes.py • routes.py • sd_client.py
This commit is contained in:
@@ -1,292 +0,0 @@
|
|||||||
"""HTTP client for remote tsr server API."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import Iterator
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
class TsrClientError(Exception):
|
|
||||||
"""Error from TsrClient operations."""
|
|
||||||
|
|
||||||
|
|
||||||
class TsrClient:
|
|
||||||
"""HTTP client wrapper for tsr server API.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
with TsrClient("http://junkpile:8080") as client:
|
|
||||||
images = client.list_images()
|
|
||||||
result = client.generate("a cat")
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, base_url: str, timeout: float = 300.0) -> None:
|
|
||||||
"""Initialize client with server URL."""
|
|
||||||
self.base_url = base_url.rstrip("/")
|
|
||||||
self.timeout = timeout
|
|
||||||
self._client: httpx.Client | None = None
|
|
||||||
|
|
||||||
def __enter__(self) -> TsrClient:
|
|
||||||
self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, *exc: object) -> None:
|
|
||||||
if self._client:
|
|
||||||
self._client.close()
|
|
||||||
self._client = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def client(self) -> httpx.Client:
|
|
||||||
"""Get the HTTP client, creating if needed."""
|
|
||||||
if self._client is None:
|
|
||||||
self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout)
|
|
||||||
return self._client
|
|
||||||
|
|
||||||
def _get(self, path: str, params: dict[str, Any] | None = None) -> Any:
|
|
||||||
"""Make GET request."""
|
|
||||||
try:
|
|
||||||
resp = self.client.get(path, params=params)
|
|
||||||
resp.raise_for_status()
|
|
||||||
return resp.json()
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
raise TsrClientError(f"Request failed: {e}") from e
|
|
||||||
|
|
||||||
def _post(self, path: str, json: dict[str, Any] | None = None) -> Any:
|
|
||||||
"""Make POST request."""
|
|
||||||
try:
|
|
||||||
resp = self.client.post(path, json=json)
|
|
||||||
resp.raise_for_status()
|
|
||||||
return resp.json()
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
raise TsrClientError(f"Request failed: {e}") from e
|
|
||||||
|
|
||||||
def _delete(self, path: str) -> Any:
|
|
||||||
"""Make DELETE request."""
|
|
||||||
try:
|
|
||||||
resp = self.client.delete(path)
|
|
||||||
resp.raise_for_status()
|
|
||||||
return resp.json()
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
raise TsrClientError(f"Request failed: {e}") from e
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Server Status
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def status(self) -> dict[str, Any]:
|
|
||||||
"""Get server status."""
|
|
||||||
return dict(self._get("/status"))
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Gallery / Images
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def list_images(self, limit: int = 50, offset: int = 0) -> dict[str, Any]:
|
|
||||||
"""List images in gallery."""
|
|
||||||
return dict(self._get("/api/images", params={"limit": limit, "offset": offset}))
|
|
||||||
|
|
||||||
def get_image_meta(self, image_id: str) -> dict[str, Any]:
|
|
||||||
"""Get metadata for an image."""
|
|
||||||
return dict(self._get(f"/api/images/{image_id}/meta"))
|
|
||||||
|
|
||||||
def delete_image(self, image_id: str) -> dict[str, Any]:
|
|
||||||
"""Delete an image."""
|
|
||||||
return dict(self._delete(f"/api/images/{image_id}"))
|
|
||||||
|
|
||||||
def edit_image(self, image_id: str, updates: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Update image metadata."""
|
|
||||||
return dict(self._post(f"/api/images/{image_id}/edit", json=updates))
|
|
||||||
|
|
||||||
def download_image(self, image_id: str) -> bytes:
|
|
||||||
"""Download image file bytes."""
|
|
||||||
try:
|
|
||||||
resp = self.client.get(f"/api/images/{image_id}")
|
|
||||||
resp.raise_for_status()
|
|
||||||
return resp.content
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
raise TsrClientError(f"Request failed: {e}") from e
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Models
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def list_models(self) -> dict[str, Any]:
|
|
||||||
"""List available models."""
|
|
||||||
return dict(self._get("/api/models"))
|
|
||||||
|
|
||||||
def get_active_model(self) -> dict[str, Any]:
|
|
||||||
"""Get currently active model."""
|
|
||||||
return dict(self._get("/api/models/active"))
|
|
||||||
|
|
||||||
def switch_model(self, model_path: str) -> dict[str, Any]:
|
|
||||||
"""Switch to a different model."""
|
|
||||||
return dict(self._post("/api/models/switch", json={"model": model_path}))
|
|
||||||
|
|
||||||
def list_loras(self) -> dict[str, Any]:
|
|
||||||
"""List available LoRAs."""
|
|
||||||
return dict(self._get("/api/models/loras"))
|
|
||||||
|
|
||||||
def scan_models(self) -> dict[str, Any]:
|
|
||||||
"""Scan model directories."""
|
|
||||||
return dict(self._get("/api/models/scan"))
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Generation
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
negative_prompt: str = "",
|
|
||||||
width: int = 512,
|
|
||||||
height: int = 512,
|
|
||||||
steps: int = 20,
|
|
||||||
cfg_scale: float = 7.0,
|
|
||||||
seed: int = -1,
|
|
||||||
sampler_name: str = "",
|
|
||||||
scheduler: str = "",
|
|
||||||
batch_size: int = 1,
|
|
||||||
save_to_gallery: bool = True,
|
|
||||||
return_base64: bool = False,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Generate images."""
|
|
||||||
body = {
|
|
||||||
"prompt": prompt,
|
|
||||||
"negative_prompt": negative_prompt,
|
|
||||||
"width": width,
|
|
||||||
"height": height,
|
|
||||||
"steps": steps,
|
|
||||||
"cfg_scale": cfg_scale,
|
|
||||||
"seed": seed,
|
|
||||||
"sampler_name": sampler_name,
|
|
||||||
"scheduler": scheduler,
|
|
||||||
"batch_size": batch_size,
|
|
||||||
"save_to_gallery": save_to_gallery,
|
|
||||||
"return_base64": return_base64,
|
|
||||||
}
|
|
||||||
return dict(self._post("/api/generate", json=body))
|
|
||||||
|
|
||||||
def list_samplers(self) -> dict[str, Any]:
|
|
||||||
"""List available samplers."""
|
|
||||||
return dict(self._get("/api/samplers"))
|
|
||||||
|
|
||||||
def list_schedulers(self) -> dict[str, Any]:
|
|
||||||
"""List available schedulers."""
|
|
||||||
return dict(self._get("/api/schedulers"))
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Download
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def start_download(
|
|
||||||
self,
|
|
||||||
version_id: int | None = None,
|
|
||||||
model_id: int | None = None,
|
|
||||||
hash_val: str | None = None,
|
|
||||||
output_dir: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Start a model download from CivitAI."""
|
|
||||||
body: dict[str, Any] = {}
|
|
||||||
if version_id:
|
|
||||||
body["version_id"] = version_id
|
|
||||||
if model_id:
|
|
||||||
body["model_id"] = model_id
|
|
||||||
if hash_val:
|
|
||||||
body["hash"] = hash_val
|
|
||||||
if output_dir:
|
|
||||||
body["output_dir"] = output_dir
|
|
||||||
return dict(self._post("/api/download", json=body))
|
|
||||||
|
|
||||||
def get_download_status(self, download_id: str) -> dict[str, Any]:
|
|
||||||
"""Get download status."""
|
|
||||||
return dict(self._get(f"/api/download/status/{download_id}"))
|
|
||||||
|
|
||||||
def list_downloads(self) -> dict[str, Any]:
|
|
||||||
"""List active downloads."""
|
|
||||||
return dict(self._get("/api/download/active"))
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Database
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def db_list_files(self) -> list[dict[str, Any]]:
|
|
||||||
"""List local files in database."""
|
|
||||||
return list(self._get("/api/db/files"))
|
|
||||||
|
|
||||||
def db_search_models(
|
|
||||||
self,
|
|
||||||
query: str | None = None,
|
|
||||||
model_type: str | None = None,
|
|
||||||
base_model: str | None = None,
|
|
||||||
limit: int = 20,
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""Search cached models."""
|
|
||||||
params: dict[str, Any] = {"limit": limit}
|
|
||||||
if query:
|
|
||||||
params["query"] = query
|
|
||||||
if model_type:
|
|
||||||
params["type"] = model_type
|
|
||||||
if base_model:
|
|
||||||
params["base"] = base_model
|
|
||||||
return list(self._get("/api/db/models", params=params))
|
|
||||||
|
|
||||||
def db_get_model(self, civitai_id: int) -> dict[str, Any]:
|
|
||||||
"""Get cached model by CivitAI ID."""
|
|
||||||
return dict(self._get(f"/api/db/models/{civitai_id}"))
|
|
||||||
|
|
||||||
def db_get_triggers(self, file_path: str | None = None, version_id: int | None = None) -> list[str]:
|
|
||||||
"""Get trigger words."""
|
|
||||||
if version_id:
|
|
||||||
return list(self._get(f"/api/db/triggers/{version_id}"))
|
|
||||||
if file_path:
|
|
||||||
return list(self._get("/api/db/triggers", params={"file_path": file_path}))
|
|
||||||
return []
|
|
||||||
|
|
||||||
def db_stats(self) -> dict[str, Any]:
|
|
||||||
"""Get database statistics."""
|
|
||||||
return dict(self._get("/api/db/stats"))
|
|
||||||
|
|
||||||
def db_scan(self, directory: str) -> dict[str, Any]:
|
|
||||||
"""Scan directory for safetensor files."""
|
|
||||||
return dict(self._post("/api/db/scan", json={"directory": directory}))
|
|
||||||
|
|
||||||
def db_link(self) -> dict[str, Any]:
|
|
||||||
"""Link unlinked files to CivitAI."""
|
|
||||||
return dict(self._post("/api/db/link"))
|
|
||||||
|
|
||||||
def db_cache(self, model_id: int) -> dict[str, Any]:
|
|
||||||
"""Cache CivitAI model data."""
|
|
||||||
return dict(self._post("/api/db/cache", json={"model_id": model_id}))
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Streaming Downloads
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def stream_image(self, image_id: str) -> Iterator[bytes]:
|
|
||||||
"""Stream image download in chunks."""
|
|
||||||
try:
|
|
||||||
with self.client.stream("GET", f"/api/images/{image_id}") as resp:
|
|
||||||
resp.raise_for_status()
|
|
||||||
yield from resp.iter_bytes(chunk_size=1024 * 64)
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
raise TsrClientError(f"HTTP {e.response.status_code}") from e
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
raise TsrClientError(f"Request failed: {e}") from e
|
|
||||||
|
|
||||||
def save_image_to(self, image_id: str, dest: Path) -> Path:
|
|
||||||
"""Download and save image to file."""
|
|
||||||
content = self.download_image(image_id)
|
|
||||||
dest.write_bytes(content)
|
|
||||||
return dest
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
"""sd-server Python client — modular, httpx-based."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from tensors.generate._http import HttpTransport
|
|
||||||
from tensors.generate.generation import GenerationAPI
|
|
||||||
from tensors.generate.info import InfoAPI
|
|
||||||
from tensors.generate.params import Img2ImgParams, Txt2ImgParams
|
|
||||||
from tensors.generate.util import save_images
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Img2ImgParams",
|
|
||||||
"SDClient",
|
|
||||||
"Txt2ImgParams",
|
|
||||||
"save_images",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class SDClient:
|
|
||||||
"""Composite client for sd-server.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
with SDClient() as c:
|
|
||||||
c.info.models()
|
|
||||||
images = c.generate.txt2img(Txt2ImgParams(prompt="a cat"))
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, host: str = "127.0.0.1", port: int = 1234) -> None:
|
|
||||||
self._http = HttpTransport(f"http://{host}:{port}")
|
|
||||||
self.info = InfoAPI(self._http)
|
|
||||||
self.generate = GenerationAPI(self._http)
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
self._http.close()
|
|
||||||
|
|
||||||
def __enter__(self) -> SDClient:
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, *exc: Any) -> None:
|
|
||||||
self.close()
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
"""HTTP transport layer wrapping httpx."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class HttpTransport:
|
|
||||||
def __init__(self, base_url: str, timeout: float = 300.0) -> None:
|
|
||||||
self._client = httpx.Client(base_url=base_url, timeout=timeout)
|
|
||||||
logger.debug("transport ready: %s", base_url)
|
|
||||||
|
|
||||||
def get(self, path: str) -> Any:
|
|
||||||
logger.debug("GET %s", path)
|
|
||||||
try:
|
|
||||||
r = self._client.get(path)
|
|
||||||
r.raise_for_status()
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
logger.error("GET %s → %d: %s", path, e.response.status_code, e.response.text[:200])
|
|
||||||
raise
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
logger.error("GET %s connection failed: %s", path, e)
|
|
||||||
raise
|
|
||||||
return r.json()
|
|
||||||
|
|
||||||
def post(self, path: str, json: dict[str, Any]) -> Any:
|
|
||||||
logger.debug("POST %s", path)
|
|
||||||
try:
|
|
||||||
r = self._client.post(path, json=json)
|
|
||||||
r.raise_for_status()
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
logger.error("POST %s → %d: %s", path, e.response.status_code, e.response.text[:200])
|
|
||||||
raise
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
logger.error("POST %s connection failed: %s", path, e)
|
|
||||||
raise
|
|
||||||
return r.json()
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
self._client.close()
|
|
||||||
logger.debug("transport closed")
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
"""Image generation endpoints."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import logging
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from tensors.generate._http import HttpTransport
|
|
||||||
from tensors.generate.params import Img2ImgParams, Txt2ImgParams
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class GenerationAPI:
|
|
||||||
def __init__(self, http: HttpTransport) -> None:
|
|
||||||
self._http = http
|
|
||||||
|
|
||||||
def txt2img(self, params: Txt2ImgParams) -> list[bytes]:
|
|
||||||
"""Generate images from text prompt."""
|
|
||||||
logger.info("txt2img: '%s' %dx%d steps=%d", params.prompt[:60], params.width, params.height, params.steps)
|
|
||||||
data = self._http.post("/sdapi/v1/txt2img", params.to_body())
|
|
||||||
images = [base64.b64decode(img) for img in data["images"]]
|
|
||||||
logger.info("txt2img: got %d image(s)", len(images))
|
|
||||||
return images
|
|
||||||
|
|
||||||
def img2img(self, params: Img2ImgParams) -> list[bytes]:
|
|
||||||
"""Generate images from image + text prompt."""
|
|
||||||
logger.info("img2img: '%s' strength=%.2f steps=%d", params.prompt[:60], params.denoising_strength, params.steps)
|
|
||||||
data = self._http.post("/sdapi/v1/img2img", params.to_body())
|
|
||||||
images = [base64.b64decode(img) for img in data["images"]]
|
|
||||||
logger.info("img2img: got %d image(s)", len(images))
|
|
||||||
return images
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
"""Model and server info endpoints."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import TYPE_CHECKING, Any
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from tensors.generate._http import HttpTransport
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class InfoAPI:
|
|
||||||
def __init__(self, http: HttpTransport) -> None:
|
|
||||||
self._http = http
|
|
||||||
|
|
||||||
def models(self) -> list[dict[str, Any]]:
|
|
||||||
"""List loaded models (OpenAI /v1/models)."""
|
|
||||||
return self._http.get("/v1/models")["data"] # type: ignore[no-any-return]
|
|
||||||
|
|
||||||
def sd_models(self) -> list[dict[str, Any]]:
|
|
||||||
"""Detailed model info (sdapi)."""
|
|
||||||
return self._http.get("/sdapi/v1/sd-models") # type: ignore[no-any-return]
|
|
||||||
|
|
||||||
def options(self) -> dict[str, Any]:
|
|
||||||
"""Current server options."""
|
|
||||||
return self._http.get("/sdapi/v1/options") # type: ignore[no-any-return]
|
|
||||||
|
|
||||||
def loras(self) -> list[dict[str, Any]]:
|
|
||||||
"""Available LoRAs from --lora-model-dir."""
|
|
||||||
result: list[dict[str, Any]] = self._http.get("/sdapi/v1/loras")
|
|
||||||
logger.info("found %d lora(s)", len(result))
|
|
||||||
return result
|
|
||||||
|
|
||||||
def samplers(self) -> list[str]:
|
|
||||||
"""Available sampler names."""
|
|
||||||
return [s["name"] for s in self._http.get("/sdapi/v1/samplers")]
|
|
||||||
|
|
||||||
def schedulers(self) -> list[str]:
|
|
||||||
"""Available scheduler names."""
|
|
||||||
return [s["name"] for s in self._http.get("/sdapi/v1/schedulers")]
|
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
"""Generation parameter dataclasses."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import TYPE_CHECKING, Any
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from tensors.generate.util import to_b64
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Txt2ImgParams:
|
|
||||||
prompt: str
|
|
||||||
negative_prompt: str = ""
|
|
||||||
width: int = 512
|
|
||||||
height: int = 512
|
|
||||||
steps: int = 20
|
|
||||||
cfg_scale: float = 7.0
|
|
||||||
seed: int = -1
|
|
||||||
batch_size: int = 1
|
|
||||||
sampler_name: str = ""
|
|
||||||
scheduler: str = ""
|
|
||||||
clip_skip: int = -1
|
|
||||||
lora: list[dict[str, Any]] | None = None
|
|
||||||
|
|
||||||
def to_body(self) -> dict[str, Any]:
|
|
||||||
body = {
|
|
||||||
"prompt": self.prompt,
|
|
||||||
"negative_prompt": self.negative_prompt,
|
|
||||||
"width": self.width,
|
|
||||||
"height": self.height,
|
|
||||||
"steps": self.steps,
|
|
||||||
"cfg_scale": self.cfg_scale,
|
|
||||||
"seed": self.seed,
|
|
||||||
"batch_size": self.batch_size,
|
|
||||||
}
|
|
||||||
if self.sampler_name:
|
|
||||||
body["sampler_name"] = self.sampler_name
|
|
||||||
if self.scheduler:
|
|
||||||
body["scheduler"] = self.scheduler
|
|
||||||
if self.clip_skip > 0:
|
|
||||||
body["clip_skip"] = self.clip_skip
|
|
||||||
if self.lora:
|
|
||||||
body["lora"] = self.lora
|
|
||||||
return body
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Img2ImgParams:
|
|
||||||
prompt: str
|
|
||||||
init_image: str | bytes | Path
|
|
||||||
negative_prompt: str = ""
|
|
||||||
width: int = -1
|
|
||||||
height: int = -1
|
|
||||||
steps: int = 20
|
|
||||||
cfg_scale: float = 7.0
|
|
||||||
denoising_strength: float = 0.75
|
|
||||||
seed: int = -1
|
|
||||||
batch_size: int = 1
|
|
||||||
sampler_name: str = ""
|
|
||||||
scheduler: str = ""
|
|
||||||
clip_skip: int = -1
|
|
||||||
mask: str | bytes | Path | None = None
|
|
||||||
inpainting_mask_invert: bool = False
|
|
||||||
lora: list[dict[str, Any]] | None = None
|
|
||||||
extra_images: list[str | bytes | Path] = field(default_factory=list)
|
|
||||||
|
|
||||||
def to_body(self) -> dict[str, Any]:
|
|
||||||
body: dict[str, Any] = {
|
|
||||||
"prompt": self.prompt,
|
|
||||||
"negative_prompt": self.negative_prompt,
|
|
||||||
"steps": self.steps,
|
|
||||||
"cfg_scale": self.cfg_scale,
|
|
||||||
"denoising_strength": self.denoising_strength,
|
|
||||||
"seed": self.seed,
|
|
||||||
"batch_size": self.batch_size,
|
|
||||||
"init_images": [to_b64(self.init_image)],
|
|
||||||
}
|
|
||||||
if self.width > 0:
|
|
||||||
body["width"] = self.width
|
|
||||||
if self.height > 0:
|
|
||||||
body["height"] = self.height
|
|
||||||
if self.mask is not None:
|
|
||||||
body["mask"] = to_b64(self.mask)
|
|
||||||
if self.inpainting_mask_invert:
|
|
||||||
body["inpainting_mask_invert"] = 1
|
|
||||||
if self.sampler_name:
|
|
||||||
body["sampler_name"] = self.sampler_name
|
|
||||||
if self.scheduler:
|
|
||||||
body["scheduler"] = self.scheduler
|
|
||||||
if self.clip_skip > 0:
|
|
||||||
body["clip_skip"] = self.clip_skip
|
|
||||||
if self.lora:
|
|
||||||
body["lora"] = self.lora
|
|
||||||
if self.extra_images:
|
|
||||||
body["extra_images"] = [to_b64(img) for img in self.extra_images]
|
|
||||||
return body
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
"""Utility functions for image encoding and file I/O."""
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def to_b64(image: str | bytes | Path) -> str:
|
|
||||||
"""Convert a file path, raw bytes, or base64 string to base64."""
|
|
||||||
if isinstance(image, (str, Path)):
|
|
||||||
path = Path(image)
|
|
||||||
if path.exists():
|
|
||||||
logger.debug("encoding file: %s", path)
|
|
||||||
return base64.b64encode(path.read_bytes()).decode()
|
|
||||||
return str(image)
|
|
||||||
if isinstance(image, bytes):
|
|
||||||
return base64.b64encode(image).decode()
|
|
||||||
raise TypeError(f"unsupported image type: {type(image)}")
|
|
||||||
|
|
||||||
|
|
||||||
def save_images(
|
|
||||||
images: list[bytes],
|
|
||||||
output_dir: str = ".",
|
|
||||||
prefix: str = "output",
|
|
||||||
) -> list[Path]:
|
|
||||||
"""Write raw PNG bytes to numbered files. Returns saved paths."""
|
|
||||||
out = Path(output_dir)
|
|
||||||
out.mkdir(parents=True, exist_ok=True)
|
|
||||||
paths = []
|
|
||||||
for i, data in enumerate(images):
|
|
||||||
path = out / f"{prefix}_{i:04d}.png"
|
|
||||||
path.write_bytes(data)
|
|
||||||
logger.info("saved: %s", path)
|
|
||||||
paths.append(path)
|
|
||||||
return paths
|
|
||||||
@@ -1,216 +0,0 @@
|
|||||||
"""FastAPI route handlers for image generation with gallery integration."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
|
||||||
from pydantic import BaseModel as PydanticBaseModel
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from tensors.server.gallery import Gallery
|
|
||||||
from tensors.server.sd_client import get_sd_headers
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Request/Response Models
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class LoraConfig(PydanticBaseModel):
|
|
||||||
"""LoRA configuration for sd-server."""
|
|
||||||
|
|
||||||
path: str
|
|
||||||
multiplier: float = Field(default=1.0, ge=0.0, le=2.0)
|
|
||||||
|
|
||||||
|
|
||||||
class GenerateRequest(PydanticBaseModel):
|
|
||||||
"""Request body for image generation."""
|
|
||||||
|
|
||||||
prompt: str
|
|
||||||
negative_prompt: str = ""
|
|
||||||
width: int = Field(default=512, ge=64, le=2048)
|
|
||||||
height: int = Field(default=512, ge=64, le=2048)
|
|
||||||
steps: int = Field(default=20, ge=1, le=150)
|
|
||||||
cfg_scale: float = Field(default=7.0, ge=0, le=30)
|
|
||||||
seed: int = -1
|
|
||||||
sampler_name: str = ""
|
|
||||||
scheduler: str = ""
|
|
||||||
batch_size: int = Field(default=1, ge=1, le=16)
|
|
||||||
save_to_gallery: bool = True
|
|
||||||
return_base64: bool = False
|
|
||||||
lora: LoraConfig | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Helper Functions
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def _build_sd_request(req: GenerateRequest) -> dict[str, Any]:
|
|
||||||
"""Build request body for sd-server."""
|
|
||||||
prompt = req.prompt
|
|
||||||
|
|
||||||
# sd-server expects LoRA in prompt as <lora:name:weight> syntax
|
|
||||||
if req.lora:
|
|
||||||
lora_name = Path(req.lora.path).stem
|
|
||||||
lora_tag = f"<lora:{lora_name}:{req.lora.multiplier}>"
|
|
||||||
prompt = f"{prompt} {lora_tag}"
|
|
||||||
|
|
||||||
body: dict[str, Any] = {
|
|
||||||
"prompt": prompt,
|
|
||||||
"negative_prompt": req.negative_prompt,
|
|
||||||
"width": req.width,
|
|
||||||
"height": req.height,
|
|
||||||
"steps": req.steps,
|
|
||||||
"cfg_scale": req.cfg_scale,
|
|
||||||
"seed": req.seed,
|
|
||||||
"batch_size": req.batch_size,
|
|
||||||
}
|
|
||||||
if req.sampler_name:
|
|
||||||
body["sampler_name"] = req.sampler_name
|
|
||||||
if req.scheduler:
|
|
||||||
body["scheduler"] = req.scheduler
|
|
||||||
return body
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_info(info: Any) -> dict[str, Any]:
|
|
||||||
"""Parse info from sd-server response."""
|
|
||||||
if isinstance(info, str):
|
|
||||||
import json # noqa: PLC0415
|
|
||||||
|
|
||||||
try:
|
|
||||||
return dict(json.loads(info))
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return {"raw": info}
|
|
||||||
return info if isinstance(info, dict) else {}
|
|
||||||
|
|
||||||
|
|
||||||
def _process_image(
|
|
||||||
img_b64: str,
|
|
||||||
index: int,
|
|
||||||
seed: int,
|
|
||||||
req: GenerateRequest,
|
|
||||||
gallery: Gallery,
|
|
||||||
model: str | None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Process a single generated image."""
|
|
||||||
image_bytes = base64.b64decode(img_b64)
|
|
||||||
image_info: dict[str, Any] = {"index": index, "seed": seed}
|
|
||||||
|
|
||||||
if req.save_to_gallery:
|
|
||||||
metadata = {
|
|
||||||
"prompt": req.prompt,
|
|
||||||
"negative_prompt": req.negative_prompt,
|
|
||||||
"width": req.width,
|
|
||||||
"height": req.height,
|
|
||||||
"steps": req.steps,
|
|
||||||
"cfg_scale": req.cfg_scale,
|
|
||||||
"sampler": req.sampler_name,
|
|
||||||
"scheduler": req.scheduler,
|
|
||||||
"model": model,
|
|
||||||
"lora": req.lora.path if req.lora else None,
|
|
||||||
"lora_weight": req.lora.multiplier if req.lora else None,
|
|
||||||
"generated_at": time.time(),
|
|
||||||
}
|
|
||||||
gallery_img = gallery.save_image(image_bytes, metadata=metadata, seed=seed)
|
|
||||||
image_info["id"] = gallery_img.id
|
|
||||||
image_info["path"] = str(gallery_img.path)
|
|
||||||
|
|
||||||
if req.return_base64:
|
|
||||||
image_info["base64"] = img_b64
|
|
||||||
|
|
||||||
return image_info
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Router Factory
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def create_generate_router() -> APIRouter: # noqa: PLR0915
|
|
||||||
"""Build a router with /api/generate endpoint."""
|
|
||||||
router = APIRouter(prefix="/api", tags=["generate"])
|
|
||||||
gallery = Gallery()
|
|
||||||
|
|
||||||
@router.post("/generate")
|
|
||||||
async def generate(request: Request, req: GenerateRequest) -> dict[str, Any]:
|
|
||||||
"""Generate images with gallery integration."""
|
|
||||||
sd_server_url = request.app.state.sd_server_url
|
|
||||||
body = _build_sd_request(req)
|
|
||||||
url = f"{sd_server_url}/sdapi/v1/txt2img"
|
|
||||||
|
|
||||||
try:
|
|
||||||
headers = get_sd_headers(request)
|
|
||||||
async with httpx.AsyncClient(timeout=300) as client:
|
|
||||||
response = await client.post(url, json=body, headers=headers)
|
|
||||||
response.raise_for_status()
|
|
||||||
result = response.json()
|
|
||||||
except httpx.ConnectError as e:
|
|
||||||
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
|
|
||||||
except httpx.HTTPError as e:
|
|
||||||
logger.exception("Generation failed")
|
|
||||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
|
||||||
|
|
||||||
images_data = result.get("images", [])
|
|
||||||
info = _parse_info(result.get("info", {}))
|
|
||||||
all_seeds = info.get("all_seeds", [req.seed] * len(images_data))
|
|
||||||
|
|
||||||
# Get model info from sd-server response if available
|
|
||||||
model_name = info.get("sd_model_name") or info.get("model")
|
|
||||||
|
|
||||||
output_images = [
|
|
||||||
_process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, model_name)
|
|
||||||
for i, img_b64 in enumerate(images_data)
|
|
||||||
]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"images": output_images,
|
|
||||||
"parameters": result.get("parameters", body),
|
|
||||||
"info": info,
|
|
||||||
"saved_to_gallery": req.save_to_gallery,
|
|
||||||
"total": len(output_images),
|
|
||||||
}
|
|
||||||
|
|
||||||
@router.get("/samplers")
|
|
||||||
async def list_samplers(request: Request) -> dict[str, Any]:
|
|
||||||
"""List available samplers from sd-server."""
|
|
||||||
sd_server_url = request.app.state.sd_server_url
|
|
||||||
url = f"{sd_server_url}/sdapi/v1/samplers"
|
|
||||||
|
|
||||||
try:
|
|
||||||
headers = get_sd_headers(request)
|
|
||||||
async with httpx.AsyncClient(timeout=30) as client:
|
|
||||||
response = await client.get(url, headers=headers)
|
|
||||||
response.raise_for_status()
|
|
||||||
return {"samplers": response.json()}
|
|
||||||
except httpx.ConnectError as e:
|
|
||||||
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
|
|
||||||
except httpx.HTTPError as e:
|
|
||||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
|
||||||
|
|
||||||
@router.get("/schedulers")
|
|
||||||
async def list_schedulers(request: Request) -> dict[str, Any]:
|
|
||||||
"""List available schedulers from sd-server."""
|
|
||||||
sd_server_url = request.app.state.sd_server_url
|
|
||||||
url = f"{sd_server_url}/sdapi/v1/schedulers"
|
|
||||||
|
|
||||||
try:
|
|
||||||
headers = get_sd_headers(request)
|
|
||||||
async with httpx.AsyncClient(timeout=30) as client:
|
|
||||||
response = await client.get(url, headers=headers)
|
|
||||||
response.raise_for_status()
|
|
||||||
return {"schedulers": response.json()}
|
|
||||||
except httpx.ConnectError as e:
|
|
||||||
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
|
|
||||||
except httpx.HTTPError as e:
|
|
||||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
|
||||||
|
|
||||||
return router
|
|
||||||
@@ -1,311 +0,0 @@
|
|||||||
"""FastAPI route handlers for model management endpoints."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from tensors.config import MODELS_DIR
|
|
||||||
from tensors.db import Database
|
|
||||||
from tensors.server.sd_client import get_sd_headers
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_HTTP_OK = 200
|
|
||||||
_SD_ENV_FILE = Path("/etc/default/sd-server")
|
|
||||||
|
|
||||||
|
|
||||||
class SwitchModelRequest(BaseModel):
|
|
||||||
"""Request body for switching models."""
|
|
||||||
|
|
||||||
model: str # Model filename or full path
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_command(*args: str) -> tuple[int, str, str]:
|
|
||||||
"""Run a shell command and return (returncode, stdout, stderr)."""
|
|
||||||
proc = await asyncio.create_subprocess_exec(
|
|
||||||
*args,
|
|
||||||
stdout=asyncio.subprocess.PIPE,
|
|
||||||
stderr=asyncio.subprocess.PIPE,
|
|
||||||
)
|
|
||||||
stdout, stderr = await proc.communicate()
|
|
||||||
return proc.returncode or 0, stdout.decode(), stderr.decode()
|
|
||||||
|
|
||||||
|
|
||||||
def _read_env_file() -> dict[str, str]:
|
|
||||||
"""Read the sd-server environment file."""
|
|
||||||
env: dict[str, str] = {}
|
|
||||||
if _SD_ENV_FILE.exists():
|
|
||||||
for raw_line in _SD_ENV_FILE.read_text().splitlines():
|
|
||||||
line = raw_line.strip()
|
|
||||||
if line and not line.startswith("#") and "=" in line:
|
|
||||||
key, _, value = line.partition("=")
|
|
||||||
env[key.strip()] = value.strip()
|
|
||||||
return env
|
|
||||||
|
|
||||||
|
|
||||||
def _write_env_file(env: dict[str, str]) -> str:
|
|
||||||
"""Generate env file content."""
|
|
||||||
lines = ["# sd-server configuration"]
|
|
||||||
for key, value in env.items():
|
|
||||||
lines.append(f"{key}={value}")
|
|
||||||
return "\n".join(lines) + "\n"
|
|
||||||
|
|
||||||
# Keywords for detecting base model category
|
|
||||||
_SD15_KEYWORDS = ("sd15", "sd1.5", "sd-1.5", "sd_1.5", "1.5", "sd-1-", "v1-5")
|
|
||||||
_LARGE_KEYWORDS = ("sdxl", "xl", "pony", "illustrious", "ilust", "noob", "animagine")
|
|
||||||
|
|
||||||
|
|
||||||
def _detect_model_category(name: str) -> str:
|
|
||||||
"""Detect model category from filename. Returns 'sd15' or 'large'."""
|
|
||||||
name_lower = name.lower()
|
|
||||||
|
|
||||||
# Check SD 1.5 keywords first
|
|
||||||
for kw in _SD15_KEYWORDS:
|
|
||||||
if kw in name_lower:
|
|
||||||
return "sd15"
|
|
||||||
|
|
||||||
# Check large model keywords
|
|
||||||
for kw in _LARGE_KEYWORDS:
|
|
||||||
if kw in name_lower:
|
|
||||||
return "large"
|
|
||||||
|
|
||||||
# Default to large (SDXL/Pony/Illustrious are more common now)
|
|
||||||
return "large"
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Helper Functions
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def scan_models(directory: Path, extensions: tuple[str, ...] = (".safetensors", ".gguf")) -> list[dict[str, Any]]:
|
|
||||||
"""Scan directory for model files."""
|
|
||||||
models: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
if not directory.exists():
|
|
||||||
return models
|
|
||||||
|
|
||||||
for ext in extensions:
|
|
||||||
for path in directory.rglob(f"*{ext}"):
|
|
||||||
stat = path.stat()
|
|
||||||
name = path.stem
|
|
||||||
models.append(
|
|
||||||
{
|
|
||||||
"name": name,
|
|
||||||
"path": str(path),
|
|
||||||
"filename": path.name,
|
|
||||||
"size_mb": round(stat.st_size / (1024 * 1024), 2),
|
|
||||||
"modified": stat.st_mtime,
|
|
||||||
"category": _detect_model_category(name),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sort by name
|
|
||||||
models.sort(key=lambda x: x["name"].lower())
|
|
||||||
return models
|
|
||||||
|
|
||||||
|
|
||||||
def _enrich_with_metadata(models: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
||||||
"""Enrich model data with CivitAI metadata from database."""
|
|
||||||
try:
|
|
||||||
with Database() as db:
|
|
||||||
db.init_schema()
|
|
||||||
|
|
||||||
for model in models:
|
|
||||||
file_path = model.get("path", "")
|
|
||||||
file_info = db.get_local_file_by_path(file_path)
|
|
||||||
|
|
||||||
if file_info and file_info.get("civitai_model_id"):
|
|
||||||
# Add human-readable name
|
|
||||||
model["display_name"] = file_info.get("model_name") or model["name"]
|
|
||||||
model["base_model"] = file_info.get("base_model")
|
|
||||||
model["model_type"] = file_info.get("model_type")
|
|
||||||
model["civitai_model_id"] = file_info.get("civitai_model_id")
|
|
||||||
model["civitai_version_id"] = file_info.get("civitai_version_id")
|
|
||||||
|
|
||||||
# Get thumbnail from version images
|
|
||||||
version_id = file_info.get("civitai_version_id")
|
|
||||||
if version_id:
|
|
||||||
cur = db.conn.cursor()
|
|
||||||
cur.execute(
|
|
||||||
"""
|
|
||||||
SELECT url FROM version_images
|
|
||||||
WHERE version_id = (SELECT id FROM model_versions WHERE civitai_id = ?)
|
|
||||||
ORDER BY id LIMIT 1
|
|
||||||
""",
|
|
||||||
(version_id,),
|
|
||||||
)
|
|
||||||
row = cur.fetchone()
|
|
||||||
if row:
|
|
||||||
model["thumbnail_url"] = row[0]
|
|
||||||
|
|
||||||
# Get trigger words
|
|
||||||
triggers = db.get_triggers_by_version(version_id) if version_id else []
|
|
||||||
model["triggers"] = triggers[:5] # Limit to first 5
|
|
||||||
else:
|
|
||||||
model["display_name"] = model["name"]
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("Failed to enrich models with metadata: %s", e)
|
|
||||||
# Fallback: just use filename as display name
|
|
||||||
for model in models:
|
|
||||||
if "display_name" not in model:
|
|
||||||
model["display_name"] = model["name"]
|
|
||||||
|
|
||||||
return models
|
|
||||||
|
|
||||||
|
|
||||||
def scan_loras(directory: Path | None = None) -> list[dict[str, Any]]:
|
|
||||||
"""Scan for LoRA files."""
|
|
||||||
lora_dir = directory or MODELS_DIR / "loras"
|
|
||||||
return scan_models(lora_dir, extensions=(".safetensors", ".gguf"))
|
|
||||||
|
|
||||||
|
|
||||||
def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]:
|
|
||||||
"""Scan for checkpoint files."""
|
|
||||||
checkpoint_dir = directory or MODELS_DIR / "checkpoints"
|
|
||||||
return scan_models(checkpoint_dir, extensions=(".safetensors", ".gguf"))
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Router Factory
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def create_models_router() -> APIRouter: # noqa: PLR0915
|
|
||||||
"""Build a router with /api/models/* endpoints."""
|
|
||||||
router = APIRouter(prefix="/api/models", tags=["models"])
|
|
||||||
|
|
||||||
@router.get("")
|
|
||||||
def list_models() -> dict[str, Any]:
|
|
||||||
"""List available checkpoint models with metadata."""
|
|
||||||
checkpoints = scan_checkpoints()
|
|
||||||
checkpoints = _enrich_with_metadata(checkpoints)
|
|
||||||
return {
|
|
||||||
"models": checkpoints,
|
|
||||||
"total": len(checkpoints),
|
|
||||||
}
|
|
||||||
|
|
||||||
@router.get("/active")
|
|
||||||
async def get_active_model(request: Request) -> dict[str, Any]:
|
|
||||||
"""Get information about the currently loaded model from sd-server."""
|
|
||||||
import httpx # noqa: PLC0415
|
|
||||||
|
|
||||||
sd_server_url = request.app.state.sd_server_url
|
|
||||||
|
|
||||||
# Try to get current model from sd-server's options endpoint
|
|
||||||
try:
|
|
||||||
headers = get_sd_headers(request)
|
|
||||||
async with httpx.AsyncClient(timeout=10) as client:
|
|
||||||
response = await client.get(f"{sd_server_url}/sdapi/v1/options", headers=headers)
|
|
||||||
if response.status_code == _HTTP_OK:
|
|
||||||
options = response.json()
|
|
||||||
model_name = options.get("sd_model_checkpoint")
|
|
||||||
return {
|
|
||||||
"loaded": True,
|
|
||||||
"model": model_name,
|
|
||||||
"sd_server_url": sd_server_url,
|
|
||||||
}
|
|
||||||
except httpx.HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return {
|
|
||||||
"loaded": False,
|
|
||||||
"model": None,
|
|
||||||
"sd_server_url": sd_server_url,
|
|
||||||
"error": "Cannot connect to sd-server",
|
|
||||||
}
|
|
||||||
|
|
||||||
@router.get("/loras")
|
|
||||||
def list_loras() -> dict[str, Any]:
|
|
||||||
"""List available LoRA files with metadata."""
|
|
||||||
loras = scan_loras()
|
|
||||||
loras = _enrich_with_metadata(loras)
|
|
||||||
return {
|
|
||||||
"loras": loras,
|
|
||||||
"total": len(loras),
|
|
||||||
}
|
|
||||||
|
|
||||||
@router.get("/scan")
|
|
||||||
def scan_all_models() -> dict[str, Any]:
|
|
||||||
"""Scan all model directories."""
|
|
||||||
checkpoints = scan_checkpoints()
|
|
||||||
loras = scan_loras()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"checkpoints": checkpoints,
|
|
||||||
"loras": loras,
|
|
||||||
"total_checkpoints": len(checkpoints),
|
|
||||||
"total_loras": len(loras),
|
|
||||||
}
|
|
||||||
|
|
||||||
@router.post("/switch")
|
|
||||||
async def switch_model(req: SwitchModelRequest) -> dict[str, Any]:
|
|
||||||
"""Switch sd-server to a different model by updating env and restarting."""
|
|
||||||
# Find the model file
|
|
||||||
checkpoints = scan_checkpoints()
|
|
||||||
model_path: str | None = None
|
|
||||||
|
|
||||||
for cp in checkpoints:
|
|
||||||
if cp["filename"] == req.model or cp["path"] == req.model or cp["name"] == req.model:
|
|
||||||
model_path = cp["path"]
|
|
||||||
break
|
|
||||||
|
|
||||||
if not model_path:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Model not found: {req.model}")
|
|
||||||
|
|
||||||
# Read current env, update SD_MODEL
|
|
||||||
env = _read_env_file()
|
|
||||||
old_model = env.get("SD_MODEL", "")
|
|
||||||
env["SD_MODEL"] = model_path
|
|
||||||
|
|
||||||
# Write new env file via sudo tee
|
|
||||||
new_content = _write_env_file(env)
|
|
||||||
proc = await asyncio.create_subprocess_exec(
|
|
||||||
"sudo", "tee", str(_SD_ENV_FILE),
|
|
||||||
stdin=asyncio.subprocess.PIPE,
|
|
||||||
stdout=asyncio.subprocess.PIPE,
|
|
||||||
stderr=asyncio.subprocess.PIPE,
|
|
||||||
)
|
|
||||||
_, tee_stderr = await proc.communicate(new_content.encode())
|
|
||||||
if proc.returncode != 0:
|
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to write env file: {tee_stderr.decode()}")
|
|
||||||
|
|
||||||
# Restart sd-server
|
|
||||||
returncode, _stdout, restart_stderr = await _run_command("sudo", "systemctl", "restart", "sd-server")
|
|
||||||
if returncode != 0:
|
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to restart sd-server: {restart_stderr}")
|
|
||||||
|
|
||||||
logger.info(f"Switched model from {old_model} to {model_path}")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"ok": True,
|
|
||||||
"old_model": old_model,
|
|
||||||
"new_model": model_path,
|
|
||||||
"message": "Model switched, sd-server restarting",
|
|
||||||
}
|
|
||||||
|
|
||||||
@router.get("/status")
|
|
||||||
async def sd_server_status() -> dict[str, Any]:
|
|
||||||
"""Get sd-server systemd service status."""
|
|
||||||
_returncode, stdout, _stderr = await _run_command("systemctl", "is-active", "sd-server")
|
|
||||||
is_active = stdout.strip() == "active"
|
|
||||||
|
|
||||||
env = _read_env_file()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"service": "sd-server",
|
|
||||||
"active": is_active,
|
|
||||||
"status": stdout.strip(),
|
|
||||||
"current_model": env.get("SD_MODEL"),
|
|
||||||
"host": env.get("SD_HOST"),
|
|
||||||
"port": env.get("SD_PORT"),
|
|
||||||
}
|
|
||||||
|
|
||||||
return router
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
"""FastAPI route handlers for the sd-server wrapper API."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
from fastapi import APIRouter, Request, Response
|
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
|
||||||
|
|
||||||
from tensors.server.sd_client import get_sd_headers
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def create_router() -> APIRouter:
|
|
||||||
"""Build a router with /status and catch-all proxy."""
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
@router.get("/status")
|
|
||||||
async def status(request: Request) -> dict[str, Any]:
|
|
||||||
"""Check if the external sd-server is reachable."""
|
|
||||||
sd_server_url = request.app.state.sd_server_url
|
|
||||||
try:
|
|
||||||
headers = get_sd_headers(request)
|
|
||||||
async with httpx.AsyncClient(timeout=5) as client:
|
|
||||||
r = await client.get(sd_server_url, headers=headers)
|
|
||||||
return {
|
|
||||||
"status": "ok",
|
|
||||||
"sd_server_url": sd_server_url,
|
|
||||||
"sd_server_status": r.status_code,
|
|
||||||
}
|
|
||||||
except httpx.HTTPError as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"sd_server_url": sd_server_url,
|
|
||||||
"error": str(e),
|
|
||||||
}
|
|
||||||
|
|
||||||
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
|
|
||||||
async def proxy(request: Request, path: str) -> Response:
|
|
||||||
"""Proxy all requests to the external sd-server."""
|
|
||||||
sd_server_url = request.app.state.sd_server_url
|
|
||||||
url = f"{sd_server_url}/{path}"
|
|
||||||
if request.url.query:
|
|
||||||
url = f"{url}?{request.url.query}"
|
|
||||||
|
|
||||||
body = await request.body()
|
|
||||||
headers = dict(request.headers)
|
|
||||||
headers.pop("host", None)
|
|
||||||
# Add API key if configured
|
|
||||||
headers.update(get_sd_headers(request))
|
|
||||||
client = request.app.state.client
|
|
||||||
|
|
||||||
try:
|
|
||||||
upstream = await client.request(
|
|
||||||
method=request.method,
|
|
||||||
url=url,
|
|
||||||
headers=headers,
|
|
||||||
content=body,
|
|
||||||
timeout=300,
|
|
||||||
)
|
|
||||||
return StreamingResponse(
|
|
||||||
content=upstream.iter_bytes(),
|
|
||||||
status_code=upstream.status_code,
|
|
||||||
headers=dict(upstream.headers),
|
|
||||||
)
|
|
||||||
except httpx.ConnectError:
|
|
||||||
return JSONResponse(
|
|
||||||
{"error": f"Cannot connect to sd-server at {sd_server_url}"},
|
|
||||||
status_code=503,
|
|
||||||
)
|
|
||||||
except httpx.TimeoutException:
|
|
||||||
return JSONResponse(
|
|
||||||
{"error": f"Timeout connecting to sd-server at {sd_server_url}"},
|
|
||||||
status_code=504,
|
|
||||||
)
|
|
||||||
|
|
||||||
return router
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
"""HTTP client utilities for sd-server communication."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from fastapi import Request
|
|
||||||
|
|
||||||
|
|
||||||
def get_sd_headers(request: Request) -> dict[str, str]:
|
|
||||||
"""Get headers for sd-server requests, including API key if configured."""
|
|
||||||
headers: dict[str, str] = {}
|
|
||||||
api_key = getattr(request.app.state, "sd_server_api_key", None)
|
|
||||||
if api_key:
|
|
||||||
headers["X-API-Key"] = api_key
|
|
||||||
return headers
|
|
||||||
|
|
||||||
|
|
||||||
async def sd_get(request: Request, path: str, *, timeout: float = 30) -> httpx.Response:
|
|
||||||
"""Make a GET request to sd-server."""
|
|
||||||
url = f"{request.app.state.sd_server_url}/{path.lstrip('/')}"
|
|
||||||
headers = get_sd_headers(request)
|
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
|
||||||
response = await client.get(url, headers=headers)
|
|
||||||
response.raise_for_status()
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
async def sd_post(request: Request, path: str, *, json: dict[str, Any] | None = None, timeout: float = 300) -> httpx.Response:
|
|
||||||
"""Make a POST request to sd-server."""
|
|
||||||
url = f"{request.app.state.sd_server_url}/{path.lstrip('/')}"
|
|
||||||
headers = get_sd_headers(request)
|
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
|
||||||
response = await client.post(url, json=json, headers=headers)
|
|
||||||
response.raise_for_status()
|
|
||||||
return response
|
|
||||||
Reference in New Issue
Block a user