226 lines
8.1 KiB
Python
226 lines
8.1 KiB
Python
"""
|
|
This file contains general utility functions and helpers specifically for API v3.
|
|
It aims to provide a clean slate for new methods and refactor existing ones from lib_general.py
|
|
that are relevant to the v3 API, while removing unused or outdated functionalities.
|
|
"""
|
|
|
|
# Standard library imports
|
|
import time
|
|
import logging
|
|
import jwt
|
|
from typing import (
|
|
Any,
|
|
Dict,
|
|
List,
|
|
Optional,
|
|
Union,
|
|
)
|
|
|
|
# Third-party imports
|
|
from fastapi import (
|
|
APIRouter,
|
|
Depends,
|
|
Header,
|
|
HTTPException,
|
|
Query,
|
|
Request,
|
|
Response,
|
|
status,
|
|
)
|
|
from pydantic import (
|
|
BaseModel,
|
|
Field,
|
|
ValidationError,
|
|
)
|
|
|
|
# Internal imports (from this project)
|
|
from app.config import settings
|
|
|
|
# Module-level logger named after this module; used by the auth helpers below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def decode_jwt(
    secret_key: str,
    token: str,
    algorithm: str = 'HS256',
) -> Union[Dict[str, Any], bool, None]:
    """
    Decode and validate a JWT token.

    Ported from lib_general.py to break circular dependencies.

    The signature is verified with ``jwt.decode`` using ``secret_key`` and
    ``algorithm``. The token's custom ``'eat'`` claim is then compared
    against ``time.time()``, so it is treated as an expiry time in epoch
    seconds.

    Args:
        secret_key: Secret used to verify the token signature.
        token: Encoded JWT string.
        algorithm: The single signing algorithm accepted. Defaults to
            ``'HS256'``, the value this function always used.

    Returns:
        - The decoded claims dict if the token is valid and ``'eat'`` is not
          in the past.
        - ``False`` if the token decoded but its ``'eat'`` time has passed.
        - ``None`` if decoding or validation failed for any reason (bad
          signature, malformed token, missing or non-comparable ``'eat'``).

        Both failure values are falsy, so callers may simply test truthiness.
    """
    try:
        decoded_token = jwt.decode(token, secret_key, algorithms=[algorithm])
        # A missing 'eat' (KeyError) or non-numeric 'eat' (TypeError) is
        # handled below like any other invalid token.
        not_expired = decoded_token['eat'] >= time.time()
    except Exception:
        # Deliberately best-effort: any failure means "not a valid token".
        # Log without the token itself so failures are diagnosable rather
        # than silently disappearing.
        logger.debug("JWT decode/validation failed", exc_info=True)
        return None
    return decoded_token if not_expired else False
|
|
|
|
|
|
# --- Pydantic Model for Authentication Context ---
|
|
class AuthContext(BaseModel):
    """
    Caller identity produced by the V3 authentication dependency.

    All identifiers are optional because the different authentication paths
    populate different subsets of them.
    """

    # Numeric account identifier.
    account_id: Union[int, None] = None
    # Public/random account identifier (e.g. the JWT 'public_key' claim or the
    # raw X-Account-ID header value).
    account_id_random: Union[str, None] = None
    user_id: Union[int, None] = None
    person_id: Union[int, None] = None
    # Which path authenticated the caller: 'none' (default), 'jwt_header',
    # 'jwt_query', 'legacy_header' or 'bypass'.
    auth_method: str = 'none'


# Backward-compatible name from the initial V3 implementation; refers to the
# very same class object.
AccountContext = AuthContext
|
|
|
|
|
|
# --- Dependency Function for V3 Authentication ---
|
|
def get_v3_auth_context(
    request: Request,
    authorization: Optional[str] = Header(None, description="Bearer <jwt_token>"),
    jwt_query: Optional[str] = Query(None, alias="jwt", description="JWT token for URL-based auth (e.g., file downloads)"),
    x_account_id: Optional[str] = Header(None, min_length=11, max_length=22, description="Legacy X-Account-ID header"),
    x_no_account_id: Optional[str] = Header(None, min_length=3, max_length=100, description="Bypass account context header"),
) -> AuthContext:
    """
    Standardized V3 Authentication Dependency.

    Credentials are checked in this order, and the first one present wins:

    1. A JWT, taken from the ``Authorization: Bearer <token>`` header or, if
       that header carries no bearer token, from the ``jwt`` query parameter
       (for URL-based access such as file downloads). A supplied JWT must
       pass ``decode_jwt``; otherwise the request is rejected.
    2. The legacy ``X-Account-ID`` header, resolved to a numeric account id
       via ``redis_lookup_id_random``.
    3. The ``X-No-Account-ID`` header, which bypasses account context.

    Args:
        request: The incoming request (not read here; part of the dependency
            signature).
        authorization: Raw ``Authorization`` header value.
        jwt_query: Value of the ``jwt`` query parameter.
        x_account_id: Legacy random account identifier header.
        x_no_account_id: Bypass header; only its truthiness is checked here.

    Returns:
        AuthContext describing the caller. ``auth_method`` records which path
        matched.

    Raises:
        HTTPException: 401 if a JWT was supplied but is invalid or expired.
            403 if ``X-Account-ID`` cannot be resolved, or if no credentials
            were supplied at all.
    """
    # Defer import to break circular dependency
    from app.db_sql import redis_lookup_id_random

    # 1. Check for JWT (Header preferred, then Query for downloads)
    bearer_token: Optional[str] = None
    if authorization:
        scheme, _, credentials = authorization.partition(" ")
        # Auth schemes are case-insensitive (RFC 7235). strip() tolerates
        # repeated spaces after the scheme. Empty credentials are treated as
        # "no bearer token", so the query parameter can still be used.
        if scheme.lower() == "bearer":
            bearer_token = credentials.strip() or None

    token: Optional[str] = None
    method = 'none'
    if bearer_token:
        token, method = bearer_token, 'jwt_header'
    elif jwt_query:
        token, method = jwt_query, 'jwt_query'

    if token:
        payload = decode_jwt(settings.JWT_KEY, token)
        # decode_jwt signals an expired or invalid token with a falsy value.
        if not payload:
            logger.warning("Invalid or expired JWT provided via %s", method)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid or expired authentication token.",
                # RFC 7235 requires a challenge header on 401 responses.
                headers={"WWW-Authenticate": "Bearer"},
            )
        logger.info(
            "JWT Validated (%s). User: %s, Account: %s",
            method, payload.get('user_id'), payload.get('account_id'),
        )
        return AuthContext(
            account_id=payload.get('account_id'),
            account_id_random=payload.get('public_key'),  # existing sign_jwt uses public_key for id_random
            user_id=payload.get('user_id'),
            person_id=payload.get('person_id'),
            auth_method=method,
        )

    # 2. Legacy / Testing Fallback: x_account_id
    if x_account_id:
        looked_up_id = redis_lookup_id_random(table_name='account', record_id_random=x_account_id)
        if not looked_up_id:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid X-Account-ID header.")
        logger.info("Authenticated via legacy header: %s", looked_up_id)
        return AuthContext(
            account_id=looked_up_id,
            account_id_random=x_account_id,
            auth_method='legacy_header',
        )

    # 3. Bypass Fallback
    if x_no_account_id:
        logger.info("Authentication bypassed via X-No-Account-ID")
        return AuthContext(
            account_id_random='--- NO ACCOUNT ---',
            auth_method='bypass',
        )

    # 4. No Auth Found
    logger.warning("No authentication provided for V3 endpoint.")
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Authentication required. Provide Authorization header or 'jwt' query parameter.",
    )
|
|
|
|
|
|
# --- Legacy wrapper to avoid breaking current V3 code ---
|
|
def get_account_context(
    auth: AuthContext = Depends(get_v3_auth_context)
) -> AuthContext:
    """
    Backward-compatible dependency for existing V3 routes.

    FastAPI resolves ``get_v3_auth_context`` and injects its result as
    ``auth``. This function hands that ``AuthContext`` back unchanged, so
    routes declaring ``Depends(get_account_context)`` keep working.
    """
    resolved_context = auth
    return resolved_context
|
|
|
|
|
|
# --- Pydantic Model for Pagination ---
|
|
class PaginationParams(BaseModel):
    """Offset/limit pagination settings for list endpoints."""

    # Maximum number of items to return.
    limit: int = 100
    # Number of items to skip before the first returned item.
    offset: int = 0


# --- Dependency Function for Pagination ---
def get_pagination_params(
    limit: int = Query(100, ge=0, description="Maximum number of items to return"),
    offset: int = Query(0, ge=0, description="Number of items to skip (for pagination)"),
) -> PaginationParams:
    """
    Build ``PaginationParams`` from the ``limit``/``offset`` query parameters.

    FastAPI enforces ``ge=0`` on both values before this body runs.
    """
    # NOTE(review): ``limit`` has no upper bound, so a client can request an
    # arbitrarily large page. Consider adding ``le=...`` once a maximum is agreed.
    page_settings = {"limit": limit, "offset": offset}
    return PaginationParams(**page_settings)
|
|
|
|
|
|
# --- Pydantic Model for Status Filtering ---
|
|
class StatusFilterParams(BaseModel):
    """Enabled/hidden status filters for list endpoints."""

    enabled: str = 'enabled'  # 'enabled', 'disabled', 'all'
    hidden: str = 'not_hidden'  # 'hidden', 'not_hidden', 'all'


# Allowed filter values. Ordered tuples (not sets) keep the order in the 400
# error messages stable: iteration order of a set of strings can differ
# between interpreter runs because of string hash randomization.
_ALLOWED_ENABLED_VALUES = ('enabled', 'disabled', 'all')
_ALLOWED_HIDDEN_VALUES = ('hidden', 'not_hidden', 'all')


# --- Dependency Function for Status Filtering ---
def get_status_filter_params(
    enabled: str = Query('enabled', description="Filter by object enabled status ('enabled', 'disabled', 'all')"),
    hidden: str = Query('not_hidden', description="Filter by object hidden status ('hidden', 'not_hidden', 'all')"),
) -> StatusFilterParams:
    """
    Validate the ``enabled``/``hidden`` query filters and bundle them.

    Args:
        enabled: One of ``'enabled'``, ``'disabled'``, ``'all'``.
        hidden: One of ``'hidden'``, ``'not_hidden'``, ``'all'``.

    Returns:
        StatusFilterParams holding the validated values.

    Raises:
        HTTPException: 400 if either value is not one of its allowed options.
    """
    if enabled not in _ALLOWED_ENABLED_VALUES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid value for 'enabled'. Must be one of {list(_ALLOWED_ENABLED_VALUES)}.",
        )
    if hidden not in _ALLOWED_HIDDEN_VALUES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid value for 'hidden'. Must be one of {list(_ALLOWED_HIDDEN_VALUES)}.",
        )
    return StatusFilterParams(enabled=enabled, hidden=hidden)
|
|
|
|
|
|
# --- Pydantic Model for Serialization Options ---
|
|
class SerializationParams(BaseModel):
    """
    Response-serialization switches requested by the client.

    Field names mirror Pydantic's dump keyword arguments of the same names.
    ``exclude_defaults`` and ``exclude_none`` were added based on
    common_route_params.
    """

    by_alias: bool = True
    exclude_unset: bool = False
    exclude_defaults: bool = False
    exclude_none: bool = False


# --- Dependency Function for Serialization Options ---
def get_serialization_params(
    by_alias: bool = Query(True, description="Whether to use field aliases for serialization"),
    exclude_unset: bool = Query(False, description="Whether to exclude unset fields from the response"),
    exclude_defaults: bool = Query(False, description="Whether to exclude fields with their default values from the response"),
    exclude_none: bool = Query(False, description="Whether to exclude fields that are None from the response"),
) -> SerializationParams:
    """Gather the four serialization query flags into a ``SerializationParams``."""
    flags = dict(
        by_alias=by_alias,
        exclude_unset=exclude_unset,
        exclude_defaults=exclude_defaults,
        exclude_none=exclude_none,
    )
    return SerializationParams(**flags)
|
|
|
|
|
|
# --- Pydantic Model for Delay ---
|
|
class DelayParams(BaseModel):
    """Artificial response delay requested by the client."""

    sleep_time_ms: int = 0  # Raw delay value in ms
    sleep_time_s: float = 0.0  # Converted to seconds for time.sleep()


# --- Dependency Function for Delay ---
def get_delay_params(
    x_delay_ms: Optional[int] = Header(0, alias='X-Delay-ms', description="Delay response for X milliseconds (header)"),
    delay_ms: Optional[int] = Query(0, description="Delay response for X milliseconds (query parameter)"),
) -> DelayParams:
    """
    Resolve the requested response delay from the header and query parameter.

    The larger of the two values wins, and a missing value counts as 0. The
    result is clamped at 0: a negative delay is meaningless, and
    ``time.sleep()`` raises ``ValueError`` for negative durations.

    Returns:
        DelayParams with the delay in milliseconds and in seconds.
    """
    # NOTE(review): there is no upper bound, so a client can hold a worker for
    # an arbitrary time. Consider capping this value.
    calculated_delay_ms = max(x_delay_ms or 0, delay_ms or 0, 0)
    return DelayParams(sleep_time_ms=calculated_delay_ms, sleep_time_s=calculated_delay_ms / 1000.0)