refactor(sql): complete modularization of search builders and ID resolution

This commit is contained in:
Scott Idem
2026-01-06 17:58:34 -05:00
parent 56fe7ed953
commit 868a0060dc
3 changed files with 186 additions and 146 deletions

View File

@@ -20,7 +20,8 @@ from app.lib_sql_search import (
sql_fulltext_qry_part as _sql_fulltext_qry_part, sql_fulltext_qry_part as _sql_fulltext_qry_part,
sql_enable_part as _sql_enable_part, sql_enable_part as _sql_enable_part,
sql_hidden_part as _sql_hidden_part, sql_hidden_part as _sql_hidden_part,
sql_where_qry_part as _sql_where_qry_part sql_where_qry_part as _sql_where_qry_part,
sql_search_qry_part as _sql_search_qry_part
) )
from app.lib_redis_helpers import ( from app.lib_redis_helpers import (
@@ -1768,156 +1769,32 @@ def sql_limit_offset_part(limit: int, offset: int = 0) -> bool|str:
# ### END ### API DB SQL Methods ### sql_limit_offset_part() ### # ### END ### API DB SQL Methods ### sql_limit_offset_part() ###
# ### BEGIN ### API DB SQL Methods ### sql_search_qry_part() ###
# NEW 2026-01-02
# Updated to support complex POST-based searches with recursive logical grouping.
# Updated 2026-01-06 to handle missing default_qry_str column gracefully.
@logger_reset @logger_reset
def sql_search_qry_part( def sql_search_qry_part(
search_query: Any, # SearchQuery model instance search_query: Any, # SearchQuery model instance
searchable_fields: List[str]|None = None, # List of allowed fields searchable_fields: List[str]|None = None, # List of allowed fields
max_depth: int = 5, # Maximum recursion depth max_depth: int = 5, # Maximum recursion depth
table_name: str|None = None, # Target table for schema validation table_name: str|None = None, # Target table for schema validation
) -> tuple[str, dict]: ) -> tuple[str, dict]:
"""
Recursively builds a SQL WHERE clause from a SearchQuery model.
Uses unique parameter names to prevent collisions.
Enforces security via field allowlist and recursion depth limits.
"""
log.setLevel(logging.INFO) log.setLevel(logging.INFO)
log.debug(locals()) log.debug(locals())
data = {}
param_counter = [0]
def get_param_name(): return _sql_search_qry_part(search_query, searchable_fields, max_depth, table_name)
param_counter[0] += 1
return f"sp_{param_counter[0]}"
operator_map = {
"eq": "=",
"ne": "!=",
"gt": ">",
"gte": ">=",
"lt": "<",
"lte": "<=",
"like": "LIKE",
"in": "IN",
"is_null": "IS NULL",
"is_not_null": "IS NOT NULL",
"contains": "LIKE",
"icontains": "LIKE",
"startswith": "LIKE",
"istartswith": "LIKE",
"endswith": "LIKE",
"iendswith": "LIKE"
}
def process_node(query_node, current_depth: int) -> str:
if current_depth > max_depth:
raise HTTPException(status_code=400, detail=f"Search query too complex (max depth {max_depth} reached).")
clauses = []
# Process 'query_string' (Standardized Full-Text Search)
if hasattr(query_node, 'query_string') and query_node.query_string:
if query_node.query_string == '%':
# Wildcard: Skip filtering for this part
pass
else:
# Check if default_qry_str exists in this table/view
use_match = True
if table_name:
try:
db.execute(text(f"SELECT default_qry_str FROM `{table_name}` LIMIT 0"))
except:
use_match = False
else:
use_match = False # Safe default if no table_name
if use_match:
p_name = get_param_name()
clauses.append(f"MATCH( default_qry_str ) AGAINST( :{p_name} IN BOOLEAN MODE )")
data[p_name] = query_node.query_string
elif searchable_fields:
# Fallback: OR LIKE across all searchable fields
like_clauses = []
for field in searchable_fields:
# Skip internal/numeric fields for full-text-like search
if not any(x in field for x in ['_id', 'enable', 'hide', 'priority', 'sort', 'group', 'created_on', 'updated_on']):
f_p_name = get_param_name()
like_clauses.append(f"`{field}` LIKE :{f_p_name}")
data[f_p_name] = f"%{query_node.query_string}%"
if like_clauses:
clauses.append(f"({' OR '.join(like_clauses)})")
# Process 'and' filters
if hasattr(query_node, 'and_filters') and query_node.and_filters:
and_clauses = []
for item in query_node.and_filters:
if hasattr(item, 'field'): # SearchFilter
clause, item_data = process_filter(item)
and_clauses.append(clause)
data.update(item_data)
else: # Nested SearchQuery
and_clauses.append(f"({process_node(item, current_depth + 1)})")
if and_clauses:
clauses.append(f"({' AND '.join(and_clauses)})")
# Process 'or' filters
if hasattr(query_node, 'or_filters') and query_node.or_filters:
or_clauses = []
for item in query_node.or_filters:
if hasattr(item, 'field'): # SearchFilter
clause, item_data = process_filter(item)
or_clauses.append(clause)
data.update(item_data)
else: # Nested SearchQuery
or_clauses.append(f"({process_node(item, current_depth + 1)})")
if or_clauses:
clauses.append(f"({' OR '.join(or_clauses)})")
return ' AND '.join(clauses)
def process_filter(f) -> tuple[str, dict]:
# Field Validation: Check against allowlist
if searchable_fields is not None and f.field not in searchable_fields:
raise HTTPException(status_code=400, detail=f"Searching on field '{f.field}' is not permitted.")
op_lower = f.op.lower()
sql_op = operator_map.get(op_lower)
if not sql_op:
raise HTTPException(status_code=400, detail=f"Unsupported search operator: {f.op}")
filter_data = {}
if op_lower in ['is_null', 'is_not_null']:
clause = f"`{f.field}` {sql_op}"
elif op_lower == 'in':
p_name = get_param_name()
clause = f"`{f.field}` IN (:{p_name})"
filter_data[p_name] = f.value
elif op_lower in ['contains', 'icontains']:
p_name = get_param_name()
clause = f"`{f.field}` LIKE :{p_name}"
filter_data[p_name] = f"%{f.value}%"
elif op_lower in ['startswith', 'istartswith']:
p_name = get_param_name()
clause = f"`{f.field}` LIKE :{p_name}"
filter_data[p_name] = f"{f.value}%"
elif op_lower in ['endswith', 'iendswith']:
p_name = get_param_name()
clause = f"`{f.field}` LIKE :{p_name}"
filter_data[p_name] = f"%{f.value}"
else:
p_name = get_param_name()
clause = f"`{f.field}` {sql_op} :{p_name}"
filter_data[p_name] = f.value
return clause, filter_data
# Initial processing
sql_where = process_node(search_query, 1)
if sql_where:
return f"AND ({sql_where})", data
return "", {}
# ### END ### API DB SQL Methods ### sql_search_qry_part() ###

View File

@@ -87,9 +87,9 @@ def redis_lookup_id_random(
log.debug(f'SQL: SELECT result: {select_results}') log.debug(f'SQL: SELECT result: {select_results}')
if isinstance(select_results, dict): if isinstance(select_results, dict):
log.info(f"""SQL: Found ID Random for: {str(record_id_random)} = {str(select_results.get('id'))}""") log.info(f"""SQL: Found ID Random for: {str(record_id_random)} = {str(select_results.get('id'))}""")
if rid := select_results.get('id'): if record_id := select_results.get('id'):
r.setex(key_name, datetime.timedelta(minutes=minutes), value=rid) r.setex(key_name, datetime.timedelta(minutes=minutes), value=record_id)
return int(rid) return int(record_id)
else: else:
log.error('The SQL result was not what was expected. The ID field was not found.') log.error('The SQL result was not what was expected. The ID field was not found.')
return False return False
@@ -141,4 +141,87 @@ def reset_redis():
"""Flushes the Redis database used for ID caching.""" """Flushes the Redis database used for ID caching."""
r = redis.Redis(host=settings.REDIS['server'], port=settings.REDIS['port'], db=7, password=None, decode_responses=True) r = redis.Redis(host=settings.REDIS['server'], port=settings.REDIS['port'], db=7, password=None, decode_responses=True)
r.flushdb() r.flushdb()
return True return True
def lookup_id_random_pop(
    obj_data: dict,
    log_lvl: int = logging.WARNING,  # DEBUG, INFO, WARNING, ERROR, EXCEPTION, CRITICAL
):
    """
    Resolve the ``*_id_random`` values found in ``obj_data`` to their ids.

    For every recognised ``<prefix>_id_random`` key the value is resolved with
    ``redis_lookup_id_random`` and stored under ``<prefix>_id`` (or
    ``event_id_only`` for the ``event_id_random_only`` prefix); the
    ``*_id_random`` key is then removed. Polymorphic pairs such as
    ``for_type`` / ``for_id_random`` are resolved the same way, using the
    value of the type key as the table name.

    Args:
        obj_data: Dict to process. It is modified in place.
        log_lvl: Level applied to the module logger for this call.

    Returns:
        The same ``obj_data`` dict, after modification. Resolved values are
        stored exactly as returned by ``redis_lookup_id_random``.
    """
    log.setLevel(log_lvl)

    # Common prefixes for ID resolution
    id_prefixes = [
        'account', 'activity_log', 'address', 'address_location', 'archive',
        'contact', 'contact_1', 'contact_2', 'cont_edu_cert', 'cont_edu_cert_person',
        'event', 'event_id_random_only', 'event_abstract', 'event_badge',
        'event_badge_template', 'event_exhibit', 'event_file', 'event_location',
        'event_person', 'event_person_profile', 'event_presentation',
        'event_presenter', 'event_registration', 'event_session', 'event_track',
        'grant', 'hosted_file', 'journal', 'journal_entry', 'membership_group',
        'membership_person_group', 'membership_person', 'membership_type',
        'membership_person_type', 'order', 'order_line', 'order_cart',
        'order_cart_line', 'organization', 'page', 'person', 'poc_event_person',
        'poc_person', 'post', 'product', 'sponsorship', 'sponsorship_cfg',
        'site', 'user',
    ]

    # Prefixes whose lookup table differs from the prefix itself
    table_overrides = {
        'address_location': 'address',
        'contact_1': 'contact',
        'contact_2': 'contact',
        'event_id_random_only': 'event',
        'poc_event_person': 'event_person',
        'poc_person': 'person',
    }

    for prefix in id_prefixes:
        key = f'{prefix}_id_random'
        # NOTE(review): for the 'event_id_random_only' prefix this looks for the
        # key 'event_id_random_only_id_random' - confirm that is the intended
        # input key and not plain 'event_id_random_only'.
        if key not in obj_data:
            continue

        resolved_id = redis_lookup_id_random(
            record_id_random=obj_data[key],
            table_name=table_overrides.get(prefix, prefix),
        )

        # Only the id key is written. The previous version also stored the id
        # under the bare prefix (e.g. obj_data['person']), which overwrote any
        # existing value of that name.
        if prefix == 'event_id_random_only':
            target_id_key = 'event_id_only'
        else:
            target_id_key = f'{prefix}_id'
        obj_data[target_id_key] = resolved_id
        obj_data.pop(key)

    # Polymorphic links: (type key, random id key, resolved id key)
    polymorphic = [
        ('for_type', 'for_id_random', 'for_id'),
        ('link_to_type', 'link_to_id_random', 'link_to_id'),
        ('object_type', 'object_id_random', 'object_id'),
        ('to_object_type', 'to_object_id_random', 'to_object_id'),
        ('from_object_type', 'from_object_id_random', 'from_object_id'),
    ]
    for type_key, rand_key, id_key in polymorphic:
        # Both halves of the pair are required: the type value names the table.
        if type_key in obj_data and rand_key in obj_data:
            obj_data[id_key] = redis_lookup_id_random(
                record_id_random=obj_data.get(rand_key),
                table_name=obj_data.get(type_key),
            )
            obj_data.pop(rand_key)

    return obj_data
def get_account_id_w_for_type_id(
    for_type: str,  # This is the table name
    for_id: int|str,
) -> bool|int|None:
    """
    Helper to find an account_id associated with an object.

    Args:
        for_type: Table name of the object; also passed to
            ``redis_lookup_id_random`` as ``table_name``.
        for_id: Identifier of the object, passed to
            ``redis_lookup_id_random`` as ``record_id_random``.

    Returns:
        The ``account_id`` value of the matching row, ``False`` when the
        SELECT yields no result, or ``None`` when the id lookup yields a
        falsy value.
    """
    # Imported here rather than at module level, as in the original helper.
    from app.db_sql import sql_select
    log.setLevel(logging.WARNING)

    fid = redis_lookup_id_random(record_id_random=for_id, table_name=for_type)
    if not fid:
        return None

    # The table name cannot be a bound parameter, so it is embedded as a
    # backtick-quoted identifier; doubling any embedded backtick keeps the
    # value from terminating the quoting early.
    quoted_table = str(for_type).replace('`', '``')
    sql = f"SELECT account_id FROM `{quoted_table}` WHERE id = :for_id LIMIT 1;"

    result = sql_select(data={'for_id': fid}, sql=sql)
    if result:
        return result.get('account_id')
    return False

View File

@@ -2,6 +2,8 @@
Modular search builder and query generators for Aether. Modular search builder and query generators for Aether.
""" """
import logging import logging
from typing import Any, List, Optional
from fastapi import HTTPException
from sqlalchemy import text from sqlalchemy import text
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -123,3 +125,81 @@ def sql_where_qry_part(qry_dict_li: list) -> tuple[str, dict]|bool:
data[field] = val data[field] = val
return ' '.join(clauses), data return ' '.join(clauses), data
return False return False
def sql_search_qry_part(
search_query: Any,
searchable_fields: List[str]|None = None,
max_depth: int = 5,
table_name: str|None = None,
) -> tuple[str, dict]:
"""Recursively builds a SQL WHERE clause from a SearchQuery model."""
from app.db_sql import db
data = {}
param_counter = [0]
def get_param_name():
param_counter[0] += 1
return f"sp_{param_counter[0]}"
operator_map = {
"eq": "=", "ne": "!=", "gt": ">", "gte": ">=", "lt": "<", "lte": "<=",
"like": "LIKE", "in": "IN", "is_null": "IS NULL", "is_not_null": "IS NOT NULL",
"contains": "LIKE", "icontains": "LIKE", "startswith": "LIKE", "istartswith": "LIKE",
"endswith": "LIKE", "iendswith": "LIKE"
}
def process_node(query_node, current_depth: int) -> str:
if current_depth > max_depth:
raise HTTPException(status_code=400, detail=f"Search query too complex.")
clauses = []
if hasattr(query_node, 'query_string') and query_node.query_string:
if query_node.query_string == '%': pass
else:
use_match = True
if table_name:
try: db.execute(text(f"SELECT default_qry_str FROM `{table_name}` LIMIT 0"))
except: use_match = False
else: use_match = False
if use_match:
p_name = get_param_name()
clauses.append(f"MATCH( default_qry_str ) AGAINST( :{p_name} IN BOOLEAN MODE )")
data[p_name] = query_node.query_string
elif searchable_fields:
like_clauses = []
for field in searchable_fields:
if not any(x in field for x in ['_id', 'enable', 'hide', 'priority', 'sort', 'group', 'created_on', 'updated_on']):
f_p_name = get_param_name()
like_clauses.append(f"`{field}` LIKE :{f_p_name}")
data[f_p_name] = f"%{query_node.query_string}%"
if like_clauses: clauses.append(f"({' OR '.join(like_clauses)})")
for filter_attr in ['and_filters', 'or_filters']:
if hasattr(query_node, filter_attr) and getattr(query_node, filter_attr):
node_clauses = []
for item in getattr(query_node, filter_attr):
if hasattr(item, 'field'):
clause, item_data = process_filter(item)
node_clauses.append(clause); data.update(item_data)
else: node_clauses.append(f"({process_node(item, current_depth + 1)})")
if node_clauses:
joiner = ' AND ' if 'and' in filter_attr else ' OR '
clauses.append(f"({joiner.join(node_clauses)})")
return ' AND '.join(clauses)
def process_filter(f) -> tuple[str, dict]:
if searchable_fields is not None and f.field not in searchable_fields:
raise HTTPException(status_code=400, detail=f"Unauthorized search field '{f.field}'")
sql_op = operator_map.get(f.op.lower())
if not sql_op: raise HTTPException(status_code=400, detail=f"Unsupported operator: {f.op}")
filter_data = {}
if f.op.lower() in ['is_null', 'is_not_null']: clause = f"`{f.field}` {sql_op}"
else:
p_name = get_param_name()
if f.op.lower() == 'in': clause = f"`{f.field}` IN (:{p_name})"; filter_data[p_name] = f.value
elif f.op.lower() in ['contains', 'icontains']: clause = f"`{f.field}` LIKE :{p_name}"; filter_data[p_name] = f"%{f.value}%"
elif f.op.lower() in ['startswith', 'istartswith']: clause = f"`{f.field}` LIKE :{p_name}"; filter_data[p_name] = f"{f.value}%"
elif f.op.lower() in ['endswith', 'iendswith']: clause = f"`{f.field}` LIKE :{p_name}"; filter_data[p_name] = f"%{f.value}"
else: clause = f"`{f.field}` {sql_op} :{p_name}"; filter_data[p_name] = f.value
return clause, filter_data
sql_where = process_node(search_query, 1)
return (f"AND ({sql_where})", data) if sql_where else ("", {})