💬 Commit message: Update 2026-02-15 00:07:11, 20 files, 231 lines
📁 Files changed: 20 📝 Lines changed: 231 • .coverage • models.db • screenshot.png • tensors-deployed.png • tensors-dropdown.png • tensors-final.png • tensors-fixed.png • tensors-reloaded.png • tensors-ui.png • civitai_routes.py • generate_routes.py • models_routes.py • sd_client.py • index-BQdjF_w0.css • index-CKJOpgtQ.js • index-DmOZ-7Sw.js • index.html • GenerateView.vue • app.ts • index.ts
|
Before Width: | Height: | Size: 48 KiB After Width: | Height: | Size: 55 KiB |
|
After Width: | Height: | Size: 55 KiB |
|
After Width: | Height: | Size: 80 KiB |
|
Before Width: | Height: | Size: 46 KiB After Width: | Height: | Size: 57 KiB |
|
After Width: | Height: | Size: 55 KiB |
|
After Width: | Height: | Size: 55 KiB |
|
After Width: | Height: | Size: 55 KiB |
@@ -10,6 +10,7 @@ from fastapi import APIRouter, Query, Response
|
|||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from tensors.config import CIVITAI_API_BASE, load_api_key
|
from tensors.config import CIVITAI_API_BASE, load_api_key
|
||||||
|
from tensors.db import Database
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -56,6 +57,18 @@ async def search_models(
|
|||||||
response = await client.get(url, params=params, headers=_get_headers(api_key))
|
response = await client.get(url, params=params, headers=_get_headers(api_key))
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result: dict[str, Any] = response.json()
|
result: dict[str, Any] = response.json()
|
||||||
|
|
||||||
|
# Cache all models from search results
|
||||||
|
items = result.get("items", [])
|
||||||
|
if items:
|
||||||
|
try:
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
for model_data in items:
|
||||||
|
db.cache_model(model_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to cache search results: %s", e)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.error("CivitAI API error: %s", e.response.status_code)
|
logger.error("CivitAI API error: %s", e.response.status_code)
|
||||||
@@ -67,7 +80,7 @@ async def search_models(
|
|||||||
|
|
||||||
@router.get("/model/{model_id}", response_model=None)
|
@router.get("/model/{model_id}", response_model=None)
|
||||||
async def get_model(model_id: int) -> dict[str, Any] | Response:
|
async def get_model(model_id: int) -> dict[str, Any] | Response:
|
||||||
"""Get model details from CivitAI."""
|
"""Get model details from CivitAI and cache to database."""
|
||||||
api_key = load_api_key()
|
api_key = load_api_key()
|
||||||
url = f"{CIVITAI_API_BASE}/models/{model_id}"
|
url = f"{CIVITAI_API_BASE}/models/{model_id}"
|
||||||
|
|
||||||
@@ -76,6 +89,15 @@ async def get_model(model_id: int) -> dict[str, Any] | Response:
|
|||||||
response = await client.get(url, headers=_get_headers(api_key))
|
response = await client.get(url, headers=_get_headers(api_key))
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result: dict[str, Any] = response.json()
|
result: dict[str, Any] = response.json()
|
||||||
|
|
||||||
|
# Cache the model data to database
|
||||||
|
try:
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
db.cache_model(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to cache model %d: %s", model_id, e)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except httpx.HTTPStatusError:
|
except httpx.HTTPStatusError:
|
||||||
return JSONResponse({"error": "Model not found"}, status_code=404)
|
return JSONResponse({"error": "Model not found"}, status_code=404)
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ def _process_image(
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def create_generate_router() -> APIRouter:
|
def create_generate_router() -> APIRouter: # noqa: PLR0915
|
||||||
"""Build a router with /api/generate endpoint."""
|
"""Build a router with /api/generate endpoint."""
|
||||||
router = APIRouter(prefix="/api", tags=["generate"])
|
router = APIRouter(prefix="/api", tags=["generate"])
|
||||||
gallery = Gallery()
|
gallery = Gallery()
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from fastapi import APIRouter, HTTPException, Request
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from tensors.config import MODELS_DIR
|
from tensors.config import MODELS_DIR
|
||||||
|
from tensors.db import Database
|
||||||
from tensors.server.sd_client import get_sd_headers
|
from tensors.server.sd_client import get_sd_headers
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -110,6 +111,56 @@ def scan_models(directory: Path, extensions: tuple[str, ...] = (".safetensors",
|
|||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
|
def _enrich_with_metadata(models: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
"""Enrich model data with CivitAI metadata from database."""
|
||||||
|
try:
|
||||||
|
with Database() as db:
|
||||||
|
db.init_schema()
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
file_path = model.get("path", "")
|
||||||
|
file_info = db.get_local_file_by_path(file_path)
|
||||||
|
|
||||||
|
if file_info and file_info.get("civitai_model_id"):
|
||||||
|
# Add human-readable name
|
||||||
|
model["display_name"] = file_info.get("model_name") or model["name"]
|
||||||
|
model["base_model"] = file_info.get("base_model")
|
||||||
|
model["model_type"] = file_info.get("model_type")
|
||||||
|
model["civitai_model_id"] = file_info.get("civitai_model_id")
|
||||||
|
model["civitai_version_id"] = file_info.get("civitai_version_id")
|
||||||
|
|
||||||
|
# Get thumbnail from version images
|
||||||
|
version_id = file_info.get("civitai_version_id")
|
||||||
|
if version_id:
|
||||||
|
cur = db.conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT url FROM version_images
|
||||||
|
WHERE version_id = (SELECT id FROM model_versions WHERE civitai_id = ?)
|
||||||
|
ORDER BY id LIMIT 1
|
||||||
|
""",
|
||||||
|
(version_id,),
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
if row:
|
||||||
|
model["thumbnail_url"] = row[0]
|
||||||
|
|
||||||
|
# Get trigger words
|
||||||
|
triggers = db.get_triggers_by_version(version_id) if version_id else []
|
||||||
|
model["triggers"] = triggers[:5] # Limit to first 5
|
||||||
|
else:
|
||||||
|
model["display_name"] = model["name"]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to enrich models with metadata: %s", e)
|
||||||
|
# Fallback: just use filename as display name
|
||||||
|
for model in models:
|
||||||
|
if "display_name" not in model:
|
||||||
|
model["display_name"] = model["name"]
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
|
|
||||||
def scan_loras(directory: Path | None = None) -> list[dict[str, Any]]:
|
def scan_loras(directory: Path | None = None) -> list[dict[str, Any]]:
|
||||||
"""Scan for LoRA files."""
|
"""Scan for LoRA files."""
|
||||||
lora_dir = directory or MODELS_DIR / "loras"
|
lora_dir = directory or MODELS_DIR / "loras"
|
||||||
@@ -127,14 +178,15 @@ def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def create_models_router() -> APIRouter:
|
def create_models_router() -> APIRouter: # noqa: PLR0915
|
||||||
"""Build a router with /api/models/* endpoints."""
|
"""Build a router with /api/models/* endpoints."""
|
||||||
router = APIRouter(prefix="/api/models", tags=["models"])
|
router = APIRouter(prefix="/api/models", tags=["models"])
|
||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
def list_models() -> dict[str, Any]:
|
def list_models() -> dict[str, Any]:
|
||||||
"""List available checkpoint models."""
|
"""List available checkpoint models with metadata."""
|
||||||
checkpoints = scan_checkpoints()
|
checkpoints = scan_checkpoints()
|
||||||
|
checkpoints = _enrich_with_metadata(checkpoints)
|
||||||
return {
|
return {
|
||||||
"models": checkpoints,
|
"models": checkpoints,
|
||||||
"total": len(checkpoints),
|
"total": len(checkpoints),
|
||||||
@@ -172,8 +224,9 @@ def create_models_router() -> APIRouter:
|
|||||||
|
|
||||||
@router.get("/loras")
|
@router.get("/loras")
|
||||||
def list_loras() -> dict[str, Any]:
|
def list_loras() -> dict[str, Any]:
|
||||||
"""List available LoRA files."""
|
"""List available LoRA files with metadata."""
|
||||||
loras = scan_loras()
|
loras = scan_loras()
|
||||||
|
loras = _enrich_with_metadata(loras)
|
||||||
return {
|
return {
|
||||||
"loras": loras,
|
"loras": loras,
|
||||||
"total": len(loras),
|
"total": len(loras),
|
||||||
@@ -220,14 +273,14 @@ def create_models_router() -> APIRouter:
|
|||||||
stdout=asyncio.subprocess.PIPE,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
stderr=asyncio.subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
)
|
)
|
||||||
_, stderr = await proc.communicate(new_content.encode())
|
_, tee_stderr = await proc.communicate(new_content.encode())
|
||||||
if proc.returncode != 0:
|
if proc.returncode != 0:
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to write env file: {stderr.decode()}")
|
raise HTTPException(status_code=500, detail=f"Failed to write env file: {tee_stderr.decode()}")
|
||||||
|
|
||||||
# Restart sd-server
|
# Restart sd-server
|
||||||
returncode, _stdout, stderr = await _run_command("sudo", "systemctl", "restart", "sd-server")
|
returncode, _stdout, restart_stderr = await _run_command("sudo", "systemctl", "restart", "sd-server")
|
||||||
if returncode != 0:
|
if returncode != 0:
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to restart sd-server: {stderr}")
|
raise HTTPException(status_code=500, detail=f"Failed to restart sd-server: {restart_stderr}")
|
||||||
|
|
||||||
logger.info(f"Switched model from {old_model} to {model_path}")
|
logger.info(f"Switched model from {old_model} to {model_path}")
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,12 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import Request
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
|
||||||
def get_sd_headers(request: Request) -> dict[str, str]:
|
def get_sd_headers(request: Request) -> dict[str, str]:
|
||||||
|
|||||||
@@ -5,8 +5,8 @@
|
|||||||
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
<title>Tensors</title>
|
<title>Tensors</title>
|
||||||
<script type="module" crossorigin src="/assets/index-DmOZ-7Sw.js"></script>
|
<script type="module" crossorigin src="/assets/index-CKJOpgtQ.js"></script>
|
||||||
<link rel="stylesheet" crossorigin href="/assets/index-CcuP2dTH.css">
|
<link rel="stylesheet" crossorigin href="/assets/index-BQdjF_w0.css">
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div id="app"></div>
|
<div id="app"></div>
|
||||||
|
|||||||
@@ -30,12 +30,22 @@ interface ChatMessage {
|
|||||||
const messages = ref<ChatMessage[]>([])
|
const messages = ref<ChatMessage[]>([])
|
||||||
|
|
||||||
const modelItems = computed(() =>
|
const modelItems = computed(() =>
|
||||||
store.models.map(m => ({ title: m.name, value: m.path }))
|
store.models.map(m => ({
|
||||||
|
title: m.display_name || m.name,
|
||||||
|
value: m.path,
|
||||||
|
thumbnail: m.thumbnail_url,
|
||||||
|
base_model: m.base_model,
|
||||||
|
}))
|
||||||
)
|
)
|
||||||
|
|
||||||
const loraItems = computed(() => [
|
const loraItems = computed(() => [
|
||||||
{ title: 'None', value: '' },
|
{ title: 'None', value: '', thumbnail: null, triggers: [] },
|
||||||
...store.filteredLoras.map(l => ({ title: l.name, value: l.path }))
|
...store.filteredLoras.map(l => ({
|
||||||
|
title: l.display_name || l.name,
|
||||||
|
value: l.path,
|
||||||
|
thumbnail: l.thumbnail_url,
|
||||||
|
triggers: l.triggers || [],
|
||||||
|
}))
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
@@ -187,9 +197,33 @@ async function generate() {
|
|||||||
:disabled="store.switchingModel || generating"
|
:disabled="store.switchingModel || generating"
|
||||||
density="compact"
|
density="compact"
|
||||||
hide-details
|
hide-details
|
||||||
style="width: 200px"
|
style="width: 280px"
|
||||||
@update:model-value="handleModelChange"
|
@update:model-value="handleModelChange"
|
||||||
/>
|
>
|
||||||
|
<template #selection="{ item }">
|
||||||
|
<div class="d-flex align-center ga-2">
|
||||||
|
<v-avatar v-if="item.raw.thumbnail" size="24" rounded="sm">
|
||||||
|
<v-img :src="item.raw.thumbnail" cover />
|
||||||
|
</v-avatar>
|
||||||
|
<v-icon v-else size="24" color="grey">mdi-cube-outline</v-icon>
|
||||||
|
<span class="text-truncate">{{ item.title }}</span>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<template #item="{ item, props }">
|
||||||
|
<v-list-item v-bind="props" :title="undefined">
|
||||||
|
<template #prepend>
|
||||||
|
<v-avatar v-if="item.raw.thumbnail" size="32" rounded="sm" class="mr-3">
|
||||||
|
<v-img :src="item.raw.thumbnail" cover />
|
||||||
|
</v-avatar>
|
||||||
|
<v-icon v-else size="32" color="grey" class="mr-3">mdi-cube-outline</v-icon>
|
||||||
|
</template>
|
||||||
|
<v-list-item-title>{{ item.title }}</v-list-item-title>
|
||||||
|
<v-list-item-subtitle v-if="item.raw.base_model" class="text-caption">
|
||||||
|
{{ item.raw.base_model }}
|
||||||
|
</v-list-item-subtitle>
|
||||||
|
</v-list-item>
|
||||||
|
</template>
|
||||||
|
</v-select>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="d-flex align-center ga-2">
|
<div class="d-flex align-center ga-2">
|
||||||
@@ -200,8 +234,32 @@ async function generate() {
|
|||||||
:disabled="generating"
|
:disabled="generating"
|
||||||
density="compact"
|
density="compact"
|
||||||
hide-details
|
hide-details
|
||||||
style="width: 150px"
|
style="width: 200px"
|
||||||
/>
|
>
|
||||||
|
<template #selection="{ item }">
|
||||||
|
<div class="d-flex align-center ga-2">
|
||||||
|
<v-avatar v-if="item.raw.thumbnail" size="24" rounded="sm">
|
||||||
|
<v-img :src="item.raw.thumbnail" cover />
|
||||||
|
</v-avatar>
|
||||||
|
<v-icon v-else size="24" color="grey">mdi-shimmer</v-icon>
|
||||||
|
<span class="text-truncate">{{ item.title }}</span>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
<template #item="{ item, props }">
|
||||||
|
<v-list-item v-bind="props" :title="undefined">
|
||||||
|
<template #prepend>
|
||||||
|
<v-avatar v-if="item.raw.thumbnail" size="32" rounded="sm" class="mr-3">
|
||||||
|
<v-img :src="item.raw.thumbnail" cover />
|
||||||
|
</v-avatar>
|
||||||
|
<v-icon v-else size="32" color="grey" class="mr-3">mdi-shimmer</v-icon>
|
||||||
|
</template>
|
||||||
|
<v-list-item-title>{{ item.title }}</v-list-item-title>
|
||||||
|
<v-list-item-subtitle v-if="item.raw.triggers?.length" class="text-caption text-truncate">
|
||||||
|
{{ item.raw.triggers.slice(0, 2).join(', ') }}
|
||||||
|
</v-list-item-subtitle>
|
||||||
|
</v-list-item>
|
||||||
|
</template>
|
||||||
|
</v-select>
|
||||||
<v-text-field
|
<v-text-field
|
||||||
v-model.number="store.loraWeight"
|
v-model.number="store.loraWeight"
|
||||||
type="number"
|
type="number"
|
||||||
|
|||||||
@@ -82,7 +82,13 @@ export const useAppStore = defineStore('app', () => {
|
|||||||
loras.value = lorasRes.loras
|
loras.value = lorasRes.loras
|
||||||
activeModel.value = activeRes.model
|
activeModel.value = activeRes.model
|
||||||
if (activeRes.model) {
|
if (activeRes.model) {
|
||||||
selectedModel.value = activeRes.model
|
// Find full path for the active model (API returns model name without extension, v-select uses path)
|
||||||
|
const activeModelPath = modelsRes.models.find(m =>
|
||||||
|
m.name === activeRes.model ||
|
||||||
|
m.filename === activeRes.model ||
|
||||||
|
m.filename.startsWith(activeRes.model + '.')
|
||||||
|
)?.path
|
||||||
|
selectedModel.value = activeModelPath || activeRes.model
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to load models:', error)
|
console.error('Failed to load models:', error)
|
||||||
|
|||||||
@@ -5,6 +5,14 @@ export interface Model {
|
|||||||
size_mb: number
|
size_mb: number
|
||||||
modified: number
|
modified: number
|
||||||
category: 'sd15' | 'large'
|
category: 'sd15' | 'large'
|
||||||
|
// Enriched from CivitAI metadata
|
||||||
|
display_name?: string
|
||||||
|
base_model?: string
|
||||||
|
model_type?: string
|
||||||
|
civitai_model_id?: number
|
||||||
|
civitai_version_id?: number
|
||||||
|
thumbnail_url?: string
|
||||||
|
triggers?: string[]
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface LoRA {
|
export interface LoRA {
|
||||||
@@ -14,6 +22,14 @@ export interface LoRA {
|
|||||||
size_mb: number
|
size_mb: number
|
||||||
modified: number
|
modified: number
|
||||||
category: 'sd15' | 'large'
|
category: 'sd15' | 'large'
|
||||||
|
// Enriched from CivitAI metadata
|
||||||
|
display_name?: string
|
||||||
|
base_model?: string
|
||||||
|
model_type?: string
|
||||||
|
civitai_model_id?: number
|
||||||
|
civitai_version_id?: number
|
||||||
|
thumbnail_url?: string
|
||||||
|
triggers?: string[]
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface GeneratedImage {
|
export interface GeneratedImage {
|
||||||
|
|||||||