Refactor V3 CRUD: Extract helper functions and unify sanitization logic.
- Created app/lib_api_crud_v3.py to house core security, filtering, and sanitization logic. - Implemented reusable sanitize_payload() to generically strip virtual lookup fields (*_id_random) and view-only fields (fields_to_exclude_from_db). - Updated app/routers/api_crud_v3.py to use the new library and consolidated sanitization across all Create/Update endpoints. - Documented Phase 1 completion in documentation/REFACTOR_API_CRUD_V3.md.
This commit is contained in:
131
app/lib_api_crud_v3.py
Normal file
131
app/lib_api_crud_v3.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from typing import Any, Dict, Optional
|
||||
import json
|
||||
import logging
|
||||
|
||||
from app.lib_general_v3 import AccountContext, StatusFilterParams
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
def check_account_access(sql_result: Any, account: AccountContext, obj_name: str = None) -> bool:
|
||||
"""
|
||||
Enforce Multi-Tenant Data Isolation.
|
||||
|
||||
Verifies that the requested record belongs to the authenticated user's account.
|
||||
Returns True if:
|
||||
- User is a Super User or System (Bypass).
|
||||
- The record's `account_id` matches the user's `account_id`.
|
||||
"""
|
||||
if account.super or account.auth_method == 'bypass':
|
||||
return True
|
||||
if not account.account_id:
|
||||
return False
|
||||
|
||||
res_account_id = None
|
||||
if isinstance(sql_result, dict):
|
||||
if obj_name == 'account':
|
||||
res_account_id = sql_result.get('id')
|
||||
else:
|
||||
res_account_id = sql_result.get('account_id')
|
||||
|
||||
if res_account_id is not None and res_account_id != account.account_id:
|
||||
return False
|
||||
return True
|
||||
|
||||
def apply_forced_account_filter(and_qry_dict: Optional[Dict], account: AccountContext, model: Any, obj_name: str) -> Dict:
|
||||
"""
|
||||
Secure Search Filtering.
|
||||
|
||||
Automatically appends an `account_id` filter to database queries to ensure
|
||||
users only retrieve records associated with their own account.
|
||||
"""
|
||||
forced = and_qry_dict or {}
|
||||
if account.super or account.auth_method == 'bypass':
|
||||
return forced
|
||||
|
||||
if obj_name == 'account':
|
||||
forced['id'] = account.account_id
|
||||
elif model and hasattr(model, '__fields__') and 'account_id' in model.__fields__:
|
||||
forced['account_id'] = account.account_id
|
||||
|
||||
return forced
|
||||
|
||||
def filter_order_by(order_by_li: Any, model: Any, table_name: str = None) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
Sanitize Sorting Parameters.
|
||||
|
||||
Prevents SQL injection and logic errors by validating that requested sort columns
|
||||
actually exist in the Pydantic model and/or the database table.
|
||||
"""
|
||||
if not order_by_li or not isinstance(order_by_li, dict) or not model:
|
||||
return order_by_li
|
||||
if not hasattr(model, '__fields__'):
|
||||
return order_by_li
|
||||
|
||||
model_fields = set(model.__fields__.keys())
|
||||
model_fields.update({f.alias for f in model.__fields__.values() if f.alias})
|
||||
filtered = {k: v for k, v in order_by_li.items() if k in model_fields}
|
||||
|
||||
if table_name and filtered:
|
||||
from app.db_sql import db
|
||||
from sqlalchemy import text
|
||||
final_filtered = {}
|
||||
for column in filtered:
|
||||
try:
|
||||
# Lightweight check to see if column exists in SQL
|
||||
db.execute(text(f"SELECT `{column}` FROM `{table_name}` LIMIT 0"))
|
||||
final_filtered[column] = filtered[column]
|
||||
except Exception:
|
||||
pass
|
||||
filtered = final_filtered
|
||||
return filtered
|
||||
|
||||
def get_supported_filters(model: Any, status_filter: StatusFilterParams) -> StatusFilterParams:
|
||||
"""
|
||||
Adaptive Status Filtering.
|
||||
|
||||
Adjusts the default filters (enabled/hidden) based on whether the target object
|
||||
actually supports those concepts (i.e., has those columns).
|
||||
"""
|
||||
if not model or not hasattr(model, "__fields__"):
|
||||
return status_filter
|
||||
# We create a new instance to avoid side effects on the dependency object
|
||||
from app.routers.dependencies_v3 import StatusFilterParams as SF
|
||||
adjusted = SF()
|
||||
adjusted.enabled = status_filter.enabled
|
||||
adjusted.hidden = status_filter.hidden
|
||||
|
||||
if 'enable' not in model.__fields__:
|
||||
adjusted.enabled = 'all'
|
||||
if 'hide' not in model.__fields__:
|
||||
adjusted.hidden = 'all'
|
||||
return adjusted
|
||||
|
||||
def safe_json_loads(json_str: Optional[str]) -> Any:
|
||||
if not json_str or json_str == 'undefined': return None
|
||||
try: return json.loads(json_str)
|
||||
except: return None
|
||||
|
||||
def sanitize_payload(data: dict, model: Any) -> None:
|
||||
"""
|
||||
Sanitizes an input payload before database insertion or update.
|
||||
|
||||
1. Removes virtual lookup fields (ending in `_id_random`) that are used for API
|
||||
convenience but do not exist in the database.
|
||||
2. Removes fields explicitly marked for exclusion in the model's
|
||||
`fields_to_exclude_from_db` ClassVar (e.g., view-only fields).
|
||||
|
||||
Modifies the `data` dictionary in-place.
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
return
|
||||
|
||||
# Filter out virtual _id_random fields (e.g., account_id_random)
|
||||
keys_to_remove = [k for k in data.keys() if k.endswith('_id_random') and k != 'id_random']
|
||||
for k in keys_to_remove:
|
||||
del data[k]
|
||||
|
||||
# Filter out model-specific excluded fields (e.g., view-only fields)
|
||||
if hasattr(model, 'fields_to_exclude_from_db'):
|
||||
for k in model.fields_to_exclude_from_db:
|
||||
if k in data:
|
||||
del data[k]
|
||||
@@ -13,6 +13,10 @@ from app.lib_general_v3 import (
|
||||
PaginationParams, StatusFilterParams,
|
||||
SerializationParams, DelayParams
|
||||
)
|
||||
from app.lib_api_crud_v3 import (
|
||||
check_account_access, apply_forced_account_filter, filter_order_by,
|
||||
get_supported_filters, safe_json_loads, sanitize_payload
|
||||
)
|
||||
from app.models.response_models import *
|
||||
from app.models.api_crud_models import SearchFilter, SearchQuery
|
||||
from app.ae_obj_types_def import obj_type_kv_li
|
||||
@@ -33,108 +37,6 @@ Key Features:
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
def check_account_access(sql_result: Any, account: AccountContext, obj_name: str = None) -> bool:
|
||||
"""
|
||||
Enforce Multi-Tenant Data Isolation.
|
||||
|
||||
Verifies that the requested record belongs to the authenticated user's account.
|
||||
Returns True if:
|
||||
- User is a Super User or System (Bypass).
|
||||
- The record's `account_id` matches the user's `account_id`.
|
||||
"""
|
||||
if account.super or account.auth_method == 'bypass':
|
||||
return True
|
||||
if not account.account_id:
|
||||
return False
|
||||
|
||||
res_account_id = None
|
||||
if isinstance(sql_result, dict):
|
||||
if obj_name == 'account':
|
||||
res_account_id = sql_result.get('id')
|
||||
else:
|
||||
res_account_id = sql_result.get('account_id')
|
||||
|
||||
if res_account_id is not None and res_account_id != account.account_id:
|
||||
return False
|
||||
return True
|
||||
|
||||
def apply_forced_account_filter(and_qry_dict: Optional[Dict], account: AccountContext, model: Any, obj_name: str) -> Dict:
|
||||
"""
|
||||
Secure Search Filtering.
|
||||
|
||||
Automatically appends an `account_id` filter to database queries to ensure
|
||||
users only retrieve records associated with their own account.
|
||||
"""
|
||||
forced = and_qry_dict or {}
|
||||
if account.super or account.auth_method == 'bypass':
|
||||
return forced
|
||||
|
||||
if obj_name == 'account':
|
||||
forced['id'] = account.account_id
|
||||
elif model and hasattr(model, '__fields__') and 'account_id' in model.__fields__:
|
||||
forced['account_id'] = account.account_id
|
||||
|
||||
return forced
|
||||
|
||||
def filter_order_by(order_by_li: Any, model: Any, table_name: str = None) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
Sanitize Sorting Parameters.
|
||||
|
||||
Prevents SQL injection and logic errors by validating that requested sort columns
|
||||
actually exist in the Pydantic model and/or the database table.
|
||||
"""
|
||||
if not order_by_li or not isinstance(order_by_li, dict) or not model:
|
||||
return order_by_li
|
||||
if not hasattr(model, '__fields__'):
|
||||
return order_by_li
|
||||
|
||||
model_fields = set(model.__fields__.keys())
|
||||
model_fields.update({f.alias for f in model.__fields__.values() if f.alias})
|
||||
filtered = {k: v for k, v in order_by_li.items() if k in model_fields}
|
||||
|
||||
if table_name and filtered:
|
||||
from app.db_sql import db
|
||||
from sqlalchemy import text
|
||||
final_filtered = {}
|
||||
for column in filtered:
|
||||
try:
|
||||
# Lightweight check to see if column exists in SQL
|
||||
db.execute(text(f"SELECT `{column}` FROM `{table_name}` LIMIT 0"))
|
||||
final_filtered[column] = filtered[column]
|
||||
except Exception:
|
||||
pass
|
||||
filtered = final_filtered
|
||||
return filtered
|
||||
|
||||
def get_supported_filters(model: Any, status_filter: StatusFilterParams) -> StatusFilterParams:
|
||||
"""
|
||||
Adaptive Status Filtering.
|
||||
|
||||
Adjusts the default filters (enabled/hidden) based on whether the target object
|
||||
actually supports those concepts (i.e., has those columns).
|
||||
"""
|
||||
if not model or not hasattr(model, "__fields__"):
|
||||
return status_filter
|
||||
# We create a new instance to avoid side effects on the dependency object
|
||||
from app.routers.dependencies_v3 import StatusFilterParams as SF
|
||||
adjusted = SF()
|
||||
adjusted.enabled = status_filter.enabled
|
||||
adjusted.hidden = status_filter.hidden
|
||||
|
||||
if 'enable' not in model.__fields__:
|
||||
adjusted.enabled = 'all'
|
||||
if 'hide' not in model.__fields__:
|
||||
adjusted.hidden = 'all'
|
||||
return adjusted
|
||||
|
||||
def safe_json_loads(json_str: Optional[str]) -> Any:
|
||||
if not json_str or json_str == 'undefined': return None
|
||||
try: return json.loads(json_str)
|
||||
except: return None
|
||||
|
||||
|
||||
# --- Routes ---
|
||||
|
||||
@router.get("/health", response_model=Resp_Body_Base)
|
||||
@@ -515,16 +417,8 @@ async def post_obj(
|
||||
|
||||
data_to_insert = validated_obj.dict(exclude_unset=True)
|
||||
|
||||
# Filter out virtual _id_random fields (e.g., account_id_random) that are not in the DB table
|
||||
keys_to_remove = [k for k in data_to_insert.keys() if k.endswith('_id_random') and k != 'id_random']
|
||||
for k in keys_to_remove:
|
||||
del data_to_insert[k]
|
||||
|
||||
# Filter out model-specific excluded fields (e.g., view-only fields like person_full_name in Journal)
|
||||
if hasattr(input_model, 'fields_to_exclude_from_db'):
|
||||
for k in input_model.fields_to_exclude_from_db:
|
||||
if k in data_to_insert:
|
||||
del data_to_insert[k]
|
||||
# Sanitize payload (remove virtual fields and view-only fields)
|
||||
sanitize_payload(data_to_insert, input_model)
|
||||
|
||||
if sql_insert_result := sql_insert(data=data_to_insert, table_name=table_name_insert):
|
||||
new_obj_id = sql_insert_result
|
||||
@@ -585,16 +479,8 @@ async def patch_obj(
|
||||
else:
|
||||
return mk_resp(data=False, status_code=404, response=response, status_message=f"Object with ID '{obj_id}' not found in database.")
|
||||
|
||||
# Filter out virtual _id_random fields (e.g., account_id_random) that are not in the DB table
|
||||
keys_to_remove = [k for k in obj_data.keys() if k.endswith('_id_random') and k != 'id_random']
|
||||
for k in keys_to_remove:
|
||||
del obj_data[k]
|
||||
|
||||
# Filter out model-specific excluded fields (e.g., view-only fields like person_full_name in Journal)
|
||||
if hasattr(input_model, 'fields_to_exclude_from_db'):
|
||||
for k in input_model.fields_to_exclude_from_db:
|
||||
if k in obj_data:
|
||||
del obj_data[k]
|
||||
# Sanitize payload (remove virtual fields and view-only fields)
|
||||
sanitize_payload(obj_data, input_model)
|
||||
|
||||
if sql_update(data=obj_data, table_name=table_name_update, record_id=record_id):
|
||||
if return_obj:
|
||||
@@ -814,6 +700,9 @@ async def post_child_obj(
|
||||
|
||||
data_to_insert = validated_obj.dict(exclude_unset=True)
|
||||
|
||||
# Sanitize payload (remove virtual fields and view-only fields)
|
||||
sanitize_payload(data_to_insert, input_model)
|
||||
|
||||
if sql_insert_result := sql_insert(data=data_to_insert, table_name=table_name_insert):
|
||||
new_obj_id = sql_insert_result
|
||||
new_obj_id_random = get_id_random(record_id=new_obj_id, table_name=child_obj_type)
|
||||
@@ -906,6 +795,9 @@ async def patch_child_obj(
|
||||
else:
|
||||
return mk_resp(data=False, status_code=404, response=response, status_message="Child not found.")
|
||||
|
||||
# Sanitize payload (remove virtual fields and view-only fields)
|
||||
sanitize_payload(obj_data, output_model)
|
||||
|
||||
if sql_update(data=obj_data, table_name=table_name_update, record_id=resolved_child_id):
|
||||
if return_obj:
|
||||
if updated_child := sql_select(table_name=table_name_select, record_id=resolved_child_id):
|
||||
|
||||
35
documentation/REFACTOR_API_CRUD_V3.md
Normal file
35
documentation/REFACTOR_API_CRUD_V3.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# Refactoring Plan: API CRUD V3
|
||||
|
||||
**Goal:** Modularize `app/routers/api_crud_v3.py` to improve maintainability, readability, and reusability. The file currently mixes route definitions, security enforcement, data sanitization, and helper utilities.
|
||||
|
||||
## Phase 1: Extract Helpers & Core Logic (Safest) - COMPLETED
|
||||
**Objective:** Move pure functions and business logic out of the router file.
|
||||
|
||||
1. **Create `app/lib_api_crud_v3.py`**: DONE
|
||||
2. **Update `app/routers/api_crud_v3.py`**: DONE (All endpoints now use `sanitize_payload`).
|
||||
|
||||
## Phase 2: Separate Child/Nested Routes - PLANNED
|
||||
|
||||
1. **Create `app/routers/api_crud_v3_nested.py`**:
|
||||
* Move `get_child_obj_li`
|
||||
* Move `post_child_obj`
|
||||
* Move `get_child_obj`
|
||||
* Move `patch_child_obj`
|
||||
* Move `delete_child_obj`
|
||||
|
||||
2. **Update `app/main.py` (or router inclusion)**:
|
||||
* Ensure the new router is included, OR include it within `api_crud_v3.py` if preferred to keep a single import point.
|
||||
|
||||
## Phase 3: Schema Introspection
|
||||
**Objective:** Isolate database introspection logic.
|
||||
|
||||
1. **Create `app/lib_schema_v3.py` (or similar)**:
|
||||
* Move the logic inside `get_obj_schema` (SQL `DESCRIBE` parsing, Pydantic introspection) to a helper function.
|
||||
|
||||
## Execution Strategy
|
||||
We will execute **Phase 1** first as it provides immediate value (removing code duplication for sanitization) with minimal risk to routing logic.
|
||||
|
||||
### Testing
|
||||
After each move:
|
||||
1. Run `tests/test_v3_router_filtering.py` (requires update to import from new location if we test the lib directly).
|
||||
2. Verify application startup.
|
||||
Reference in New Issue
Block a user