Add shared OAuth authentication routes
Support cross-domain auth for tensors-web with return_url parameter. New endpoints: /auth/login, /auth/github, /auth/callback, /auth/verify Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from scalar_fastapi import get_scalar_api_reference
|
from scalar_fastapi import get_scalar_api_reference
|
||||||
|
|
||||||
from tensors.config import get_server_api_key
|
from tensors.config import get_server_api_key
|
||||||
|
from tensors.server.auth_routes import create_auth_router
|
||||||
from tensors.server.civitai_routes import create_civitai_router
|
from tensors.server.civitai_routes import create_civitai_router
|
||||||
from tensors.server.comfyui_routes import create_comfyui_router
|
from tensors.server.comfyui_routes import create_comfyui_router
|
||||||
from tensors.server.db_routes import create_db_router
|
from tensors.server.db_routes import create_db_router
|
||||||
@@ -63,6 +64,9 @@ def create_app() -> FastAPI:
|
|||||||
title="tensors API",
|
title="tensors API",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Shared OAuth auth (no API key required)
|
||||||
|
app.include_router(create_auth_router())
|
||||||
|
|
||||||
# ComfyUI proxy (handles its own session auth)
|
# ComfyUI proxy (handles its own session auth)
|
||||||
app.include_router(create_comfyui_router())
|
app.include_router(create_comfyui_router())
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,423 @@
|
|||||||
|
"""Shared OAuth authentication for tensors apps."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import os
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import APIRouter, Cookie, HTTPException, Query, Request, status
|
||||||
|
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse, Response
|
||||||
|
|
||||||
|
router = APIRouter(tags=["Auth"])
|
||||||
|
|
||||||
|
# Config from environment
|
||||||
|
SESSION_SECRET = os.environ.get("SESSION_SECRET", "tensors-comfyui-secret-change-me")
|
||||||
|
SESSION_MAX_AGE = 86400 * 7 # 7 days
|
||||||
|
|
||||||
|
# GitHub OAuth config
|
||||||
|
GITHUB_CLIENT_ID = os.environ.get("GITHUB_CLIENT_ID", "")
|
||||||
|
GITHUB_CLIENT_SECRET = os.environ.get("GITHUB_CLIENT_SECRET", "")
|
||||||
|
GITHUB_ALLOWED_USERS = os.environ.get("GITHUB_ALLOWED_USERS", "").split(",")
|
||||||
|
|
||||||
|
# Allowed redirect URLs (for security)
|
||||||
|
ALLOWED_REDIRECT_HOSTS = [
|
||||||
|
"tensors.saiden.dev",
|
||||||
|
"localhost",
|
||||||
|
"127.0.0.1",
|
||||||
|
]
|
||||||
|
|
||||||
|
# OAuth state storage (in-memory, short-lived)
|
||||||
|
# Format: state -> (timestamp, return_url)
|
||||||
|
_oauth_states: dict[str, tuple[float, str | None]] = {}
|
||||||
|
|
||||||
|
_SESSION_TOKEN_PARTS = 3
|
||||||
|
|
||||||
|
|
||||||
|
def _create_session_token(username: str) -> str:
|
||||||
|
"""Create a signed session token."""
|
||||||
|
expires = int(time.time()) + SESSION_MAX_AGE
|
||||||
|
data = f"{username}:{expires}"
|
||||||
|
signature = hmac.new(SESSION_SECRET.encode(), data.encode(), hashlib.sha256).hexdigest()[:32]
|
||||||
|
return f"{data}:{signature}"
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_session_token(token: str | None) -> str | None:
|
||||||
|
"""Verify a session token. Returns username if valid, None otherwise."""
|
||||||
|
if not token:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parts = token.split(":")
|
||||||
|
if len(parts) != _SESSION_TOKEN_PARTS:
|
||||||
|
return None
|
||||||
|
username, expires_str, signature = parts
|
||||||
|
expires = int(expires_str)
|
||||||
|
if time.time() > expires:
|
||||||
|
return None
|
||||||
|
data = f"{username}:{expires_str}"
|
||||||
|
expected = hmac.new(SESSION_SECRET.encode(), data.encode(), hashlib.sha256).hexdigest()[:32]
|
||||||
|
if hmac.compare_digest(signature, expected):
|
||||||
|
return username
|
||||||
|
return None
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_auth_configured() -> bool:
|
||||||
|
"""Check if GitHub OAuth is configured."""
|
||||||
|
return bool(GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_valid_redirect_url(url: str | None) -> bool:
|
||||||
|
"""Check if redirect URL is allowed."""
|
||||||
|
if not url:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
parsed = urllib.parse.urlparse(url)
|
||||||
|
host = parsed.hostname or ""
|
||||||
|
return host in ALLOWED_REDIRECT_HOSTS or host.endswith(".saiden.dev")
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_old_states() -> None:
|
||||||
|
"""Remove OAuth states older than 10 minutes."""
|
||||||
|
cutoff = time.time() - 600
|
||||||
|
for state in list(_oauth_states.keys()):
|
||||||
|
if _oauth_states[state][0] < cutoff:
|
||||||
|
del _oauth_states[state]
|
||||||
|
|
||||||
|
|
||||||
|
LOGIN_PAGE_HTML = """
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Tensors - Login</title>
|
||||||
|
<style>
|
||||||
|
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||||
|
body {
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||||
|
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 50%, #0f0f23 100%);
|
||||||
|
min-height: 100vh;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
color: #e0e0e0;
|
||||||
|
}
|
||||||
|
.login-container {
|
||||||
|
background: rgba(30, 30, 46, 0.95);
|
||||||
|
border: 1px solid rgba(255, 255, 255, 0.1);
|
||||||
|
border-radius: 16px;
|
||||||
|
padding: 40px;
|
||||||
|
width: 100%;
|
||||||
|
max-width: 400px;
|
||||||
|
box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.5);
|
||||||
|
}
|
||||||
|
.logo {
|
||||||
|
text-align: center;
|
||||||
|
margin-bottom: 32px;
|
||||||
|
}
|
||||||
|
.logo h1 {
|
||||||
|
font-size: 28px;
|
||||||
|
font-weight: 600;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
-webkit-background-clip: text;
|
||||||
|
-webkit-text-fill-color: transparent;
|
||||||
|
background-clip: text;
|
||||||
|
}
|
||||||
|
.logo p {
|
||||||
|
color: #888;
|
||||||
|
font-size: 14px;
|
||||||
|
margin-top: 8px;
|
||||||
|
}
|
||||||
|
.github-btn {
|
||||||
|
width: 100%;
|
||||||
|
padding: 14px;
|
||||||
|
font-size: 16px;
|
||||||
|
font-weight: 600;
|
||||||
|
border: none;
|
||||||
|
border-radius: 8px;
|
||||||
|
background: #24292f;
|
||||||
|
color: #fff;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: transform 0.2s, box-shadow 0.2s, background 0.2s;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
gap: 10px;
|
||||||
|
text-decoration: none;
|
||||||
|
}
|
||||||
|
.github-btn:hover {
|
||||||
|
background: #32383f;
|
||||||
|
transform: translateY(-2px);
|
||||||
|
box-shadow: 0 10px 20px -10px rgba(0, 0, 0, 0.5);
|
||||||
|
}
|
||||||
|
.github-btn svg { width: 20px; height: 20px; fill: currentColor; }
|
||||||
|
.error {
|
||||||
|
background: rgba(239, 68, 68, 0.1);
|
||||||
|
border: 1px solid rgba(239, 68, 68, 0.3);
|
||||||
|
color: #f87171;
|
||||||
|
padding: 12px 16px;
|
||||||
|
border-radius: 8px;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
font-size: 14px;
|
||||||
|
}
|
||||||
|
.footer {
|
||||||
|
text-align: center;
|
||||||
|
margin-top: 24px;
|
||||||
|
font-size: 12px;
|
||||||
|
color: #666;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="login-container">
|
||||||
|
<div class="logo">
|
||||||
|
<h1>Tensors</h1>
|
||||||
|
<p>Sign in to continue</p>
|
||||||
|
</div>
|
||||||
|
{{ERROR}}
|
||||||
|
<a href="{{AUTH_URL}}" class="github-btn">
|
||||||
|
<svg viewBox="0 0 16 16" aria-hidden="true">
|
||||||
|
<path d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0016 8c0-4.42-3.58-8-8-8z"/>
|
||||||
|
</svg>
|
||||||
|
Sign in with GitHub
|
||||||
|
</a>
|
||||||
|
<div class="footer">
|
||||||
|
Powered by tensors
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/auth/login")
|
||||||
|
async def login_page(
|
||||||
|
return_url: str | None = Query(None, description="URL to redirect after login"),
|
||||||
|
error: str | None = Query(None),
|
||||||
|
) -> HTMLResponse:
|
||||||
|
"""Show login page."""
|
||||||
|
error_html = ""
|
||||||
|
if error:
|
||||||
|
error_html = f'<div class="error">{error}</div>'
|
||||||
|
|
||||||
|
# Build auth URL with return_url
|
||||||
|
auth_url = "/auth/github"
|
||||||
|
if return_url and _is_valid_redirect_url(return_url):
|
||||||
|
auth_url += f"?return_url={urllib.parse.quote(return_url)}"
|
||||||
|
|
||||||
|
html = LOGIN_PAGE_HTML.replace("{{ERROR}}", error_html).replace("{{AUTH_URL}}", auth_url)
|
||||||
|
return HTMLResponse(content=html)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/auth/github")
|
||||||
|
async def github_auth(
|
||||||
|
request: Request,
|
||||||
|
return_url: str | None = Query(None, description="URL to redirect after login"),
|
||||||
|
) -> Response:
|
||||||
|
"""Redirect to GitHub OAuth."""
|
||||||
|
if not _is_auth_configured():
|
||||||
|
return RedirectResponse(
|
||||||
|
url="/auth/login?error=GitHub+OAuth+not+configured",
|
||||||
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not _is_valid_redirect_url(return_url):
|
||||||
|
return RedirectResponse(
|
||||||
|
url="/auth/login?error=Invalid+redirect+URL",
|
||||||
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate state for CSRF protection, store return_url
|
||||||
|
state = secrets.token_urlsafe(32)
|
||||||
|
_oauth_states[state] = (time.time(), return_url)
|
||||||
|
_cleanup_old_states()
|
||||||
|
|
||||||
|
# Build callback URL
|
||||||
|
callback_url = str(request.url_for("github_callback"))
|
||||||
|
|
||||||
|
# Build GitHub OAuth URL
|
||||||
|
params = {
|
||||||
|
"client_id": GITHUB_CLIENT_ID,
|
||||||
|
"redirect_uri": callback_url,
|
||||||
|
"scope": "read:user",
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
query = "&".join(f"{k}={v}" for k, v in params.items())
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"https://github.com/login/oauth/authorize?{query}",
|
||||||
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/auth/callback")
|
||||||
|
async def github_callback(
|
||||||
|
request: Request,
|
||||||
|
code: str | None = None,
|
||||||
|
state: str | None = None,
|
||||||
|
) -> Response:
|
||||||
|
"""Handle GitHub OAuth callback."""
|
||||||
|
# Verify state and get return_url
|
||||||
|
if not state or state not in _oauth_states:
|
||||||
|
return RedirectResponse(
|
||||||
|
url="/auth/login?error=Invalid+OAuth+state",
|
||||||
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
|
)
|
||||||
|
_, return_url = _oauth_states.pop(state)
|
||||||
|
|
||||||
|
if not code:
|
||||||
|
error_url = "/auth/login?error=No+authorization+code"
|
||||||
|
if return_url:
|
||||||
|
error_url += f"&return_url={urllib.parse.quote(return_url)}"
|
||||||
|
return RedirectResponse(url=error_url, status_code=status.HTTP_303_SEE_OTHER)
|
||||||
|
|
||||||
|
# Exchange code for access token
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
token_response = await client.post(
|
||||||
|
"https://github.com/login/oauth/access_token",
|
||||||
|
data={
|
||||||
|
"client_id": GITHUB_CLIENT_ID,
|
||||||
|
"client_secret": GITHUB_CLIENT_SECRET,
|
||||||
|
"code": code,
|
||||||
|
},
|
||||||
|
headers={"Accept": "application/json"},
|
||||||
|
)
|
||||||
|
token_data = token_response.json()
|
||||||
|
|
||||||
|
if "error" in token_data:
|
||||||
|
error_msg = token_data.get("error_description", "OAuth+error")
|
||||||
|
error_url = f"/auth/login?error={error_msg}"
|
||||||
|
if return_url:
|
||||||
|
error_url += f"&return_url={urllib.parse.quote(return_url)}"
|
||||||
|
return RedirectResponse(url=error_url, status_code=status.HTTP_303_SEE_OTHER)
|
||||||
|
|
||||||
|
access_token = token_data.get("access_token")
|
||||||
|
if not access_token:
|
||||||
|
return RedirectResponse(
|
||||||
|
url="/auth/login?error=No+access+token",
|
||||||
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get user info
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
user_response = await client.get(
|
||||||
|
"https://api.github.com/user",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"Accept": "application/vnd.github+json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
user_data = user_response.json()
|
||||||
|
|
||||||
|
username = user_data.get("login", "")
|
||||||
|
if not username:
|
||||||
|
return RedirectResponse(
|
||||||
|
url="/auth/login?error=Could+not+get+GitHub+username",
|
||||||
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if user is allowed
|
||||||
|
allowed = [u.strip().lower() for u in GITHUB_ALLOWED_USERS if u.strip()]
|
||||||
|
if allowed and username.lower() not in allowed:
|
||||||
|
return RedirectResponse(
|
||||||
|
url="/auth/login?error=User+not+authorized",
|
||||||
|
status_code=status.HTTP_303_SEE_OTHER,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create session token
|
||||||
|
token = _create_session_token(username)
|
||||||
|
|
||||||
|
# Redirect to return_url with token, or show success
|
||||||
|
if return_url:
|
||||||
|
# Add token to return URL
|
||||||
|
separator = "&" if "?" in return_url else "?"
|
||||||
|
redirect_url = f"{return_url}{separator}token={token}"
|
||||||
|
return RedirectResponse(url=redirect_url, status_code=status.HTTP_303_SEE_OTHER)
|
||||||
|
|
||||||
|
# No return_url - set cookie and show success
|
||||||
|
response = RedirectResponse(url="/auth/success", status_code=status.HTTP_303_SEE_OTHER)
|
||||||
|
response.set_cookie(
|
||||||
|
key="tensors_session",
|
||||||
|
value=token,
|
||||||
|
max_age=SESSION_MAX_AGE,
|
||||||
|
httponly=True,
|
||||||
|
samesite="lax",
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/auth/verify")
|
||||||
|
async def verify_token(
|
||||||
|
token: str | None = Query(None),
|
||||||
|
tensors_session: str | None = Cookie(default=None),
|
||||||
|
) -> JSONResponse:
|
||||||
|
"""Verify a session token. Returns user info if valid."""
|
||||||
|
# Check token from query or cookie
|
||||||
|
check_token = token or tensors_session
|
||||||
|
|
||||||
|
username = _verify_session_token(check_token)
|
||||||
|
if username:
|
||||||
|
return JSONResponse({"valid": True, "username": username})
|
||||||
|
return JSONResponse({"valid": False}, status_code=status.HTTP_401_UNAUTHORIZED)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/auth/success")
|
||||||
|
async def auth_success() -> HTMLResponse:
|
||||||
|
"""Show success page after login."""
|
||||||
|
html = """
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Login Successful</title>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||||
|
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 50%, #0f0f23 100%);
|
||||||
|
min-height: 100vh;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
color: #e0e0e0;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
text-align: center;
|
||||||
|
background: rgba(30, 30, 46, 0.95);
|
||||||
|
padding: 40px;
|
||||||
|
border-radius: 16px;
|
||||||
|
border: 1px solid rgba(255, 255, 255, 0.1);
|
||||||
|
}
|
||||||
|
h1 { color: #4ade80; margin-bottom: 16px; }
|
||||||
|
a { color: #667eea; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<h1>Login Successful!</h1>
|
||||||
|
<p>You are now authenticated.</p>
|
||||||
|
<p style="margin-top: 16px;"><a href="/docs">Go to API Docs</a></p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
return HTMLResponse(content=html)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/auth/logout")
|
||||||
|
async def logout(return_url: str | None = Query(None)) -> Response:
|
||||||
|
"""Clear session and redirect."""
|
||||||
|
redirect_to = return_url if _is_valid_redirect_url(return_url) else "/auth/login"
|
||||||
|
response = RedirectResponse(url=redirect_to, status_code=status.HTTP_303_SEE_OTHER)
|
||||||
|
response.delete_cookie("tensors_session")
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def create_auth_router() -> APIRouter:
|
||||||
|
"""Return the auth router."""
|
||||||
|
return router
|
||||||
Reference in New Issue
Block a user