140 lines
5.0 KiB
Python
140 lines
5.0 KiB
Python
import base64
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import time
|
|
from functools import lru_cache
|
|
|
|
import jwt
|
|
from jwt import PyJWKClient, PyJWTError
|
|
from fastapi import HTTPException, Request, status
|
|
|
|
from app.config import Settings
|
|
|
|
# Name of the cookie that carries the signed admin session token
# (written by callers of sign_session, read back in require_admin).
SESSION_COOKIE: str = "miem_admin_session"
|
|
|
|
|
|
def verify_admin(username: str, password: str, settings: Settings) -> bool:
    """Check submitted admin credentials against the configured ones.

    Both comparisons always run (no short-circuiting ``and``), so response
    timing does not reveal whether the username alone was correct. Values are
    compared as UTF-8 bytes because ``hmac.compare_digest`` raises
    ``TypeError`` for ``str`` arguments containing non-ASCII characters.

    Args:
        username: Username submitted by the client.
        password: Password submitted by the client.
        settings: Provides ``admin_username`` and ``admin_password``.

    Returns:
        True only when both the username and the password match.
    """
    username_ok = hmac.compare_digest(
        username.encode("utf-8"), settings.admin_username.encode("utf-8")
    )
    password_ok = hmac.compare_digest(
        password.encode("utf-8"), settings.admin_password.encode("utf-8")
    )
    # Bitwise ``&`` on two bools evaluates both operands and still yields a bool.
    return username_ok & password_ok
|
|
|
|
|
|
def sign_session(username: str, settings: Settings) -> str:
    """Create a signed session token for ``username``.

    The token has the form ``<payload>.<signature>``:

    * ``payload`` is URL-safe base64 of compact JSON ``{"sub": ..., "iat": ...}``,
      where ``iat`` is the current Unix time in whole seconds;
    * ``signature`` is the hex HMAC-SHA256 of the payload text keyed with
      ``settings.session_secret``.
    """
    claims = {"sub": username, "iat": int(time.time())}
    compact_json = json.dumps(claims, separators=(",", ":"))
    encoded_payload = base64.urlsafe_b64encode(compact_json.encode("utf-8")).decode("ascii")

    secret = settings.session_secret.encode("utf-8")
    mac = hmac.new(secret, encoded_payload.encode("ascii"), hashlib.sha256)
    return ".".join((encoded_payload, mac.hexdigest()))
|
|
|
|
|
|
def read_session(token: str | None, settings: Settings) -> str | None:
    """Verify a session token and return the username it carries.

    The token must be ``<payload>.<signature>`` where the signature is the hex
    HMAC-SHA256 of the payload keyed with ``settings.session_secret``.
    The payload must decode (URL-safe base64, then JSON) to an object whose
    ``sub`` is a string.

    The embedded ``iat`` is not checked here, so a token stays valid until
    ``session_secret`` changes.

    Args:
        token: Raw cookie value, or None when the cookie is absent.
        settings: Provides ``session_secret``.

    Returns:
        The ``sub`` string, or None for any missing, malformed, tampered or
        non-conforming token. This function does not raise for bad input.
    """
    # Cookie values are client-controlled; non-ASCII text would make both
    # ``str.encode("ascii")`` and ``hmac.compare_digest`` raise instead of
    # simply rejecting the token.
    if not token or not token.isascii() or "." not in token:
        return None
    payload, signature = token.rsplit(".", 1)
    expected = hmac.new(
        settings.session_secret.encode("utf-8"), payload.encode("ascii"), hashlib.sha256
    ).hexdigest()
    if not hmac.compare_digest(signature, expected):
        return None
    try:
        data = json.loads(base64.urlsafe_b64decode(payload.encode("ascii")))
    except ValueError:
        # binascii.Error, json.JSONDecodeError and UnicodeDecodeError are all
        # ValueError subclasses.
        return None
    # A correctly signed payload that is not a JSON object (or has a non-str
    # subject) is still rejected rather than crashing on ``.get``.
    if not isinstance(data, dict):
        return None
    subject = data.get("sub")
    return subject if isinstance(subject, str) else None
|
|
|
|
|
|
def require_admin(request: Request, settings: Settings) -> str:
    """Return the logged-in admin username, or redirect to the login page.

    Reads the ``SESSION_COOKIE`` cookie and validates it with ``read_session``.
    When no valid username comes back, raises ``HTTPException`` with status 303
    and a ``Location: /admin/login`` header so the client is sent to the login form.
    """
    session_token = request.cookies.get(SESSION_COOKIE)
    username = read_session(session_token, settings)
    if username:
        return username
    raise HTTPException(
        status_code=status.HTTP_303_SEE_OTHER,
        headers={"Location": "/admin/login"},
    )
|
|
|
|
|
|
def require_mcp_auth(request: Request, settings: Settings) -> None:
    """Authenticate an MCP request from its ``Authorization: Bearer`` header.

    The accepted credential depends on ``settings.mcp_auth_mode``:

    * ``"token"``: the bearer value must equal ``settings.mcp_token``
      (constant-time comparison);
    * ``"oauth"``: the bearer value is validated as an OAuth JWT access token
      by ``_validate_mcp_oauth_token``.

    Raises:
        HTTPException: 401 (built by ``_mcp_unauthorized``) when the header is
            missing, malformed or carries an empty/rejected token. In OAuth
            mode ``_validate_mcp_oauth_token`` may also raise 401 or 403.
    """
    auth = request.headers.get("authorization", "")
    prefix = "Bearer "
    # RFC 7235 section 2.1: the auth-scheme is case-insensitive ("bearer x" is valid).
    if auth[: len(prefix)].lower() != prefix.lower():
        raise _mcp_unauthorized(settings, "Missing bearer token")
    token = auth[len(prefix):].strip()
    if not token:
        # An empty credential must never reach the comparisons below.
        raise _mcp_unauthorized(settings, "Missing bearer token")

    if _mcp_static_token_allowed(settings) and _static_token_matches(token, settings.mcp_token):
        return

    if _mcp_oauth_allowed(settings):
        _validate_mcp_oauth_token(token, settings)
        return

    raise _mcp_unauthorized(settings, "Invalid MCP token")


def _static_token_matches(presented: str, expected: str | None) -> bool:
    """Constant-time check of a presented bearer token against the configured one."""
    # An unset or empty configured token must never authenticate anyone.
    if not expected:
        return False
    # Compare bytes: hmac.compare_digest rejects non-ASCII str with TypeError.
    return hmac.compare_digest(presented.encode("utf-8"), expected.encode("utf-8"))
|
|
|
|
|
|
def require_mcp_token(request: Request, settings: Settings) -> None:
    """Delegate to ``require_mcp_auth``; raises whatever it raises.

    Presumably kept under this name for existing callers — confirm before removing.
    """
    return require_mcp_auth(request, settings)
|
|
|
|
|
|
def mcp_protected_resource_metadata(settings: Settings) -> dict:
    """Build the protected-resource metadata document for the MCP endpoint.

    The keys mirror OAuth 2.0 Protected Resource Metadata (RFC 9728) field names.
    ``authorization_servers`` lists the configured issuer without trailing
    slashes, or is empty when no issuer is configured. ``resource_documentation``
    reuses the resource URL itself.
    """
    issuer = settings.mcp_oauth_issuer
    servers = []
    if issuer:
        servers.append(issuer.rstrip("/"))

    resource_url = settings.mcp_resource_url
    return dict(
        resource=resource_url,
        authorization_servers=servers,
        bearer_methods_supported=["header"],
        scopes_supported=[settings.mcp_oauth_required_scope],
        resource_documentation=resource_url,
    )
|
|
|
|
|
|
def _mcp_static_token_allowed(settings: Settings) -> bool:
|
|
return settings.mcp_auth_mode == "token"
|
|
|
|
|
|
def _mcp_oauth_allowed(settings: Settings) -> bool:
|
|
return settings.mcp_auth_mode == "oauth"
|
|
|
|
|
|
def _validate_mcp_oauth_token(token: str, settings: Settings) -> None:
    """Validate an OAuth JWT access token presented to the MCP endpoint.

    Checks, in order:

    1. the issuer, audience and JWKS URL are configured;
    2. the signature, using the key resolved from the JWKS by
       ``_get_mcp_oauth_signing_key`` and an allow-list of asymmetric algorithms;
    3. ``aud`` against ``settings.mcp_oauth_audience``;
    4. that ``exp`` and ``iss`` are present (PyJWT also rejects expired tokens);
    5. ``iss`` against the configured issuer, ignoring trailing slashes;
    6. that the granted scopes include ``settings.mcp_oauth_required_scope``.

    Raises:
        HTTPException: 401 (via ``_mcp_unauthorized``) when OAuth is not
            configured or the token fails validation; 403 with an RFC 6750
            ``insufficient_scope`` challenge when the required scope is missing.
    """
    if not settings.mcp_oauth_issuer or not settings.mcp_oauth_audience or not settings.oauth_jwks_url():
        raise _mcp_unauthorized(settings, "MCP OAuth is not configured")

    expected_issuer = settings.mcp_oauth_issuer.rstrip("/")
    try:
        signing_key = _get_mcp_oauth_signing_key(token, settings).key
        claims = jwt.decode(
            token,
            signing_key,
            # Asymmetric algorithms only, so a token cannot pick an HS* algorithm
            # and have a public key used as an HMAC secret.
            algorithms=["RS256", "RS384", "RS512", "ES256", "ES384", "ES512"],
            audience=settings.mcp_oauth_audience,
            # Without a required exp an access token would never expire
            # (RFC 9068 section 2.2 makes exp and iss mandatory).
            options={"require": ["exp", "iss"]},
        )
    except PyJWTError as exc:
        raise _mcp_unauthorized(settings, "Invalid OAuth access token") from exc

    # Issuer is checked here rather than via jwt.decode(issuer=...): some IdPs
    # emit "iss" with a trailing "/", which an exact match against the
    # slash-stripped configured issuer would always reject.
    token_issuer = claims.get("iss")
    if not isinstance(token_issuer, str) or token_issuer.rstrip("/") != expected_issuer:
        raise _mcp_unauthorized(settings, "Invalid OAuth access token")

    required_scope = settings.mcp_oauth_required_scope
    if not _claims_have_scope(claims, required_scope):
        # RFC 6750 section 3.1: tell the client which scope it is missing.
        challenge = (
            f'Bearer error="insufficient_scope", scope="{required_scope}", '
            f'resource_metadata="{_mcp_metadata_url(settings)}"'
        )
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Missing required MCP OAuth scope",
            headers={"WWW-Authenticate": challenge},
        )
|
|
|
|
|
|
def _claims_have_scope(claims: dict, required_scope: str) -> bool:
|
|
scopes: set[str] = set()
|
|
scope = claims.get("scope")
|
|
if isinstance(scope, str):
|
|
scopes.update(scope.split())
|
|
scp = claims.get("scp")
|
|
if isinstance(scp, str):
|
|
scopes.update(scp.split())
|
|
elif isinstance(scp, list):
|
|
scopes.update(str(item) for item in scp)
|
|
return required_scope in scopes
|
|
|
|
|
|
@lru_cache(maxsize=16)
def _get_jwk_client(jwks_url: str) -> PyJWKClient:
    """Return a ``PyJWKClient`` for ``jwks_url``.

    Memoised per URL (up to 16 URLs), so repeated calls with the same URL reuse
    one client instance across requests.
    """
    client = PyJWKClient(jwks_url)
    return client
|
|
|
|
|
|
def _get_mcp_oauth_signing_key(token: str, settings: Settings):
    """Resolve the JWKS signing key for ``token``.

    Uses the cached client for ``settings.oauth_jwks_url()``. PyJWT errors
    propagate to the caller.
    """
    jwks_client = _get_jwk_client(settings.oauth_jwks_url())
    return jwks_client.get_signing_key_from_jwt(token)
|
|
|
|
|
|
def _mcp_unauthorized(settings: Settings, detail: str) -> HTTPException:
    """Build (not raise) a 401 ``HTTPException`` for a failed MCP auth attempt.

    In OAuth mode a ``WWW-Authenticate: Bearer resource_metadata="..."`` header
    is attached so clients can discover the protected-resource metadata URL.
    In other modes no headers are sent.
    """
    headers = (
        {"WWW-Authenticate": f'Bearer resource_metadata="{_mcp_metadata_url(settings)}"'}
        if _mcp_oauth_allowed(settings)
        else {}
    )
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers=headers,
    )
|
|
|
|
|
|
def _mcp_metadata_url(settings: Settings) -> str:
|
|
resource_url = settings.mcp_resource_url.rstrip("/")
|
|
base_url = resource_url[: -len("/mcp")] if resource_url.endswith("/mcp") else resource_url
|
|
return f"{base_url}/.well-known/oauth-protected-resource"
|