diff --git a/app/lib_api_crud_v3.py b/app/lib_api_crud_v3.py index 1fb1647..3e321a9 100644 --- a/app/lib_api_crud_v3.py +++ b/app/lib_api_crud_v3.py @@ -122,20 +122,36 @@ def safe_json_loads(json_str: Optional[str]) -> Any: try: return json.loads(json_str) except: return None -def sanitize_payload(data: dict, model: Any) -> None: +def sanitize_payload(data: dict, model: Any, ignore_extra: bool = False) -> None: """ Sanitizes an input payload before database insertion or update. - 1. Removes virtual lookup fields (ending in `_id_random`) that are used for API + 1. Resolves virtual lookup fields (`*_id_random`) into their integer database IDs. + 2. Removes virtual lookup fields (ending in `_id_random`) that are used for API convenience but do not exist in the database. - 2. Removes fields explicitly marked for exclusion in the model's + 3. Removes fields explicitly marked for exclusion in the model's `fields_to_exclude_from_db` ClassVar (e.g., view-only fields). + 4. If `ignore_extra` is True, removes all fields NOT present in the model definition. Modifies the `data` dictionary in-place. """ if not isinstance(data, dict): return + from app.db_sql import redis_lookup_id_random + + # Resolve virtual _id_random fields to integer IDs (e.g., account_id_random -> account_id) + # This must happen BEFORE we delete them. + for k, v in list(data.items()): + if k.endswith('_id_random') and k != 'id_random' and v: + target_id_field = k.replace('_id_random', '_id') + # Only resolve if the integer version is missing or null + if not data.get(target_id_field): + obj_type_lookup = k.replace('_id_random', '') + resolved_id = redis_lookup_id_random(record_id_random=v, table_name=obj_type_lookup) + if resolved_id: + data[target_id_field] = resolved_id + # Filter out virtual _id_random fields (e.g., account_id_random) keys_to_remove = [k for k in data.keys() if k.endswith('_id_random') and k != 'id_random'] for k in keys_to_remove: @@ -146,3 +162,15 @@ def sanitize_payload(data: dict, model: Any) -> None: for k in model.fields_to_exclude_from_db: if k in data: del data[k] + + # If permissive mode is on, remove any field not in the Pydantic model + if ignore_extra and model and hasattr(model, '__fields__'): + model_fields = set(model.__fields__.keys()) + # Also check for aliases + for f in model.__fields__.values(): + if f.alias: + model_fields.add(f.alias) + + extra_keys = [k for k in data.keys() if k not in model_fields] + for k in extra_keys: + del data[k] diff --git a/app/models/response_models.py b/app/models/response_models.py index 58979af..e2cb43e 100644 --- a/app/models/response_models.py +++ b/app/models/response_models.py @@ -41,7 +41,7 @@ def mk_resp( status_message: str = '', status_name: str = '', success: bool = True, - details: str = '', + details: Union[None, str, dict, list] = '', include: dict = None, exclude: dict = None, by_alias: bool = True, diff --git a/app/routers/api_crud_v3.py b/app/routers/api_crud_v3.py index 14a81e4..441f396 100644 --- a/app/routers/api_crud_v3.py +++ b/app/routers/api_crud_v3.py @@ -370,6 +370,7 @@ async def post_obj( response: Response, obj_type_l1: str = Path(min_length=2, max_length=50), return_obj: Optional[bool] = True, + x_ae_ignore_extra_fields: Optional[bool] = Header(False), account: AccountContext = Depends(get_account_context), serialization: SerializationParams = Depends(), delay: DelayParams = Depends(), @@ -377,10 +378,10 @@ async def post_obj( """ Create Object. - 1. Validates input against Pydantic model (`mdl_in`). - 2. Injects `account_id` for ownership. - 3. **Sanitizes Payload**: Removes virtual lookup fields (`*_id_random`) and view-only fields (`fields_to_exclude_from_db`) - to prevent "unknown column" errors during insertion. + 1. Injects `account_id` for ownership. + 2. **Sanitizes Payload**: Resolves `*_id_random` -> `*_id`, removes virtual fields, and view-only fields. + - If `x-ae-ignore-extra-fields: true` header is provided, unknown fields are stripped. + 3. Validates input against Pydantic model (`mdl_in`). 4. Returns the created object or just its ID. """ from app.db_sql import sql_insert, get_id_random, sql_select @@ -407,16 +408,20 @@ async def post_obj( elif obj_name == 'account': return mk_resp(data=False, status_code=403, response=response, status_message="Account creation is restricted.") + # Sanitize payload (ID resolution, virtual fields, and optionally extra fields) + sanitize_payload(obj_data, input_model, ignore_extra=x_ae_ignore_extra_fields) + try: validated_obj = input_model(**obj_data) + except ValidationError as e: + # Return structured errors (field -> error message) for UI feedback + structured_errors = {err['loc'][-1]: err['msg'] for err in e.errors()} + return mk_resp(data=False, status_code=400, response=response, status_message="Validation Failed", details=structured_errors) except Exception as e: return mk_resp(data=False, status_code=400, response=response, status_message="Validation Failed", details=str(e)) data_to_insert = validated_obj.dict(exclude_unset=True) - # Sanitize payload (remove virtual fields and view-only fields) - sanitize_payload(data_to_insert, input_model) - if sql_insert_result := sql_insert(data=data_to_insert, table_name=table_name_insert): new_obj_id = sql_insert_result new_obj_id_random = get_id_random(record_id=new_obj_id, table_name=obj_name) @@ -438,6 +443,7 @@ async def patch_obj( obj_type_l1: str = Path(min_length=2, max_length=50), obj_id: str = Path(min_length=11, max_length=22), return_obj: Optional[bool] = True, + x_ae_ignore_extra_fields: Optional[bool] = Header(False), account: AccountContext = Depends(get_account_context), serialization: SerializationParams = Depends(), delay: DelayParams = Depends(), @@ -446,7 +452,8 @@ async def patch_obj( Update Object (Partial). 1. Resolves ID and checks access permissions. - 2. **Sanitizes Payload**: Removes virtual lookup fields and view-only fields. + 2. **Sanitizes Payload**: Resolves `*_id_random` -> `*_id`, removes virtual fields, and view-only fields. + - If `x-ae-ignore-extra-fields: true` header is provided, unknown fields are stripped. 3. Performs SQL UPDATE. """ from app.db_sql import redis_lookup_id_random, sql_select, sql_update @@ -477,8 +484,8 @@ async def patch_obj( else: return mk_resp(data=False, status_code=404, response=response, status_message=f"Object with ID '{obj_id}' not found in database.") - # Sanitize payload (remove virtual fields and view-only fields) - sanitize_payload(obj_data, input_model) + # Sanitize payload (ID resolution, virtual fields, and optionally extra fields) + sanitize_payload(obj_data, input_model, ignore_extra=x_ae_ignore_extra_fields) if sql_update(data=obj_data, table_name=table_name_update, record_id=record_id): if return_obj: diff --git a/tests/verify_feedback_fixes.py b/tests/verify_feedback_fixes.py new file mode 100644 index 0000000..ee991fd --- /dev/null +++ b/tests/verify_feedback_fixes.py @@ -0,0 +1,104 @@ +import sys +import os +from unittest.mock import MagicMock + +# --- Environment Setup --- +sys.modules['redis'] = MagicMock() +sys.modules['sqlalchemy'] = MagicMock() +sys.modules['app.config'] = MagicMock() +sys.modules['html2text'] = MagicMock() +sys.modules['app.log'] = MagicMock() +sys.modules['app.lib_general'] = MagicMock() + +# Mock app.db_sql +mock_db_sql = MagicMock() +# Mock ID resolution: abc -> 123 +mock_db_sql.redis_lookup_id_random.side_effect = lambda record_id_random, table_name: 123 if record_id_random == 'abc' else None +sys.modules['app.db_sql'] = mock_db_sql + +# Add project root to path +sys.path.append(os.getcwd()) + +from app.lib_api_crud_v3 import sanitize_payload +from pydantic import BaseModel, Field, ValidationError +from typing import Optional, List, ClassVar + +class MockModel(BaseModel): + id: Optional[int] + name: str = Field(None, min_length=3) + account_id: Optional[int] + + fields_to_exclude_from_db: ClassVar[List[str]] = ['computed_field'] + +def test_permissive_update(): + print("--- Testing Permissive Update (ignore_extra=True) ---") + payload = { + "name": "Test", + "extra_field": "Should be removed", + "computed_field": "Should be removed" + } + sanitize_payload(payload, MockModel, ignore_extra=True) + print(f"Sanitized Payload: {payload}") + + assert "extra_field" not in payload + assert "computed_field" not in payload + assert payload["name"] == "Test" + print("āœ… Permissive update stripping works.") + +def test_strict_update(): + print("\n--- Testing Strict Update (ignore_extra=False) ---") + payload = { + "name": "Test", + "extra_field": "Should be removed", + "computed_field": "Should be removed" + } + sanitize_payload(payload, MockModel, ignore_extra=False) + print(f"Sanitized Payload: {payload}") + + assert "extra_field" in payload + assert "computed_field" not in payload + print("āœ… Strict update correctly preserves unknown fields (waiting for DB error) but strips excluded fields.") + +def test_id_resolution(): + print("\n--- Testing ID Resolution ---") + payload = { + "name": "Test", + "account_id_random": "abc" + } + sanitize_payload(payload, MockModel) + print(f"Sanitized Payload: {payload}") + + assert payload.get("account_id") == 123 + assert "account_id_random" not in payload + print("āœ… ID resolution (account_id_random -> account_id) works.") + +def test_structured_validation_errors(): + print("\n--- Testing Structured Validation Errors ---") + payload = { + "name": "a" # Too short + } + try: + MockModel(**payload) + except ValidationError as e: + structured_errors = {err['loc'][-1]: err['msg'] for err in e.errors()} + print(f"Structured Errors: {structured_errors}") + assert "name" in structured_errors + # Pydantic 1.x error message + assert "at least 3 characters" in structured_errors["name"] + print("āœ… Structured validation errors work.") + +if __name__ == "__main__": + try: + test_permissive_update() + test_strict_update() + test_id_resolution() + test_structured_validation_errors() + print("\nšŸŽ‰ All local logic tests passed!") + except AssertionError as e: + print(f"\nāŒ Test failed: {e}") + sys.exit(1) + except Exception as e: + print(f"\nšŸ’„ An error occurred: {e}") + import traceback + traceback.print_exc() + sys.exit(1) \ No newline at end of file