Starting the slow migration to FastAPI...

This commit is contained in:
Scott Idem
2020-11-10 17:04:12 -05:00
parent 6bee9c19cf
commit 8433960d0d
12 changed files with 524 additions and 163 deletions

1
.gitignore vendored
View File

@@ -127,6 +127,7 @@ config.py
.directory .directory
tmp/ tmp/
temp/ temp/
log/
development/ development/
myapp/files/ myapp/files/
myapp/file_distribution/ myapp/file_distribution/

View File

@@ -1,5 +1,7 @@
gunicorn
uvicorn uvicorn
fastapi[all] fastapi[all]
SQLAlchemy SQLAlchemy
mysqlclient mysqlclient
redis redis
aioredis

View File

@@ -1,25 +1,36 @@
import secrets import secrets
from datetime import timedelta
from app.config import settings from app.config import settings
from .log import *
from sqlalchemy import create_engine, text from sqlalchemy import create_engine, text
from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.exc import IntegrityError, OperationalError
#from app import db #from sqlalchemy.ext.declarative import declarative_base
#from sqlalchemy.orm import sessionmaker, session
AMS_DB_SERVER = 'linode.oneskyit.com' db_uri = settings.SQLALCHEMY_DATABASE_URI
AMS_DB_PORT = '3306' # default = 3306
AMS_DB_NAME = 'aether_dev' #onesky_ams_dev
AMS_DB_USERNAME = 'onesky_aether'
AMS_DB_PASSWORD = '$onesky.Aether.2020'
connection_string = db_uri
connection_string = 'mysql://'+AMS_DB_USERNAME+':'+AMS_DB_PASSWORD+'@'+AMS_DB_SERVER+'/'+AMS_DB_NAME
engine = create_engine(name_or_url=connection_string, pool_size=10, pool_recycle=120, pool_pre_ping=True, echo=True, echo_pool=True, isolation_level='READ COMMITTED') engine = create_engine(name_or_url=connection_string, pool_size=10, pool_recycle=120, pool_pre_ping=True, echo=True, echo_pool=True, isolation_level='READ COMMITTED')
# NOTE: The default isolation_level is 'REPEATABLE READ'. This can sometimes not show updated data. # NOTE: The default isolation_level is 'REPEATABLE READ'. This can sometimes not show updated data.
#SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
#Base = declarative_base()
db = engine.connect() db = engine.connect()
# Dependency
#def get_db():
#db = SessionLocal()
#try:
#yield db
#finally:
#db.close()
# Insert a new record with values given. # Insert a new record with values given.
def sql_insert(table_name=None, record=None, sql=None, data=None, id_random_length=None): def sql_insert(table_name=None, record=None, sql=None, data=None, id_random_length=None):
print('** sql_insert() ***') print('** sql_insert() ***')
@@ -104,7 +115,10 @@ def sql_insert(table_name=None, record=None, sql=None, data=None, id_random_leng
# NOTE: Select records using custom SQL SELECT statements. # NOTE: Select records using custom SQL SELECT statements.
def sql_select(sql=None, data=None, table_name=None, record_id=None, record_id_random=None, field_name=None, field_value=None, as_list=False): def sql_select(sql=None, data=None, table_name=None, record_id=None, record_id_random=None, field_name=None, field_value=None, as_list=False):
print('*** sql_select() ***') log.setLevel(logging.DEBUG) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
custom_sql = None
if record_id and table_name: if record_id and table_name:
sql = text( sql = text(
@@ -130,6 +144,8 @@ def sql_select(sql=None, data=None, table_name=None, record_id=None, record_id_r
WHERE `"""+table_name+"""`."""+field_name+""" = :field_value WHERE `"""+table_name+"""`."""+field_name+""" = :field_value
""" """
) )
data = {}
data[field_name] = field_value
elif table_name: elif table_name:
sql = text( sql = text(
""" """
@@ -138,42 +154,48 @@ def sql_select(sql=None, data=None, table_name=None, record_id=None, record_id_r
""" """
) )
elif sql: elif sql:
print('SQL found') log.info('SQL found')
custom_sql = True
sql = text(sql) sql = text(sql)
else: else:
print('One or more required fields are missing') log.warn('One or more required fields are missing')
return False return False
try: try:
#if record_id or record_id_random: if not custom_sql:
#result = db.execute(sql, record_id=record_id, record_id_random=record_id_random) log.info('Executing a simple SQL select with no extra data dict...')
#elif field_name and field_value: result = db.execute(sql, record_id=record_id, record_id_random=record_id_random, table_name=table_name, field_name=field_name, field_value=field_value)
#result = db.execute(sql, field_value=field_value) elif custom_sql and data:
#elif sql and data: log.info('Executing a custom SQL select and including the data dict...')
#result = db.execute(sql, data) result = db.execute(sql, data)
print('Executing SQL...') elif custom_sql:
result = db.execute(sql, data=data, record_id=record_id, record_id_random=record_id_random, table_name=table_name, field_name=field_name, field_value=field_value) log.info('Executing a custom SQL select with no extra data dict...')
result = db.execute(sql)
except Exception as e: except Exception as e:
print('*** An exception happened. ***') log.error('*** An exception happened. ***')
print(repr(e)) log.error(repr(e))
print('***') log.error('***')
print(str(e)) log.error(str(e))
print('^^^ exception ^^^') log.error('^^^ exception ^^^')
return False return False
else: else:
if result.rowcount == 1 and as_list: if result.rowcount == 1 and as_list:
print('Single as list') log.info('Found one record. Returning as a list.')
record = dict(result.fetchone()) record = dict(result.fetchone())
return [record] return [record]
elif result.rowcount == 1 and not as_list: elif result.rowcount == 1 and not as_list:
print('Single as single') log.info('Found one record. Returning as a dict.')
#record = result.fetchone() #record = result.fetchone()
record = dict(result.fetchone()) record = dict(result.fetchone())
return record return record
elif result.rowcount > 1: elif result.rowcount > 1:
print('List as list') log.info('Found more than one record. Returning as a list of dicts.')
#records = result.fetchall() #records = result.fetchall()
records = [dict(u) for u in result.fetchall()] records = [dict(u) for u in result.fetchall()]
return records return records
elif as_list:
log.info('No records found. Returning as a list.')
return [None]
else: else:
return False log.info('No records found. Returning None.')
return None

View File

@@ -1,12 +1,13 @@
import redis
from datetime import datetime, time, timedelta from datetime import datetime, time, timedelta
from fastapi import APIRouter, Depends, Header, HTTPException, status from fastapi import APIRouter, Depends, Header, HTTPException, status
from pydantic import BaseModel, EmailStr, Field from pydantic import BaseModel, EmailStr, Field
from typing import Dict, List, Optional, Set, Union from typing import Dict, List, Optional, Set, Union
from .log import *
from .db import *
#router = APIRouter()
#import app
async def get_token_header(x_token: str = Header(...)): async def get_token_header(x_token: str = Header(...)):
if x_token != 'fake-super-secret-token': if x_token != 'fake-super-secret-token':
@@ -14,8 +15,27 @@ async def get_token_header(x_token: str = Header(...)):
async def get_account_header(x_account_id: str = Header(...)): async def get_account_header(x_account_id: str = Header(...)):
log.setLevel(logging.DEBUG) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
print('get_account_header(): '+x_account_id) print('get_account_header(): '+x_account_id)
return x_account_id
if len(x_account_id):
log.info('The x-account-id header has a value.')
if account_id := redis_lookup_id_random(table_name='account', record_id_random=x_account_id):
log.setLevel(logging.DEBUG)
log.info('Found the account_id with the account_id_random value: '+x_account_id)
account = { 'id': account_id, 'id_random': x_account_id }
else:
log.setLevel(logging.DEBUG)
log.info('The x-account-id was invalid and not empty...')
#raise HTTPException(status_code=500)
raise HTTPException(status_code=400) # or 404?
#return False
elif x_account_id == '':
log.info('The x-account-id header was empty.')
account = { 'id': None, 'id_random': None }
return account
#Add the processing time to the response header. #Add the processing time to the response header.
@@ -50,3 +70,46 @@ async def get_account_header(x_account_id: str = Header(...)):
#async def get_account_header(x_account_id: str = Header(...)): #async def get_account_header(x_account_id: str = Header(...)):
#print('get_account_header(): '+x_account_id+'z9999z') #print('get_account_header(): '+x_account_id+'z9999z')
#return x_account_id+'z9999z' #return x_account_id+'z9999z'
# Attempt to look up id_random key
# If success then return the id number
# If not success and there is a table_name then check the database table passed
# If found in database table then store in Redis
def redis_lookup_id_random(record_id_random=None, table_name=None):
log.setLevel(logging.DEBUG) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
r = redis.Redis(host='localhost', port=6379, db=7, password=None, decode_responses=True)
key_name = 'record_id:'+record_id_random
record_id = r.get(key_name)
#print('Record ID? '+str(record_id))
if record_id:
print('TTL for: '+key_name+' : '+str(record_id)+' is '+str(r.ttl(key_name))+' seconds')
return record_id
elif table_name:
data = { 'id_random': record_id_random }
sql = """
SELECT id
FROM `"""+table_name+"""` AS `table`
WHERE table.id_random = :id_random
"""
if select_results := sql_select(table_name=table_name, record_id_random=record_id_random): # sql_select(sql=sql, data=data)
#print('Record ID random found: '+str(select_results['id']))
record_id = select_results['id']
r.setex(key_name, timedelta(minutes=2), value=record_id)
return record_id
else:
#print('Record ID random was not found')
return None
else:
print('Missing table_name to select from for id_random')
return False
#return False

View File

@@ -16,7 +16,18 @@ from sqlalchemy.exc import IntegrityError, OperationalError
from . import config from . import config
from .lib_general import * from .lib_general import *
from .log import * from .log import *
from .routers import items, users, websockets
# Import the routers here first:
from .routers import items, journals, users, websockets
# TEST TEST TEST
print('**** Calling db.py ... ****')
#from .db import engine, SessionLocal, Base
from .db import db
print('**** Called db.py ****')
# TEST TEST TEST
#log = logging.getLogger('root') #log = logging.getLogger('root')
#log.setLevel(logging.ERROR) # DEBUG > INFO > WARNING > ERROR > CRITICAL #log.setLevel(logging.ERROR) # DEBUG > INFO > WARNING > ERROR > CRITICAL
@@ -36,14 +47,7 @@ def get_settings():
app.mount('/static', StaticFiles(directory='static'), name='static') app.mount('/static', StaticFiles(directory='static'), name='static')
app.include_router( # Set up each route once the router has been imported
users.router,
prefix='/user',
tags=['Users'],
#dependencies=[Depends(get_token_header)],
#dependencies=[Depends(get_account_header)],
#responses={404: {'description': 'Not found'}},
)
app.include_router( app.include_router(
items.router, items.router,
prefix='/item', prefix='/item',
@@ -51,9 +55,26 @@ app.include_router(
#dependencies=[Depends(get_token_header)], #dependencies=[Depends(get_token_header)],
#responses={404: {'description': 'Not found'}}, #responses={404: {'description': 'Not found'}},
) )
app.include_router(
journals.router,
prefix='/journal',
tags=['Journals'],
#dependencies=[Depends(get_token_header)],
#dependencies=[Depends(get_account_header)],
#responses={404: {'description': 'Not found'}},
)
app.include_router(
users.router,
prefix='/user',
tags=['Users'],
#dependencies=[Depends(get_token_header)],
#dependencies=[Depends(get_account_header)],
#responses={404: {'description': 'Not found'}},
)
app.include_router( app.include_router(
websockets.router, websockets.router,
#prefix='/item', #prefix='/websocket',
tags=['Websockets'], tags=['Websockets'],
#dependencies=[Depends(get_token_header)], #dependencies=[Depends(get_token_header)],
#responses={404: {'description': 'Not found'}}, #responses={404: {'description': 'Not found'}},
@@ -82,6 +103,24 @@ app.add_middleware(
# END: CORS # END: CORS
@app.on_event('startup')
async def startup():
log.setLevel(logging.INFO) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
log.info('FastAPI app is starting up...')
#await database.connect()
@app.on_event('shutdown')
async def shutdown():
log.setLevel(logging.INFO) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
log.info('FastAPI app is shutting down...')
#await database.disconnect()
#Add the processing time to the response header. #Add the processing time to the response header.
@app.middleware('http') @app.middleware('http')
async def add_process_time_header(request: Request, call_next): async def add_process_time_header(request: Request, call_next):
@@ -95,7 +134,10 @@ async def add_process_time_header(request: Request, call_next):
@app.get('/', tags=['Default']) @app.get('/', tags=['Default'])
async def get_root(): async def get_root():
print(config.settings.APP_NAME) log.setLevel(logging.INFO) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
log.info(config.settings.APP_NAME)
log.setLevel(logging.DEBUG) log.setLevel(logging.DEBUG)
@@ -118,3 +160,42 @@ async def get_root():
print('^^^') print('^^^')
return {'hello': 'This is the Aether API using FastAPI.'} return {'hello': 'This is the Aether API using FastAPI.'}
# ### TEST TEST TEST ### #
@app.get('/quick_test', tags=['Default'])
async def quick_test():
log.setLevel(logging.DEBUG) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
log.info('Getting all accounts...')
sql = text(
"""
SELECT *
FROM `account`
"""
)
try:
result = db.execute(sql)
except Exception as e:
log.error('*** An exception happened. ***')
log.error(repr(e))
log.error('***')
log.error(str(e))
log.error('^^^ exception ^^^')
else:
if result.rowcount:
records = result.fetchall()
log.debug(records)
else:
log.warning('Something went wrong.')
log.info('Got the account list')
response = {}
response['hello'] = 'This is the Aether API using FastAPI.'
response['data'] = records
return response
# ### TEST TEST TEST ### #

View File

@@ -1,7 +1,13 @@
from datetime import datetime, time, timedelta
from fastapi import APIRouter, HTTPException, status from fastapi import APIRouter, HTTPException, status
from pydantic import BaseModel, EmailStr, Field from pydantic import BaseModel, EmailStr, Field
from typing import Dict, List, Optional, Set, Union from typing import Dict, List, Optional, Set, Union
from ..lib_general import *
from ..log import *
from app.config import settings
from app.db import *
router = APIRouter() router = APIRouter()

View File

@@ -1,24 +0,0 @@
from fastapi import APIRouter, HTTPException
router = APIRouter()
@router.get("/")
async def read_items():
return [{"name": "Item Foo"}, {"name": "item Bar"}]
@router.get("/{item_id}")
async def read_item(item_id: str):
return {"name": "Fake Specific Item", "item_id": item_id}
@router.put(
"/{item_id}",
tags=["custom"],
responses={403: {"description": "Operation forbidden"}},
)
async def update_item(item_id: str):
if item_id != "foo":
raise HTTPException(status_code=403, detail="You can only update the item: foo")
return {"item_id": item_id, "name": "The Fighters"}

View File

@@ -0,0 +1,40 @@
from datetime import datetime, time, timedelta
from pydantic import BaseModel, EmailStr, Field
from typing import Dict, List, Optional, Set, Union
class JournalBase(BaseModel):
#id_random: str = None # This should not be None. It is required.
#id_random: str = Field(None, example='iyOrkTnHEuyYUNeePbEdIg', min_length=11, max_length=22)
account_id_random: str = None # This should not be None. It is required.
user_id_random: str = Field(None, example='iyOrkTnHEuyYUNeePbEdIg', min_length=11, max_length=22)
default_private: Optional[bool] = None
default_public: Optional[bool] = None
default_personal: Optional[bool] = None
default_professional: Optional[bool] = None
private_passcode: str = Field(None, example='my passcode', min_length=3, max_length=20)
title: str = Field(None, example='The Journal Title', min_length=3, max_length=200)
summary: Optional[str] = None
hide: Optional[bool] = None
status: Optional[int] = None
archive_on: Optional[datetime] = None
archive: Optional[bool] = None
priority: Optional[bool] = None
sort: Optional[int] = None
group: Optional[str] = None
notes: Optional[str] = None
class JournalIn(JournalBase):
id_random: str = Field(None, example='iyOrkTnHEuyYUNeePbEdIg', min_length=11, max_length=22)
class JournalOut(JournalBase):
id_random: str = Field(None, example='iyOrkTnHEuyYUNeePbEdIg', min_length=11, max_length=22)
created_on: datetime
update_on: Optional[datetime] = None

159
app/routers/journals.py Normal file
View File

@@ -0,0 +1,159 @@
from datetime import datetime, time, timedelta
from fastapi import APIRouter, Depends, Header, HTTPException, status
from pydantic import BaseModel, EmailStr, Field
from typing import Dict, List, Optional, Set, Union
from ..lib_general import *
from ..log import *
from app.config import settings
from app.db import *
from .journal_models import *
router = APIRouter()
@router.post(
"/",
response_model=JournalOut,
response_model_exclude_unset=True,
summary='Create a new journal account',
status_code=status.HTTP_201_CREATED
)
async def create_journal(journal: JournalIn, x_account_id: str = Header(...)):
"""
Create a new journal account
"""
journal = dict(journal)
table_name = 'journal'
# Look up the journal['account_id_random'] and match to a record ID from Redis
if account_id := redis_lookup_id_random(table_name='account', record_id_random=x_account_id):
journal['account_id'] = account_id
journal.pop('account_id_random')
else:
print('Something went wrong with the id_random lookup.')
raise HTTPException(status_code=500)
if result := sql_insert(table_name=table_name, record=journal, id_random_length=16):
print(type(result))
if type(result) == int: # isinstance(result, int):
# Select the new record to return as a response.
if new_journal := dict(sql_select(table_name=table_name, record_id=result)):
return new_journal
else:
print('New journal record was not found.')
raise HTTPException(status_code=400)
else:
print('There is likely a duplicate record. A new record was not created.')
raise HTTPException(status_code=400)
else:
print('No journal record was not created')
raise HTTPException(status_code=400)
#@router.patch('/{id_random}', response_model=JournalOut, dependencies=[Depends(get_account_header)])
#async def update_journal(id_random: str, journal: JournalIn, x_account_id: str = Header(...)):
#async def update_journal(id_random: str, journal: JournalIn):
@router.patch(
'/{id_random}',
response_model=JournalOut,
summary='Update a journal account'
)
async def update_journal(id_random: str, journal: JournalIn, x_account_id: str = Depends(get_account_header)):
"""
Update a journal account
"""
journal = {}
journal['id_random'] = id_random
journal['account_id_random'] = x_account_id
journal['title'] = 'tit'
journal['summary'] = 'sum'
#journal['created_on'] = datetime.now()
journal['default_private'] = True
journal['default_public'] = False
journal['default_personal'] = False
journal['default_professional'] = False
return journal
@router.delete('/{id_random}', response_model=bool)
async def delete_journal(id_random: str, x_account_id: str = Depends(get_account_header)):
"""
Delete a journal account
"""
return True
return False
@router.get('/', response_model=List[JournalOut])
@router.get('/list_all', response_model=List[JournalOut])
async def list_journals():
"""
Get a list of journals
"""
log.setLevel(logging.DEBUG)
log.debug(str(locals().keys())+' | '+str(locals().values()))
log.debug(locals())
journals = [{'journalname': 'test.journal.1'}, {'journalname': 'test.journal.2'}, {'journalname': 'Scott.Idem'}]
log.info('Getting all journals...')
sql = """
SELECT *
FROM `journal`
/*WHERE id=1*/
"""
#records = sql_select(sql=sql, as_list=True)
records = sql_select(table_name='v_journal', as_list=True)
if records:
log.info('Got the journal list')
return records
else:
log.info('No journal records found')
raise HTTPException(status_code=404)
@router.get(
'/{journal_id_random}',
response_model=JournalOut,
summary='Get a journal with an id (id_random)'
)
async def get_journal_id(journal_id_random: str, x_account_id: str = Header(...)):
"""
Get a journal with an id (id_random)
"""
log.setLevel(logging.WARN) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
if account_id := redis_lookup_id_random(table_name='account', record_id_random=x_account_id):
#journal['account_id'] = account_id
#journal.pop('account_id_random')
pass
else:
log.warning('Something went wrong with the id_random lookup.')
raise HTTPException(status_code=500)
if journal_id := redis_lookup_id_random(table_name='journal', record_id_random=journal_id_random):
#journal['journal_id'] = journal_id
#journal.pop('account_id_random')
pass
else:
log.warning('Something went wrong with the id_random lookup.')
raise HTTPException(status_code=500)
record = sql_select(table_name='v_journal', record_id=journal_id)
if record:
log.info('Got the journal')
return record
else:
log.info('No journal record found')
raise HTTPException(status_code=404)

View File

@@ -7,11 +7,8 @@ from ..lib_general import *
from ..log import * from ..log import *
from app.config import settings from app.config import settings
from app.db import * from app.db import *
from app.redis import *
from .user_models import * from .user_models import *
#import logging
router = APIRouter() router = APIRouter()
@@ -66,6 +63,8 @@ async def update_user(id_random: str, user: UserIn, x_account_id: str = Depends(
""" """
Update a user account Update a user account
""" """
log.setLevel(logging.DEBUG)
log.debug(locals())
user = {} user = {}
user['id_random'] = id_random user['id_random'] = id_random
@@ -84,6 +83,8 @@ async def delete_user(id_random: str, x_account_id: str = Depends(get_account_he
""" """
Delete a user account Delete a user account
""" """
log.setLevel(logging.DEBUG)
log.debug(locals())
return True return True
return False return False
@@ -91,60 +92,58 @@ async def delete_user(id_random: str, x_account_id: str = Depends(get_account_he
@router.get('/', response_model=List[UserOut]) @router.get('/', response_model=List[UserOut])
@router.get('/list_all', response_model=List[UserOut]) @router.get('/list_all', response_model=List[UserOut])
async def list_users(): async def list_users(x_account: str = Depends(get_account_header)):
""" """
Get a list of users Get a list of users
""" """
log.setLevel(logging.DEBUG) log.setLevel(logging.DEBUG)
log.debug(str(locals().keys())+' | '+str(locals().values()))
log.debug(locals()) log.debug(locals())
#log.setLevel(logging.INFO) if x_account['id']:
#log.info(None) log.info('The x-account-id was given and is not empty...')
sql = """
log.setLevel(logging.WARNING) SELECT *
FROM `user`
print('***') WHERE account_id = :account_id
log.debug('This is debug') # 10 DEBUG """
log.info('This is info') # 20 INFO records = sql_select(table_name='user', field_name='account_id', field_value=x_account['id'], as_list=True)
log.warn('This is warn') # 30 WARNING elif x_account['id'] is None:
log.warning('This is a warning') # 30 WARNING log.info('The x-account-id was given, but is empty...')
log.error('This is an error') # 40 ERROR sql = """
log.exception('This is an exception') # 40 ERROR SELECT *
log.critical('This is critical') # 50 CRITICAL FROM `user`
"""
records = sql_select(table_name='user', as_list=True)
users = [{'username': 'test.user.1'}, {'username': 'test.user.2'}, {'username': 'Scott.Idem'}]
print('Getting all users...')
sql = """
SELECT *
FROM `user`
/*WHERE id=1*/
"""
records = sql_select(sql=sql, as_list=True)
#records = sql_select(table_name='user')
if records: if records:
print('Got the user list') log.info('Returning a user list...')
return records return records
else: else:
print('No user records found') log.info('No user records found...')
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@router.get('/{username}') @router.get('/{username}')
async def get_user_username(username: str, x_account_id: str = Header(...)): async def get_user_username(username: str, x_account: str = Depends(get_account_header)):
return {'username': username} log.setLevel(logging.DEBUG)
log.debug(locals())
data = {}
data['username'] = username
#@router.get('/me') if x_account['id']:
#async def get_user_current(): sql = """
#user_out: UserOut SELECT *
FROM `user`
WHERE account_id = :account_id AND username=:username
"""
data['account_id'] = x_account['id']
elif x_account['id'] is None:
sql = """
SELECT *
FROM `user`
WHERE (account_id IS NULL OR account_id = "") AND username=:username
"""
record = sql_select(sql=sql, data=data)
#return {'username': 'test.user'} return record

View File

@@ -1,18 +0,0 @@
from fastapi import APIRouter
router = APIRouter()
@router.get("/users/", tags=["users"])
async def read_users():
return [{"username": "Foo"}, {"username": "Bar"}]
@router.get("/users/me", tags=["users"])
async def read_user_me():
return {"username": "fakecurrentuser"}
@router.get("/users/{username}", tags=["users"])
async def read_user(username: str):
return {"username": username}

View File

@@ -2,6 +2,11 @@ from fastapi import APIRouter, FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from typing import List from typing import List
from ..lib_general import *
from ..log import *
from app.config import settings
from app.db import *
router = APIRouter() router = APIRouter()
@@ -23,7 +28,9 @@ html = """
<script> <script>
var client_id = Date.now() var client_id = Date.now()
document.querySelector("#ws-id").textContent = client_id; document.querySelector("#ws-id").textContent = client_id;
var ws = new WebSocket(`ws://localhost:5005/ws/${client_id}`); //var ws = new WebSocket(`ws://localhost:5005/ws/${client_id}`);
var ws = new WebSocket("ws://localhost:8000/ws_redis");
//var ws = new WebSocket("ws://fastapi.localhost/ws_redis");
ws.onmessage = function(event) { ws.onmessage = function(event) {
var messages = document.getElementById('messages') var messages = document.getElementById('messages')
var message = document.createElement('li') var message = document.createElement('li')
@@ -43,41 +50,64 @@ html = """
""" """
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def send_personal_message(self, message: str, websocket: WebSocket):
await websocket.send_text(message)
async def broadcast(self, message: str):
for connection in self.active_connections:
await connection.send_text(message)
manager = ConnectionManager()
@router.get("/ws_test") @router.get("/ws_test")
async def websocket_root(): async def get():
log.setLevel(logging.DEBUG)
log.debug(locals())
return HTMLResponse(html) return HTMLResponse(html)
@router.websocket("/ws/{client_id}") @router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, client_id: int): async def websocket_endpoint(websocket: WebSocket):
await manager.connect(websocket) log.setLevel(logging.DEBUG)
try: log.debug(locals())
while True: await websocket.accept()
data = await websocket.receive_text() await redis_connector(websocket)
await manager.send_personal_message(f"You wrote: {data}", websocket)
await manager.broadcast(f"Client #{client_id} says: {data}")
except WebSocketDisconnect: async def redis_connector(
manager.disconnect(websocket) websocket: WebSocket, redis_uri: str = "redis://localhost:6379"
await manager.broadcast(f"Client #{client_id} left the chat") ):
log.setLevel(logging.DEBUG)
log.debug(locals())
async def consumer_handler(ws: WebSocket, r):
try:
while True:
message = await ws.receive_text()
if message:
#logging.info(ws)
#logging.info(dir(message))
data = json.loads(message)
#await r.publish("chat:c", message)
#await r.publish("chat:c", str(data['message']))
await r.publish("chat:c", str(data['client_id']))
await r.publish("chat:c", str(data))
except WebSocketDisconnect as exc:
# TODO this needs handling better
logger.error(exc)
async def producer_handler(r, ws: WebSocket):
(channel,) = await r.subscribe("chat:c")
assert isinstance(channel, aioredis.Channel)
try:
while True:
message = await channel.get()
if message:
await ws.send_text(message.decode("utf-8"))
except Exception as exc:
# TODO this needs handling better
logger.error(exc)
redis = await aioredis.create_redis_pool(redis_uri)
consumer_task = consumer_handler(websocket, redis)
producer_task = producer_handler(redis, websocket)
done, pending = await asyncio.wait(
[consumer_task, producer_task], return_when=asyncio.FIRST_COMPLETED,
)
logger.debug(f"Done task: {done}")
for task in pending:
logger.debug(f"Canceling task: {task}")
task.cancel()
redis.close()
await redis.wait_closed()