Security: Implement recursion depth limits and field allowlists for Advanced Search; add reference SQL exports.

This commit is contained in:
Scott Idem
2026-01-02 19:38:37 -05:00
parent 5a4c82e4cb
commit 09ec231303
5 changed files with 211 additions and 9 deletions

View File

@@ -1,7 +1,8 @@
import datetime, json, pytz, random, redis, secrets
from typing import Any, Optional
from typing import Any, List, Optional
from timeit import default_timer as timer
from fastapi import HTTPException
from app.config import settings
from app.log import log, logging, logger_reset
@@ -616,6 +617,7 @@ def sql_select(
or_like_dict: dict|None = None,
and_in_dict_li: dict|None = None,
search_query: Any|None = None, # NEW 2026-01-02 (SearchQuery model)
searchable_fields: List[str]|None = None, # NEW 2026-01-03
fulltext_qry_field_li: list|None = None, # ['field_name_1', 'field_name_2']
fulltext_qry_str: str|None = None, # 'search string'
order_by_li: dict|None = None, # {"the_field_name": "DESC"}
@@ -721,7 +723,7 @@ def sql_select(
sql_search_qry = ''
if search_query:
log.info('Creating partial SQL string for complex SearchQuery.')
sql_search_qry, data_search = sql_search_qry_part(search_query)
sql_search_qry, data_search = sql_search_qry_part(search_query, searchable_fields=searchable_fields)
data = {**data, **data_search}
sql = text(
@@ -828,7 +830,7 @@ def sql_select(
sql_search_qry = ''
if search_query:
log.info('Creating partial SQL string for complex SearchQuery.')
sql_search_qry, data_search = sql_search_qry_part(search_query)
sql_search_qry, data_search = sql_search_qry_part(search_query, searchable_fields=searchable_fields)
data = {**data, **data_search}
# # NOTE: Version 3 of the fulltext search
@@ -2110,16 +2112,19 @@ def sql_limit_offset_part(limit: int, offset: int = 0) -> bool|str:
@logger_reset
def sql_search_qry_part(
search_query: Any, # SearchQuery model instance
searchable_fields: List[str]|None = None, # List of allowed fields
max_depth: int = 5, # Maximum recursion depth
) -> tuple[str, dict]:
"""
Recursively builds a SQL WHERE clause from a SearchQuery model.
Uses unique parameter names to prevent collisions.
Enforces security via field allowlist and recursion depth limits.
"""
log.setLevel(logging.INFO)
log.debug(locals())
data = {}
param_counter = [0] # List used as a reference for unique parameter names
param_counter = [0]
def get_param_name():
param_counter[0] += 1
@@ -2138,7 +2143,10 @@ def sql_search_qry_part(
"is_not_null": "IS NOT NULL"
}
def process_node(query_node) -> str:
def process_node(query_node, current_depth: int) -> str:
if current_depth > max_depth:
raise HTTPException(status_code=400, detail=f"Search query too complex (max depth {max_depth} reached).")
clauses = []
# Process 'query_string' (Standardized Full-Text Search)
@@ -2156,7 +2164,7 @@ def sql_search_qry_part(
and_clauses.append(clause)
data.update(item_data)
else: # Nested SearchQuery
and_clauses.append(f"({process_node(item)})")
and_clauses.append(f"({process_node(item, current_depth + 1)})")
if and_clauses:
clauses.append(f"({' AND '.join(and_clauses)})")
@@ -2169,13 +2177,17 @@ def sql_search_qry_part(
or_clauses.append(clause)
data.update(item_data)
else: # Nested SearchQuery
or_clauses.append(f"({process_node(item)})")
or_clauses.append(f"({process_node(item, current_depth + 1)})")
if or_clauses:
clauses.append(f"({' OR '.join(or_clauses)})")
return ' AND '.join(clauses)
def process_filter(f) -> tuple[str, dict]:
# Field Validation: Check against allowlist
if searchable_fields is not None and f.field not in searchable_fields:
raise HTTPException(status_code=400, detail=f"Searching on field '{f.field}' is not permitted.")
sql_op = operator_map.get(f.op.lower())
if not sql_op:
raise ValueError(f"Unsupported search operator: {f.op}")
@@ -2185,7 +2197,6 @@ def sql_search_qry_part(
clause = f"`{f.field}` {sql_op}"
elif f.op.lower() == 'in':
p_name = get_param_name()
# IN operator requires a tuple or list
clause = f"`{f.field}` IN (:{p_name})"
filter_data[p_name] = f.value
else:
@@ -2196,7 +2207,7 @@ def sql_search_qry_part(
return clause, filter_data
# Initial processing
sql_where = process_node(search_query)
sql_where = process_node(search_query, 1)
if sql_where:
return f"AND ({sql_where})", data
return "", {}