💬 Commit message: Update 2026-02-14 22:47:41, 18 files, 494 lines
📁 Files changed: 18 📝 Lines changed: 494 • deploy.md • TASK.md • justfile • deploy.sh • config.py • __init__.py • generate_routes.py • models_routes.py • routes.py • sd_client.py • index-CcuP2dTH.css • index-DmOZ-7Sw.js • index-J_qzb7Jl.js • index-QncGJEyk.css • index.html • client.ts • GenerateView.vue • app.ts
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
# Deploy Tensors to Junkpile
|
||||
|
||||
Build, deploy, and restart tensors on junkpile with verification.
|
||||
|
||||
Run the deploy script:
|
||||
|
||||
```bash
|
||||
./scripts/deploy.sh
|
||||
```
|
||||
|
||||
## What it does
|
||||
|
||||
1. **Build UI** - Runs `npm run build` in `tensors/server/ui/`
|
||||
2. **Sync code** - Rsyncs Python code and static files to junkpile
|
||||
3. **Restart tensors** - Runs `sudo systemctl restart tensors`
|
||||
4. **Verify tensors** - Checks `/api/models/status` responds
|
||||
5. **Verify sd-server** - Checks sd-server is active
|
||||
6. **Verify external** - Checks `sd-api.saiden.dev` responds
|
||||
|
||||
## Endpoints
|
||||
|
||||
- **UI**: https://tensors.saiden.dev
|
||||
- **API**: https://sd-api.saiden.dev (requires `X-API-Key` header)
|
||||
@@ -0,0 +1,74 @@
|
||||
# Tensors Refactoring Task
|
||||
|
||||
## Goal
|
||||
|
||||
Transform tensors from a "sd-server wrapper" to a **pure client** (CLI + UI) that talks to sd-server API directly.
|
||||
|
||||
- Default API URL: `https://sd-api.saiden.dev`
|
||||
- API key auth via config: `sd_server_api_key`
|
||||
- No wrapping/proxying/restarting of sd-server
|
||||
|
||||
## Current Architecture
|
||||
|
||||
### CLI Commands
|
||||
|
||||
| Command | What it does | Talks to |
|
||||
|---------|-------------|----------|
|
||||
| `tsr info <file>` | Read safetensor metadata, fetch CivitAI info | Local file + CivitAI API |
|
||||
| `tsr search` | Search CivitAI models | CivitAI API |
|
||||
| `tsr get <id>` | Fetch model info from CivitAI | CivitAI API |
|
||||
| `tsr dl <id>` | Download model from CivitAI | CivitAI API |
|
||||
| `tsr config` | Manage config (API keys, remotes) | Local config |
|
||||
| `tsr generate` | Generate images | sd-server API (direct or via wrapper) |
|
||||
| `tsr status` | Show wrapper status | tensors wrapper API |
|
||||
| **`tsr reload`** | **Reload sd-server with new model** | **tensors wrapper API** |
|
||||
| `tsr serve` | Start wrapper API (proxies to sd-server) | Starts FastAPI server |
|
||||
| `tsr db` | Manage local models DB | Local SQLite |
|
||||
| `tsr images` | Manage remote gallery | tensors wrapper API |
|
||||
| `tsr models` | List models on remote | tensors wrapper API |
|
||||
| `tsr remote` | Manage remote server config | Local config |
|
||||
|
||||
### Server (`tsr serve`)
|
||||
|
||||
Starts a FastAPI server that:
|
||||
1. **Proxies** all requests to sd-server (catch-all route)
|
||||
2. **Serves Vue UI** at `/`
|
||||
3. **Adds features**: gallery, CivitAI search, model listing, downloads
|
||||
4. **Has a `/reload` endpoint** that proxies to sd-server
|
||||
|
||||
## What to Remove
|
||||
|
||||
| Remove | Reason |
|
||||
|--------|--------|
|
||||
| `tsr reload` command | sd-server manages its own models |
|
||||
| `/reload` route in server | Same |
|
||||
| `switch_model` in client.py | Same |
|
||||
| Proxy wrapper concept | tensors should call API directly, not proxy |
|
||||
|
||||
## What to Keep/Refactor
|
||||
|
||||
| Keep | Change |
|
||||
|------|--------|
|
||||
| `tsr serve` | Just serve Vue UI, no proxying |
|
||||
| `tsr generate` | Call sd-server API directly with API key |
|
||||
| Vue UI | Call sd-server API directly (already does via `/api/*`) |
|
||||
| `tsr models` | List models from sd-server API directly |
|
||||
|
||||
## Already Done
|
||||
|
||||
- [x] Added `get_sd_server_api_key()` to config.py
|
||||
- [x] Added `sd_server_api_key` to app state in server/__init__.py
|
||||
- [x] Created `sd_client.py` with `get_sd_headers()` helper
|
||||
- [x] Updated `generate_routes.py` to use API key headers
|
||||
- [x] Updated `routes.py` to use API key headers
|
||||
- [x] Updated `models_routes.py` to use API key headers
|
||||
- [x] Created local config at `~/.xdg/tensors/config.toml` with sd-api.saiden.dev
|
||||
|
||||
## Still TODO
|
||||
|
||||
- [ ] Remove `tsr reload` command from cli.py
|
||||
- [ ] Remove `/reload` route if it exists
|
||||
- [ ] Remove `switch_model` from client.py
|
||||
- [ ] Decide: keep `tsr serve` as UI server or remove entirely?
|
||||
- [ ] Update Vue UI to call sd-server API directly (not via `/api/*` proxy)
|
||||
- [ ] Clean up unused wrapper/proxy code
|
||||
@@ -34,3 +34,7 @@ ui-dev:
|
||||
# Build UI for production
|
||||
ui-build:
|
||||
cd tensors/server/ui && npm run build
|
||||
|
||||
# Deploy to junkpile (build, sync, restart, verify)
|
||||
deploy:
|
||||
./scripts/deploy.sh
|
||||
|
||||
Executable
+72
@@ -0,0 +1,72 @@
|
||||
#!/usr/bin/env bash
|
||||
# Deploy tensors to junkpile
|
||||
# Usage: ./scripts/deploy.sh
|
||||
|
||||
set -e
|
||||
|
||||
REMOTE="chi@junkpile"
|
||||
REMOTE_DIR="~/Projects/tensors"
|
||||
LOCAL_DIR="$(cd "$(dirname "$0")/.." && pwd)"
|
||||
|
||||
echo "==> Building UI..."
|
||||
cd "$LOCAL_DIR/tensors/server/ui"
|
||||
npm run build
|
||||
|
||||
echo ""
|
||||
echo "==> Syncing Python code to junkpile..."
|
||||
rsync -av --delete \
|
||||
--exclude='.git' \
|
||||
--exclude='__pycache__' \
|
||||
--exclude='.venv' \
|
||||
--exclude='node_modules' \
|
||||
--exclude='.ruff_cache' \
|
||||
--exclude='.mypy_cache' \
|
||||
--exclude='.pytest_cache' \
|
||||
--exclude='*.egg-info' \
|
||||
"$LOCAL_DIR/tensors/" "$REMOTE:$REMOTE_DIR/tensors/"
|
||||
|
||||
echo ""
|
||||
echo "==> Restarting tensors service..."
|
||||
ssh "$REMOTE" "sudo systemctl restart tensors"
|
||||
|
||||
echo ""
|
||||
echo "==> Waiting for tensors to start..."
|
||||
sleep 2
|
||||
|
||||
echo ""
|
||||
echo "==> Verifying tensors API..."
|
||||
TENSORS_STATUS=$(ssh "$REMOTE" "curl -s localhost:8081/api/models/status" 2>/dev/null)
|
||||
if echo "$TENSORS_STATUS" | grep -q '"active":true'; then
|
||||
echo "✓ tensors API responding"
|
||||
echo " Current model: $(echo "$TENSORS_STATUS" | jq -r '.current_model' | xargs basename)"
|
||||
else
|
||||
echo "✗ tensors API not responding"
|
||||
echo "$TENSORS_STATUS"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "==> Verifying sd-server..."
|
||||
SD_STATUS=$(ssh "$REMOTE" "curl -s localhost:1234/sdapi/v1/sd-models" 2>/dev/null)
|
||||
if echo "$SD_STATUS" | grep -q 'model_name'; then
|
||||
echo "✓ sd-server responding"
|
||||
echo " Models available: $(echo "$SD_STATUS" | jq length)"
|
||||
else
|
||||
echo "✗ sd-server not responding"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "==> Verifying external access..."
|
||||
EXT_STATUS=$(curl -s -H "X-API-Key: v00YKDdHzLmwTLUJ07iMn4umLvcsKa9i" https://sd-api.saiden.dev/sdapi/v1/sd-models 2>/dev/null)
|
||||
if echo "$EXT_STATUS" | grep -q 'model_name'; then
|
||||
echo "✓ sd-api.saiden.dev responding"
|
||||
else
|
||||
echo "✗ sd-api.saiden.dev not responding"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "==> Deploy complete!"
|
||||
echo " UI: https://tensors.saiden.dev"
|
||||
echo " API: https://sd-api.saiden.dev"
|
||||
@@ -274,3 +274,27 @@ def get_sd_server_url() -> str:
|
||||
return str(url)
|
||||
|
||||
return SD_SERVER_DEFAULT_URL
|
||||
|
||||
|
||||
def get_sd_server_api_key() -> str | None:
|
||||
"""Get the sd-server API key.
|
||||
|
||||
Resolution order:
|
||||
1. SD_SERVER_API_KEY environment variable
|
||||
2. config.toml [server].sd_server_api_key
|
||||
3. None (no authentication)
|
||||
"""
|
||||
# Check environment variable first
|
||||
env_key = os.environ.get("SD_SERVER_API_KEY")
|
||||
if env_key:
|
||||
return env_key
|
||||
|
||||
# Check config file
|
||||
config = load_config()
|
||||
server_config = config.get("server", {})
|
||||
if isinstance(server_config, dict):
|
||||
key = server_config.get("sd_server_api_key")
|
||||
if key:
|
||||
return str(key)
|
||||
|
||||
return None
|
||||
|
||||
@@ -12,7 +12,7 @@ from fastapi import FastAPI
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from tensors.config import get_sd_server_url
|
||||
from tensors.config import get_sd_server_api_key, get_sd_server_url
|
||||
from tensors.server.civitai_routes import create_civitai_router
|
||||
from tensors.server.db_routes import create_db_router
|
||||
from tensors.server.download_routes import create_download_router
|
||||
@@ -37,11 +37,15 @@ def create_app(sd_server_url: str | None = None) -> FastAPI:
|
||||
get_sd_server_url() to resolve from env/config.
|
||||
"""
|
||||
backend_url = sd_server_url or get_sd_server_url()
|
||||
api_key = get_sd_server_api_key()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
||||
_app.state.sd_server_url = backend_url
|
||||
_app.state.sd_server_api_key = api_key
|
||||
logger.info(f"Proxying to sd-server at: {backend_url}")
|
||||
if api_key:
|
||||
logger.info("Using API key authentication for sd-server")
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
_app.state.client = client
|
||||
yield
|
||||
|
||||
@@ -13,6 +13,7 @@ from pydantic import BaseModel as PydanticBaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from tensors.server.gallery import Gallery
|
||||
from tensors.server.sd_client import get_sd_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -141,8 +142,9 @@ def create_generate_router() -> APIRouter:
|
||||
url = f"{sd_server_url}/sdapi/v1/txt2img"
|
||||
|
||||
try:
|
||||
headers = get_sd_headers(request)
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
response = await client.post(url, json=body)
|
||||
response = await client.post(url, json=body, headers=headers)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
except httpx.ConnectError as e:
|
||||
@@ -178,8 +180,9 @@ def create_generate_router() -> APIRouter:
|
||||
url = f"{sd_server_url}/sdapi/v1/samplers"
|
||||
|
||||
try:
|
||||
headers = get_sd_headers(request)
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(url)
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
return {"samplers": response.json()}
|
||||
except httpx.ConnectError as e:
|
||||
@@ -194,8 +197,9 @@ def create_generate_router() -> APIRouter:
|
||||
url = f"{sd_server_url}/sdapi/v1/schedulers"
|
||||
|
||||
try:
|
||||
headers = get_sd_headers(request)
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(url)
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
return {"schedulers": response.json()}
|
||||
except httpx.ConnectError as e:
|
||||
|
||||
@@ -2,19 +2,58 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from tensors.config import MODELS_DIR
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from tensors.server.sd_client import get_sd_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_HTTP_OK = 200
|
||||
_SD_ENV_FILE = Path("/etc/default/sd-server")
|
||||
|
||||
|
||||
class SwitchModelRequest(BaseModel):
|
||||
"""Request body for switching models."""
|
||||
|
||||
model: str # Model filename or full path
|
||||
|
||||
|
||||
async def _run_command(*args: str) -> tuple[int, str, str]:
|
||||
"""Run a shell command and return (returncode, stdout, stderr)."""
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*args,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
return proc.returncode or 0, stdout.decode(), stderr.decode()
|
||||
|
||||
|
||||
def _read_env_file() -> dict[str, str]:
|
||||
"""Read the sd-server environment file."""
|
||||
env: dict[str, str] = {}
|
||||
if _SD_ENV_FILE.exists():
|
||||
for raw_line in _SD_ENV_FILE.read_text().splitlines():
|
||||
line = raw_line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
key, _, value = line.partition("=")
|
||||
env[key.strip()] = value.strip()
|
||||
return env
|
||||
|
||||
|
||||
def _write_env_file(env: dict[str, str]) -> str:
|
||||
"""Generate env file content."""
|
||||
lines = ["# sd-server configuration"]
|
||||
for key, value in env.items():
|
||||
lines.append(f"{key}={value}")
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
# Keywords for detecting base model category
|
||||
_SD15_KEYWORDS = ("sd15", "sd1.5", "sd-1.5", "sd_1.5", "1.5", "sd-1-", "v1-5")
|
||||
@@ -110,8 +149,9 @@ def create_models_router() -> APIRouter:
|
||||
|
||||
# Try to get current model from sd-server's options endpoint
|
||||
try:
|
||||
headers = get_sd_headers(request)
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.get(f"{sd_server_url}/sdapi/v1/options")
|
||||
response = await client.get(f"{sd_server_url}/sdapi/v1/options", headers=headers)
|
||||
if response.status_code == _HTTP_OK:
|
||||
options = response.json()
|
||||
model_name = options.get("sd_model_checkpoint")
|
||||
@@ -152,4 +192,67 @@ def create_models_router() -> APIRouter:
|
||||
"total_loras": len(loras),
|
||||
}
|
||||
|
||||
@router.post("/switch")
|
||||
async def switch_model(req: SwitchModelRequest) -> dict[str, Any]:
|
||||
"""Switch sd-server to a different model by updating env and restarting."""
|
||||
# Find the model file
|
||||
checkpoints = scan_checkpoints()
|
||||
model_path: str | None = None
|
||||
|
||||
for cp in checkpoints:
|
||||
if cp["filename"] == req.model or cp["path"] == req.model or cp["name"] == req.model:
|
||||
model_path = cp["path"]
|
||||
break
|
||||
|
||||
if not model_path:
|
||||
raise HTTPException(status_code=404, detail=f"Model not found: {req.model}")
|
||||
|
||||
# Read current env, update SD_MODEL
|
||||
env = _read_env_file()
|
||||
old_model = env.get("SD_MODEL", "")
|
||||
env["SD_MODEL"] = model_path
|
||||
|
||||
# Write new env file via sudo tee
|
||||
new_content = _write_env_file(env)
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"sudo", "tee", str(_SD_ENV_FILE),
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
_, stderr = await proc.communicate(new_content.encode())
|
||||
if proc.returncode != 0:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to write env file: {stderr.decode()}")
|
||||
|
||||
# Restart sd-server
|
||||
returncode, _stdout, stderr = await _run_command("sudo", "systemctl", "restart", "sd-server")
|
||||
if returncode != 0:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to restart sd-server: {stderr}")
|
||||
|
||||
logger.info(f"Switched model from {old_model} to {model_path}")
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"old_model": old_model,
|
||||
"new_model": model_path,
|
||||
"message": "Model switched, sd-server restarting",
|
||||
}
|
||||
|
||||
@router.get("/status")
|
||||
async def sd_server_status() -> dict[str, Any]:
|
||||
"""Get sd-server systemd service status."""
|
||||
_returncode, stdout, _stderr = await _run_command("systemctl", "is-active", "sd-server")
|
||||
is_active = stdout.strip() == "active"
|
||||
|
||||
env = _read_env_file()
|
||||
|
||||
return {
|
||||
"service": "sd-server",
|
||||
"active": is_active,
|
||||
"status": stdout.strip(),
|
||||
"current_model": env.get("SD_MODEL"),
|
||||
"host": env.get("SD_HOST"),
|
||||
"port": env.get("SD_PORT"),
|
||||
}
|
||||
|
||||
return router
|
||||
|
||||
@@ -9,6 +9,8 @@ import httpx
|
||||
from fastapi import APIRouter, Request, Response
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from tensors.server.sd_client import get_sd_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -21,8 +23,9 @@ def create_router() -> APIRouter:
|
||||
"""Check if the external sd-server is reachable."""
|
||||
sd_server_url = request.app.state.sd_server_url
|
||||
try:
|
||||
headers = get_sd_headers(request)
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
r = await client.get(sd_server_url)
|
||||
r = await client.get(sd_server_url, headers=headers)
|
||||
return {
|
||||
"status": "ok",
|
||||
"sd_server_url": sd_server_url,
|
||||
@@ -46,6 +49,8 @@ def create_router() -> APIRouter:
|
||||
body = await request.body()
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
# Add API key if configured
|
||||
headers.update(get_sd_headers(request))
|
||||
client = request.app.state.client
|
||||
|
||||
try:
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""HTTP client utilities for sd-server communication."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
|
||||
|
||||
def get_sd_headers(request: Request) -> dict[str, str]:
|
||||
"""Get headers for sd-server requests, including API key if configured."""
|
||||
headers: dict[str, str] = {}
|
||||
api_key = getattr(request.app.state, "sd_server_api_key", None)
|
||||
if api_key:
|
||||
headers["X-API-Key"] = api_key
|
||||
return headers
|
||||
|
||||
|
||||
async def sd_get(request: Request, path: str, *, timeout: float = 30) -> httpx.Response:
|
||||
"""Make a GET request to sd-server."""
|
||||
url = f"{request.app.state.sd_server_url}/{path.lstrip('/')}"
|
||||
headers = get_sd_headers(request)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
|
||||
async def sd_post(request: Request, path: str, *, json: dict[str, Any] | None = None, timeout: float = 300) -> httpx.Response:
|
||||
"""Make a POST request to sd-server."""
|
||||
url = f"{request.app.state.sd_server_url}/{path.lstrip('/')}"
|
||||
headers = get_sd_headers(request)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(url, json=json, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -5,8 +5,8 @@
|
||||
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Tensors</title>
|
||||
<script type="module" crossorigin src="/assets/index-J_qzb7Jl.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/assets/index-QncGJEyk.css">
|
||||
<script type="module" crossorigin src="/assets/index-DmOZ-7Sw.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/assets/index-CcuP2dTH.css">
|
||||
</head>
|
||||
<body>
|
||||
<div id="app"></div>
|
||||
|
||||
@@ -26,13 +26,26 @@ export async function getActiveModel(): Promise<{ loaded: boolean; model: string
|
||||
return fetchJson('/api/models/active')
|
||||
}
|
||||
|
||||
export async function switchModel(model: string): Promise<{ ok: boolean }> {
|
||||
export async function switchModel(model: string): Promise<{ ok: boolean; old_model: string; new_model: string }> {
|
||||
return fetchJson('/api/models/switch', {
|
||||
method: 'POST',
|
||||
body: JSON.stringify({ model }),
|
||||
})
|
||||
}
|
||||
|
||||
export interface ServerStatus {
|
||||
service: string
|
||||
active: boolean
|
||||
status: string
|
||||
current_model: string | null
|
||||
host: string | null
|
||||
port: string | null
|
||||
}
|
||||
|
||||
export async function getServerStatus(): Promise<ServerStatus> {
|
||||
return fetchJson('/api/models/status')
|
||||
}
|
||||
|
||||
export async function getLoras(): Promise<{ loras: LoRA[]; total: number }> {
|
||||
return fetchJson('/api/models/loras')
|
||||
}
|
||||
|
||||
@@ -9,6 +9,16 @@ const store = useAppStore()
|
||||
const prompt = ref('')
|
||||
const generating = ref(false)
|
||||
|
||||
// Snackbar states
|
||||
const showError = computed({
|
||||
get: () => !!store.switchError,
|
||||
set: () => { store.switchError = null }
|
||||
})
|
||||
const showSuccess = computed({
|
||||
get: () => !store.switchingModel && !!store.switchMessage && !store.switchError,
|
||||
set: () => { store.switchMessage = null }
|
||||
})
|
||||
|
||||
interface ChatMessage {
|
||||
prompt: string
|
||||
params: string
|
||||
@@ -95,6 +105,29 @@ async function generate() {
|
||||
|
||||
<template>
|
||||
<v-container fluid class="fill-height pa-0 d-flex flex-column">
|
||||
<!-- Model switch overlay -->
|
||||
<v-overlay
|
||||
:model-value="store.switchingModel"
|
||||
class="align-center justify-center"
|
||||
persistent
|
||||
>
|
||||
<v-card class="pa-6 text-center" min-width="300">
|
||||
<v-progress-circular indeterminate color="primary" size="48" class="mb-4" />
|
||||
<div class="text-h6">{{ store.switchMessage || 'Switching model...' }}</div>
|
||||
<div class="text-caption text-grey mt-2">sd-server is restarting</div>
|
||||
</v-card>
|
||||
</v-overlay>
|
||||
|
||||
<!-- Error snackbar -->
|
||||
<v-snackbar v-model="showError" color="error" timeout="5000">
|
||||
{{ store.switchError }}
|
||||
</v-snackbar>
|
||||
|
||||
<!-- Success snackbar -->
|
||||
<v-snackbar v-model="showSuccess" color="success" timeout="3000">
|
||||
{{ store.switchMessage }}
|
||||
</v-snackbar>
|
||||
|
||||
<!-- Chat area -->
|
||||
<v-container fluid class="flex-grow-1 overflow-y-auto pa-4">
|
||||
<div v-if="messages.length === 0" class="text-center text-grey mt-16">
|
||||
|
||||
@@ -66,6 +66,8 @@ export const useAppStore = defineStore('app', () => {
|
||||
// Loading states
|
||||
const loadingModels = ref(false)
|
||||
const switchingModel = ref(false)
|
||||
const switchMessage = ref<string | null>(null)
|
||||
const switchError = ref<string | null>(null)
|
||||
|
||||
// Actions
|
||||
async function loadModels() {
|
||||
@@ -93,12 +95,38 @@ export const useAppStore = defineStore('app', () => {
|
||||
if (modelPath === activeModel.value) return
|
||||
|
||||
switchingModel.value = true
|
||||
switchMessage.value = 'Switching model...'
|
||||
switchError.value = null
|
||||
|
||||
try {
|
||||
await api.switchModel(modelPath)
|
||||
activeModel.value = modelPath
|
||||
selectedModel.value = modelPath
|
||||
} catch (error) {
|
||||
const result = await api.switchModel(modelPath)
|
||||
switchMessage.value = 'Restarting sd-server...'
|
||||
|
||||
// Poll for server to come back online (up to 60 seconds)
|
||||
let attempts = 0
|
||||
const maxAttempts = 30
|
||||
while (attempts < maxAttempts) {
|
||||
await new Promise(resolve => setTimeout(resolve, 2000))
|
||||
try {
|
||||
const status = await api.getServerStatus()
|
||||
if (status.active && status.current_model?.includes(modelPath.split('/').pop() || '')) {
|
||||
activeModel.value = result.new_model
|
||||
selectedModel.value = result.new_model
|
||||
switchMessage.value = 'Model switched successfully'
|
||||
setTimeout(() => { switchMessage.value = null }, 3000)
|
||||
return
|
||||
}
|
||||
} catch {
|
||||
// Server still restarting, continue polling
|
||||
}
|
||||
attempts++
|
||||
switchMessage.value = `Waiting for sd-server... (${attempts}/${maxAttempts})`
|
||||
}
|
||||
throw new Error('Timeout waiting for sd-server to restart')
|
||||
} catch (error: any) {
|
||||
console.error('Failed to switch model:', error)
|
||||
switchError.value = error.message || 'Failed to switch model'
|
||||
setTimeout(() => { switchError.value = null }, 5000)
|
||||
throw error
|
||||
} finally {
|
||||
switchingModel.value = false
|
||||
@@ -132,6 +160,8 @@ export const useAppStore = defineStore('app', () => {
|
||||
// Loading states
|
||||
loadingModels,
|
||||
switchingModel,
|
||||
switchMessage,
|
||||
switchError,
|
||||
|
||||
// Actions
|
||||
loadModels,
|
||||
|
||||
Reference in New Issue
Block a user