207 lines
8.8 KiB
Python
207 lines
8.8 KiB
Python
from fastapi import Depends, Header, HTTPException, Query, Request, Response, status
|
|
from typing import Optional, Union
|
|
import logging
|
|
import asyncio
|
|
|
|
from app.models.auth_models import AccountContext
|
|
|
|
# Module-scoped logger, named after this module.
log = logging.getLogger(name=__name__)
|
|
|
|
# --- Deprecation Dependency ---
|
|
|
|
class DeprecationParams:
    """Dependency that emits a warning whenever a legacy/deprecated route is hit.

    Instantiating it logs one WARNING line containing the HTTP method, the
    request path and the client host ('Unknown' when the request carries no
    client information). It stores no state.
    """

    def __init__(self, request: Request):
        client = request.client
        origin = client.host if client else 'Unknown'
        log.warning(f"!!! DEPRECATED ROUTE ACCESSED: [{request.method}] {request.url.path} from {origin}")
|
|
|
|
# --- Account Context Dependencies ---
|
|
|
|
def get_account_context_optional(
    x_account_id: Optional[str] = Header(None, min_length=11, max_length=22),
    x_no_account_id: Optional[str] = Header(None, min_length=3, max_length=100),
    x_no_account_id_token: Optional[str] = Query(None, alias='jwt', min_length=11),
    x_aether_api_key: Optional[str] = Header(None, min_length=11, max_length=22),
    x_aether_api_key_query: Optional[str] = Query(None, alias='api_key', min_length=11, max_length=22),
) -> AccountContext:
    """
    Resolve the account context for a request and enforce API key validation.

    Step 1 (mandatory): an API key is taken from the ``x-aether-api-key``
    header, falling back to the ``api_key`` query parameter. It is matched
    against ``public_key`` OR ``secret_key`` in the ``api_key`` table. The
    key must have a truthy ``enable`` flag. The current time must also fall
    inside the optional ``enable_from`` / ``enable_to`` window.

    Step 2 (only when the key is valid): the account is resolved from the
    first source present, in this order:
      A. ``x-account-id`` header, a random account ID resolved with
         ``redis_lookup_id_random``;
      B. ``jwt`` query parameter, decoded as a JWT when it contains a dot,
         with a legacy fallback that treats the value as a raw random ID;
      C. ``x-no-account-id`` header administrative bypass (account_id 1, all
         role flags set), unless its value is one of the "falsy" sentinels.

    Authentication failures are reported via ``auth_error`` rather than
    raised. ``auth_method`` is ``'guest'`` when the API key check failed.
    Imports are deferred to call time to prevent a circular dependency at
    startup.

    Args:
        x_account_id: Random account ID from the ``x-account-id`` header.
        x_no_account_id: Administrative bypass header value.
        x_no_account_id_token: JWT (or legacy raw random ID) from ``?jwt=``.
        x_aether_api_key: API key from the ``x-aether-api-key`` header.
        x_aether_api_key_query: API key from the ``?api_key=`` query param.

    Returns:
        AccountContext with the resolved IDs, auth method, role flags, the
        decoded token payload (if any) and the last auth error (if any).
    """
    from app.db_sql import redis_lookup_id_random, sql_select
    from app.lib_jwt import decode_jwt
    from app.config import settings
    from datetime import datetime

    def _key_hint(key: str) -> str:
        # Redacted hint for logs. The lookup also matches secret keys, so
        # the raw value must never be written to the log.
        return f"{key[:4]}...({len(key)} chars)"

    resolved_account_id = None
    resolved_account_id_random = None
    resolved_token_payload = None
    auth_method = 'guest'
    api_key_authorized = False
    auth_error = None

    # 1. Mandatory machine auth (API key): prefer header, fall back to query param.
    key_to_check = x_aether_api_key or x_aether_api_key_query

    if key_to_check:
        sql = "SELECT * FROM api_key WHERE (public_key = :key OR secret_key = :key) LIMIT 1"
        if api_key_results := sql_select(sql=sql, data={'key': key_to_check}):
            # sql_select returns a list when raw SQL is used
            api_key_rec = api_key_results[0] if isinstance(api_key_results, list) else api_key_results

            if api_key_rec.get('enable'):
                # NOTE(review): naive local time; assumes enable_from/enable_to
                # are stored as naive datetimes in the same zone - confirm.
                now = datetime.now()
                enable_from = api_key_rec.get('enable_from')
                enable_to = api_key_rec.get('enable_to')
                # An empty bound means the window is open on that side.
                if (not enable_from or enable_from <= now) and (not enable_to or now <= enable_to):
                    api_key_authorized = True
                else:
                    auth_error = "API Key expired or not yet valid."
            else:
                auth_error = "API Key is disabled."
        else:
            auth_error = "API Key not found or invalid."

        if auth_error:
            # Log a redacted hint only; the raw value may be a secret key.
            log.error("Security: %s Key: %s", auth_error, _key_hint(key_to_check))
    else:
        auth_error = "Mandatory API Key missing."

    # 2. Context resolution (only if the API key is valid).
    if api_key_authorized:
        # Default to machine auth; each successful branch below clears the error.
        auth_method = 'api_key'
        auth_error = "Account context required for this operation."

        # A. Resolve via account ID header.
        if x_account_id:
            resolved_account_id_random = x_account_id
            if looked_up_id := redis_lookup_id_random(table_name='account', record_id_random=x_account_id):
                resolved_account_id = looked_up_id
                auth_method = 'account_header'
                auth_error = None
            else:
                auth_error = f"Account ID '{x_account_id}' not found."

        # B. Resolve via JWT / token query param.
        elif x_no_account_id_token:
            # A real JWT contains dots; other values go to the legacy fallback.
            if '.' in x_no_account_id_token:
                if decoded := decode_jwt(secret_key=settings.JWT_KEY, token=x_no_account_id_token):
                    # Keep the full payload for session context (even for guests).
                    resolved_token_payload = decoded
                    # JWTs store the RANDOM string IDs to prevent exposure of internal IDs.
                    resolved_account_id_random = decoded.get('account_id')
                    if resolved_account_id_random:
                        if looked_up_id := redis_lookup_id_random(table_name='account', record_id_random=resolved_account_id_random):
                            resolved_account_id = looked_up_id
                            auth_method = 'jwt_token'
                            auth_error = None
                        else:
                            auth_error = f"Account ID '{resolved_account_id_random}' from token not found."
                    else:
                        # Valid JWT without account_id (e.g. platform-wide guest):
                        # auth_method becomes 'jwt_token' but account_id stays None.
                        auth_method = 'jwt_token'
                        auth_error = "Valid token provided, but no account context found in payload."
                else:
                    auth_error = "Failed to decode JWT token."
                    log.warning("Security: %s", auth_error)

            # Legacy fallback: the token may be a raw random account ID string.
            if auth_method in ['guest', 'api_key', 'jwt_token'] and auth_error:
                if looked_up_id := redis_lookup_id_random(table_name='account', record_id_random=x_no_account_id_token):
                    resolved_account_id = looked_up_id
                    resolved_account_id_random = x_no_account_id_token
                    auth_method = 'token_query'
                    auth_error = None

        # C. Resolve via administrative bypass.
        # NOTE(review): any caller holding a valid API key can reach this
        # branch and receive administrator/manager/super - confirm intended.
        elif x_no_account_id and x_no_account_id.lower() not in ['false', '0', 'null', 'undefined', 'none', 'no_account_id_here']:
            resolved_account_id = 1
            resolved_account_id_random = '--- NO ACCOUNT ---'
            auth_method = 'bypass'
            auth_error = None

    log.info(
        "V3 Auth: method=%s, authorized=%s, account=%s, error=%s",
        auth_method, api_key_authorized, resolved_account_id_random, auth_error,
    )

    # Role flags: the bypass grants every role; otherwise roles come from
    # truthy claims in a decoded JWT payload.
    is_admin = is_manager = is_super = (auth_method == 'bypass')
    if resolved_token_payload:
        is_admin = is_admin or bool(resolved_token_payload.get('administrator'))
        is_manager = is_manager or bool(resolved_token_payload.get('manager'))
        is_super = is_super or bool(resolved_token_payload.get('super'))

    return AccountContext(
        account_id=resolved_account_id,
        account_id_random=resolved_account_id_random,
        auth_method=auth_method,
        administrator=is_admin,
        manager=is_manager,
        super=is_super,
        token_payload=resolved_token_payload,
        auth_error=auth_error,
    )
|
|
|
|
def get_account_context(
    x_account_id: Optional[str] = Header(None, min_length=11, max_length=22),
    x_no_account_id: Optional[str] = Header(None, min_length=3, max_length=100),
    x_no_account_id_token: Optional[str] = Query(None, alias='jwt', min_length=11),
    x_aether_api_key: Optional[str] = Header(None, min_length=11, max_length=22),
    x_aether_api_key_query: Optional[str] = Query(None, alias='api_key', min_length=11, max_length=22),
) -> AccountContext:
    """Strict variant of ``get_account_context_optional``.

    Accepts the same inputs and delegates the resolution to the optional
    variant. It raises ``HTTPException(403)`` in two cases:
      - the request resolved as ``'guest'``;
      - no account ID was resolved and the context lacks ``super``.
    The 403 detail is the resolver's ``auth_error``, or a generic message.
    """
    context = get_account_context_optional(
        x_account_id=x_account_id,
        x_no_account_id=x_no_account_id,
        x_no_account_id_token=x_no_account_id_token,
        x_aether_api_key=x_aether_api_key,
        x_aether_api_key_query=x_aether_api_key_query,
    )
    denied = context.auth_method == 'guest' or (context.account_id is None and not context.super)
    if not denied:
        return context
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=context.auth_error or "Account context required.",
    )
|
|
|
|
|
|
# --- Shared Pagination & Status Dependencies ---
|
|
|
|
class PaginationParams:
    """Shared ``limit`` / ``offset`` query parameters.

    ``limit`` defaults to 100 and ``offset`` to 0. FastAPI's ``Query(ge=0)``
    validation rejects negative values for both; no upper bound is applied.
    """

    def __init__(
        self,
        limit: int = Query(100, ge=0),
        offset: int = Query(0, ge=0),
    ):
        self.limit, self.offset = limit, offset
|
|
|
|
class StatusFilterParams:
    """Shared status-filter query parameters.

    ``enabled`` defaults to ``'enabled'`` and ``hidden`` to ``'not_hidden'``.
    Both are kept as raw strings; this class does not validate or interpret
    them.
    """

    def __init__(
        self,
        enabled: str = Query('enabled'),
        hidden: str = Query('not_hidden'),
    ):
        self.hidden = hidden
        self.enabled = enabled
|
|
|
|
class SerializationParams:
    """Boolean query flags controlling response serialization.

    Defaults: ``by_alias=True``; ``exclude_unset``, ``exclude_defaults`` and
    ``exclude_none`` are all False. The names match Pydantic's serialization
    keyword arguments.
    NOTE(review): presumably forwarded to model serialization by callers -
    confirm at the call sites.
    """

    def __init__(
        self,
        by_alias: bool = Query(True),
        exclude_unset: bool = Query(False),
        exclude_defaults: bool = Query(False),
        exclude_none: bool = Query(False),
    ):
        # Store each flag as a same-named instance attribute.
        vars(self).update(
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
        )
|
|
|
|
class DelayParams:
    """Requested artificial delay, read from a header and/or query parameter.

    The delay may come from the ``X-Delay-ms`` header or the ``delay_ms``
    query parameter; when both are given the larger one wins. Missing or
    ``None`` values count as 0. Negative values are clamped to 0 so that
    ``sleep_time_s`` is always safe to pass to a sleep call.

    Attributes:
        sleep_time_ms: Effective delay in milliseconds (>= 0).
        sleep_time_s: The same delay in seconds, as a float.
    """

    def __init__(
        self,
        x_delay_ms: Optional[int] = Header(0, alias='X-Delay-ms'),
        delay_ms: Optional[int] = Query(0),
    ):
        # Clamp at 0. If both inputs were negative, the old max() went
        # negative, and time.sleep() raises ValueError for negative durations.
        val = max(x_delay_ms or 0, delay_ms or 0, 0)
        self.sleep_time_ms = val
        self.sleep_time_s = val / 1000.0
|