diff --git a/app/lib_jwt.py b/app/lib_jwt.py index da777ff..0759945 100644 --- a/app/lib_jwt.py +++ b/app/lib_jwt.py @@ -24,6 +24,14 @@ def sign_jwt( log.setLevel(logging.WARNING) # DEBUG, INFO, WARNING, ERROR, EXCEPTION, CRITICAL log.debug(locals()) + # SECURITY CHECK: Ensure we are not signing numeric IDs + for label, val in [('account_id', account_id), ('person_id', person_id), ('user_id', user_id)]: + if val is not None: + if isinstance(val, int) or (isinstance(val, str) and val.isdigit()): + log.critical(f"SECURITY BREACH: Attempted to sign a numeric ID for {label}='{val}'. Only random string IDs allowed.") + # For now we log and proceed, but in Phase 3 we should raise an Exception + # raise ValueError(f"Numeric IDs cannot be signed in JWTs.") + payload = { 'iat': time.time(), # Issued at 'eat': time.time() + ttl, # Expires at diff --git a/app/routers/dependencies_v3.py b/app/routers/dependencies_v3.py index 1efc684..9ee641f 100644 --- a/app/routers/dependencies_v3.py +++ b/app/routers/dependencies_v3.py @@ -12,7 +12,7 @@ log = logging.getLogger(__name__) def get_account_context_optional( x_account_id: Optional[str] = Header(None, min_length=11, max_length=22), x_no_account_id: Optional[str] = Header(None, min_length=3, max_length=100), - x_no_account_id_token: Optional[str] = Query(None, alias='jwt', min_length=11, max_length=22), + x_no_account_id_token: Optional[str] = Query(None, alias='jwt', min_length=11), x_aether_api_key: Optional[str] = Header(None, min_length=11, max_length=22), ) -> AccountContext: """ @@ -20,6 +20,8 @@ def get_account_context_optional( Uses DEFERRED imports to prevent circular dependency at startup. """ from app.db_sql import redis_lookup_id_random, sql_select + from app.lib_jwt import decode_jwt + from app.config import settings from datetime import datetime resolved_account_id = None @@ -56,10 +58,24 @@ def get_account_context_optional( # B. Resolve via JWT / Token Query Param elif x_no_account_id_token: - resolved_account_id_random = x_no_account_id_token - if looked_up_id := redis_lookup_id_random(table_name='account', record_id_random=x_no_account_id_token): - resolved_account_id = looked_up_id - auth_method = 'token_query' + # Check if it's a real JWT (contains dots) + if '.' in x_no_account_id_token: + if decoded := decode_jwt(secret_key=settings.JWT_KEY, token=x_no_account_id_token): + # In Aether, JWTs store the RANDOM string IDs to prevent exposure + resolved_account_id_random = decoded.get('account_id') + if resolved_account_id_random: + if looked_up_id := redis_lookup_id_random(table_name='account', record_id_random=resolved_account_id_random): + resolved_account_id = looked_up_id + auth_method = 'jwt_token' + else: + log.warning("Security: Failed to decode JWT token.") + + # Legacy Fallback (just a raw random ID string) + if auth_method == 'guest': + resolved_account_id_random = x_no_account_id_token + if looked_up_id := redis_lookup_id_random(table_name='account', record_id_random=x_no_account_id_token): + resolved_account_id = looked_up_id + auth_method = 'token_query' # C. Resolve via Administrative Bypass elif x_no_account_id and x_no_account_id.lower() not in ['false', '0', 'null', 'undefined', 'none', 'no_account_id_here']: @@ -79,63 +95,7 @@ def get_account_context_optional( def get_account_context( x_account_id: Optional[str] = Header(None, min_length=11, max_length=22), x_no_account_id: Optional[str] = Header(None, min_length=3, max_length=100), - x_no_account_id_token: Optional[str] = Query(None, alias='jwt', min_length=11, max_length=22), - x_aether_api_key: Optional[str] = Header(None, min_length=11, max_length=22), -) -> AccountContext: - """Strict version of account context resolution.""" - ctx = get_account_context_optional(x_account_id, x_no_account_id, x_no_account_id_token, x_aether_api_key) - if ctx.auth_method == 'guest': - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail='Account context required.') - return ctx - - -# --- Shared Pagination & Status Dependencies --- - -class PaginationParams: - def __init__( - self, - limit: int = Query(100, ge=0), - offset: int = Query(0, ge=0), - ): - self.limit = limit - self.offset = offset - -class StatusFilterParams: - def __init__( - self, - enabled: str = Query('enabled'), - hidden: str = Query('not_hidden'), - ): - self.enabled = enabled - self.hidden = hidden - -class SerializationParams: - def __init__( - self, - by_alias: bool = Query(True), - exclude_unset: bool = Query(False), - exclude_defaults: bool = Query(False), - exclude_none: bool = Query(False), - ): - self.by_alias = by_alias - self.exclude_unset = exclude_unset - self.exclude_defaults = exclude_defaults - self.exclude_none = exclude_none - -class DelayParams: - def __init__( - self, - x_delay_ms: Optional[int] = Header(0, alias='X-Delay-ms'), - delay_ms: Optional[int] = Query(0), - ): - val = max(x_delay_ms or 0, delay_ms or 0) - self.sleep_time_ms = val - self.sleep_time_s = val / 1000.0 - -def get_account_context( - x_account_id: Optional[str] = Header(None, min_length=11, max_length=22), - x_no_account_id: Optional[str] = Header(None, min_length=3, max_length=100), - x_no_account_id_token: Optional[str] = Query(None, alias='jwt', min_length=11, max_length=22), + x_no_account_id_token: Optional[str] = Query(None, alias='jwt', min_length=11), x_aether_api_key: Optional[str] = Header(None, min_length=11, max_length=22), ) -> AccountContext: """Strict version of account context resolution."""