diff --git a/tensors/config.py b/tensors/config.py index 9321f4a..9d56cd5 100644 --- a/tensors/config.py +++ b/tensors/config.py @@ -396,3 +396,32 @@ def get_sd_server_api_key() -> str | None: return str(key) return None + + +# ============================================================================ +# Tensors Server API Key +# ============================================================================ + + +def get_server_api_key() -> str | None: + """Get the tensors server API key for authentication. + + Resolution order: + 1. TENSORS_API_KEY environment variable + 2. config.toml [server].api_key + 3. None (no authentication required) + """ + # Check environment variable first + env_key = os.environ.get("TENSORS_API_KEY") + if env_key: + return env_key + + # Check config file + config = load_config() + server_config = config.get("server", {}) + if isinstance(server_config, dict): + key = server_config.get("api_key") + if key: + return str(key) + + return None diff --git a/tensors/server/__init__.py b/tensors/server/__init__.py index ced16cf..015d44d 100644 --- a/tensors/server/__init__.py +++ b/tensors/server/__init__.py @@ -6,9 +6,10 @@ import logging from contextlib import asynccontextmanager from typing import TYPE_CHECKING -from fastapi import FastAPI +from fastapi import Depends, FastAPI from scalar_fastapi import get_scalar_api_reference +from tensors.config import get_server_api_key from tensors.server.civitai_routes import create_civitai_router from tensors.server.db_routes import create_db_router from tensors.server.download_routes import create_download_router @@ -29,7 +30,11 @@ def create_app() -> FastAPI: @asynccontextmanager async def lifespan(_app: FastAPI) -> AsyncIterator[None]: - logger.info("Tensors server starting") + api_key = get_server_api_key() + if api_key: + logger.info("Tensors server starting (auth enabled)") + else: + logger.info("Tensors server starting (no auth)") yield app = FastAPI( @@ -41,6 +46,7 @@ def create_app() -> FastAPI: redoc_url=None, ) + # Public endpoints (no auth) @app.get("/status") async def status() -> dict[str, str]: return {"status": "ok"} @@ -52,10 +58,13 @@ def create_app() -> FastAPI: title="tensors API", ) - app.include_router(create_civitai_router()) - app.include_router(create_db_router()) - app.include_router(create_gallery_router()) - app.include_router(create_download_router()) + # Protected routers (auth required if configured) + from tensors.server.auth import verify_api_key # noqa: PLC0415 + + app.include_router(create_civitai_router(), dependencies=[Depends(verify_api_key)]) + app.include_router(create_db_router(), dependencies=[Depends(verify_api_key)]) + app.include_router(create_gallery_router(), dependencies=[Depends(verify_api_key)]) + app.include_router(create_download_router(), dependencies=[Depends(verify_api_key)]) return app diff --git a/tensors/server/auth.py b/tensors/server/auth.py new file mode 100644 index 0000000..af0e7a5 --- /dev/null +++ b/tensors/server/auth.py @@ -0,0 +1,51 @@ +"""Authentication for tensors API.""" + +from __future__ import annotations + +from typing import Annotated + +from fastapi import Depends, HTTPException, Security, status +from fastapi.security import APIKeyHeader, APIKeyQuery + +from tensors.config import get_server_api_key + +# API key can be passed via header or query param +api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) +api_key_query = APIKeyQuery(name="api_key", auto_error=False) + + +def verify_api_key( + header_key: Annotated[str | None, Security(api_key_header)] = None, + query_key: Annotated[str | None, Security(api_key_query)] = None, +) -> str | None: + """Verify API key from header or query parameter. + + If no server API key is configured, authentication is disabled. + If configured, the key must match. + """ + server_key = get_server_api_key() + + # No auth required if no key configured + if not server_key: + return None + + # Check header first, then query + provided_key = header_key or query_key + + if not provided_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API key required. Provide via X-API-Key header or api_key query param.", + ) + + if provided_key != server_key: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid API key", + ) + + return provided_key + + +# Dependency for protected routes +RequireAuth = Annotated[str | None, Depends(verify_api_key)]