diff --git a/tensors/client.py b/tensors/client.py deleted file mode 100644 index 74493a4..0000000 --- a/tensors/client.py +++ /dev/null @@ -1,292 +0,0 @@ -"""HTTP client for remote tsr server API.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -import httpx - -if TYPE_CHECKING: - from collections.abc import Iterator - from pathlib import Path - - -class TsrClientError(Exception): - """Error from TsrClient operations.""" - - -class TsrClient: - """HTTP client wrapper for tsr server API. - - Usage: - with TsrClient("http://junkpile:8080") as client: - images = client.list_images() - result = client.generate("a cat") - """ - - def __init__(self, base_url: str, timeout: float = 300.0) -> None: - """Initialize client with server URL.""" - self.base_url = base_url.rstrip("/") - self.timeout = timeout - self._client: httpx.Client | None = None - - def __enter__(self) -> TsrClient: - self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout) - return self - - def __exit__(self, *exc: object) -> None: - if self._client: - self._client.close() - self._client = None - - @property - def client(self) -> httpx.Client: - """Get the HTTP client, creating if needed.""" - if self._client is None: - self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout) - return self._client - - def _get(self, path: str, params: dict[str, Any] | None = None) -> Any: - """Make GET request.""" - try: - resp = self.client.get(path, params=params) - resp.raise_for_status() - return resp.json() - except httpx.HTTPStatusError as e: - raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e - except httpx.RequestError as e: - raise TsrClientError(f"Request failed: {e}") from e - - def _post(self, path: str, json: dict[str, Any] | None = None) -> Any: - """Make POST request.""" - try: - resp = self.client.post(path, json=json) - resp.raise_for_status() - return resp.json() - except httpx.HTTPStatusError as e: - raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e - except httpx.RequestError as e: - raise TsrClientError(f"Request failed: {e}") from e - - def _delete(self, path: str) -> Any: - """Make DELETE request.""" - try: - resp = self.client.delete(path) - resp.raise_for_status() - return resp.json() - except httpx.HTTPStatusError as e: - raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e - except httpx.RequestError as e: - raise TsrClientError(f"Request failed: {e}") from e - - # ========================================================================= - # Server Status - # ========================================================================= - - def status(self) -> dict[str, Any]: - """Get server status.""" - return dict(self._get("/status")) - - # ========================================================================= - # Gallery / Images - # ========================================================================= - - def list_images(self, limit: int = 50, offset: int = 0) -> dict[str, Any]: - """List images in gallery.""" - return dict(self._get("/api/images", params={"limit": limit, "offset": offset})) - - def get_image_meta(self, image_id: str) -> dict[str, Any]: - """Get metadata for an image.""" - return dict(self._get(f"/api/images/{image_id}/meta")) - - def delete_image(self, image_id: str) -> dict[str, Any]: - """Delete an image.""" - return dict(self._delete(f"/api/images/{image_id}")) - - def edit_image(self, image_id: str, updates: dict[str, Any]) -> dict[str, Any]: - """Update image metadata.""" - return dict(self._post(f"/api/images/{image_id}/edit", json=updates)) - - def download_image(self, image_id: str) -> bytes: - """Download image file bytes.""" - try: - resp = self.client.get(f"/api/images/{image_id}") - resp.raise_for_status() - return resp.content - except httpx.HTTPStatusError as e: - raise TsrClientError(f"HTTP {e.response.status_code}: {e.response.text}") from e - except httpx.RequestError as e: - raise TsrClientError(f"Request failed: {e}") from e - - # ========================================================================= - # Models - # ========================================================================= - - def list_models(self) -> dict[str, Any]: - """List available models.""" - return dict(self._get("/api/models")) - - def get_active_model(self) -> dict[str, Any]: - """Get currently active model.""" - return dict(self._get("/api/models/active")) - - def switch_model(self, model_path: str) -> dict[str, Any]: - """Switch to a different model.""" - return dict(self._post("/api/models/switch", json={"model": model_path})) - - def list_loras(self) -> dict[str, Any]: - """List available LoRAs.""" - return dict(self._get("/api/models/loras")) - - def scan_models(self) -> dict[str, Any]: - """Scan model directories.""" - return dict(self._get("/api/models/scan")) - - # ========================================================================= - # Generation - # ========================================================================= - - def generate( - self, - prompt: str, - negative_prompt: str = "", - width: int = 512, - height: int = 512, - steps: int = 20, - cfg_scale: float = 7.0, - seed: int = -1, - sampler_name: str = "", - scheduler: str = "", - batch_size: int = 1, - save_to_gallery: bool = True, - return_base64: bool = False, - ) -> dict[str, Any]: - """Generate images.""" - body = { - "prompt": prompt, - "negative_prompt": negative_prompt, - "width": width, - "height": height, - "steps": steps, - "cfg_scale": cfg_scale, - "seed": seed, - "sampler_name": sampler_name, - "scheduler": scheduler, - "batch_size": batch_size, - "save_to_gallery": save_to_gallery, - "return_base64": return_base64, - } - return dict(self._post("/api/generate", json=body)) - - def list_samplers(self) -> dict[str, Any]: - """List available samplers.""" - return dict(self._get("/api/samplers")) - - def list_schedulers(self) -> dict[str, Any]: - """List available schedulers.""" - return dict(self._get("/api/schedulers")) - - # ========================================================================= - # Download - # ========================================================================= - - def start_download( - self, - version_id: int | None = None, - model_id: int | None = None, - hash_val: str | None = None, - output_dir: str | None = None, - ) -> dict[str, Any]: - """Start a model download from CivitAI.""" - body: dict[str, Any] = {} - if version_id: - body["version_id"] = version_id - if model_id: - body["model_id"] = model_id - if hash_val: - body["hash"] = hash_val - if output_dir: - body["output_dir"] = output_dir - return dict(self._post("/api/download", json=body)) - - def get_download_status(self, download_id: str) -> dict[str, Any]: - """Get download status.""" - return dict(self._get(f"/api/download/status/{download_id}")) - - def list_downloads(self) -> dict[str, Any]: - """List active downloads.""" - return dict(self._get("/api/download/active")) - - # ========================================================================= - # Database - # ========================================================================= - - def db_list_files(self) -> list[dict[str, Any]]: - """List local files in database.""" - return list(self._get("/api/db/files")) - - def db_search_models( - self, - query: str | None = None, - model_type: str | None = None, - base_model: str | None = None, - limit: int = 20, - ) -> list[dict[str, Any]]: - """Search cached models.""" - params: dict[str, Any] = {"limit": limit} - if query: - params["query"] = query - if model_type: - params["type"] = model_type - if base_model: - params["base"] = base_model - return list(self._get("/api/db/models", params=params)) - - def db_get_model(self, civitai_id: int) -> dict[str, Any]: - """Get cached model by CivitAI ID.""" - return dict(self._get(f"/api/db/models/{civitai_id}")) - - def db_get_triggers(self, file_path: str | None = None, version_id: int | None = None) -> list[str]: - """Get trigger words.""" - if version_id: - return list(self._get(f"/api/db/triggers/{version_id}")) - if file_path: - return list(self._get("/api/db/triggers", params={"file_path": file_path})) - return [] - - def db_stats(self) -> dict[str, Any]: - """Get database statistics.""" - return dict(self._get("/api/db/stats")) - - def db_scan(self, directory: str) -> dict[str, Any]: - """Scan directory for safetensor files.""" - return dict(self._post("/api/db/scan", json={"directory": directory})) - - def db_link(self) -> dict[str, Any]: - """Link unlinked files to CivitAI.""" - return dict(self._post("/api/db/link")) - - def db_cache(self, model_id: int) -> dict[str, Any]: - """Cache CivitAI model data.""" - return dict(self._post("/api/db/cache", json={"model_id": model_id})) - - # ========================================================================= - # Streaming Downloads - # ========================================================================= - - def stream_image(self, image_id: str) -> Iterator[bytes]: - """Stream image download in chunks.""" - try: - with self.client.stream("GET", f"/api/images/{image_id}") as resp: - resp.raise_for_status() - yield from resp.iter_bytes(chunk_size=1024 * 64) - except httpx.HTTPStatusError as e: - raise TsrClientError(f"HTTP {e.response.status_code}") from e - except httpx.RequestError as e: - raise TsrClientError(f"Request failed: {e}") from e - - def save_image_to(self, image_id: str, dest: Path) -> Path: - """Download and save image to file.""" - content = self.download_image(image_id) - dest.write_bytes(content) - return dest diff --git a/tensors/generate/__init__.py b/tensors/generate/__init__.py deleted file mode 100644 index 5741ef3..0000000 --- a/tensors/generate/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -"""sd-server Python client — modular, httpx-based.""" - -from __future__ import annotations - -from typing import Any - -from tensors.generate._http import HttpTransport -from tensors.generate.generation import GenerationAPI -from tensors.generate.info import InfoAPI -from tensors.generate.params import Img2ImgParams, Txt2ImgParams -from tensors.generate.util import save_images - -__all__ = [ - "Img2ImgParams", - "SDClient", - "Txt2ImgParams", - "save_images", -] - - -class SDClient: - """Composite client for sd-server. - - Usage:: - - with SDClient() as c: - c.info.models() - images = c.generate.txt2img(Txt2ImgParams(prompt="a cat")) - """ - - def __init__(self, host: str = "127.0.0.1", port: int = 1234) -> None: - self._http = HttpTransport(f"http://{host}:{port}") - self.info = InfoAPI(self._http) - self.generate = GenerationAPI(self._http) - - def close(self) -> None: - self._http.close() - - def __enter__(self) -> SDClient: - return self - - def __exit__(self, *exc: Any) -> None: - self.close() diff --git a/tensors/generate/_http.py b/tensors/generate/_http.py deleted file mode 100644 index ea4337d..0000000 --- a/tensors/generate/_http.py +++ /dev/null @@ -1,46 +0,0 @@ -"""HTTP transport layer wrapping httpx.""" - -from __future__ import annotations - -import logging -from typing import Any - -import httpx - -logger = logging.getLogger(__name__) - - -class HttpTransport: - def __init__(self, base_url: str, timeout: float = 300.0) -> None: - self._client = httpx.Client(base_url=base_url, timeout=timeout) - logger.debug("transport ready: %s", base_url) - - def get(self, path: str) -> Any: - logger.debug("GET %s", path) - try: - r = self._client.get(path) - r.raise_for_status() - except httpx.HTTPStatusError as e: - logger.error("GET %s → %d: %s", path, e.response.status_code, e.response.text[:200]) - raise - except httpx.RequestError as e: - logger.error("GET %s connection failed: %s", path, e) - raise - return r.json() - - def post(self, path: str, json: dict[str, Any]) -> Any: - logger.debug("POST %s", path) - try: - r = self._client.post(path, json=json) - r.raise_for_status() - except httpx.HTTPStatusError as e: - logger.error("POST %s → %d: %s", path, e.response.status_code, e.response.text[:200]) - raise - except httpx.RequestError as e: - logger.error("POST %s connection failed: %s", path, e) - raise - return r.json() - - def close(self) -> None: - self._client.close() - logger.debug("transport closed") diff --git a/tensors/generate/generation.py b/tensors/generate/generation.py deleted file mode 100644 index acc8b89..0000000 --- a/tensors/generate/generation.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Image generation endpoints.""" - -from __future__ import annotations - -import base64 -import logging -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from tensors.generate._http import HttpTransport - from tensors.generate.params import Img2ImgParams, Txt2ImgParams - -logger = logging.getLogger(__name__) - - -class GenerationAPI: - def __init__(self, http: HttpTransport) -> None: - self._http = http - - def txt2img(self, params: Txt2ImgParams) -> list[bytes]: - """Generate images from text prompt.""" - logger.info("txt2img: '%s' %dx%d steps=%d", params.prompt[:60], params.width, params.height, params.steps) - data = self._http.post("/sdapi/v1/txt2img", params.to_body()) - images = [base64.b64decode(img) for img in data["images"]] - logger.info("txt2img: got %d image(s)", len(images)) - return images - - def img2img(self, params: Img2ImgParams) -> list[bytes]: - """Generate images from image + text prompt.""" - logger.info("img2img: '%s' strength=%.2f steps=%d", params.prompt[:60], params.denoising_strength, params.steps) - data = self._http.post("/sdapi/v1/img2img", params.to_body()) - images = [base64.b64decode(img) for img in data["images"]] - logger.info("img2img: got %d image(s)", len(images)) - return images diff --git a/tensors/generate/info.py b/tensors/generate/info.py deleted file mode 100644 index abcb7ae..0000000 --- a/tensors/generate/info.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Model and server info endpoints.""" - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from tensors.generate._http import HttpTransport - -logger = logging.getLogger(__name__) - - -class InfoAPI: - def __init__(self, http: HttpTransport) -> None: - self._http = http - - def models(self) -> list[dict[str, Any]]: - """List loaded models (OpenAI /v1/models).""" - return self._http.get("/v1/models")["data"] # type: ignore[no-any-return] - - def sd_models(self) -> list[dict[str, Any]]: - """Detailed model info (sdapi).""" - return self._http.get("/sdapi/v1/sd-models") # type: ignore[no-any-return] - - def options(self) -> dict[str, Any]: - """Current server options.""" - return self._http.get("/sdapi/v1/options") # type: ignore[no-any-return] - - def loras(self) -> list[dict[str, Any]]: - """Available LoRAs from --lora-model-dir.""" - result: list[dict[str, Any]] = self._http.get("/sdapi/v1/loras") - logger.info("found %d lora(s)", len(result)) - return result - - def samplers(self) -> list[str]: - """Available sampler names.""" - return [s["name"] for s in self._http.get("/sdapi/v1/samplers")] - - def schedulers(self) -> list[str]: - """Available scheduler names.""" - return [s["name"] for s in self._http.get("/sdapi/v1/schedulers")] diff --git a/tensors/generate/params.py b/tensors/generate/params.py deleted file mode 100644 index be18ce8..0000000 --- a/tensors/generate/params.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Generation parameter dataclasses.""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from pathlib import Path - -from tensors.generate.util import to_b64 - - -@dataclass -class Txt2ImgParams: - prompt: str - negative_prompt: str = "" - width: int = 512 - height: int = 512 - steps: int = 20 - cfg_scale: float = 7.0 - seed: int = -1 - batch_size: int = 1 - sampler_name: str = "" - scheduler: str = "" - clip_skip: int = -1 - lora: list[dict[str, Any]] | None = None - - def to_body(self) -> dict[str, Any]: - body = { - "prompt": self.prompt, - "negative_prompt": self.negative_prompt, - "width": self.width, - "height": self.height, - "steps": self.steps, - "cfg_scale": self.cfg_scale, - "seed": self.seed, - "batch_size": self.batch_size, - } - if self.sampler_name: - body["sampler_name"] = self.sampler_name - if self.scheduler: - body["scheduler"] = self.scheduler - if self.clip_skip > 0: - body["clip_skip"] = self.clip_skip - if self.lora: - body["lora"] = self.lora - return body - - -@dataclass -class Img2ImgParams: - prompt: str - init_image: str | bytes | Path - negative_prompt: str = "" - width: int = -1 - height: int = -1 - steps: int = 20 - cfg_scale: float = 7.0 - denoising_strength: float = 0.75 - seed: int = -1 - batch_size: int = 1 - sampler_name: str = "" - scheduler: str = "" - clip_skip: int = -1 - mask: str | bytes | Path | None = None - inpainting_mask_invert: bool = False - lora: list[dict[str, Any]] | None = None - extra_images: list[str | bytes | Path] = field(default_factory=list) - - def to_body(self) -> dict[str, Any]: - body: dict[str, Any] = { - "prompt": self.prompt, - "negative_prompt": self.negative_prompt, - "steps": self.steps, - "cfg_scale": self.cfg_scale, - "denoising_strength": self.denoising_strength, - "seed": self.seed, - "batch_size": self.batch_size, - "init_images": [to_b64(self.init_image)], - } - if self.width > 0: - body["width"] = self.width - if self.height > 0: - body["height"] = self.height - if self.mask is not None: - body["mask"] = to_b64(self.mask) - if self.inpainting_mask_invert: - body["inpainting_mask_invert"] = 1 - if self.sampler_name: - body["sampler_name"] = self.sampler_name - if self.scheduler: - body["scheduler"] = self.scheduler - if self.clip_skip > 0: - body["clip_skip"] = self.clip_skip - if self.lora: - body["lora"] = self.lora - if self.extra_images: - body["extra_images"] = [to_b64(img) for img in self.extra_images] - return body diff --git a/tensors/generate/util.py b/tensors/generate/util.py deleted file mode 100644 index 4014115..0000000 --- a/tensors/generate/util.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Utility functions for image encoding and file I/O.""" - -import base64 -import logging -from pathlib import Path - -logger = logging.getLogger(__name__) - - -def to_b64(image: str | bytes | Path) -> str: - """Convert a file path, raw bytes, or base64 string to base64.""" - if isinstance(image, (str, Path)): - path = Path(image) - if path.exists(): - logger.debug("encoding file: %s", path) - return base64.b64encode(path.read_bytes()).decode() - return str(image) - if isinstance(image, bytes): - return base64.b64encode(image).decode() - raise TypeError(f"unsupported image type: {type(image)}") - - -def save_images( - images: list[bytes], - output_dir: str = ".", - prefix: str = "output", -) -> list[Path]: - """Write raw PNG bytes to numbered files. Returns saved paths.""" - out = Path(output_dir) - out.mkdir(parents=True, exist_ok=True) - paths = [] - for i, data in enumerate(images): - path = out / f"{prefix}_{i:04d}.png" - path.write_bytes(data) - logger.info("saved: %s", path) - paths.append(path) - return paths diff --git a/tensors/server/generate_routes.py b/tensors/server/generate_routes.py deleted file mode 100644 index 75bc8ec..0000000 --- a/tensors/server/generate_routes.py +++ /dev/null @@ -1,216 +0,0 @@ -"""FastAPI route handlers for image generation with gallery integration.""" - -from __future__ import annotations - -import base64 -import logging -import time -from pathlib import Path -from typing import Any - -import httpx -from fastapi import APIRouter, HTTPException, Request -from pydantic import BaseModel as PydanticBaseModel -from pydantic import Field - -from tensors.server.gallery import Gallery -from tensors.server.sd_client import get_sd_headers - -logger = logging.getLogger(__name__) - - -# ============================================================================= -# Request/Response Models -# ============================================================================= - - -class LoraConfig(PydanticBaseModel): - """LoRA configuration for sd-server.""" - - path: str - multiplier: float = Field(default=1.0, ge=0.0, le=2.0) - - -class GenerateRequest(PydanticBaseModel): - """Request body for image generation.""" - - prompt: str - negative_prompt: str = "" - width: int = Field(default=512, ge=64, le=2048) - height: int = Field(default=512, ge=64, le=2048) - steps: int = Field(default=20, ge=1, le=150) - cfg_scale: float = Field(default=7.0, ge=0, le=30) - seed: int = -1 - sampler_name: str = "" - scheduler: str = "" - batch_size: int = Field(default=1, ge=1, le=16) - save_to_gallery: bool = True - return_base64: bool = False - lora: LoraConfig | None = None - - -# ============================================================================= -# Helper Functions -# ============================================================================= - - -def _build_sd_request(req: GenerateRequest) -> dict[str, Any]: - """Build request body for sd-server.""" - prompt = req.prompt - - # sd-server expects LoRA in prompt as syntax - if req.lora: - lora_name = Path(req.lora.path).stem - lora_tag = f"" - prompt = f"{prompt} {lora_tag}" - - body: dict[str, Any] = { - "prompt": prompt, - "negative_prompt": req.negative_prompt, - "width": req.width, - "height": req.height, - "steps": req.steps, - "cfg_scale": req.cfg_scale, - "seed": req.seed, - "batch_size": req.batch_size, - } - if req.sampler_name: - body["sampler_name"] = req.sampler_name - if req.scheduler: - body["scheduler"] = req.scheduler - return body - - -def _parse_info(info: Any) -> dict[str, Any]: - """Parse info from sd-server response.""" - if isinstance(info, str): - import json # noqa: PLC0415 - - try: - return dict(json.loads(info)) - except json.JSONDecodeError: - return {"raw": info} - return info if isinstance(info, dict) else {} - - -def _process_image( - img_b64: str, - index: int, - seed: int, - req: GenerateRequest, - gallery: Gallery, - model: str | None, -) -> dict[str, Any]: - """Process a single generated image.""" - image_bytes = base64.b64decode(img_b64) - image_info: dict[str, Any] = {"index": index, "seed": seed} - - if req.save_to_gallery: - metadata = { - "prompt": req.prompt, - "negative_prompt": req.negative_prompt, - "width": req.width, - "height": req.height, - "steps": req.steps, - "cfg_scale": req.cfg_scale, - "sampler": req.sampler_name, - "scheduler": req.scheduler, - "model": model, - "lora": req.lora.path if req.lora else None, - "lora_weight": req.lora.multiplier if req.lora else None, - "generated_at": time.time(), - } - gallery_img = gallery.save_image(image_bytes, metadata=metadata, seed=seed) - image_info["id"] = gallery_img.id - image_info["path"] = str(gallery_img.path) - - if req.return_base64: - image_info["base64"] = img_b64 - - return image_info - - -# ============================================================================= -# Router Factory -# ============================================================================= - - -def create_generate_router() -> APIRouter: # noqa: PLR0915 - """Build a router with /api/generate endpoint.""" - router = APIRouter(prefix="/api", tags=["generate"]) - gallery = Gallery() - - @router.post("/generate") - async def generate(request: Request, req: GenerateRequest) -> dict[str, Any]: - """Generate images with gallery integration.""" - sd_server_url = request.app.state.sd_server_url - body = _build_sd_request(req) - url = f"{sd_server_url}/sdapi/v1/txt2img" - - try: - headers = get_sd_headers(request) - async with httpx.AsyncClient(timeout=300) as client: - response = await client.post(url, json=body, headers=headers) - response.raise_for_status() - result = response.json() - except httpx.ConnectError as e: - raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e - except httpx.HTTPError as e: - logger.exception("Generation failed") - raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e - - images_data = result.get("images", []) - info = _parse_info(result.get("info", {})) - all_seeds = info.get("all_seeds", [req.seed] * len(images_data)) - - # Get model info from sd-server response if available - model_name = info.get("sd_model_name") or info.get("model") - - output_images = [ - _process_image(img_b64, i, all_seeds[i] if i < len(all_seeds) else req.seed + i, req, gallery, model_name) - for i, img_b64 in enumerate(images_data) - ] - - return { - "images": output_images, - "parameters": result.get("parameters", body), - "info": info, - "saved_to_gallery": req.save_to_gallery, - "total": len(output_images), - } - - @router.get("/samplers") - async def list_samplers(request: Request) -> dict[str, Any]: - """List available samplers from sd-server.""" - sd_server_url = request.app.state.sd_server_url - url = f"{sd_server_url}/sdapi/v1/samplers" - - try: - headers = get_sd_headers(request) - async with httpx.AsyncClient(timeout=30) as client: - response = await client.get(url, headers=headers) - response.raise_for_status() - return {"samplers": response.json()} - except httpx.ConnectError as e: - raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e - except httpx.HTTPError as e: - raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e - - @router.get("/schedulers") - async def list_schedulers(request: Request) -> dict[str, Any]: - """List available schedulers from sd-server.""" - sd_server_url = request.app.state.sd_server_url - url = f"{sd_server_url}/sdapi/v1/schedulers" - - try: - headers = get_sd_headers(request) - async with httpx.AsyncClient(timeout=30) as client: - response = await client.get(url, headers=headers) - response.raise_for_status() - return {"schedulers": response.json()} - except httpx.ConnectError as e: - raise HTTPException(status_code=503, detail=f"Cannot connect to sd-server: {e}") from e - except httpx.HTTPError as e: - raise HTTPException(status_code=502, detail=f"sd-server error: {e}") from e - - return router diff --git a/tensors/server/models_routes.py b/tensors/server/models_routes.py deleted file mode 100644 index 03cbcd6..0000000 --- a/tensors/server/models_routes.py +++ /dev/null @@ -1,311 +0,0 @@ -"""FastAPI route handlers for model management endpoints.""" - -from __future__ import annotations - -import asyncio -import logging -from pathlib import Path -from typing import Any - -from fastapi import APIRouter, HTTPException, Request -from pydantic import BaseModel - -from tensors.config import MODELS_DIR -from tensors.db import Database -from tensors.server.sd_client import get_sd_headers - -logger = logging.getLogger(__name__) - -_HTTP_OK = 200 -_SD_ENV_FILE = Path("/etc/default/sd-server") - - -class SwitchModelRequest(BaseModel): - """Request body for switching models.""" - - model: str # Model filename or full path - - -async def _run_command(*args: str) -> tuple[int, str, str]: - """Run a shell command and return (returncode, stdout, stderr).""" - proc = await asyncio.create_subprocess_exec( - *args, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - stdout, stderr = await proc.communicate() - return proc.returncode or 0, stdout.decode(), stderr.decode() - - -def _read_env_file() -> dict[str, str]: - """Read the sd-server environment file.""" - env: dict[str, str] = {} - if _SD_ENV_FILE.exists(): - for raw_line in _SD_ENV_FILE.read_text().splitlines(): - line = raw_line.strip() - if line and not line.startswith("#") and "=" in line: - key, _, value = line.partition("=") - env[key.strip()] = value.strip() - return env - - -def _write_env_file(env: dict[str, str]) -> str: - """Generate env file content.""" - lines = ["# sd-server configuration"] - for key, value in env.items(): - lines.append(f"{key}={value}") - return "\n".join(lines) + "\n" - -# Keywords for detecting base model category -_SD15_KEYWORDS = ("sd15", "sd1.5", "sd-1.5", "sd_1.5", "1.5", "sd-1-", "v1-5") -_LARGE_KEYWORDS = ("sdxl", "xl", "pony", "illustrious", "ilust", "noob", "animagine") - - -def _detect_model_category(name: str) -> str: - """Detect model category from filename. Returns 'sd15' or 'large'.""" - name_lower = name.lower() - - # Check SD 1.5 keywords first - for kw in _SD15_KEYWORDS: - if kw in name_lower: - return "sd15" - - # Check large model keywords - for kw in _LARGE_KEYWORDS: - if kw in name_lower: - return "large" - - # Default to large (SDXL/Pony/Illustrious are more common now) - return "large" - - -# ============================================================================= -# Helper Functions -# ============================================================================= - - -def scan_models(directory: Path, extensions: tuple[str, ...] = (".safetensors", ".gguf")) -> list[dict[str, Any]]: - """Scan directory for model files.""" - models: list[dict[str, Any]] = [] - - if not directory.exists(): - return models - - for ext in extensions: - for path in directory.rglob(f"*{ext}"): - stat = path.stat() - name = path.stem - models.append( - { - "name": name, - "path": str(path), - "filename": path.name, - "size_mb": round(stat.st_size / (1024 * 1024), 2), - "modified": stat.st_mtime, - "category": _detect_model_category(name), - } - ) - - # Sort by name - models.sort(key=lambda x: x["name"].lower()) - return models - - -def _enrich_with_metadata(models: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Enrich model data with CivitAI metadata from database.""" - try: - with Database() as db: - db.init_schema() - - for model in models: - file_path = model.get("path", "") - file_info = db.get_local_file_by_path(file_path) - - if file_info and file_info.get("civitai_model_id"): - # Add human-readable name - model["display_name"] = file_info.get("model_name") or model["name"] - model["base_model"] = file_info.get("base_model") - model["model_type"] = file_info.get("model_type") - model["civitai_model_id"] = file_info.get("civitai_model_id") - model["civitai_version_id"] = file_info.get("civitai_version_id") - - # Get thumbnail from version images - version_id = file_info.get("civitai_version_id") - if version_id: - cur = db.conn.cursor() - cur.execute( - """ - SELECT url FROM version_images - WHERE version_id = (SELECT id FROM model_versions WHERE civitai_id = ?) - ORDER BY id LIMIT 1 - """, - (version_id,), - ) - row = cur.fetchone() - if row: - model["thumbnail_url"] = row[0] - - # Get trigger words - triggers = db.get_triggers_by_version(version_id) if version_id else [] - model["triggers"] = triggers[:5] # Limit to first 5 - else: - model["display_name"] = model["name"] - - except Exception as e: - logger.warning("Failed to enrich models with metadata: %s", e) - # Fallback: just use filename as display name - for model in models: - if "display_name" not in model: - model["display_name"] = model["name"] - - return models - - -def scan_loras(directory: Path | None = None) -> list[dict[str, Any]]: - """Scan for LoRA files.""" - lora_dir = directory or MODELS_DIR / "loras" - return scan_models(lora_dir, extensions=(".safetensors", ".gguf")) - - -def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]: - """Scan for checkpoint files.""" - checkpoint_dir = directory or MODELS_DIR / "checkpoints" - return scan_models(checkpoint_dir, extensions=(".safetensors", ".gguf")) - - -# ============================================================================= -# Router Factory -# ============================================================================= - - -def create_models_router() -> APIRouter: # noqa: PLR0915 - """Build a router with /api/models/* endpoints.""" - router = APIRouter(prefix="/api/models", tags=["models"]) - - @router.get("") - def list_models() -> dict[str, Any]: - """List available checkpoint models with metadata.""" - checkpoints = scan_checkpoints() - checkpoints = _enrich_with_metadata(checkpoints) - return { - "models": checkpoints, - "total": len(checkpoints), - } - - @router.get("/active") - async def get_active_model(request: Request) -> dict[str, Any]: - """Get information about the currently loaded model from sd-server.""" - import httpx # noqa: PLC0415 - - sd_server_url = request.app.state.sd_server_url - - # Try to get current model from sd-server's options endpoint - try: - headers = get_sd_headers(request) - async with httpx.AsyncClient(timeout=10) as client: - response = await client.get(f"{sd_server_url}/sdapi/v1/options", headers=headers) - if response.status_code == _HTTP_OK: - options = response.json() - model_name = options.get("sd_model_checkpoint") - return { - "loaded": True, - "model": model_name, - "sd_server_url": sd_server_url, - } - except httpx.HTTPError: - pass - - return { - "loaded": False, - "model": None, - "sd_server_url": sd_server_url, - "error": "Cannot connect to sd-server", - } - - @router.get("/loras") - def list_loras() -> dict[str, Any]: - """List available LoRA files with metadata.""" - loras = scan_loras() - loras = _enrich_with_metadata(loras) - return { - "loras": loras, - "total": len(loras), - } - - @router.get("/scan") - def scan_all_models() -> dict[str, Any]: - """Scan all model directories.""" - checkpoints = scan_checkpoints() - loras = scan_loras() - - return { - "checkpoints": checkpoints, - "loras": loras, - "total_checkpoints": len(checkpoints), - "total_loras": len(loras), - } - - @router.post("/switch") - async def switch_model(req: SwitchModelRequest) -> dict[str, Any]: - """Switch sd-server to a different model by updating env and restarting.""" - # Find the model file - checkpoints = scan_checkpoints() - model_path: str | None = None - - for cp in checkpoints: - if cp["filename"] == req.model or cp["path"] == req.model or cp["name"] == req.model: - model_path = cp["path"] - break - - if not model_path: - raise HTTPException(status_code=404, detail=f"Model not found: {req.model}") - - # Read current env, update SD_MODEL - env = _read_env_file() - old_model = env.get("SD_MODEL", "") - env["SD_MODEL"] = model_path - - # Write new env file via sudo tee - new_content = _write_env_file(env) - proc = await asyncio.create_subprocess_exec( - "sudo", "tee", str(_SD_ENV_FILE), - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - _, tee_stderr = await proc.communicate(new_content.encode()) - if proc.returncode != 0: - raise HTTPException(status_code=500, detail=f"Failed to write env file: {tee_stderr.decode()}") - - # Restart sd-server - returncode, _stdout, restart_stderr = await _run_command("sudo", "systemctl", "restart", "sd-server") - if returncode != 0: - raise HTTPException(status_code=500, detail=f"Failed to restart sd-server: {restart_stderr}") - - logger.info(f"Switched model from {old_model} to {model_path}") - - return { - "ok": True, - "old_model": old_model, - "new_model": model_path, - "message": "Model switched, sd-server restarting", - } - - @router.get("/status") - async def sd_server_status() -> dict[str, Any]: - """Get sd-server systemd service status.""" - _returncode, stdout, _stderr = await _run_command("systemctl", "is-active", "sd-server") - is_active = stdout.strip() == "active" - - env = _read_env_file() - - return { - "service": "sd-server", - "active": is_active, - "status": stdout.strip(), - "current_model": env.get("SD_MODEL"), - "host": env.get("SD_HOST"), - "port": env.get("SD_PORT"), - } - - return router diff --git a/tensors/server/routes.py b/tensors/server/routes.py deleted file mode 100644 index 647fc74..0000000 --- a/tensors/server/routes.py +++ /dev/null @@ -1,80 +0,0 @@ -"""FastAPI route handlers for the sd-server wrapper API.""" - -from __future__ import annotations - -import logging -from typing import Any - -import httpx -from fastapi import APIRouter, Request, Response -from fastapi.responses import JSONResponse, StreamingResponse - -from tensors.server.sd_client import get_sd_headers - -logger = logging.getLogger(__name__) - - -def create_router() -> APIRouter: - """Build a router with /status and catch-all proxy.""" - router = APIRouter() - - @router.get("/status") - async def status(request: Request) -> dict[str, Any]: - """Check if the external sd-server is reachable.""" - sd_server_url = request.app.state.sd_server_url - try: - headers = get_sd_headers(request) - async with httpx.AsyncClient(timeout=5) as client: - r = await client.get(sd_server_url, headers=headers) - return { - "status": "ok", - "sd_server_url": sd_server_url, - "sd_server_status": r.status_code, - } - except httpx.HTTPError as e: - return { - "status": "error", - "sd_server_url": sd_server_url, - "error": str(e), - } - - @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]) - async def proxy(request: Request, path: str) -> Response: - """Proxy all requests to the external sd-server.""" - sd_server_url = request.app.state.sd_server_url - url = f"{sd_server_url}/{path}" - if request.url.query: - url = f"{url}?{request.url.query}" - - body = await request.body() - headers = dict(request.headers) - headers.pop("host", None) - # Add API key if configured - headers.update(get_sd_headers(request)) - client = request.app.state.client - - try: - upstream = await client.request( - method=request.method, - url=url, - headers=headers, - content=body, - timeout=300, - ) - return StreamingResponse( - content=upstream.iter_bytes(), - status_code=upstream.status_code, - headers=dict(upstream.headers), - ) - except httpx.ConnectError: - return JSONResponse( - {"error": f"Cannot connect to sd-server at {sd_server_url}"}, - status_code=503, - ) - except httpx.TimeoutException: - return JSONResponse( - {"error": f"Timeout connecting to sd-server at {sd_server_url}"}, - status_code=504, - ) - - return router diff --git a/tensors/server/sd_client.py b/tensors/server/sd_client.py deleted file mode 100644 index 4e728e3..0000000 --- a/tensors/server/sd_client.py +++ /dev/null @@ -1,39 +0,0 @@ -"""HTTP client utilities for sd-server communication.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -import httpx - -if TYPE_CHECKING: - from fastapi import Request - - -def get_sd_headers(request: Request) -> dict[str, str]: - """Get headers for sd-server requests, including API key if configured.""" - headers: dict[str, str] = {} - api_key = getattr(request.app.state, "sd_server_api_key", None) - if api_key: - headers["X-API-Key"] = api_key - return headers - - -async def sd_get(request: Request, path: str, *, timeout: float = 30) -> httpx.Response: - """Make a GET request to sd-server.""" - url = f"{request.app.state.sd_server_url}/{path.lstrip('/')}" - headers = get_sd_headers(request) - async with httpx.AsyncClient(timeout=timeout) as client: - response = await client.get(url, headers=headers) - response.raise_for_status() - return response - - -async def sd_post(request: Request, path: str, *, json: dict[str, Any] | None = None, timeout: float = 300) -> httpx.Response: - """Make a POST request to sd-server.""" - url = f"{request.app.state.sd_server_url}/{path.lstrip('/')}" - headers = get_sd_headers(request) - async with httpx.AsyncClient(timeout=timeout) as client: - response = await client.post(url, json=json, headers=headers) - response.raise_for_status() - return response