💬 Commit message: Update 2026-02-15 18:22:20, 3 files, 101 lines
📁 Files changed: 3 📝 Lines changed: 101 • config.py • __init__.py • auth.py
This commit is contained in:
@@ -396,3 +396,32 @@ def get_sd_server_api_key() -> str | None:
|
|||||||
return str(key)
|
return str(key)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tensors Server API Key
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def get_server_api_key() -> str | None:
|
||||||
|
"""Get the tensors server API key for authentication.
|
||||||
|
|
||||||
|
Resolution order:
|
||||||
|
1. TENSORS_API_KEY environment variable
|
||||||
|
2. config.toml [server].api_key
|
||||||
|
3. None (no authentication required)
|
||||||
|
"""
|
||||||
|
# Check environment variable first
|
||||||
|
env_key = os.environ.get("TENSORS_API_KEY")
|
||||||
|
if env_key:
|
||||||
|
return env_key
|
||||||
|
|
||||||
|
# Check config file
|
||||||
|
config = load_config()
|
||||||
|
server_config = config.get("server", {})
|
||||||
|
if isinstance(server_config, dict):
|
||||||
|
key = server_config.get("api_key")
|
||||||
|
if key:
|
||||||
|
return str(key)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|||||||
@@ -6,9 +6,10 @@ import logging
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import Depends, FastAPI
|
||||||
from scalar_fastapi import get_scalar_api_reference
|
from scalar_fastapi import get_scalar_api_reference
|
||||||
|
|
||||||
|
from tensors.config import get_server_api_key
|
||||||
from tensors.server.civitai_routes import create_civitai_router
|
from tensors.server.civitai_routes import create_civitai_router
|
||||||
from tensors.server.db_routes import create_db_router
|
from tensors.server.db_routes import create_db_router
|
||||||
from tensors.server.download_routes import create_download_router
|
from tensors.server.download_routes import create_download_router
|
||||||
@@ -29,7 +30,11 @@ def create_app() -> FastAPI:
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
||||||
logger.info("Tensors server starting")
|
api_key = get_server_api_key()
|
||||||
|
if api_key:
|
||||||
|
logger.info("Tensors server starting (auth enabled)")
|
||||||
|
else:
|
||||||
|
logger.info("Tensors server starting (no auth)")
|
||||||
yield
|
yield
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
@@ -41,6 +46,7 @@ def create_app() -> FastAPI:
|
|||||||
redoc_url=None,
|
redoc_url=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Public endpoints (no auth)
|
||||||
@app.get("/status")
|
@app.get("/status")
|
||||||
async def status() -> dict[str, str]:
|
async def status() -> dict[str, str]:
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
@@ -52,10 +58,13 @@ def create_app() -> FastAPI:
|
|||||||
title="tensors API",
|
title="tensors API",
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(create_civitai_router())
|
# Protected routers (auth required if configured)
|
||||||
app.include_router(create_db_router())
|
from tensors.server.auth import verify_api_key # noqa: PLC0415
|
||||||
app.include_router(create_gallery_router())
|
|
||||||
app.include_router(create_download_router())
|
app.include_router(create_civitai_router(), dependencies=[Depends(verify_api_key)])
|
||||||
|
app.include_router(create_db_router(), dependencies=[Depends(verify_api_key)])
|
||||||
|
app.include_router(create_gallery_router(), dependencies=[Depends(verify_api_key)])
|
||||||
|
app.include_router(create_download_router(), dependencies=[Depends(verify_api_key)])
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,51 @@
|
|||||||
|
"""Authentication for tensors API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Security, status
|
||||||
|
from fastapi.security import APIKeyHeader, APIKeyQuery
|
||||||
|
|
||||||
|
from tensors.config import get_server_api_key
|
||||||
|
|
||||||
|
# API key can be passed via header or query param
|
||||||
|
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||||
|
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_api_key(
|
||||||
|
header_key: Annotated[str | None, Security(api_key_header)] = None,
|
||||||
|
query_key: Annotated[str | None, Security(api_key_query)] = None,
|
||||||
|
) -> str | None:
|
||||||
|
"""Verify API key from header or query parameter.
|
||||||
|
|
||||||
|
If no server API key is configured, authentication is disabled.
|
||||||
|
If configured, the key must match.
|
||||||
|
"""
|
||||||
|
server_key = get_server_api_key()
|
||||||
|
|
||||||
|
# No auth required if no key configured
|
||||||
|
if not server_key:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check header first, then query
|
||||||
|
provided_key = header_key or query_key
|
||||||
|
|
||||||
|
if not provided_key:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="API key required. Provide via X-API-Key header or api_key query param.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if provided_key != server_key:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Invalid API key",
|
||||||
|
)
|
||||||
|
|
||||||
|
return provided_key
|
||||||
|
|
||||||
|
|
||||||
|
# Dependency for protected routes
|
||||||
|
RequireAuth = Annotated[str | None, Depends(verify_api_key)]
|
||||||
Reference in New Issue
Block a user