# Files: OSIT-AE-API-FastAPI/app/routers/dependencies_v3.py
# 190 lines · 7.0 KiB · Python

from fastapi import Depends, Header, HTTPException, Query, Response, status
from typing import Optional, Union
import logging
import asyncio
from app.models.auth_models import AccountContext
# Module-level logger named after this module; used for the security warnings
# emitted while validating the X-Aether-Api-Key header.
log = logging.getLogger(__name__)
# --- Account Context Dependencies ---
# Lower-cased X-No-Account-Id header values that are treated as "not set",
# i.e. they do NOT trigger the administrative bypass.
_NO_ACCOUNT_ID_OFF_VALUES = frozenset(
    {'false', '0', 'null', 'undefined', 'none', 'no_account_id_here'}
)


def _mask_secret(secret: str) -> str:
    """Return a log-safe form of *secret* revealing at most its last 4 characters.

    >>> _mask_secret('abcdefghijkl')
    '***ijkl'
    >>> _mask_secret('short')
    '***'
    """
    if len(secret) <= 8:
        return '***'
    return f'***{secret[-4:]}'


def _is_api_key_authorized(api_key: str) -> bool:
    """Return True if *api_key* is a currently valid row of the ``api_key`` table.

    The key may match either the ``public_key`` or the ``secret_key`` column.
    It is authorized when the matching row has a truthy ``enable`` flag and the
    current local time lies inside the optional ``[enable_from, enable_to]``
    window (a missing bound is treated as open).

    Every rejection is logged as a security warning. The key is masked in the
    log, because it may be the secret key.
    """
    # Deferred import: app.db_sql importing at module load causes a circular import.
    from app.db_sql import sql_select
    from datetime import datetime

    # NOTE(review): a *public* key alone is enough to authorize; confirm intended.
    sql = "SELECT * FROM api_key WHERE (public_key = :key OR secret_key = :key) LIMIT 1"
    api_key_rec = sql_select(sql=sql, data={'key': api_key})
    masked_key = _mask_secret(api_key)

    if not api_key_rec:
        log.warning("Security: API Key %s not found.", masked_key)
        return False
    if not api_key_rec.get('enable'):
        log.warning("Security: API Key %s is disabled.", masked_key)
        return False

    # NOTE(review): naive local time; assumes enable_from/enable_to come back as
    # naive datetimes as well (naive vs aware comparison raises TypeError).
    now = datetime.now()
    enable_from = api_key_rec.get('enable_from')
    enable_to = api_key_rec.get('enable_to')
    if (not enable_from or enable_from <= now) and (not enable_to or now <= enable_to):
        return True

    log.warning("Security: API Key %s expired/not yet valid.", masked_key)
    return False


def get_account_context_optional(
    x_account_id: Optional[str] = Header(None, min_length=11, max_length=22),
    x_no_account_id: Optional[str] = Header(None, min_length=3, max_length=100),
    x_no_account_id_token: Optional[str] = Query(None, alias='jwt', min_length=11, max_length=22),
    x_aether_api_key: Optional[str] = Header(None, min_length=11, max_length=22),
) -> AccountContext:
    """
    Resolves the account context and enforces API Key validation.
    Uses DEFERRED imports to prevent circular dependency at startup.

    A valid ``X-Aether-Api-Key`` is mandatory. Without one (missing, unknown,
    disabled, or outside its validity window), the result is a guest context.
    With a valid key, the first matching source wins:

    A. ``X-Account-Id`` header        -> auth_method ``'legacy_header'``
    B. ``?jwt=`` query parameter      -> auth_method ``'token_query'``
    C. ``X-No-Account-Id`` header     -> auth_method ``'bypass'`` (elevated roles)

    For A and B, the supplied id must resolve through ``redis_lookup_id_random``.
    If it does not, the context stays ``'guest'``; the supplied value is still
    reported as ``account_id_random``.

    Returns:
        AccountContext. Its administrator/manager/super flags are set only for the bypass.
    """
    # Deferred import: avoids a circular import at startup.
    from app.db_sql import redis_lookup_id_random

    resolved_account_id = None
    resolved_account_id_random = None
    auth_method = 'guest'

    # 1. Mandatory machine auth (API key).
    # This identifies the script/app, regardless of the user/account context.
    if x_aether_api_key and _is_api_key_authorized(x_aether_api_key):
        # 2. Context resolution (only reached with a valid API key).
        # NOTE(review): the source's indentation was lost. This reconstruction
        # fails closed: an id/token that does not resolve leaves the caller a guest.
        if x_account_id:
            # A. Resolve via Account ID header.
            resolved_account_id_random = x_account_id
            if looked_up_id := redis_lookup_id_random(table_name='account', record_id_random=x_account_id):
                resolved_account_id = looked_up_id
                auth_method = 'legacy_header'
        elif x_no_account_id_token:
            # B. Resolve via JWT / token query param.
            resolved_account_id_random = x_no_account_id_token
            if looked_up_id := redis_lookup_id_random(table_name='account', record_id_random=x_no_account_id_token):
                resolved_account_id = looked_up_id
                auth_method = 'token_query'
        elif x_no_account_id and x_no_account_id.lower() not in _NO_ACCOUNT_ID_OFF_VALUES:
            # C. Administrative bypass.
            # NOTE(review): any value outside the small "off" list grants
            # administrator/manager/super below; consider an explicit allow-list.
            resolved_account_id = None
            resolved_account_id_random = '--- NO ACCOUNT ---'
            auth_method = 'bypass'

    # Only the administrative bypass grants elevated roles.
    is_bypass = auth_method == 'bypass'
    return AccountContext(
        account_id=resolved_account_id,
        account_id_random=resolved_account_id_random,
        auth_method=auth_method,
        administrator=is_bypass,
        manager=is_bypass,
        super=is_bypass,
    )
# NOTE: get_account_context, PaginationParams, StatusFilterParams,
# SerializationParams and DelayParams are each defined exactly once, below.
# An identical earlier copy of them used to sit here. Python bound the later
# definitions, so this copy was dead code (ruff F811, "redefinition of unused
# name"). It was removed so that changes to these dependencies cannot land in
# the shadowed copy.
def get_account_context(
    x_account_id: Optional[str] = Header(None, min_length=11, max_length=22),
    x_no_account_id: Optional[str] = Header(None, min_length=3, max_length=100),
    x_no_account_id_token: Optional[str] = Query(None, alias='jwt', min_length=11, max_length=22),
    x_aether_api_key: Optional[str] = Header(None, min_length=11, max_length=22),
) -> AccountContext:
    """Strict version of account context resolution.

    Takes the same header/query inputs as ``get_account_context_optional`` and
    delegates resolution to it. It then refuses any request whose resolved
    ``auth_method`` is ``'guest'``.

    Raises:
        HTTPException: 403 Forbidden ("Account context required.") for guests.
    """
    context = get_account_context_optional(
        x_account_id=x_account_id,
        x_no_account_id=x_no_account_id,
        x_no_account_id_token=x_no_account_id_token,
        x_aether_api_key=x_aether_api_key,
    )
    if context.auth_method != 'guest':
        return context
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail='Account context required.',
    )
# --- Shared Pagination & Status Dependencies ---
class PaginationParams:
    """Pagination query parameters, injectable as a FastAPI class dependency.

    Attributes:
        limit: Value of ``?limit=`` (default 100, validated ``>= 0``);
            presumably the maximum number of rows to return.
        offset: Value of ``?offset=`` (default 0, validated ``>= 0``);
            presumably the number of rows to skip.
    """

    def __init__(
        self,
        limit: int = Query(100, ge=0),
        offset: int = Query(0, ge=0),
    ):
        self.limit, self.offset = limit, offset
class StatusFilterParams:
    """Enabled/hidden status filters taken from the query string.

    Attributes:
        enabled: Raw ``?enabled=`` value (default ``'enabled'``).
        hidden: Raw ``?hidden=`` value (default ``'not_hidden'``).

    NOTE(review): both are free-form strings. The set of values that consumers
    actually understand is not visible in this module.
    """

    def __init__(
        self,
        enabled: str = Query('enabled'),
        hidden: str = Query('not_hidden'),
    ):
        self.enabled, self.hidden = enabled, hidden
class SerializationParams:
    """Response-serialization switches taken from the query string.

    The attribute names mirror Pydantic's ``dict()``/``model_dump()`` keyword
    arguments. They are presumably forwarded there by the endpoints (not visible
    in this module).

    Attributes:
        by_alias: ``?by_alias=`` (default True).
        exclude_unset: ``?exclude_unset=`` (default False).
        exclude_defaults: ``?exclude_defaults=`` (default False).
        exclude_none: ``?exclude_none=`` (default False).
    """

    def __init__(
        self,
        by_alias: bool = Query(True),
        exclude_unset: bool = Query(False),
        exclude_defaults: bool = Query(False),
        exclude_none: bool = Query(False),
    ):
        (
            self.by_alias,
            self.exclude_unset,
            self.exclude_defaults,
            self.exclude_none,
        ) = (by_alias, exclude_unset, exclude_defaults, exclude_none)
class DelayParams:
    """Requested artificial delay, read from a header and/or a query parameter.

    The larger of the ``X-Delay-ms`` header and the ``?delay_ms=`` query
    parameter wins. Missing (None) values count as 0, and the result is clamped
    at 0, so the delay is never negative.

    Attributes:
        sleep_time_ms: Requested delay in milliseconds (int, >= 0).
        sleep_time_s: The same delay in seconds (float). It is presumably handed
            to a sleep call by the endpoint (not visible here).
    """

    def __init__(
        self,
        x_delay_ms: Optional[int] = Header(0, alias='X-Delay-ms'),
        delay_ms: Optional[int] = Query(0),
    ):
        # The trailing 0 clamps negative requests: a negative delay is meaningless,
        # and e.g. time.sleep() raises ValueError on a negative argument.
        val = max(x_delay_ms or 0, delay_ms or 0, 0)
        # NOTE(review): there is no upper bound, so a client can request an
        # arbitrarily long delay; consider capping it if this is publicly exposed.
        self.sleep_time_ms = val
        self.sleep_time_s = val / 1000.0