From 8459b57e1b39b5ee37ed6ac36feff294b3b0f79c Mon Sep 17 00:00:00 2001 From: Scott Idem Date: Fri, 9 Jan 2026 16:16:44 -0500 Subject: [PATCH] Refactor V3 CRUD: Extract helper functions and unify sanitization logic. - Created app/lib_api_crud_v3.py to house core security, filtering, and sanitization logic. - Implemented reusable sanitize_payload() to generically strip virtual lookup fields (*_id_random) and view-only fields (fields_to_exclude_from_db). - Updated app/routers/api_crud_v3.py to use the new library and consolidated sanitization across all Create/Update endpoints. - Documented Phase 1 completion in documentation/REFACTOR_API_CRUD_V3.md. --- app/lib_api_crud_v3.py | 131 +++++++++++++++++++++++++ app/routers/api_crud_v3.py | 136 +++----------------------- documentation/REFACTOR_API_CRUD_V3.md | 35 +++++++ 3 files changed, 180 insertions(+), 122 deletions(-) create mode 100644 app/lib_api_crud_v3.py create mode 100644 documentation/REFACTOR_API_CRUD_V3.md diff --git a/app/lib_api_crud_v3.py b/app/lib_api_crud_v3.py new file mode 100644 index 0000000..3ace5c5 --- /dev/null +++ b/app/lib_api_crud_v3.py @@ -0,0 +1,131 @@ +from typing import Any, Dict, Optional +import json +import logging + +from app.lib_general_v3 import AccountContext, StatusFilterParams + +log = logging.getLogger(__name__) + +def check_account_access(sql_result: Any, account: AccountContext, obj_name: str = None) -> bool: + """ + Enforce Multi-Tenant Data Isolation. + + Verifies that the requested record belongs to the authenticated user's account. + Returns True if: + - User is a Super User or System (Bypass). + - The record's `account_id` matches the user's `account_id`. + """ + if account.super or account.auth_method == 'bypass': + return True + if not account.account_id: + return False + + res_account_id = None + if isinstance(sql_result, dict): + if obj_name == 'account': + res_account_id = sql_result.get('id') + else: + res_account_id = sql_result.get('account_id') + + if res_account_id is not None and res_account_id != account.account_id: + return False + return True + +def apply_forced_account_filter(and_qry_dict: Optional[Dict], account: AccountContext, model: Any, obj_name: str) -> Dict: + """ + Secure Search Filtering. + + Automatically appends an `account_id` filter to database queries to ensure + users only retrieve records associated with their own account. + """ + forced = and_qry_dict or {} + if account.super or account.auth_method == 'bypass': + return forced + + if obj_name == 'account': + forced['id'] = account.account_id + elif model and hasattr(model, '__fields__') and 'account_id' in model.__fields__: + forced['account_id'] = account.account_id + + return forced + +def filter_order_by(order_by_li: Any, model: Any, table_name: str = None) -> Optional[Dict[str, str]]: + """ + Sanitize Sorting Parameters. + + Prevents SQL injection and logic errors by validating that requested sort columns + actually exist in the Pydantic model and/or the database table. + """ + if not order_by_li or not isinstance(order_by_li, dict) or not model: + return order_by_li + if not hasattr(model, '__fields__'): + return order_by_li + + model_fields = set(model.__fields__.keys()) + model_fields.update({f.alias for f in model.__fields__.values() if f.alias}) + filtered = {k: v for k, v in order_by_li.items() if k in model_fields} + + if table_name and filtered: + from app.db_sql import db + from sqlalchemy import text + final_filtered = {} + for column in filtered: + try: + # Lightweight check to see if column exists in SQL + db.execute(text(f"SELECT `{column}` FROM `{table_name}` LIMIT 0")) + final_filtered[column] = filtered[column] + except Exception: + pass + filtered = final_filtered + return filtered + +def get_supported_filters(model: Any, status_filter: StatusFilterParams) -> StatusFilterParams: + """ + Adaptive Status Filtering. + + Adjusts the default filters (enabled/hidden) based on whether the target object + actually supports those concepts (i.e., has those columns). + """ + if not model or not hasattr(model, "__fields__"): + return status_filter + # We create a new instance to avoid side effects on the dependency object + from app.routers.dependencies_v3 import StatusFilterParams as SF + adjusted = SF() + adjusted.enabled = status_filter.enabled + adjusted.hidden = status_filter.hidden + + if 'enable' not in model.__fields__: + adjusted.enabled = 'all' + if 'hide' not in model.__fields__: + adjusted.hidden = 'all' + return adjusted + +def safe_json_loads(json_str: Optional[str]) -> Any: + if not json_str or json_str == 'undefined': return None + try: return json.loads(json_str) + except: return None + +def sanitize_payload(data: dict, model: Any) -> None: + """ + Sanitizes an input payload before database insertion or update. + + 1. Removes virtual lookup fields (ending in `_id_random`) that are used for API + convenience but do not exist in the database. + 2. Removes fields explicitly marked for exclusion in the model's + `fields_to_exclude_from_db` ClassVar (e.g., view-only fields). + + Modifies the `data` dictionary in-place. + """ + if not isinstance(data, dict): + return + + # Filter out virtual _id_random fields (e.g., account_id_random) + keys_to_remove = [k for k in data.keys() if k.endswith('_id_random') and k != 'id_random'] + for k in keys_to_remove: + del data[k] + + # Filter out model-specific excluded fields (e.g., view-only fields) + if hasattr(model, 'fields_to_exclude_from_db'): + for k in model.fields_to_exclude_from_db: + if k in data: + del data[k] diff --git a/app/routers/api_crud_v3.py b/app/routers/api_crud_v3.py index ccf71ba..7189f75 100644 --- a/app/routers/api_crud_v3.py +++ b/app/routers/api_crud_v3.py @@ -13,6 +13,10 @@ from app.lib_general_v3 import ( PaginationParams, StatusFilterParams, SerializationParams, DelayParams ) +from app.lib_api_crud_v3 import ( + check_account_access, apply_forced_account_filter, filter_order_by, + get_supported_filters, safe_json_loads, sanitize_payload +) from app.models.response_models import * from app.models.api_crud_models import SearchFilter, SearchQuery from app.ae_obj_types_def import obj_type_kv_li @@ -33,108 +37,6 @@ Key Features: router = APIRouter() -# --- Helpers --- - -def check_account_access(sql_result: Any, account: AccountContext, obj_name: str = None) -> bool: - """ - Enforce Multi-Tenant Data Isolation. - - Verifies that the requested record belongs to the authenticated user's account. - Returns True if: - - User is a Super User or System (Bypass). - - The record's `account_id` matches the user's `account_id`. - """ - if account.super or account.auth_method == 'bypass': - return True - if not account.account_id: - return False - - res_account_id = None - if isinstance(sql_result, dict): - if obj_name == 'account': - res_account_id = sql_result.get('id') - else: - res_account_id = sql_result.get('account_id') - - if res_account_id is not None and res_account_id != account.account_id: - return False - return True - -def apply_forced_account_filter(and_qry_dict: Optional[Dict], account: AccountContext, model: Any, obj_name: str) -> Dict: - """ - Secure Search Filtering. - - Automatically appends an `account_id` filter to database queries to ensure - users only retrieve records associated with their own account. - """ - forced = and_qry_dict or {} - if account.super or account.auth_method == 'bypass': - return forced - - if obj_name == 'account': - forced['id'] = account.account_id - elif model and hasattr(model, '__fields__') and 'account_id' in model.__fields__: - forced['account_id'] = account.account_id - - return forced - -def filter_order_by(order_by_li: Any, model: Any, table_name: str = None) -> Optional[Dict[str, str]]: - """ - Sanitize Sorting Parameters. - - Prevents SQL injection and logic errors by validating that requested sort columns - actually exist in the Pydantic model and/or the database table. - """ - if not order_by_li or not isinstance(order_by_li, dict) or not model: - return order_by_li - if not hasattr(model, '__fields__'): - return order_by_li - - model_fields = set(model.__fields__.keys()) - model_fields.update({f.alias for f in model.__fields__.values() if f.alias}) - filtered = {k: v for k, v in order_by_li.items() if k in model_fields} - - if table_name and filtered: - from app.db_sql import db - from sqlalchemy import text - final_filtered = {} - for column in filtered: - try: - # Lightweight check to see if column exists in SQL - db.execute(text(f"SELECT `{column}` FROM `{table_name}` LIMIT 0")) - final_filtered[column] = filtered[column] - except Exception: - pass - filtered = final_filtered - return filtered - -def get_supported_filters(model: Any, status_filter: StatusFilterParams) -> StatusFilterParams: - """ - Adaptive Status Filtering. - - Adjusts the default filters (enabled/hidden) based on whether the target object - actually supports those concepts (i.e., has those columns). - """ - if not model or not hasattr(model, "__fields__"): - return status_filter - # We create a new instance to avoid side effects on the dependency object - from app.routers.dependencies_v3 import StatusFilterParams as SF - adjusted = SF() - adjusted.enabled = status_filter.enabled - adjusted.hidden = status_filter.hidden - - if 'enable' not in model.__fields__: - adjusted.enabled = 'all' - if 'hide' not in model.__fields__: - adjusted.hidden = 'all' - return adjusted - -def safe_json_loads(json_str: Optional[str]) -> Any: - if not json_str or json_str == 'undefined': return None - try: return json.loads(json_str) - except: return None - - # --- Routes --- @router.get("/health", response_model=Resp_Body_Base) @@ -515,16 +417,8 @@ async def post_obj( data_to_insert = validated_obj.dict(exclude_unset=True) - # Filter out virtual _id_random fields (e.g., account_id_random) that are not in the DB table - keys_to_remove = [k for k in data_to_insert.keys() if k.endswith('_id_random') and k != 'id_random'] - for k in keys_to_remove: - del data_to_insert[k] - - # Filter out model-specific excluded fields (e.g., view-only fields like person_full_name in Journal) - if hasattr(input_model, 'fields_to_exclude_from_db'): - for k in input_model.fields_to_exclude_from_db: - if k in data_to_insert: - del data_to_insert[k] + # Sanitize payload (remove virtual fields and view-only fields) + sanitize_payload(data_to_insert, input_model) if sql_insert_result := sql_insert(data=data_to_insert, table_name=table_name_insert): new_obj_id = sql_insert_result @@ -585,16 +479,8 @@ async def patch_obj( else: return mk_resp(data=False, status_code=404, response=response, status_message=f"Object with ID '{obj_id}' not found in database.") - # Filter out virtual _id_random fields (e.g., account_id_random) that are not in the DB table - keys_to_remove = [k for k in obj_data.keys() if k.endswith('_id_random') and k != 'id_random'] - for k in keys_to_remove: - del obj_data[k] - - # Filter out model-specific excluded fields (e.g., view-only fields like person_full_name in Journal) - if hasattr(input_model, 'fields_to_exclude_from_db'): - for k in input_model.fields_to_exclude_from_db: - if k in obj_data: - del obj_data[k] + # Sanitize payload (remove virtual fields and view-only fields) + sanitize_payload(obj_data, input_model) if sql_update(data=obj_data, table_name=table_name_update, record_id=record_id): if return_obj: @@ -814,6 +700,9 @@ async def post_child_obj( data_to_insert = validated_obj.dict(exclude_unset=True) + # Sanitize payload (remove virtual fields and view-only fields) + sanitize_payload(data_to_insert, input_model) + if sql_insert_result := sql_insert(data=data_to_insert, table_name=table_name_insert): new_obj_id = sql_insert_result new_obj_id_random = get_id_random(record_id=new_obj_id, table_name=child_obj_type) @@ -906,6 +795,9 @@ async def patch_child_obj( else: return mk_resp(data=False, status_code=404, response=response, status_message="Child not found.") + # Sanitize payload (remove virtual fields and view-only fields) + sanitize_payload(obj_data, output_model) + if sql_update(data=obj_data, table_name=table_name_update, record_id=resolved_child_id): if return_obj: if updated_child := sql_select(table_name=table_name_select, record_id=resolved_child_id): diff --git a/documentation/REFACTOR_API_CRUD_V3.md b/documentation/REFACTOR_API_CRUD_V3.md new file mode 100644 index 0000000..8a2440d --- /dev/null +++ b/documentation/REFACTOR_API_CRUD_V3.md @@ -0,0 +1,35 @@ +# Refactoring Plan: API CRUD V3 + +**Goal:** Modularize `app/routers/api_crud_v3.py` to improve maintainability, readability, and reusability. The file currently mixes route definitions, security enforcement, data sanitization, and helper utilities. + +## Phase 1: Extract Helpers & Core Logic (Safest) - COMPLETED +**Objective:** Move pure functions and business logic out of the router file. + +1. **Create `app/lib_api_crud_v3.py`**: DONE +2. **Update `app/routers/api_crud_v3.py`**: DONE (All endpoints now use `sanitize_payload`). + +## Phase 2: Separate Child/Nested Routes - PLANNED + +1. **Create `app/routers/api_crud_v3_nested.py`**: + * Move `get_child_obj_li` + * Move `post_child_obj` + * Move `get_child_obj` + * Move `patch_child_obj` + * Move `delete_child_obj` + +2. **Update `app/main.py` (or router inclusion)**: + * Ensure the new router is included, OR include it within `api_crud_v3.py` if preferred to keep a single import point. + +## Phase 3: Schema Introspection +**Objective:** Isolate database introspection logic. + +1. **Create `app/lib_schema_v3.py` (or similar)**: + * Move the logic inside `get_obj_schema` (SQL `DESCRIBE` parsing, Pydantic introspection) to a helper function. + +## Execution Strategy +We will execute **Phase 1** first as it provides immediate value (removing code duplication for sanitization) with minimal risk to routing logic. + +### Testing +After each move: +1. Run `tests/test_v3_router_filtering.py` (requires update to import from new location if we test the lib directly). +2. Verify application startup.