""" This file contains general utility functions and helpers specifically for API v3. It aims to provide a clean slate for new methods and refactor existing ones from lib_general.py that are relevant to the v3 API, while removing unused or outdated functionalities. """ # Standard library imports import time import logging import jwt from typing import ( Any, Dict, List, Optional, Union, ) # Third-party imports from fastapi import ( APIRouter, Depends, Header, HTTPException, Query, Request, Response, status, ) from pydantic import ( BaseModel, Field, ValidationError, ) # Internal imports (from this project) from app.config import settings logger = logging.getLogger(__name__) def decode_jwt( secret_key: str, token: str, ) -> dict: """ Decodes and validates a JWT token. Ported from lib_general.py to break circular dependencies. """ algorithm = 'HS256' try: decoded_token = jwt.decode(token, secret_key, algorithms=[algorithm]) if decoded_token['eat'] >= time.time(): return decoded_token else: return False except Exception: return None # --- Pydantic Model for Authentication Context --- class AuthContext(BaseModel): account_id: Optional[int] = None account_id_random: Optional[str] = None user_id: Optional[int] = None person_id: Optional[int] = None auth_method: str = 'none' # 'jwt_header', 'jwt_query', 'legacy_header', 'bypass' # Alias for backward compatibility with initial V3 implementation AccountContext = AuthContext # --- Dependency Function for V3 Authentication --- def get_v3_auth_context( request: Request, authorization: Optional[str] = Header(None, description="Bearer "), jwt_query: Optional[str] = Query(None, alias="jwt", description="JWT token for URL-based auth (e.g., file downloads)"), x_account_id: Optional[str] = Header(None, min_length=11, max_length=22, description="Legacy X-Account-ID header"), x_no_account_id: Optional[str] = Header(None, min_length=3, max_length=100, description="Bypass account context header"), ) -> AuthContext: """ Standardized V3 Authentication Dependency. Supports JWT in Authorization header (Bearer) OR 'jwt' query parameter. Falls back to legacy headers for backward compatibility. """ # Defer import to break circular dependency from app.db_sql import redis_lookup_id_random # 1. Check for JWT (Header preferred, then Query for downloads) token = None method = 'none' if authorization and authorization.startswith("Bearer "): token = authorization.split(" ")[1] method = 'jwt_header' elif jwt_query: token = jwt_query method = 'jwt_query' if token: payload = decode_jwt(settings.JWT_KEY, token) if payload: logger.info(f"JWT Validated ({method}). User: {payload.get('user_id')}, Account: {payload.get('account_id')}") return AuthContext( account_id=payload.get('account_id'), account_id_random=payload.get('public_key'), # existing sign_jwt uses public_key for id_random user_id=payload.get('user_id'), person_id=payload.get('person_id'), auth_method=method ) else: logger.warning(f"Invalid or expired JWT provided via {method}") raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired authentication token.") # 2. Legacy / Testing Fallback: x_account_id if x_account_id: if looked_up_id := redis_lookup_id_random(table_name='account', record_id_random=x_account_id): logger.info(f"Authenticated via legacy header: {looked_up_id}") return AuthContext( account_id=looked_up_id, account_id_random=x_account_id, auth_method='legacy_header' ) else: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid X-Account-ID header.") # 3. Bypass Fallback if x_no_account_id: logger.info("Authentication bypassed via X-No-Account-ID") return AuthContext( account_id_random='--- NO ACCOUNT ---', auth_method='bypass' ) # 4. No Auth Found logger.warning("No authentication provided for V3 endpoint.") raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Authentication required. Provide Authorization header or 'jwt' query parameter." ) # --- Legacy wrapper to avoid breaking current V3 code --- def get_account_context( auth: AuthContext = Depends(get_v3_auth_context) ) -> AuthContext: """ Alias for the new auth dependency to maintain compatibility with existing V3 routes. """ return auth # --- Pydantic Model for Pagination --- class PaginationParams(BaseModel): limit: int = 100 # Default limit offset: int = 0 # --- Dependency Function for Pagination --- def get_pagination_params( limit: int = Query(100, ge=0, description="Maximum number of items to return"), offset: int = Query(0, ge=0, description="Number of items to skip (for pagination)"), ) -> PaginationParams: return PaginationParams(limit=limit, offset=offset) # --- Pydantic Model for Status Filtering --- class StatusFilterParams(BaseModel): enabled: str = 'enabled' # 'enabled', 'disabled', 'all' hidden: str = 'not_hidden' # 'hidden', 'not_hidden', 'all' # --- Dependency Function for Status Filtering --- def get_status_filter_params( enabled: str = Query('enabled', description="Filter by object enabled status ('enabled', 'disabled', 'all')"), hidden: str = Query('not_hidden', description="Filter by object hidden status ('hidden', 'not_hidden', 'all')"), ) -> StatusFilterParams: allowed_enabled_values = {'enabled', 'disabled', 'all'} allowed_hidden_values = {'hidden', 'not_hidden', 'all'} if enabled not in allowed_enabled_values: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid value for 'enabled'. Must be one of {list(allowed_enabled_values)}." ) if hidden not in allowed_hidden_values: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid value for 'hidden'. Must be one of {list(allowed_hidden_values)}." ) return StatusFilterParams(enabled=enabled, hidden=hidden) # --- Pydantic Model for Serialization Options --- class SerializationParams(BaseModel): by_alias: bool = True exclude_unset: bool = False exclude_defaults: bool = False # Added based on common_route_params exclude_none: bool = False # Added based on common_route_params # --- Dependency Function for Serialization Options --- def get_serialization_params( by_alias: bool = Query(True, description="Whether to use field aliases for serialization"), exclude_unset: bool = Query(False, description="Whether to exclude unset fields from the response"), exclude_defaults: bool = Query(False, description="Whether to exclude fields with their default values from the response"), exclude_none: bool = Query(False, description="Whether to exclude fields that are None from the response"), ) -> SerializationParams: return SerializationParams( by_alias=by_alias, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) # --- Pydantic Model for Delay --- class DelayParams(BaseModel): sleep_time_ms: int = 0 # Raw delay value in ms sleep_time_s: float = 0.0 # Converted to seconds for time.sleep() # --- Dependency Function for Delay --- def get_delay_params( x_delay_ms: Optional[int] = Header(0, alias='X-Delay-ms', description="Delay response for X milliseconds (header)"), delay_ms: Optional[int] = Query(0, description="Delay response for X milliseconds (query parameter)"), ) -> DelayParams: calculated_delay_ms = max(x_delay_ms or 0, delay_ms or 0) return DelayParams(sleep_time_ms=calculated_delay_ms, sleep_time_s=calculated_delay_ms / 1000.0)