💬 Commit message: Update 2026-02-15 06:21:35, 7 files, 1559 lines

📁 Files changed: 7
📝 Lines changed: 1559

  • .coverage
  • cli.py
  • __init__.py
  • conftest.py
  • test_client.py
  • test_generate.py
  • test_server.py
This commit is contained in:
Adam Ladachowski
2026-02-15 06:21:35 +01:00
parent c419e443ae
commit 356d8fd156
7 changed files with 52 additions and 1507 deletions
-20
View File
@@ -7,11 +7,6 @@ import json
import struct
import pytest
import respx
from tensors.generate import SDClient
BASE_URL = "http://127.0.0.1:1234"
# 1x1 red PNG for image response stubs
TINY_PNG = (
@@ -43,18 +38,3 @@ def temp_safetensor(tmp_path):
f.write(header_bytes)
return file_path
@pytest.fixture()
def mock_api():
"""Activate respx mock for the sd-server base URL."""
with respx.mock(base_url=BASE_URL, assert_all_called=False) as rsps:
yield rsps
@pytest.fixture()
def client(mock_api: respx.MockRouter) -> SDClient: # noqa: ARG001
"""SDClient wired to the mocked transport."""
c = SDClient()
yield c # type: ignore[misc]
c.close()
-481
View File
@@ -1,481 +0,0 @@
"""Tests for the TsrClient HTTP client module."""
from __future__ import annotations
import pytest
import respx
from httpx import Response
from tensors.client import TsrClient, TsrClientError
BASE_URL = "http://test-server:8080"
@pytest.fixture
def mock_server():
"""Activate respx mock for the test server."""
with respx.mock(base_url=BASE_URL, assert_all_called=False) as rsps:
yield rsps
@pytest.fixture
def client(mock_server) -> TsrClient: # noqa: ARG001 - mock_server activates respx
"""TsrClient connected to mock server."""
return TsrClient(BASE_URL)
# =============================================================================
# Status Tests
# =============================================================================
class TestStatus:
"""Tests for server status endpoint."""
def test_status_success(self, client: TsrClient, mock_server) -> None:
"""Test getting server status."""
mock_server.get("/status").mock(return_value=Response(200, json={"running": True, "pid": 12345, "model": "/test.gguf"}))
with client:
result = client.status()
assert result["running"] is True
assert result["pid"] == 12345
def test_status_error(self, client: TsrClient, mock_server) -> None:
"""Test handling status error."""
mock_server.get("/status").mock(return_value=Response(503, text="Service unavailable"))
with client, pytest.raises(TsrClientError, match="HTTP 503"):
client.status()
# =============================================================================
# Gallery Tests
# =============================================================================
class TestGalleryImages:
"""Tests for gallery image operations."""
def test_list_images(self, client: TsrClient, mock_server) -> None:
"""Test listing gallery images."""
mock_server.get("/api/images").mock(
return_value=Response(
200,
json={
"images": [
{"id": "123_0", "filename": "123_0.png", "width": 512, "height": 512},
{"id": "124_1", "filename": "124_1.png", "width": 1024, "height": 1024},
],
"total": 2,
},
)
)
with client:
result = client.list_images()
assert len(result["images"]) == 2
assert result["total"] == 2
def test_list_images_with_pagination(self, client: TsrClient, mock_server) -> None:
"""Test listing images with pagination."""
mock_server.get("/api/images", params={"limit": 10, "offset": 5}).mock(
return_value=Response(200, json={"images": [], "total": 100})
)
with client:
result = client.list_images(limit=10, offset=5)
assert result["total"] == 100
def test_get_image_meta(self, client: TsrClient, mock_server) -> None:
"""Test getting image metadata."""
mock_server.get("/api/images/123_0/meta").mock(
return_value=Response(
200,
json={
"id": "123_0",
"path": "/gallery/123_0.png",
"metadata": {"prompt": "test prompt", "seed": 42},
},
)
)
with client:
result = client.get_image_meta("123_0")
assert result["id"] == "123_0"
assert result["metadata"]["prompt"] == "test prompt"
def test_delete_image(self, client: TsrClient, mock_server) -> None:
"""Test deleting an image."""
mock_server.delete("/api/images/123_0").mock(return_value=Response(200, json={"deleted": True, "id": "123_0"}))
with client:
result = client.delete_image("123_0")
assert result["deleted"] is True
def test_edit_image(self, client: TsrClient, mock_server) -> None:
"""Test editing image metadata."""
mock_server.post("/api/images/123_0/edit").mock(
return_value=Response(200, json={"id": "123_0", "metadata": {"tags": ["favorite"], "rating": 5}})
)
with client:
result = client.edit_image("123_0", {"tags": ["favorite"], "rating": 5})
assert result["metadata"]["tags"] == ["favorite"]
def test_download_image(self, client: TsrClient, mock_server) -> None:
"""Test downloading image bytes."""
image_bytes = b"\x89PNG test image data"
mock_server.get("/api/images/123_0").mock(return_value=Response(200, content=image_bytes))
with client:
result = client.download_image("123_0")
assert result == image_bytes
# =============================================================================
# Models Tests
# =============================================================================
class TestModels:
"""Tests for model management operations."""
def test_list_models(self, client: TsrClient, mock_server) -> None:
"""Test listing available models."""
mock_server.get("/api/models").mock(
return_value=Response(
200,
json={
"models": [
{"name": "sdxl_base", "path": "/models/sdxl_base.safetensors"},
{"name": "pony_v6", "path": "/models/pony_v6.safetensors"},
],
"active": "/models/sdxl_base.safetensors",
},
)
)
with client:
result = client.list_models()
assert len(result["models"]) == 2
assert result["active"] == "/models/sdxl_base.safetensors"
def test_get_active_model(self, client: TsrClient, mock_server) -> None:
"""Test getting active model."""
mock_server.get("/api/models/active").mock(return_value=Response(200, json={"model": "/models/sdxl_base.safetensors"}))
with client:
result = client.get_active_model()
assert result["model"] == "/models/sdxl_base.safetensors"
def test_switch_model(self, client: TsrClient, mock_server) -> None:
"""Test switching model."""
mock_server.post("/api/models/switch").mock(
return_value=Response(200, json={"status": "ok", "model": "/models/pony_v6.safetensors"})
)
with client:
result = client.switch_model("/models/pony_v6.safetensors")
assert result["status"] == "ok"
def test_list_loras(self, client: TsrClient, mock_server) -> None:
"""Test listing LoRAs."""
mock_server.get("/api/models/loras").mock(
return_value=Response(
200,
json={
"loras": [
{"name": "detail_tweaker", "path": "/loras/detail_tweaker.safetensors"},
]
},
)
)
with client:
result = client.list_loras()
assert len(result["loras"]) == 1
def test_scan_models(self, client: TsrClient, mock_server) -> None:
"""Test scanning models."""
mock_server.get("/api/models/scan").mock(return_value=Response(200, json={"scanned": 5}))
with client:
result = client.scan_models()
assert result["scanned"] == 5
# =============================================================================
# Generation Tests
# =============================================================================
class TestGeneration:
"""Tests for image generation."""
def test_generate(self, client: TsrClient, mock_server) -> None:
"""Test generating an image."""
mock_server.post("/api/generate").mock(
return_value=Response(
200,
json={
"images": [{"id": "999_42", "seed": 42}],
"parameters": {"prompt": "test prompt", "seed": 42},
},
)
)
with client:
result = client.generate(
prompt="test prompt",
width=512,
height=512,
seed=42,
)
assert len(result["images"]) == 1
assert result["images"][0]["seed"] == 42
def test_generate_with_all_params(self, client: TsrClient, mock_server) -> None:
"""Test generation with all parameters."""
mock_server.post("/api/generate").mock(return_value=Response(200, json={"images": []}))
with client:
result = client.generate(
prompt="detailed test prompt",
negative_prompt="bad quality",
width=1024,
height=1024,
steps=30,
cfg_scale=5.5,
seed=12345,
sampler_name="DPM++ 2M",
scheduler="karras",
batch_size=2,
save_to_gallery=False,
return_base64=True,
)
assert "images" in result
def test_list_samplers(self, client: TsrClient, mock_server) -> None:
"""Test listing samplers."""
mock_server.get("/api/samplers").mock(return_value=Response(200, json={"samplers": ["Euler", "DPM++ 2M", "Euler a"]}))
with client:
result = client.list_samplers()
assert "samplers" in result
def test_list_schedulers(self, client: TsrClient, mock_server) -> None:
"""Test listing schedulers."""
mock_server.get("/api/schedulers").mock(
return_value=Response(200, json={"schedulers": ["simple", "karras", "sgm_uniform"]})
)
with client:
result = client.list_schedulers()
assert "schedulers" in result
# =============================================================================
# Download Tests
# =============================================================================
class TestDownload:
"""Tests for CivitAI download operations."""
def test_start_download_by_version(self, client: TsrClient, mock_server) -> None:
"""Test starting download by version ID."""
mock_server.post("/api/download").mock(
return_value=Response(200, json={"download_id": "abc123", "status": "started", "version_id": 12345})
)
with client:
result = client.start_download(version_id=12345)
assert result["download_id"] == "abc123"
def test_start_download_by_hash(self, client: TsrClient, mock_server) -> None:
"""Test starting download by hash."""
mock_server.post("/api/download").mock(return_value=Response(200, json={"download_id": "def456", "status": "started"}))
with client:
result = client.start_download(hash_val="ABC123DEF456")
assert result["status"] == "started"
def test_get_download_status(self, client: TsrClient, mock_server) -> None:
"""Test getting download status."""
mock_server.get("/api/download/status/abc123").mock(
return_value=Response(200, json={"download_id": "abc123", "status": "downloading", "progress": 0.5})
)
with client:
result = client.get_download_status("abc123")
assert result["progress"] == 0.5
def test_list_downloads(self, client: TsrClient, mock_server) -> None:
"""Test listing active downloads."""
mock_server.get("/api/download/active").mock(
return_value=Response(200, json={"downloads": [{"id": "abc123", "progress": 0.75}]})
)
with client:
result = client.list_downloads()
assert len(result["downloads"]) == 1
# =============================================================================
# Database Tests
# =============================================================================
class TestDatabase:
"""Tests for database operations."""
def test_db_list_files(self, client: TsrClient, mock_server) -> None:
"""Test listing local files."""
mock_server.get("/api/db/files").mock(
return_value=Response(200, json=[{"id": 1, "file_path": "/models/test.safetensors", "sha256": "abc123"}])
)
with client:
result = client.db_list_files()
assert len(result) == 1
assert result[0]["sha256"] == "abc123"
def test_db_search_models(self, client: TsrClient, mock_server) -> None:
"""Test searching cached models."""
mock_server.get("/api/db/models").mock(
return_value=Response(200, json=[{"civitai_id": 12345, "name": "Test Model", "type": "LORA"}])
)
with client:
result = client.db_search_models(query="Test", model_type="LORA")
assert len(result) == 1
assert result[0]["name"] == "Test Model"
def test_db_get_model(self, client: TsrClient, mock_server) -> None:
"""Test getting cached model."""
mock_server.get("/api/db/models/12345").mock(
return_value=Response(200, json={"civitai_id": 12345, "name": "Test Model", "type": "Checkpoint"})
)
with client:
result = client.db_get_model(12345)
assert result["name"] == "Test Model"
def test_db_get_triggers(self, client: TsrClient, mock_server) -> None:
"""Test getting trigger words."""
mock_server.get("/api/db/triggers/12345").mock(return_value=Response(200, json=["trigger1", "trigger2"]))
with client:
result = client.db_get_triggers(version_id=12345)
assert result == ["trigger1", "trigger2"]
def test_db_stats(self, client: TsrClient, mock_server) -> None:
"""Test getting database stats."""
mock_server.get("/api/db/stats").mock(
return_value=Response(200, json={"local_files": 10, "models": 5, "model_versions": 15})
)
with client:
result = client.db_stats()
assert result["local_files"] == 10
def test_db_scan(self, client: TsrClient, mock_server) -> None:
"""Test scanning directory."""
mock_server.post("/api/db/scan").mock(return_value=Response(200, json={"scanned": 3, "files": []}))
with client:
result = client.db_scan("/models")
assert result["scanned"] == 3
def test_db_link(self, client: TsrClient, mock_server) -> None:
"""Test linking files to CivitAI."""
mock_server.post("/api/db/link").mock(return_value=Response(200, json={"linked": 2}))
with client:
result = client.db_link()
assert result["linked"] == 2
def test_db_cache(self, client: TsrClient, mock_server) -> None:
"""Test caching model data."""
mock_server.post("/api/db/cache").mock(return_value=Response(200, json={"model_id": 12345, "cached": True}))
with client:
result = client.db_cache(12345)
assert result["cached"] is True
# =============================================================================
# Error Handling Tests
# =============================================================================
class TestErrorHandling:
"""Tests for error handling."""
def test_http_error(self, client: TsrClient, mock_server) -> None:
"""Test HTTP error handling."""
mock_server.get("/api/images").mock(return_value=Response(500, text="Internal server error"))
with client, pytest.raises(TsrClientError, match="HTTP 500"):
client.list_images()
def test_not_found_error(self, client: TsrClient, mock_server) -> None:
"""Test 404 error handling."""
mock_server.get("/api/images/nonexistent/meta").mock(return_value=Response(404, json={"detail": "Image not found"}))
with client, pytest.raises(TsrClientError, match="HTTP 404"):
client.get_image_meta("nonexistent")
# =============================================================================
# Context Manager Tests
# =============================================================================
class TestContextManager:
"""Tests for context manager usage."""
def test_context_manager(self, mock_server) -> None:
"""Test client works as context manager."""
mock_server.get("/status").mock(return_value=Response(200, json={"running": True}))
with TsrClient(BASE_URL) as client:
result = client.status()
assert result["running"] is True
def test_client_without_context(self, mock_server) -> None:
"""Test client works without context manager."""
mock_server.get("/status").mock(return_value=Response(200, json={"running": True}))
client = TsrClient(BASE_URL)
result = client.status()
assert result["running"] is True
-303
View File
@@ -1,303 +0,0 @@
"""Tests for tensors.generate package."""
from __future__ import annotations
import base64
import json
from pathlib import Path
import httpx
import pytest
import respx
from tensors.generate import SDClient
from tensors.generate._http import HttpTransport
from tensors.generate.params import Img2ImgParams, Txt2ImgParams
from tensors.generate.util import save_images, to_b64
from tests.conftest import BASE_URL, TINY_PNG, TINY_PNG_B64
# ── util ──────────────────────────────────────────────────────────────
class TestToB64:
def test_bytes_input(self):
raw = b"hello"
assert to_b64(raw) == base64.b64encode(raw).decode()
def test_file_path(self, tmp_path: Path):
f = tmp_path / "img.png"
f.write_bytes(b"\x89PNG")
result = to_b64(str(f))
assert base64.b64decode(result) == b"\x89PNG"
def test_pathlib_path(self, tmp_path: Path):
f = tmp_path / "img.png"
f.write_bytes(b"data")
result = to_b64(f)
assert base64.b64decode(result) == b"data"
def test_passthrough_string(self):
b64 = base64.b64encode(b"already").decode()
assert to_b64(b64) == b64
def test_unsupported_type(self):
with pytest.raises(TypeError, match="unsupported image type"):
to_b64(12345) # type: ignore[arg-type]
class TestSaveImages:
def test_saves_files(self, tmp_path: Path):
images = [b"img0", b"img1", b"img2"]
paths = save_images(images, str(tmp_path), prefix="test")
assert len(paths) == 3
for i, p in enumerate(paths):
assert p.name == f"test_{i:04d}.png"
assert p.read_bytes() == images[i]
def test_creates_directory(self, tmp_path: Path):
out = tmp_path / "sub" / "dir"
save_images([b"x"], str(out))
assert (out / "output_0000.png").exists()
# ── params ────────────────────────────────────────────────────────────
class TestTxt2ImgParams:
def test_minimal_body(self):
p = Txt2ImgParams(prompt="a cat")
body = p.to_body()
assert body["prompt"] == "a cat"
assert body["width"] == 512
assert body["height"] == 512
assert body["steps"] == 20
assert body["seed"] == -1
assert "sampler_name" not in body
assert "scheduler" not in body
assert "clip_skip" not in body
assert "lora" not in body
def test_optional_fields_included(self):
p = Txt2ImgParams(
prompt="test",
sampler_name="euler_a",
scheduler="karras",
clip_skip=2,
lora=[{"path": "x.safetensors", "multiplier": 0.5}],
)
body = p.to_body()
assert body["sampler_name"] == "euler_a"
assert body["scheduler"] == "karras"
assert body["clip_skip"] == 2
assert len(body["lora"]) == 1
class TestImg2ImgParams:
def test_minimal_body(self, tmp_path: Path):
img = tmp_path / "init.png"
img.write_bytes(b"\x89PNG")
p = Img2ImgParams(prompt="paint it", init_image=str(img))
body = p.to_body()
assert body["prompt"] == "paint it"
assert body["denoising_strength"] == 0.75
decoded = base64.b64decode(body["init_images"][0])
assert decoded == b"\x89PNG"
assert "width" not in body
assert "height" not in body
assert "mask" not in body
def test_all_optional_fields(self, tmp_path: Path):
img = tmp_path / "init.png"
img.write_bytes(b"img")
mask = tmp_path / "mask.png"
mask.write_bytes(b"mask")
extra = tmp_path / "extra.png"
extra.write_bytes(b"extra")
p = Img2ImgParams(
prompt="test",
init_image=str(img),
mask=str(mask),
width=768,
height=768,
inpainting_mask_invert=True,
sampler_name="euler",
scheduler="simple",
clip_skip=1,
lora=[{"path": "a.gguf", "multiplier": 1.0}],
extra_images=[str(extra)],
)
body = p.to_body()
assert body["width"] == 768
assert body["mask"]
assert body["inpainting_mask_invert"] == 1
assert body["sampler_name"] == "euler"
assert len(body["extra_images"]) == 1
# ── _http ─────────────────────────────────────────────────────────────
class TestHttpTransport:
def test_get_success(self):
with respx.mock(base_url=BASE_URL) as rsps:
rsps.get("/test").respond(json={"ok": True})
t = HttpTransport(BASE_URL)
assert t.get("/test") == {"ok": True}
t.close()
def test_post_success(self):
with respx.mock(base_url=BASE_URL) as rsps:
rsps.post("/gen").respond(json={"images": []})
t = HttpTransport(BASE_URL)
assert t.post("/gen", {"prompt": "x"}) == {"images": []}
t.close()
def test_get_http_error(self):
with respx.mock(base_url=BASE_URL) as rsps:
rsps.get("/bad").respond(status_code=404, text="not found")
t = HttpTransport(BASE_URL)
with pytest.raises(httpx.HTTPStatusError):
t.get("/bad")
t.close()
def test_post_http_error(self):
with respx.mock(base_url=BASE_URL) as rsps:
rsps.post("/bad").respond(status_code=500, text="error")
t = HttpTransport(BASE_URL)
with pytest.raises(httpx.HTTPStatusError):
t.post("/bad", {})
t.close()
def test_get_connection_error(self):
with respx.mock(base_url=BASE_URL) as rsps:
rsps.get("/fail").mock(side_effect=httpx.ConnectError("refused"))
t = HttpTransport(BASE_URL)
with pytest.raises(httpx.ConnectError):
t.get("/fail")
t.close()
# ── info ──────────────────────────────────────────────────────────────
class TestInfoAPI:
def test_models(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.get("/v1/models").respond(json={"data": [{"id": "sd-cpp-local", "object": "model", "owned_by": "local"}]})
result = client.info.models()
assert len(result) == 1
assert result[0]["id"] == "sd-cpp-local"
def test_sd_models(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.get("/sdapi/v1/sd-models").respond(
json=[{"title": "sdxl", "model_name": "sdxl", "filename": "sdxl.safetensors"}]
)
result = client.info.sd_models()
assert result[0]["title"] == "sdxl"
def test_options(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.get("/sdapi/v1/options").respond(
json={
"samples_format": "png",
"sd_model_checkpoint": "v1-5",
}
)
result = client.info.options()
assert result["sd_model_checkpoint"] == "v1-5"
def test_loras(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.get("/sdapi/v1/loras").respond(
json=[
{"name": "style", "path": "style.safetensors"},
]
)
result = client.info.loras()
assert len(result) == 1
assert result[0]["name"] == "style"
def test_samplers(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.get("/sdapi/v1/samplers").respond(
json=[
{"name": "euler", "aliases": ["euler"], "options": {}},
{"name": "euler_a", "aliases": ["euler_a"], "options": {}},
]
)
result = client.info.samplers()
assert result == ["euler", "euler_a"]
def test_schedulers(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.get("/sdapi/v1/schedulers").respond(
json=[
{"name": "discrete", "label": "discrete"},
{"name": "karras", "label": "karras"},
]
)
result = client.info.schedulers()
assert result == ["discrete", "karras"]
# ── generation ────────────────────────────────────────────────────────
class TestTxt2Img:
def test_returns_decoded_images(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.post("/sdapi/v1/txt2img").respond(
json={
"images": [TINY_PNG_B64],
"parameters": {},
"info": "",
}
)
images = client.generate.txt2img(Txt2ImgParams(prompt="a cat"))
assert len(images) == 1
assert images[0] == TINY_PNG
def test_multiple_images(self, mock_api: respx.MockRouter, client: SDClient):
mock_api.post("/sdapi/v1/txt2img").respond(
json={
"images": [TINY_PNG_B64, TINY_PNG_B64, TINY_PNG_B64],
"parameters": {},
"info": "",
}
)
params = Txt2ImgParams(prompt="cats", batch_size=3)
images = client.generate.txt2img(params)
assert len(images) == 3
def test_sends_correct_body(self, mock_api: respx.MockRouter, client: SDClient):
route = mock_api.post("/sdapi/v1/txt2img").respond(
json={
"images": [TINY_PNG_B64],
"parameters": {},
"info": "",
}
)
params = Txt2ImgParams(
prompt="hello",
width=768,
height=768,
steps=30,
sampler_name="euler_a",
)
client.generate.txt2img(params)
sent = json.loads(route.calls[0].request.content)
assert sent["prompt"] == "hello"
assert sent["width"] == 768
assert sent["sampler_name"] == "euler_a"
class TestImg2Img:
def test_returns_decoded_images(self, mock_api: respx.MockRouter, client: SDClient, tmp_path: Path):
mock_api.post("/sdapi/v1/img2img").respond(
json={
"images": [TINY_PNG_B64],
"parameters": {},
"info": "",
}
)
img = tmp_path / "init.png"
img.write_bytes(TINY_PNG)
params = Img2ImgParams(prompt="paint", init_image=str(img))
images = client.generate.img2img(params)
assert len(images) == 1
assert images[0] == TINY_PNG
+8 -82
View File
@@ -1,12 +1,8 @@
"""Tests for tensors.server package (FastAPI sd-server proxy wrapper)."""
"""Tests for tensors.server package (gallery and CivitAI management)."""
from __future__ import annotations
from unittest.mock import AsyncMock
import httpx
import pytest
import respx
from fastapi.testclient import TestClient
from tensors.server import create_app
@@ -14,87 +10,17 @@ from tensors.server import create_app
@pytest.fixture()
def api() -> TestClient:
"""Create test client with mock sd-server URL."""
return TestClient(create_app(sd_server_url="http://mock-sd-server:1234"))
"""Create test client."""
return TestClient(create_app())
class TestStatus:
@respx.mock
def test_status_when_backend_reachable(self) -> None:
"""Test status endpoint when sd-server is reachable."""
respx.get("http://mock-sd-server:1234/").mock(return_value=httpx.Response(200))
with TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) as client:
r = client.get("/status")
assert r.status_code == 200
data = r.json()
assert data["status"] == "ok"
assert data["sd_server_url"] == "http://mock-sd-server:1234"
@respx.mock
def test_status_when_backend_unreachable(self) -> None:
"""Test status endpoint when sd-server is not reachable."""
respx.get("http://mock-sd-server:1234/").mock(side_effect=httpx.ConnectError("Connection refused"))
with TestClient(create_app(sd_server_url="http://mock-sd-server:1234")) as client:
r = client.get("/status")
assert r.status_code == 200
data = r.json()
assert data["status"] == "error"
assert "Connection refused" in data["error"]
class TestProxy:
def test_proxy_forwards_request(self, api: TestClient) -> None:
"""Test proxy forwards GET requests to backend."""
upstream_response = httpx.Response(
200,
json={"data": [{"id": "model-1"}]},
headers={"content-type": "application/json"},
)
mock_client = AsyncMock()
mock_client.request.return_value = upstream_response
api.app.state.client = mock_client # type: ignore[attr-defined]
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
r = api.get("/v1/models")
def test_status_ok(self, api: TestClient) -> None:
"""Test status endpoint returns ok."""
r = api.get("/status")
assert r.status_code == 200
assert r.json() == {"data": [{"id": "model-1"}]}
mock_client.request.assert_called_once()
def test_proxy_forwards_post_with_body(self, api: TestClient) -> None:
"""Test proxy forwards POST requests with body."""
upstream_response = httpx.Response(200, json={"ok": True})
mock_client = AsyncMock()
mock_client.request.return_value = upstream_response
api.app.state.client = mock_client # type: ignore[attr-defined]
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
r = api.post("/sdapi/v1/txt2img", json={"prompt": "hello"})
assert r.status_code == 200
mock_client.request.assert_called_once()
def test_proxy_503_on_connect_error(self, api: TestClient) -> None:
"""Test proxy returns 503 when backend is unreachable."""
mock_client = AsyncMock()
mock_client.request.side_effect = httpx.ConnectError("Connection refused")
api.app.state.client = mock_client # type: ignore[attr-defined]
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
r = api.get("/v1/models")
assert r.status_code == 503
assert "Cannot connect" in r.json()["error"]
def test_proxy_504_on_timeout(self, api: TestClient) -> None:
"""Test proxy returns 504 on timeout."""
mock_client = AsyncMock()
mock_client.request.side_effect = httpx.TimeoutException("Timeout")
api.app.state.client = mock_client # type: ignore[attr-defined]
api.app.state.sd_server_url = "http://mock-sd-server:1234" # type: ignore[attr-defined]
r = api.get("/v1/models")
assert r.status_code == 504
assert "Timeout" in r.json()["error"]
data = r.json()
assert data["status"] == "ok"
# =============================================================================