Work on websockets end points and management

This commit is contained in:
Scott Idem
2023-03-30 19:27:39 -04:00
parent 224aaed969
commit cff165d9d9
5 changed files with 160 additions and 77 deletions

View File

@@ -12,6 +12,7 @@ class Settings(BaseSettings):
AETHER_CFG['id'] = 0 AETHER_CFG['id'] = 0
# AETHER_CFG['api_id'] = 0 # NOT CURRENTLY NEED OR USED # AETHER_CFG['api_id'] = 0 # NOT CURRENTLY NEED OR USED
JWT_KEY = '' # 22 characters; super secret Aether JWT signing key
# APP_NAME: str = "Aether API (FastAPI)" # APP_NAME: str = "Aether API (FastAPI)"
# SUPER_EMAIL: EmailStr = 'Aether.Super@oneskyit.com' # SUPER_EMAIL: EmailStr = 'Aether.Super@oneskyit.com'

View File

@@ -123,6 +123,9 @@ class Common_Route_Params:
@logger_reset # This breaks things for some reason when the function is async. Do not use async def common_route_params()! @logger_reset # This breaks things for some reason when the function is async. Do not use async def common_route_params()!
def common_route_params( def common_route_params(
x_account_id: str = Header(..., min_length=11, max_length=22), x_account_id: str = Header(..., min_length=11, max_length=22),
# x_aether_api_key: Optional[str] = Header(..., min_length=11, max_length=22),
# x_aether_api_token: Optional[str] = Header(..., min_length=11, max_length=22),
# x_aether_jwt_token: Optional[str] = Header(..., min_length=11, max_length=50),
enabled: str = 'enabled', # all, enabled, disabled enabled: str = 'enabled', # all, enabled, disabled
limit: int = 100, limit: int = 100,
offset: int = 0, offset: int = 0,
@@ -246,12 +249,14 @@ def verify_secure_hash_string(string: str, string_hash: str):
@logger_reset @logger_reset
def sign_jwt( def sign_jwt(
secret_key: str, # Secret/Private/Password secret_key: str, # Secret/Private/Password
public_key: str, # Will be part of the token. Use to look up secret when verifying.
ttl: int = 60, # Default to 60 seconds ttl: int = 60, # Default to 60 seconds
max_renew: int = 0, # Default to 0 max_renew: int = 0, # Default to 0
public_key: str = None, # Will be part of the token. Use to look up secret when verifying.???
account_id: str = None, account_id: str = None,
person_id: str = None, person_id: str = None,
user_id: str = None, user_id: str = None,
json_str: str = None,
b64_str: str = None,
) -> Dict[str, str]: ) -> Dict[str, str]:
log.setLevel(logging.WARNING) # DEBUG, INFO, WARNING, ERROR, EXCEPTION, CRITICAL log.setLevel(logging.WARNING) # DEBUG, INFO, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals()) log.debug(locals())
@@ -264,6 +269,8 @@ def sign_jwt(
'account_id': account_id, 'account_id': account_id,
'person_id': person_id, 'person_id': person_id,
'user_id': user_id, 'user_id': user_id,
'json_str': json_str,
'b64_str': b64_str,
} }
secret = secret_key secret = secret_key
algorithm = 'HS256' algorithm = 'HS256'

View File

@@ -17,7 +17,15 @@ router = APIRouter()
# ### BEGIN ### API API ### request_jwt() ### # ### BEGIN ### API API ### request_jwt() ###
# Generate JWT using associated API private key # This can be used to generate JWTs for various purposes:
# * for end client browser API access
# * for proof of sign in
# * newer/better version of sign in by URL
# Generate (sign) JWT using Aether platform super secret key or x_aether_signing_key sort of secret key if passed. The Aether platform super secret JWT signing key must be used API access token
# If x_aether_api_key is passed then set higher TTL
# If old and valid x_aether_api_jwt_token is passed then decode and decrease TTL by 1
# Updated 2023-03-24
# Verify JWT using the API public key's associated API private key # Verify JWT using the API public key's associated API private key
# API server or trusted app can generate JWTs # API server or trusted app can generate JWTs
# JWT contains: # JWT contains:
@@ -33,14 +41,27 @@ router = APIRouter()
# Updated 2021-07-14 # Updated 2021-07-14
@router.get('/request_jwt', response_model=Resp_Body_Base) @router.get('/request_jwt', response_model=Resp_Body_Base)
async def request_jwt( async def request_jwt(
x_aether_api_secret_key: Optional[str] = Header(None, min_length=22, max_length=22), # If passed then can also set TTL x_aether_signing_key: Optional[str] = Header(None, min_length=22, max_length=22), # The (secret) signing key. Keep safe!!! If passed then use to sign JWT. Otherwise need to get from system/environment.
x_aether_api_public_key: Optional[str] = Header(None, min_length=22, max_length=22), # Used to look up the API secret if not given
x_aether_api_token: Optional[str] = Header(None), # Token given to client by an API key holder (short max TTL) # x_aether_secret_key: Optional[str] = Header(None, min_length=22, max_length=22), # The Aether secret key. Keep safe!!! If passed then can also set TTL
account_id: str = None,
session_id: str = None, # End client (web browser) x_aether_api_key: Optional[str] = Header(None, min_length=22, max_length=22), # The client side API key. This should be kept secret by the client. If passed then store with JWT and can set TTL.
client_id: str = None, # End client (web browser)
person_id: str = None, # x_aether_api_public_key: Optional[str] = Header(None, min_length=22, max_length=22), # Used to look up the API secret if not given
user_id: str = None,
x_aether_jwt: Optional[str] = Header(None), # A JWT that was created and given to client browser or server in the past. It may or may not be valid. If the x_aether_signing_key was not passed, then assume it was signed with the Aether super secret key.
account_id: str = None, # Handle this different because it is special
json_str: str = None, # This is what should be stored
b64_str: str = None, # This is what should be stored
# I would like payload to be a dict, but then we have to use POST instead of GET...
# Maybe base64 encode and decode?
# session_id: str = None, # End client (web browser)
# client_id: str = None, # End client (web browser)
# person_id: str = None,
# user_id: str = None,
max_ttl: int = 300, # Number of seconds to live. Only use if given the API secret key. max_ttl: int = 300, # Number of seconds to live. Only use if given the API secret key.
# Seconds: 3600 = 1 hr; 300 = 5 min # Seconds: 3600 = 1 hr; 300 = 5 min
max_renew: int = 5, # Decrease count by 1 until 0 if only sent a current API token. max_renew: int = 5, # Decrease count by 1 until 0 if only sent a current API token.
@@ -49,16 +70,43 @@ async def request_jwt(
log.setLevel(logging.WARNING) # DEBUG, INFO, WARNING, ERROR, EXCEPTION, CRITICAL log.setLevel(logging.WARNING) # DEBUG, INFO, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals()) log.debug(locals())
if x_aether_api_secret_key or x_aether_api_token: pass # One of these is required
if account_id or json_str or b64_str: pass
else: return mk_resp(data=False, status_code=400, response=response) # Bad Request else: return mk_resp(data=False, status_code=400, response=response) # Bad Request
if not x_aether_api_secret_key: max_ttl = 300 # Override any max_ttl if no API secret # Possible overrides and checks go here
if not x_aether_api_secret_key: max_renew = 5 # Override any max_rewnew if no API secret if x_aether_signing_key: pass
elif x_aether_api_key:
# Override any if for API JWT???
max_ttl = 3600
max_renew = 5
# if not x_aether_secret_key: max_renew = 5 # Override any max_rewnew if no API secret
# api_secret_key = x_aether_secret_key
api_secret_key = x_aether_api_secret_key signing_key = None
if x_aether_signing_key:
signing_key = x_aether_signing_key
elif settings.JWT_KEY:
signing_key = settings.JWT_KEY
else:
log.error('No key found to sign the JWT with!')
return mk_resp(data=False, status_code=400, response=response) # Bad Request
if x_aether_api_secret_key:
log.debug(f'Contains a value in x_aether_api_secret_key: {x_aether_api_secret_key}') payload = {}
payload['account_id'] = account_id
payload['json_str'] = json_str
payload['b64_str'] = b64_str
token = sign_jwt(secret_key=signing_key, public_key=x_aether_api_key, ttl=max_ttl, max_renew=max_renew, **payload)
response_data = { 'jwt': token }
return mk_resp(data=response_data)
if x_aether_secret_key:
log.debug(f'Contains a value in x_aether_secret_key: {x_aether_secret_key}')
table_name_select = 'api_key' table_name_select = 'api_key'
field_name = 'secret_key' field_name = 'secret_key'
@@ -125,7 +173,7 @@ async def request_jwt(
payload['user_id'] = user_id payload['user_id'] = user_id
token = sign_jwt(secret_key=api_secret_key, public_key=api_public_key, ttl=max_ttl, max_renew=max_renew, **payload) token = sign_jwt(secret_key=api_secret_key, public_key=api_public_key, ttl=max_ttl, max_renew=max_renew, **payload)
response_data = { 'api_access_jwt': token } response_data = { 'jwt': token }
return mk_resp(data=response_data) return mk_resp(data=response_data)
# ### END ### API API ### request_jwt() ### # ### END ### API API ### request_jwt() ###

View File

@@ -20,7 +20,10 @@ router = APIRouter()
# ### BEGIN ### API Hosted File ### directory_check() ### # ### BEGIN ### API Hosted File ### directory_check() ###
# Updated 2022-08-09 # This can be used to clean up the hosted_files directory. Currently it only looks for hashed files in the root, but that is kind of useless now. 2023-03-28
# This needs to be updated to delete orphan files (no records in the DB (dev, test, prod)). Careful...
# I also need to clean up the DB side if there is no file in the hosted_files directory. Less concerning?
# Updated 2023-03-28
@router.get('/directory_check', response_model=Resp_Body_Base) @router.get('/directory_check', response_model=Resp_Body_Base)
async def directory_check( async def directory_check(
rm_orphan: bool = False, rm_orphan: bool = False,
@@ -30,11 +33,6 @@ async def directory_check(
log.setLevel(logging.INFO) # DEBUG, INFO, WARNING, ERROR, EXCEPTION, CRITICAL log.setLevel(logging.INFO) # DEBUG, INFO, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals()) log.debug(locals())
# print('HERE HERE HERE')
# return mk_resp(data=True, response=commons.response, status_message='HERE HERE HERE The hosted file directory check.')
# ### Orphan file: ### Delete file from server # ### Orphan file: ### Delete file from server
hosted_files_path = settings.FILES_PATH['hosted_files_root'] hosted_files_path = settings.FILES_PATH['hosted_files_root']
# hosted_files_path = '/home/scott/tmp/hosted_files_dev/' # hosted_files_path = '/home/scott/tmp/hosted_files_dev/'
@@ -49,12 +47,17 @@ async def directory_check(
log.info('Path exists! Going to get a list of files...') log.info('Path exists! Going to get a list of files...')
directory_list = os.listdir(full_directory_path) directory_list = os.listdir(full_directory_path)
count = 0
result_list = [] result_list = []
for directory_item in directory_list: for directory_item in directory_list:
if count >= 100: break
file_path_w_item = os.path.join(full_directory_path, directory_item) file_path_w_item = os.path.join(full_directory_path, directory_item)
# log.info(f'Full file path with directory item: {file_path_w_item}') # log.info(f'Full file path with directory item: {file_path_w_item}')
log.info(f'Checking directory item: {directory_item}') # log.info(f'Checking directory item: {directory_item}')
if os.path.isfile(file_path_w_item): if os.path.isfile(file_path_w_item):
# ### Found file ###
# log.debug(f'File: {directory_item}') # log.debug(f'File: {directory_item}')
# result_list.append(file_path_w_item) # result_list.append(file_path_w_item)
@@ -63,66 +66,82 @@ async def directory_check(
log.warning(f'Not a hashed file! File: {directory_item}') log.warning(f'Not a hashed file! File: {directory_item}')
continue continue
if lookup_file_hash_result := lookup_file_hash(file_hash=directory_item.replace('.file', '')): log.info(f'Hosted hashed file found: {directory_item}')
# log.info('DB record found') result_list.append(file_path_w_item)
# result_list.append(file_path_w_item)
pass
else:
log.warning(f'Hosted File record not found!!! File: {directory_item}')
result_list.append(file_path_w_item)
if rm_orphan:
log.info('Going remove the hosted file from server...')
try: # Create a subdirectory with the first 2 characters of the hash
# log.warning('DELETE') full_subdirectory_path = os.path.join(full_directory_path, directory_item[:2])
pathlib.Path(file_path_w_item).unlink() log.info(f'Making directory: {full_subdirectory_path}')
# continue os.makedirs(full_subdirectory_path, exist_ok=True)
except OSError as e:
log.error("Error: %s : %s" % (file_path, e.strerror)) # Move the file to the subdirectory
# return False log.info(f'Moving to: {full_subdirectory_path}')
continue shutil.move(os.path.join(full_directory_path, directory_item), os.path.join(full_subdirectory_path, directory_item))
# if lookup_file_hash_result := lookup_file_hash(file_hash=directory_item.replace('.file', '')):
# log.info('DB record found')
# # result_list.append(file_path_w_item)
# pass
# else:
# log.warning(f'Hosted File record not found!!! File: {directory_item}')
# result_list.append(file_path_w_item)
# if rm_orphan:
# log.info('Going remove the hosted file from server...')
# try:
# # log.warning('DELETE')
# pathlib.Path(file_path_w_item).unlink()
# # continue
# except OSError as e:
# log.error("Error: %s : %s" % (file_path, e.strerror))
# # return False
# continue
else: else:
# ### Found directory ###
# continue
# log.debug(f'Directory: {directory_item}') # log.debug(f'Directory: {directory_item}')
# pass # pass
log.info('Subdirectory Path exists! Going to get a list of files...') log.info('Subdirectory Path exists! Going to get a list of files... [LATER]')
full_subdirectory_path = os.path.join(full_directory_path, directory_item) # full_subdirectory_path = os.path.join(full_directory_path, directory_item)
subdirectory_list = os.listdir(full_subdirectory_path) # subdirectory_list = os.listdir(full_subdirectory_path)
subdirectory_result_list = [] # subdirectory_result_list = []
for subdirectory_item in subdirectory_list: # for subdirectory_item in subdirectory_list:
file_path_w_item = os.path.join(full_subdirectory_path, subdirectory_item) # file_path_w_item = os.path.join(full_subdirectory_path, subdirectory_item)
# log.info(f'Full file path with directory item: {file_path_w_item}') # # log.info(f'Full file path with directory item: {file_path_w_item}')
log.info(f'Checking subdirectory item: {subdirectory_item}') # log.info(f'Checking subdirectory item: {subdirectory_item}')
if os.path.isfile(file_path_w_item): # if os.path.isfile(file_path_w_item):
# log.debug(f'File: {subdirectory_item}') # # log.debug(f'File: {subdirectory_item}')
# subdirectory_result_list.append(file_path_w_item) # # subdirectory_result_list.append(file_path_w_item)
if '.file' in subdirectory_item: pass # if '.file' in subdirectory_item: pass
else: # else:
log.warning(f'Not a hashed file! File: {subdirectory_item}') # log.warning(f'Not a hashed file! File: {subdirectory_item}')
continue # continue
if lookup_file_hash_result := lookup_file_hash(file_hash=subdirectory_item.replace('.file', '')): # if lookup_file_hash_result := lookup_file_hash(file_hash=subdirectory_item.replace('.file', '')):
# log.info('DB record found') # # log.info('DB record found')
# subdirectory_result_list.append(file_path_w_item) # # subdirectory_result_list.append(file_path_w_item)
pass # pass
else: # else:
log.warning(f'Hosted File record not found!!! File: {subdirectory_item}') # log.warning(f'Hosted File record not found!!! File: {subdirectory_item}')
result_list.append(file_path_w_item) # result_list.append(file_path_w_item)
if rm_orphan: # if rm_orphan:
log.info('Going remove the hosted file from server...') # log.info('Going remove the hosted file from server...')
try: # try:
# log.warning('DELETE') # # log.warning('DELETE')
pathlib.Path(file_path_w_item).unlink() # pathlib.Path(file_path_w_item).unlink()
# continue # # continue
except OSError as e: # except OSError as e:
log.error("Error: %s : %s" % (file_path, e.strerror)) # log.error("Error: %s : %s" % (file_path, e.strerror))
# return False # # return False
continue # continue
else: # else:
log.warning(f'Subdirectory: {subdirectory_item}') # log.warning(f'Subdirectory: {subdirectory_item}')
pass # pass
count = count + 1
return mk_resp(data=result_list, response=commons.response, status_message='The hosted file directory check.') return mk_resp(data=result_list, response=commons.response, status_message='The hosted file directory check.')
else: else:

View File

@@ -66,6 +66,7 @@ async def get(response: Response = Response):
class ConnectionManager: class ConnectionManager:
def __init__(self): def __init__(self):
# NOTE: The active_connections list should be in Redis
self.active_connections: List[WebSocket] = [] self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket): async def connect(self, websocket: WebSocket):
@@ -76,6 +77,10 @@ class ConnectionManager:
def disconnect(self, websocket: WebSocket): def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket) self.active_connections.remove(websocket)
# Targets: echo, direct, group, broadcast
# send_ text, bytes, json
# receive_ text, bytes, json
async def echo(self, message: str, websocket: WebSocket): async def echo(self, message: str, websocket: WebSocket):
log.setLevel(logging.DEBUG) log.setLevel(logging.DEBUG)
# log.debug(dir(websocket)) # log.debug(dir(websocket))
@@ -94,7 +99,7 @@ class ConnectionManager:
log.debug(connection) log.debug(connection)
await connection.send_text(message) await connection.send_text(message)
async def group(self, group_id: str, data: str): async def group(self, group_id: str, data: dict):
log.setLevel(logging.DEBUG) log.setLevel(logging.DEBUG)
log.debug(locals()) log.debug(locals())
@@ -205,6 +210,9 @@ async def ws_client_id(
log.debug(f'Command: {cmd}') log.debug(f'Command: {cmd}')
log.debug(f'Message: {msg}') log.debug(f'Message: {msg}')
data['client_id'] = client_id
data['group_id'] = group_id
await manager.group(group_id=group_id, data=data) await manager.group(group_id=group_id, data=data)
# if msg_type: # if msg_type: