💬 Commit message: Update 2026-02-15 00:07:11, 20 files, 231 lines
📁 Files changed: 20 📝 Lines changed: 231 • .coverage • models.db • screenshot.png • tensors-deployed.png • tensors-dropdown.png • tensors-final.png • tensors-fixed.png • tensors-reloaded.png • tensors-ui.png • civitai_routes.py • generate_routes.py • models_routes.py • sd_client.py • index-BQdjF_w0.css • index-CKJOpgtQ.js • index-DmOZ-7Sw.js • index.html • GenerateView.vue • app.ts • index.ts
This commit is contained in:
@@ -10,6 +10,7 @@ from fastapi import APIRouter, Query, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from tensors.config import CIVITAI_API_BASE, load_api_key
|
||||
from tensors.db import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -56,6 +57,18 @@ async def search_models(
|
||||
response = await client.get(url, params=params, headers=_get_headers(api_key))
|
||||
response.raise_for_status()
|
||||
result: dict[str, Any] = response.json()
|
||||
|
||||
# Cache all models from search results
|
||||
items = result.get("items", [])
|
||||
if items:
|
||||
try:
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
for model_data in items:
|
||||
db.cache_model(model_data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cache search results: %s", e)
|
||||
|
||||
return result
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error("CivitAI API error: %s", e.response.status_code)
|
||||
@@ -67,7 +80,7 @@ async def search_models(
|
||||
|
||||
@router.get("/model/{model_id}", response_model=None)
|
||||
async def get_model(model_id: int) -> dict[str, Any] | Response:
|
||||
"""Get model details from CivitAI."""
|
||||
"""Get model details from CivitAI and cache to database."""
|
||||
api_key = load_api_key()
|
||||
url = f"{CIVITAI_API_BASE}/models/{model_id}"
|
||||
|
||||
@@ -76,6 +89,15 @@ async def get_model(model_id: int) -> dict[str, Any] | Response:
|
||||
response = await client.get(url, headers=_get_headers(api_key))
|
||||
response.raise_for_status()
|
||||
result: dict[str, Any] = response.json()
|
||||
|
||||
# Cache the model data to database
|
||||
try:
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
db.cache_model(result)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cache model %d: %s", model_id, e)
|
||||
|
||||
return result
|
||||
except httpx.HTTPStatusError:
|
||||
return JSONResponse({"error": "Model not found"}, status_code=404)
|
||||
|
||||
@@ -129,7 +129,7 @@ def _process_image(
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_generate_router() -> APIRouter:
|
||||
def create_generate_router() -> APIRouter: # noqa: PLR0915
|
||||
"""Build a router with /api/generate endpoint."""
|
||||
router = APIRouter(prefix="/api", tags=["generate"])
|
||||
gallery = Gallery()
|
||||
|
||||
@@ -11,6 +11,7 @@ from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from tensors.config import MODELS_DIR
|
||||
from tensors.db import Database
|
||||
from tensors.server.sd_client import get_sd_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -110,6 +111,56 @@ def scan_models(directory: Path, extensions: tuple[str, ...] = (".safetensors",
|
||||
return models
|
||||
|
||||
|
||||
def _enrich_with_metadata(models: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Enrich model data with CivitAI metadata from database."""
|
||||
try:
|
||||
with Database() as db:
|
||||
db.init_schema()
|
||||
|
||||
for model in models:
|
||||
file_path = model.get("path", "")
|
||||
file_info = db.get_local_file_by_path(file_path)
|
||||
|
||||
if file_info and file_info.get("civitai_model_id"):
|
||||
# Add human-readable name
|
||||
model["display_name"] = file_info.get("model_name") or model["name"]
|
||||
model["base_model"] = file_info.get("base_model")
|
||||
model["model_type"] = file_info.get("model_type")
|
||||
model["civitai_model_id"] = file_info.get("civitai_model_id")
|
||||
model["civitai_version_id"] = file_info.get("civitai_version_id")
|
||||
|
||||
# Get thumbnail from version images
|
||||
version_id = file_info.get("civitai_version_id")
|
||||
if version_id:
|
||||
cur = db.conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT url FROM version_images
|
||||
WHERE version_id = (SELECT id FROM model_versions WHERE civitai_id = ?)
|
||||
ORDER BY id LIMIT 1
|
||||
""",
|
||||
(version_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
model["thumbnail_url"] = row[0]
|
||||
|
||||
# Get trigger words
|
||||
triggers = db.get_triggers_by_version(version_id) if version_id else []
|
||||
model["triggers"] = triggers[:5] # Limit to first 5
|
||||
else:
|
||||
model["display_name"] = model["name"]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to enrich models with metadata: %s", e)
|
||||
# Fallback: just use filename as display name
|
||||
for model in models:
|
||||
if "display_name" not in model:
|
||||
model["display_name"] = model["name"]
|
||||
|
||||
return models
|
||||
|
||||
|
||||
def scan_loras(directory: Path | None = None) -> list[dict[str, Any]]:
|
||||
"""Scan for LoRA files."""
|
||||
lora_dir = directory or MODELS_DIR / "loras"
|
||||
@@ -127,14 +178,15 @@ def scan_checkpoints(directory: Path | None = None) -> list[dict[str, Any]]:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_models_router() -> APIRouter:
|
||||
def create_models_router() -> APIRouter: # noqa: PLR0915
|
||||
"""Build a router with /api/models/* endpoints."""
|
||||
router = APIRouter(prefix="/api/models", tags=["models"])
|
||||
|
||||
@router.get("")
|
||||
def list_models() -> dict[str, Any]:
|
||||
"""List available checkpoint models."""
|
||||
"""List available checkpoint models with metadata."""
|
||||
checkpoints = scan_checkpoints()
|
||||
checkpoints = _enrich_with_metadata(checkpoints)
|
||||
return {
|
||||
"models": checkpoints,
|
||||
"total": len(checkpoints),
|
||||
@@ -172,8 +224,9 @@ def create_models_router() -> APIRouter:
|
||||
|
||||
@router.get("/loras")
|
||||
def list_loras() -> dict[str, Any]:
|
||||
"""List available LoRA files."""
|
||||
"""List available LoRA files with metadata."""
|
||||
loras = scan_loras()
|
||||
loras = _enrich_with_metadata(loras)
|
||||
return {
|
||||
"loras": loras,
|
||||
"total": len(loras),
|
||||
@@ -220,14 +273,14 @@ def create_models_router() -> APIRouter:
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
_, stderr = await proc.communicate(new_content.encode())
|
||||
_, tee_stderr = await proc.communicate(new_content.encode())
|
||||
if proc.returncode != 0:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to write env file: {stderr.decode()}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to write env file: {tee_stderr.decode()}")
|
||||
|
||||
# Restart sd-server
|
||||
returncode, _stdout, stderr = await _run_command("sudo", "systemctl", "restart", "sd-server")
|
||||
returncode, _stdout, restart_stderr = await _run_command("sudo", "systemctl", "restart", "sd-server")
|
||||
if returncode != 0:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to restart sd-server: {stderr}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to restart sd-server: {restart_stderr}")
|
||||
|
||||
logger.info(f"Switched model from {old_model} to {model_path}")
|
||||
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import Request
|
||||
|
||||
|
||||
def get_sd_headers(request: Request) -> dict[str, str]:
|
||||
|
||||
+1
-1
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -5,8 +5,8 @@
|
||||
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Tensors</title>
|
||||
<script type="module" crossorigin src="/assets/index-DmOZ-7Sw.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/assets/index-CcuP2dTH.css">
|
||||
<script type="module" crossorigin src="/assets/index-CKJOpgtQ.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/assets/index-BQdjF_w0.css">
|
||||
</head>
|
||||
<body>
|
||||
<div id="app"></div>
|
||||
|
||||
@@ -30,12 +30,22 @@ interface ChatMessage {
|
||||
const messages = ref<ChatMessage[]>([])
|
||||
|
||||
const modelItems = computed(() =>
|
||||
store.models.map(m => ({ title: m.name, value: m.path }))
|
||||
store.models.map(m => ({
|
||||
title: m.display_name || m.name,
|
||||
value: m.path,
|
||||
thumbnail: m.thumbnail_url,
|
||||
base_model: m.base_model,
|
||||
}))
|
||||
)
|
||||
|
||||
const loraItems = computed(() => [
|
||||
{ title: 'None', value: '' },
|
||||
...store.filteredLoras.map(l => ({ title: l.name, value: l.path }))
|
||||
{ title: 'None', value: '', thumbnail: null, triggers: [] },
|
||||
...store.filteredLoras.map(l => ({
|
||||
title: l.display_name || l.name,
|
||||
value: l.path,
|
||||
thumbnail: l.thumbnail_url,
|
||||
triggers: l.triggers || [],
|
||||
}))
|
||||
])
|
||||
|
||||
|
||||
@@ -187,9 +197,33 @@ async function generate() {
|
||||
:disabled="store.switchingModel || generating"
|
||||
density="compact"
|
||||
hide-details
|
||||
style="width: 200px"
|
||||
style="width: 280px"
|
||||
@update:model-value="handleModelChange"
|
||||
/>
|
||||
>
|
||||
<template #selection="{ item }">
|
||||
<div class="d-flex align-center ga-2">
|
||||
<v-avatar v-if="item.raw.thumbnail" size="24" rounded="sm">
|
||||
<v-img :src="item.raw.thumbnail" cover />
|
||||
</v-avatar>
|
||||
<v-icon v-else size="24" color="grey">mdi-cube-outline</v-icon>
|
||||
<span class="text-truncate">{{ item.title }}</span>
|
||||
</div>
|
||||
</template>
|
||||
<template #item="{ item, props }">
|
||||
<v-list-item v-bind="props" :title="undefined">
|
||||
<template #prepend>
|
||||
<v-avatar v-if="item.raw.thumbnail" size="32" rounded="sm" class="mr-3">
|
||||
<v-img :src="item.raw.thumbnail" cover />
|
||||
</v-avatar>
|
||||
<v-icon v-else size="32" color="grey" class="mr-3">mdi-cube-outline</v-icon>
|
||||
</template>
|
||||
<v-list-item-title>{{ item.title }}</v-list-item-title>
|
||||
<v-list-item-subtitle v-if="item.raw.base_model" class="text-caption">
|
||||
{{ item.raw.base_model }}
|
||||
</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</template>
|
||||
</v-select>
|
||||
</div>
|
||||
|
||||
<div class="d-flex align-center ga-2">
|
||||
@@ -200,8 +234,32 @@ async function generate() {
|
||||
:disabled="generating"
|
||||
density="compact"
|
||||
hide-details
|
||||
style="width: 150px"
|
||||
/>
|
||||
style="width: 200px"
|
||||
>
|
||||
<template #selection="{ item }">
|
||||
<div class="d-flex align-center ga-2">
|
||||
<v-avatar v-if="item.raw.thumbnail" size="24" rounded="sm">
|
||||
<v-img :src="item.raw.thumbnail" cover />
|
||||
</v-avatar>
|
||||
<v-icon v-else size="24" color="grey">mdi-shimmer</v-icon>
|
||||
<span class="text-truncate">{{ item.title }}</span>
|
||||
</div>
|
||||
</template>
|
||||
<template #item="{ item, props }">
|
||||
<v-list-item v-bind="props" :title="undefined">
|
||||
<template #prepend>
|
||||
<v-avatar v-if="item.raw.thumbnail" size="32" rounded="sm" class="mr-3">
|
||||
<v-img :src="item.raw.thumbnail" cover />
|
||||
</v-avatar>
|
||||
<v-icon v-else size="32" color="grey" class="mr-3">mdi-shimmer</v-icon>
|
||||
</template>
|
||||
<v-list-item-title>{{ item.title }}</v-list-item-title>
|
||||
<v-list-item-subtitle v-if="item.raw.triggers?.length" class="text-caption text-truncate">
|
||||
{{ item.raw.triggers.slice(0, 2).join(', ') }}
|
||||
</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</template>
|
||||
</v-select>
|
||||
<v-text-field
|
||||
v-model.number="store.loraWeight"
|
||||
type="number"
|
||||
|
||||
@@ -82,7 +82,13 @@ export const useAppStore = defineStore('app', () => {
|
||||
loras.value = lorasRes.loras
|
||||
activeModel.value = activeRes.model
|
||||
if (activeRes.model) {
|
||||
selectedModel.value = activeRes.model
|
||||
// Find full path for the active model (API returns model name without extension, v-select uses path)
|
||||
const activeModelPath = modelsRes.models.find(m =>
|
||||
m.name === activeRes.model ||
|
||||
m.filename === activeRes.model ||
|
||||
m.filename.startsWith(activeRes.model + '.')
|
||||
)?.path
|
||||
selectedModel.value = activeModelPath || activeRes.model
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load models:', error)
|
||||
|
||||
@@ -5,6 +5,14 @@ export interface Model {
|
||||
size_mb: number
|
||||
modified: number
|
||||
category: 'sd15' | 'large'
|
||||
// Enriched from CivitAI metadata
|
||||
display_name?: string
|
||||
base_model?: string
|
||||
model_type?: string
|
||||
civitai_model_id?: number
|
||||
civitai_version_id?: number
|
||||
thumbnail_url?: string
|
||||
triggers?: string[]
|
||||
}
|
||||
|
||||
export interface LoRA {
|
||||
@@ -14,6 +22,14 @@ export interface LoRA {
|
||||
size_mb: number
|
||||
modified: number
|
||||
category: 'sd15' | 'large'
|
||||
// Enriched from CivitAI metadata
|
||||
display_name?: string
|
||||
base_model?: string
|
||||
model_type?: string
|
||||
civitai_model_id?: number
|
||||
civitai_version_id?: number
|
||||
thumbnail_url?: string
|
||||
triggers?: string[]
|
||||
}
|
||||
|
||||
export interface GeneratedImage {
|
||||
|
||||
Reference in New Issue
Block a user