💬 Commit message: Update 2026-02-15 06:15:02, 11 files, 1240 lines
📁 Files changed: 11 📝 Lines changed: 1240 • client.py • __init__.py • _http.py • generation.py • info.py • params.py • util.py • generate_routes.py • models_routes.py • routes.py • sd_client.py
This commit is contained in:
@@ -1,292 +0,0 @@
|
||||
"""HTTP client for remote tsr server API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TsrClientError(Exception):
|
||||
"""Error from TsrClient operations."""
|
||||
|
||||
|
||||
class TsrClient:
|
||||
"""HTTP client wrapper for tsr server API.
|
||||
|
||||
Usage:
|
||||
with TsrClient("http://junkpile:8080") as client:
|
||||
images = client.list_images()
|
||||
result = client.generate("a cat")
|
||||
"""
|
||||
|
||||
def __init__(self, base_url: str, timeout: float = 300.0) -> None:
|
||||
"""Initialize client with server URL."""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout = timeout
|
||||
self._client: httpx.Client | None = None
|
||||
|
||||
def __enter__(self) -> TsrClient:
|
||||
self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout)
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc: object) -> None:
|
||||
if self._client:
|
||||
self._client.close()
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self) -> httpx.Client:
|
||||
"""Get the HTTP client, creating if needed."""
|
||||
if self._client is None:
|
||||
self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout)
|
||||
return self._client
|
||||
|
||||
def _get(self, path: str, params: dict[str, Any] | None = None) -> Any:
|
||||
"""Make GET request."""
|
||||
try:
|
||||
resp = self.client.get(path, params=params)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
||||
except httpx.RequestError as e:
|
||||
raise TsrClientError(f"Request failed: {e}") from e
|
||||
|
||||
def _post(self, path: str, json: dict[str, Any] | None = None) -> Any:
|
||||
"""Make POST request."""
|
||||
try:
|
||||
resp = self.client.post(path, json=json)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
||||
except httpx.RequestError as e:
|
||||
raise TsrClientError(f"Request failed: {e}") from e
|
||||
|
||||
def _delete(self, path: str) -> Any:
|
||||
"""Make DELETE request."""
|
||||
try:
|
||||
resp = self.client.delete(path)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
||||
except httpx.RequestError as e:
|
||||
raise TsrClientError(f"Request failed: {e}") from e
|
||||
|
||||
# =========================================================================
|
||||
# Server Status
|
||||
# =========================================================================
|
||||
|
||||
def status(self) -> dict[str, Any]:
|
||||
"""Get server status."""
|
||||
return dict(self._get("/status"))
|
||||
|
||||
# =========================================================================
|
||||
# Gallery / Images
|
||||
# =========================================================================
|
||||
|
||||
def list_images(self, limit: int = 50, offset: int = 0) -> dict[str, Any]:
|
||||
"""List images in gallery."""
|
||||
return dict(self._get("/api/images", params={"limit": limit, "offset": offset}))
|
||||
|
||||
def get_image_meta(self, image_id: str) -> dict[str, Any]:
|
||||
"""Get metadata for an image."""
|
||||
return dict(self._get(f"/api/images/{image_id}/meta"))
|
||||
|
||||
def delete_image(self, image_id: str) -> dict[str, Any]:
|
||||
"""Delete an image."""
|
||||
return dict(self._delete(f"/api/images/{image_id}"))
|
||||
|
||||
def edit_image(self, image_id: str, updates: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Update image metadata."""
|
||||
return dict(self._post(f"/api/images/{image_id}/edit", json=updates))
|
||||
|
||||
def download_image(self, image_id: str) -> bytes:
|
||||
"""Download image file bytes."""
|
||||
try:
|
||||
resp = self.client.get(f"/api/images/{image_id}")
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e
|
||||
except httpx.RequestError as e:
|
||||
raise TsrClientError(f"Request failed: {e}") from e
|
||||
|
||||
# =========================================================================
|
||||
# Models
|
||||
# =========================================================================
|
||||
|
||||
def list_models(self) -> dict[str, Any]:
|
||||
"""List available models."""
|
||||
return dict(self._get("/api/models"))
|
||||
|
||||
def get_active_model(self) -> dict[str, Any]:
|
||||
"""Get currently active model."""
|
||||
return dict(self._get("/api/models/active"))
|
||||
|
||||
def switch_model(self, model_path: str) -> dict[str, Any]:
|
||||
"""Switch to a different model."""
|
||||
return dict(self._post("/api/models/switch", json={"model": model_path}))
|
||||
|
||||
def list_loras(self) -> dict[str, Any]:
|
||||
"""List available LoRAs."""
|
||||
return dict(self._get("/api/models/loras"))
|
||||
|
||||
def scan_models(self) -> dict[str, Any]:
|
||||
"""Scan model directories."""
|
||||
return dict(self._get("/api/models/scan"))
|
||||
|
||||
# =========================================================================
|
||||
# Generation
|
||||
# =========================================================================
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
width: int = 512,
|
||||
height: int = 512,
|
||||
steps: int = 20,
|
||||
cfg_scale: float = 7.0,
|
||||
seed: int = -1,
|
||||
sampler_name: str = "",
|
||||
scheduler: str = "",
|
||||
batch_size: int = 1,
|
||||
save_to_gallery: bool = True,
|
||||
return_base64: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Generate images."""
|
||||
body = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"steps": steps,
|
||||
"cfg_scale": cfg_scale,
|
||||
"seed": seed,
|
||||
"sampler_name": sampler_name,
|
||||
"scheduler": scheduler,
|
||||
"batch_size": batch_size,
|
||||
"save_to_gallery": save_to_gallery,
|
||||
"return_base64": return_base64,
|
||||
}
|
||||
return dict(self._post("/api/generate", json=body))
|
||||
|
||||
def list_samplers(self) -> dict[str, Any]:
|
||||
"""List available samplers."""
|
||||
return dict(self._get("/api/samplers"))
|
||||
|
||||
def list_schedulers(self) -> dict[str, Any]:
|
||||
"""List available schedulers."""
|
||||
return dict(self._get("/api/schedulers"))
|
||||
|
||||
# =========================================================================
|
||||
# Download
|
||||
# =========================================================================
|
||||
|
||||
def start_download(
|
||||
self,
|
||||
version_id: int | None = None,
|
||||
model_id: int | None = None,
|
||||
hash_val: str | None = None,
|
||||
output_dir: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Start a model download from CivitAI."""
|
||||
body: dict[str, Any] = {}
|
||||
if version_id:
|
||||
body["version_id"] = version_id
|
||||
if model_id:
|
||||
body["model_id"] = model_id
|
||||
if hash_val:
|
||||
body["hash"] = hash_val
|
||||
if output_dir:
|
||||
body["output_dir"] = output_dir
|
||||
return dict(self._post("/api/download", json=body))
|
||||
|
||||
def get_download_status(self, download_id: str) -> dict[str, Any]:
|
||||
"""Get download status."""
|
||||
return dict(self._get(f"/api/download/status/{download_id}"))
|
||||
|
||||
def list_downloads(self) -> dict[str, Any]:
|
||||
"""List active downloads."""
|
||||
return dict(self._get("/api/download/active"))
|
||||
|
||||
# =========================================================================
|
||||
# Database
|
||||
# =========================================================================
|
||||
|
||||
def db_list_files(self) -> list[dict[str, Any]]:
|
||||
"""List local files in database."""
|
||||
return list(self._get("/api/db/files"))
|
||||
|
||||
def db_search_models(
|
||||
self,
|
||||
query: str | None = None,
|
||||
model_type: str | None = None,
|
||||
base_model: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search cached models."""
|
||||
params: dict[str, Any] = {"limit": limit}
|
||||
if query:
|
||||
params["query"] = query
|
||||
if model_type:
|
||||
params["type"] = model_type
|
||||
if base_model:
|
||||
params["base"] = base_model
|
||||
return list(self._get("/api/db/models", params=params))
|
||||
|
||||
def db_get_model(self, civitai_id: int) -> dict[str, Any]:
|
||||
"""Get cached model by CivitAI ID."""
|
||||
return dict(self._get(f"/api/db/models/{civitai_id}"))
|
||||
|
||||
def db_get_triggers(self, file_path: str | None = None, version_id: int | None = None) -> list[str]:
|
||||
"""Get trigger words."""
|
||||
if version_id:
|
||||
return list(self._get(f"/api/db/triggers/{version_id}"))
|
||||
if file_path:
|
||||
return list(self._get("/api/db/triggers", params={"file_path": file_path}))
|
||||
return []
|
||||
|
||||
def db_stats(self) -> dict[str, Any]:
|
||||
"""Get database statistics."""
|
||||
return dict(self._get("/api/db/stats"))
|
||||
|
||||
def db_scan(self, directory: str) -> dict[str, Any]:
|
||||
"""Scan directory for safetensor files."""
|
||||
return dict(self._post("/api/db/scan", json={"directory": directory}))
|
||||
|
||||
def db_link(self) -> dict[str, Any]:
|
||||
"""Link unlinked files to CivitAI."""
|
||||
return dict(self._post("/api/db/link"))
|
||||
|
||||
def db_cache(self, model_id: int) -> dict[str, Any]:
|
||||
"""Cache CivitAI model data."""
|
||||
return dict(self._post("/api/db/cache", json={"model_id": model_id}))
|
||||
|
||||
# =========================================================================
|
||||
# Streaming Downloads
|
||||
# =========================================================================
|
||||
|
||||
def stream_image(self, image_id: str) -> Iterator[bytes]:
|
||||
"""Stream image download in chunks."""
|
||||
try:
|
||||
with self.client.stream("GET", f"/api/images/{image_id}") as resp:
|
||||
resp.raise_for_status()
|
||||
yield from resp.iter_bytes(chunk_size=1024 * 64)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise TsrClientError(f"HTTP {e.response.status_code}") from e
|
||||
except httpx.RequestError as e:
|
||||
raise TsrClientError(f"Request failed: {e}") from e
|
||||
|
||||
def save_image_to(self, image_id: str, dest: Path) -> Path:
|
||||
"""Download and save image to file."""
|
||||
content = self.download_image(image_id)
|
||||
dest.write_bytes(content)
|
||||
return dest
|
||||
@@ -1,43 +0,0 @@
|
||||
"""sd-server Python client — modular, httpx-based."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from tensors.generate._http import HttpTransport
|
||||
from tensors.generate.generation import GenerationAPI
|
||||
from tensors.generate.info import InfoAPI
|
||||
from tensors.generate.params import Img2ImgParams, Txt2ImgParams
|
||||
from tensors.generate.util import save_images
|
||||
|
||||
__all__ = [
|
||||
"Img2ImgParams",
|
||||
"SDClient",
|
||||
"Txt2ImgParams",
|
||||
"save_images",
|
||||
]
|
||||
|
||||
|
||||
class SDClient:
|
||||
"""Composite client for sd-server.
|
||||
|
||||
Usage::
|
||||
|
||||
with SDClient() as c:
|
||||
c.info.models()
|
||||
images = c.generate.txt2img(Txt2ImgParams(prompt="a cat"))
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "127.0.0.1", port: int = 1234) -> None:
|
||||
self._http = HttpTransport(f"http://{host}:{port}")
|
||||
self.info = InfoAPI(self._http)
|
||||
self.generate = GenerationAPI(self._http)
|
||||
|
||||
def close(self) -> None:
|
||||
self._http.close()
|
||||
|
||||
def __enter__(self) -> SDClient:
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc: Any) -> None:
|
||||
self.close()
|
||||
@@ -1,46 +0,0 @@
|
||||
"""HTTP transport layer wrapping httpx."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HttpTransport:
|
||||
def __init__(self, base_url: str, timeout: float = 300.0) -> None:
|
||||
self._client = httpx.Client(base_url=base_url, timeout=timeout)
|
||||
logger.debug("transport ready: %s", base_url)
|
||||
|
||||
def get(self, path: str) -> Any:
|
||||
logger.debug("GET %s", path)
|
||||
try:
|
||||
r = self._client.get(path)
|
||||
r.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error("GET %s → %d: %s", path, e.response.status_code, e.response.text[:200])
|
||||
raise
|
||||
except httpx.RequestError as e:
|
||||
logger.error("GET %s connection failed: %s", path, e)
|
||||
raise
|
||||
return r.json()
|
||||
|
||||
def post(self, path: str, json: dict[str, Any]) -> Any:
|
||||
logger.debug("POST %s", path)
|
||||
try:
|
||||
r = self._client.post(path, json=json)
|
||||
r.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error("POST %s → %d: %s", path, e.response.status_code, e.response.text[:200])
|
||||
raise
|
||||
except httpx.RequestError as e:
|
||||
logger.error("POST %s connection failed: %s", path, e)
|
||||
raise
|
||||
return r.json()
|
||||
|
||||
def close(self) -> None:
|
||||
self._client.close()
|
||||
logger.debug("transport closed")
|
||||
@@ -1,34 +0,0 @@
|
||||
"""Image generation endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensors.generate._http import HttpTransport
|
||||
from tensors.generate.params import Img2ImgParams, Txt2ImgParams
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenerationAPI:
|
||||
def __init__(self, http: HttpTransport) -> None:
|
||||
self._http = http
|
||||
|
||||
def txt2img(self, params: Txt2ImgParams) -> list[bytes]:
|
||||
"""Generate images from text prompt."""
|
||||
logger.info("txt2img: '%s' %dx%d steps=%d", params.prompt[:60], params.width, params.height, params.steps)
|
||||
data = self._http.post("/sdapi/v1/txt2img", params.to_body())
|
||||
images = [base64.b64decode(img) for img in data["images"]]
|
||||
logger.info("txt2img: got %d image(s)", len(images))
|
||||
return images
|
||||
|
||||
def img2img(self, params: Img2ImgParams) -> list[bytes]:
|
||||
"""Generate images from image + text prompt."""
|
||||
logger.info("img2img: '%s' strength=%.2f steps=%d", params.prompt[:60], params.denoising_strength, params.steps)
|
||||
data = self._http.post("/sdapi/v1/img2img", params.to_body())
|
||||
images = [base64.b64decode(img) for img in data["images"]]
|
||||
logger.info("img2img: got %d image(s)", len(images))
|
||||
return images
|
||||
@@ -1,42 +0,0 @@
|
||||
"""Model and server info endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensors.generate._http import HttpTransport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InfoAPI:
|
||||
def __init__(self, http: HttpTransport) -> None:
|
||||
self._http = http
|
||||
|
||||
def models(self) -> list[dict[str, Any]]:
|
||||
"""List loaded models (OpenAI /v1/models)."""
|
||||
return self._http.get("/v1/models")["data"] # type: ignore[no-any-return]
|
||||
|
||||
def sd_models(self) -> list[dict[str, Any]]:
|
||||
"""Detailed model info (sdapi)."""
|
||||
return self._http.get("/sdapi/v1/sd-models") # type: ignore[no-any-return]
|
||||
|
||||
def options(self) -> dict[str, Any]:
|
||||
"""Current server options."""
|
||||
return self._http.get("/sdapi/v1/options") # type: ignore[no-any-return]
|
||||
|
||||
def loras(self) -> list[dict[str, Any]]:
|
||||
"""Available LoRAs from --lora-model-dir."""
|
||||
result: list[dict[str, Any]] = self._http.get("/sdapi/v1/loras")
|
||||
logger.info("found %d lora(s)", len(result))
|
||||
return result
|
||||
|
||||
def samplers(self) -> list[str]:
|
||||
"""Available sampler names."""
|
||||
return [s["name"] for s in self._http.get("/sdapi/v1/samplers")]
|
||||
|
||||
def schedulers(self) -> list[str]:
|
||||
"""Available scheduler names."""
|
||||
return [s["name"] for s in self._http.get("/sdapi/v1/schedulers")]
|
||||
@@ -1,100 +0,0 @@
|
||||
"""Generation parameter dataclasses."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from tensors.generate.util import to_b64
|
||||
|
||||
|
||||
@dataclass
|
||||
class Txt2ImgParams:
|
||||
prompt: str
|
||||
negative_prompt: str = ""
|
||||
width: int = 512
|
||||
height: int = 512
|
||||
steps: int = 20
|
||||
cfg_scale: float = 7.0
|
||||
seed: int = -1
|
||||
batch_size: int = 1
|
||||
sampler_name: str = ""
|
||||
scheduler: str = ""
|
||||
clip_skip: int = -1
|
||||
lora: list[dict[str, Any]] | None = None
|
||||
|
||||
def to_body(self) -> dict[str, Any]:
|
||||
body = {
|
||||
"prompt": self.prompt,
|
||||
"negative_prompt": self.negative_prompt,
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
"steps": self.steps,
|
||||
"cfg_scale": self.cfg_scale,
|
||||
"seed": self.seed,
|
||||
"batch_size": self.batch_size,
|
||||
}
|
||||
if self.sampler_name:
|
||||
body["sampler_name"] = self.sampler_name
|
||||
if self.scheduler:
|
||||
body["scheduler"] = self.scheduler
|
||||
if self.clip_skip > 0:
|
||||
body["clip_skip"] = self.clip_skip
|
||||
if self.lora:
|
||||
body["lora"] = self.lora
|
||||
return body
|
||||
|
||||
|
||||
@dataclass
|
||||
class Img2ImgParams:
|
||||
prompt: str
|
||||
init_image: str | bytes | Path
|
||||
negative_prompt: str = ""
|
||||
width: int = -1
|
||||
height: int = -1
|
||||
steps: int = 20
|
||||
cfg_scale: float = 7.0
|
||||
denoising_strength: float = 0.75
|
||||
seed: int = -1
|
||||
batch_size: int = 1
|
||||
sampler_name: str = ""
|
||||
scheduler: str = ""
|
||||
clip_skip: int = -1
|
||||
mask: str | bytes | Path | None = None
|
||||
inpainting_mask_invert: bool = False
|
||||
lora: list[dict[str, Any]] | None = None
|
||||
extra_images: list[str | bytes | Path] = field(default_factory=list)
|
||||
|
||||
def to_body(self) -> dict[str, Any]:
|
||||
body: dict[str, Any] = {
|
||||
"prompt": self.prompt,
|
||||
"negative_prompt": self.negative_prompt,
|
||||
"steps": self.steps,
|
||||
"cfg_scale": self.cfg_scale,
|
||||
"denoising_strength": self.denoising_strength,
|
||||
"seed": self.seed,
|
||||
"batch_size": self.batch_size,
|
||||
"init_images": [to_b64(self.init_image)],
|
||||
}
|
||||
if self.width > 0:
|
||||
body["width"] = self.width
|
||||
if self.height > 0:
|
||||
body["height"] = self.height
|
||||
if self.mask is not None:
|
||||
body["mask"] = to_b64(self.mask)
|
||||
if self.inpainting_mask_invert:
|
||||
body["inpainting_mask_invert"] = 1
|
||||
if self.sampler_name:
|
||||
body["sampler_name"] = self.sampler_name
|
||||
if self.scheduler:
|
||||
body["scheduler"] = self.scheduler
|
||||
if self.clip_skip > 0:
|
||||
body["clip_skip"] = self.clip_skip
|
||||
if self.lora:
|
||||
body["lora"] = self.lora
|
||||
if self.extra_images:
|
||||
body["extra_images"] = [to_b64(img) for img in self.extra_images]
|
||||
return body
|
||||
@@ -1,37 +0,0 @@
|
||||
"""Utility functions for image encoding and file I/O."""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def to_b64(image: str | bytes | Path) -> str:
|
||||
"""Convert a file path, raw bytes, or base64 string to base64."""
|
||||
if isinstance(image, (str, Path)):
|
||||
path = Path(image)
|
||||
if path.exists():
|
||||
logger.debug("encoding file: %s", path)
|
||||
return base64.b64encode(path.read_bytes()).decode()
|
||||
return str(image)
|
||||
if isinstance(image, bytes):
|
||||
return base64.b64encode(image).decode()
|
||||
raise TypeError(f"unsupported image type: {type(image)}")
|
||||
|
||||
|
||||
def save_images(
|
||||
images: list[bytes],
|
||||
output_dir: str = ".",
|
||||
prefix: str = "output",
|
||||
) -> list[Path]:
|
||||
"""Write raw PNG bytes to numbered files. Returns saved paths."""
|
||||
out = Path(output_dir)
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
paths = []
|
||||
for i, data in enumerate(images):
|
||||
path = out / f"{prefix}_{i:04d}.png"
|
||||
path.write_bytes(data)
|
||||
logger.info("saved: %s", path)
|
||||
paths.append(path)
|
||||
return paths
|
||||
@@ -1,216 +0,0 @@
|
||||
"""FastAPI route handlers for image generation with gallery integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from tensors.server.gallery import Gallery
|
||||
from tensors.server.sd_client import get_sd_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Request/Response Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class LoraConfig(PydanticBaseModel):
|
||||
"""LoRA configuration for sd-server."""
|
||||
|
||||
path: str
|
||||
multiplier: float = Field(default=1.0, ge=0.0, le=2.0)
|
||||
|
||||
|
||||
class GenerateRequest(PydanticBaseModel):
|
||||
"""Request body for image generation."""
|
||||
|
||||
prompt: str
|
||||
negative_prompt: str = ""
|
||||
width: int = Field(default=512, ge=64, le=2048)
|
||||
height: int = Field(default=512, ge=64, le=2048)
|
||||
steps: int = Field(default=20, ge=1, le=150)
|
||||
cfg_scale: float = Field(default=7.0, ge=0, le=30)
|
||||
seed: int = -1
|
||||
sampler_name: str = ""
|
||||
scheduler: str = ""
|
||||
batch_size: int = Field(default=1, ge=1, le=16)
|
||||
save_to_gallery: bool = True
|
||||
return_base64: bool = False
|
||||
lora: LoraConfig | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _build_sd_request(req: GenerateRequest) -> dict[str, Any]:
|
||||
"""Build request body for sd-server."""
|
||||
prompt = req.prompt
|
||||
|
||||
# sd-server expects LoRA in prompt as <lora:name:weight> syntax
|
||||
if req.lora:
|
||||
lora_name = Path(req.lora.path).stem
|
||||
lora_tag = f"<lora:{lora_name}:{req.lora.multiplier}>"
|
||||
prompt = f"{prompt} {lora_tag}"
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": req.negative_prompt,
|
||||
"width": req.width,
|
||||
"height": req.height,
|
||||
"steps": req.steps,
|
||||
"cfg_scale": req.cfg_scale,
|
||||
"seed": req.seed,
|
||||
"batch_size": req.batch_size,
|
||||
}
|
||||
if req.sampler_name:
|
||||
body["sampler_name"] = req.sampler_name
|
||||
if req.scheduler:
|
||||
body["scheduler"] = req.scheduler
|
||||
return body
|
||||
|
||||
|
||||
def _parse_info(info: Any) -> dict[str, Any]:
|
||||
"""Parse info from sd-server response."""
|
||||
if isinstance(info, str):
|
||||
import json # noqa: PLC0415
|
||||
|
||||
try:
|
||||
return dict(json.loads(info))
|
||||
except json.JSONDecodeError:
|
||||
return {"raw": info}
|
||||
return info if isinstance(info, dict) else {}
|
||||
|
||||
|
||||
def _process_image(
|
||||
img_b64: str,
|
||||
index: int,
|
||||
seed: int,
|
||||
req: GenerateRequest,
|
||||
gallery: Gallery,
|
||||
model: str | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Process a single generated image."""
|
||||
image_bytes = base64.b64decode(img_b64)
|
||||
image_info: dict[str, Any] = {"index": index, "seed": seed}
|
||||
|
||||
if req.save_to_gallery:
|
||||
metadata = {
|
||||
"prompt": req.prompt,
|
||||
"negative_prompt": req.negative_prompt,
|
||||
"width": req.width,
|
||||
"height": req.height,
|
||||
"steps": req.steps,
|
||||
"cfg_scale": req.cfg_scale,
|
||||
"sampler": req.sampler_name,
|
||||
"scheduler": req.scheduler,
|
||||
"model": model,
|
||||
"lora": req.lora.path if req.lora else None,
|
||||
"lora_weight": req.lora.multiplier if req.lora else None,
|
||||
"generated_at": time.time(),
|
||||
}
|
||||
gallery_img = gallery.save_image(image_bytes, metadata=metadata, seed=seed)
|
||||
image_info["id"] = gallery_img.id
|
||||
image_info["path"] = str(gallery_img.path)
|
||||
|
||||
if req.return_base64:
|
||||
image_info["base64"] = img_b64
|
||||
|
||||
return image_info
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Router Factory
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_generate_router() -> APIRouter: # noqa: PLR0915
|
||||
"""Build a router with /api/generate endpoint."""
|
||||
router = APIRouter(prefix="/api", tags=["generate"])
|
||||
gallery = Gallery()
|
||||
|
||||
@router.post("/generate")
|
||||
async def generate(request: Request, req: GenerateRequest) -> dict[str, Any]:
|
||||
"""Generate images with gallery integration."""
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
body = _build_sd_request(req)
|
||||
url = f"{sd_server_url}/sdapi/v1/txt2img"
|
||||
|
||||
try:
|
||||
headers = get_sd_headers(request)
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
response = await client.post(url, json=body, headers=headers)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
except httpx.ConnectError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
|
||||
except httpx.HTTPError as e:
|
||||
logger.exception("Generation failed")
|
||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||
|
||||
images_data = result.get("images", [])
|
||||
info = _parse_info(result.get("info", {}))
|
||||
all_seeds = info.get("all_seeds", [req.seed] * len(images_data))
|
||||
|
||||
# Get model info from sd-server response if available
|
||||
model_name = info.get("sd_model_name") or info.get("model")
|
||||
|
||||
output_images = [
|
||||
_process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, model_name)
|
||||
for i, img_b64 in enumerate(images_data)
|
||||
]
|
||||
|
||||
return {
|
||||
"images": output_images,
|
||||
"parameters": result.get("parameters", body),
|
||||
"info": info,
|
||||
"saved_to_gallery": req.save_to_gallery,
|
||||
"total": len(output_images),
|
||||
}
|
||||
|
||||
@router.get("/samplers")
|
||||
async def list_samplers(request: Request) -> dict[str, Any]:
|
||||
"""List available samplers from sd-server."""
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
url = f"{sd_server_url}/sdapi/v1/samplers"
|
||||
|
||||
try:
|
||||
headers = get_sd_headers(request)
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
return {"samplers": response.json()}
|
||||
except httpx.ConnectError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||
|
||||
@router.get("/schedulers")
|
||||
async def list_schedulers(request: Request) -> dict[str, Any]:
|
||||
"""List available schedulers from sd-server."""
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
url = f"{sd_server_url}/sdapi/v1/schedulers"
|
||||
|
||||
try:
|
||||
headers = get_sd_headers(request)
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
return {"schedulers": response.json()}
|
||||
except httpx.ConnectError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e
|
||||
|
||||
return router
|
||||
@@ -1,311 +0,0 @@
|
||||
"""FastAPI route handlers for model management endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from tensors.config import MODELS_DIR
|
||||
from tensors.db import Database
|
||||
from tensors.server.sd_client import get_sd_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_HTTP_OK = 200
|
||||
_SD_ENV_FILE = Path("/etc/default/sd-server")
|
||||
|
||||
|
||||
class SwitchModelRequest(BaseModel):
|
||||
"""Request body for switching models."""
|
||||
|
||||
model: str # Model filename or full path
|
||||
|
||||
|
||||
async def _run_command(*args: str) -> tuple[int, str, str]:
|
||||
"""Run a shell command and return (returncode, stdout, stderr)."""
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*args,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
return proc.returncode or 0, stdout.decode(), stderr.decode()
|
||||
|
||||
|
||||
def _read_env_file() -> dict[str, str]:
|
||||
"""Read the sd-server environment file."""
|
||||
env: dict[str, str] = {}
|
||||
if _SD_ENV_FILE.exists():
|
||||
for raw_line in _SD_ENV_FILE.read_text().splitlines():
|
||||
line = raw_line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
key, _, value = line.partition("=")
|
||||
env[key.strip()] = value.strip()
|
||||
return env
|
||||
|
||||
|
||||
def _write_env_file(env: dict[str, str]) -> str:
|
||||
"""Generate env file content."""
|
||||
lines = ["# sd-server configuration"]
|
||||
for key, value in env.items():
|
||||
lines.append(f"{key}={value}")
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
# Keywords for detecting base model category
|
||||
_SD15_KEYWORDS = ("sd15", "sd1.5", "sd-1.5", "sd_1.5", "1.5", "sd-1-", "v1-5")
|
||||
_LARGE_KEYWORDS = ("sdxl", "xl", "pony", "illustrious", "ilust", "noob", "animagine")
|
||||
|
||||
|
||||
def _detect_model_category(name: str) -> str:
|
||||
"""Detect model category from filename. Returns 'sd15' or 'large'."""
|
||||
name_lower = name.lower()
|
||||
|
||||
# Check SD 1.5 keywords first
|
||||
for kw in _SD15_KEYWORDS:
|
||||
if kw in name_lower:
|
||||
return "sd15"
|
||||
|
||||
# Check large model keywords
|
||||
for kw in _LARGE_KEYWORDS:
|
||||
if kw in name_lower:
|
||||
return "large"
|
||||
|
||||
# Default to large (SDXL/Pony/Illustrious are more common now)
|
||||
return "large"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def scan_models(directory: Path, extensions: tuple[str, ...] = (".safetensors", ".gguf")) -> list[dict[str, Any]]:
|
||||
"""Scan directory for model files."""
|
||||
models: list[dict[str, Any]] = []
|
||||
|
||||
if not directory.exists():
|
||||
return models
|
||||
|
||||
for ext in extensions:
|
||||
for path in directory.rglob(f"*{ext}"):
|
||||
stat = path.stat()
|
||||
name = path.stem
|
||||
models.append(
|
||||
{
|
||||
"name": name,
|
||||
"path": str(path),
|
||||
"filename": path.name,
|
||||
"size_mb": round(stat.st_size / (1024 * 1024), 2),
|
||||
"modified": stat.st_mtime,
|
||||
"category": _detect_model_category(name),
|
||||
}
|
||||
)
|
||||
|
||||
# Sort by name
|
||||
models.sort(key=lambda x: x["name"].lower())
|
||||
return models
|
||||
|
||||
|
||||
def _enrich_with_metadata(models: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Enrich model data with CivitAI metadata from database."""
|
||||
try:
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
|
||||
for model in models:
|
||||
file_path = model.get("path", "")
|
||||
file_info = db.get_local_file_by_path(file_path)
|
||||
|
||||
if file_info and file_info.get("civitai_model_id"):
|
||||
# Add human-readable name
|
||||
model["display_name"] = file_info.get("model_name") or model["name"]
|
||||
model["base_model"] = file_info.get("base_model")
|
||||
model["model_type"] = file_info.get("model_type")
|
||||
model["civitai_model_id"] = file_info.get("civitai_model_id")
|
||||
model["civitai_version_id"] = file_info.get("civitai_version_id")
|
||||
|
||||
# Get thumbnail from version images
|
||||
version_id = file_info.get("civitai_version_id")
|
||||
if version_id:
|
||||
cur = db.conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT url FROM version_images
|
||||
WHERE version_id = (SELECT id FROM model_versions WHERE civitai_id = ?)
|
||||
ORDER BY id LIMIT 1
|
||||
""",
|
||||
(version_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
model["thumbnail_url"] = row[0]
|
||||
|
||||
# Get trigger words
|
||||
triggers = db.get_triggers_by_version(version_id) if version_id else []
|
||||
model["triggers"] = triggers[:5] # Limit to first 5
|
||||
else:
|
||||
model["display_name"] = model["name"]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to enrich models with metadata: %s", e)
|
||||
# Fallback: just use filename as display name
|
||||
for model in models:
|
||||
if "display_name" not in model:
|
||||
model["display_name"] = model["name"]
|
||||
|
||||
return models
|
||||
|
||||
|
||||
def scan_loras(directory: Path | None = None) -> list[dict[str, Any]]:
|
||||
"""Scan for LoRA files."""
|
||||
lora_dir = directory or MODELS_DIR / "loras"
|
||||
return scan_models(lora_dir, extensions=(".safetensors", ".gguf"))
|
||||
|
||||
|
||||
def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]:
|
||||
"""Scan for checkpoint files."""
|
||||
checkpoint_dir = directory or MODELS_DIR / "checkpoints"
|
||||
return scan_models(checkpoint_dir, extensions=(".safetensors", ".gguf"))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Router Factory
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_models_router() -> APIRouter: # noqa: PLR0915
|
||||
"""Build a router with /api/models/* endpoints."""
|
||||
router = APIRouter(prefix="/api/models", tags=["models"])
|
||||
|
||||
@router.get("")
|
||||
def list_models() -> dict[str, Any]:
|
||||
"""List available checkpoint models with metadata."""
|
||||
checkpoints = scan_checkpoints()
|
||||
checkpoints = _enrich_with_metadata(checkpoints)
|
||||
return {
|
||||
"models": checkpoints,
|
||||
"total": len(checkpoints),
|
||||
}
|
||||
|
||||
@router.get("/active")
|
||||
async def get_active_model(request: Request) -> dict[str, Any]:
|
||||
"""Get information about the currently loaded model from sd-server."""
|
||||
import httpx # noqa: PLC0415
|
||||
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
|
||||
# Try to get current model from sd-server's options endpoint
|
||||
try:
|
||||
headers = get_sd_headers(request)
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.get(f"{sd_server_url}/sdapi/v1/options", headers=headers)
|
||||
if response.status_code == _HTTP_OK:
|
||||
options = response.json()
|
||||
model_name = options.get("sd_model_checkpoint")
|
||||
return {
|
||||
"loaded": True,
|
||||
"model": model_name,
|
||||
"sd_server_url": sd_server_url,
|
||||
}
|
||||
except httpx.HTTPError:
|
||||
pass
|
||||
|
||||
return {
|
||||
"loaded": False,
|
||||
"model": None,
|
||||
"sd_server_url": sd_server_url,
|
||||
"error": "Cannot connect to sd-server",
|
||||
}
|
||||
|
||||
@router.get("/loras")
|
||||
def list_loras() -> dict[str, Any]:
|
||||
"""List available LoRA files with metadata."""
|
||||
loras = scan_loras()
|
||||
loras = _enrich_with_metadata(loras)
|
||||
return {
|
||||
"loras": loras,
|
||||
"total": len(loras),
|
||||
}
|
||||
|
||||
@router.get("/scan")
|
||||
def scan_all_models() -> dict[str, Any]:
|
||||
"""Scan all model directories."""
|
||||
checkpoints = scan_checkpoints()
|
||||
loras = scan_loras()
|
||||
|
||||
return {
|
||||
"checkpoints": checkpoints,
|
||||
"loras": loras,
|
||||
"total_checkpoints": len(checkpoints),
|
||||
"total_loras": len(loras),
|
||||
}
|
||||
|
||||
@router.post("/switch")
|
||||
async def switch_model(req: SwitchModelRequest) -> dict[str, Any]:
|
||||
"""Switch sd-server to a different model by updating env and restarting."""
|
||||
# Find the model file
|
||||
checkpoints = scan_checkpoints()
|
||||
model_path: str | None = None
|
||||
|
||||
for cp in checkpoints:
|
||||
if cp["filename"] == req.model or cp["path"] == req.model or cp["name"] == req.model:
|
||||
model_path = cp["path"]
|
||||
break
|
||||
|
||||
if not model_path:
|
||||
raise HTTPException(status_code=404, detail=f"Model not found: {req.model}")
|
||||
|
||||
# Read current env, update SD_MODEL
|
||||
env = _read_env_file()
|
||||
old_model = env.get("SD_MODEL", "")
|
||||
env["SD_MODEL"] = model_path
|
||||
|
||||
# Write new env file via sudo tee
|
||||
new_content = _write_env_file(env)
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"sudo", "tee", str(_SD_ENV_FILE),
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
_, tee_stderr = await proc.communicate(new_content.encode())
|
||||
if proc.returncode != 0:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to write env file: {tee_stderr.decode()}")
|
||||
|
||||
# Restart sd-server
|
||||
returncode, _stdout, restart_stderr = await _run_command("sudo", "systemctl", "restart", "sd-server")
|
||||
if returncode != 0:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to restart sd-server: {restart_stderr}")
|
||||
|
||||
logger.info(f"Switched model from {old_model} to {model_path}")
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"old_model": old_model,
|
||||
"new_model": model_path,
|
||||
"message": "Model switched, sd-server restarting",
|
||||
}
|
||||
|
||||
@router.get("/status")
|
||||
async def sd_server_status() -> dict[str, Any]:
|
||||
"""Get sd-server systemd service status."""
|
||||
_returncode, stdout, _stderr = await _run_command("systemctl", "is-active", "sd-server")
|
||||
is_active = stdout.strip() == "active"
|
||||
|
||||
env = _read_env_file()
|
||||
|
||||
return {
|
||||
"service": "sd-server",
|
||||
"active": is_active,
|
||||
"status": stdout.strip(),
|
||||
"current_model": env.get("SD_MODEL"),
|
||||
"host": env.get("SD_HOST"),
|
||||
"port": env.get("SD_PORT"),
|
||||
}
|
||||
|
||||
return router
|
||||
@@ -1,80 +0,0 @@
|
||||
"""FastAPI route handlers for the sd-server wrapper API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Request, Response
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from tensors.server.sd_client import get_sd_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_router() -> APIRouter:
|
||||
"""Build a router with /status and catch-all proxy."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/status")
|
||||
async def status(request: Request) -> dict[str, Any]:
|
||||
"""Check if the external sd-server is reachable."""
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
try:
|
||||
headers = get_sd_headers(request)
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
r = await client.get(sd_server_url, headers=headers)
|
||||
return {
|
||||
"status": "ok",
|
||||
"sd_server_url": sd_server_url,
|
||||
"sd_server_status": r.status_code,
|
||||
}
|
||||
except httpx.HTTPError as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"sd_server_url": sd_server_url,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
|
||||
async def proxy(request: Request, path: str) -> Response:
|
||||
"""Proxy all requests to the external sd-server."""
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
url = f"{sd_server_url}/{path}"
|
||||
if request.url.query:
|
||||
url = f"{url}?{request.url.query}"
|
||||
|
||||
body = await request.body()
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
# Add API key if configured
|
||||
headers.update(get_sd_headers(request))
|
||||
client = request.app.state.client
|
||||
|
||||
try:
|
||||
upstream = await client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
timeout=300,
|
||||
)
|
||||
return StreamingResponse(
|
||||
content=upstream.iter_bytes(),
|
||||
status_code=upstream.status_code,
|
||||
headers=dict(upstream.headers),
|
||||
)
|
||||
except httpx.ConnectError:
|
||||
return JSONResponse(
|
||||
{"error": f"Cannot connect to sd-server at {sd_server_url}"},
|
||||
status_code=503,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
return JSONResponse(
|
||||
{"error": f"Timeout connecting to sd-server at {sd_server_url}"},
|
||||
status_code=504,
|
||||
)
|
||||
|
||||
return router
|
||||
@@ -1,39 +0,0 @@
|
||||
"""HTTP client utilities for sd-server communication."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import Request
|
||||
|
||||
|
||||
def get_sd_headers(request: Request) -> dict[str, str]:
|
||||
"""Get headers for sd-server requests, including API key if configured."""
|
||||
headers: dict[str, str] = {}
|
||||
api_key = getattr(request.app.state, "sd_server_api_key", None)
|
||||
if api_key:
|
||||
headers["X-API-Key"] = api_key
|
||||
return headers
|
||||
|
||||
|
||||
async def sd_get(request: Request, path: str, *, timeout: float = 30) -> httpx.Response:
|
||||
"""Make a GET request to sd-server."""
|
||||
url = f"{request.app.state.sd_server_url}/{path.lstrip('/')}"
|
||||
headers = get_sd_headers(request)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
|
||||
async def sd_post(request: Request, path: str, *, json: dict[str, Any] | None = None, timeout: float = 300) -> httpx.Response:
|
||||
"""Make a POST request to sd-server."""
|
||||
url = f"{request.app.state.sd_server_url}/{path.lstrip('/')}"
|
||||
headers = get_sd_headers(request)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(url, json=json, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
Reference in New Issue
Block a user