Starting the slow migration to FastAPI...

This commit is contained in:
Scott Idem
2020-11-10 17:04:12 -05:00
parent 6bee9c19cf
commit 8433960d0d
12 changed files with 524 additions and 163 deletions

1
.gitignore vendored
View File

@@ -127,6 +127,7 @@ config.py
.directory
tmp/
temp/
log/
development/
myapp/files/
myapp/file_distribution/

View File

@@ -1,5 +1,7 @@
gunicorn
uvicorn
fastapi[all]
SQLAlchemy
mysqlclient
redis
aioredis

View File

@@ -1,25 +1,36 @@
import secrets
from datetime import timedelta
from app.config import settings
from .log import *
from sqlalchemy import create_engine, text
from sqlalchemy.exc import IntegrityError, OperationalError
#from app import db
#from sqlalchemy.ext.declarative import declarative_base
#from sqlalchemy.orm import sessionmaker, session
AMS_DB_SERVER = 'linode.oneskyit.com'
AMS_DB_PORT = '3306' # default = 3306
AMS_DB_NAME = 'aether_dev' #onesky_ams_dev
AMS_DB_USERNAME = 'onesky_aether'
AMS_DB_PASSWORD = '$onesky.Aether.2020'
db_uri = settings.SQLALCHEMY_DATABASE_URI
connection_string = 'mysql://'+AMS_DB_USERNAME+':'+AMS_DB_PASSWORD+'@'+AMS_DB_SERVER+'/'+AMS_DB_NAME
connection_string = db_uri
engine = create_engine(name_or_url=connection_string, pool_size=10, pool_recycle=120, pool_pre_ping=True, echo=True, echo_pool=True, isolation_level='READ COMMITTED')
# NOTE: The default isolation_level is 'REPEATABLE READ'. This can sometimes not show updated data.
#SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
#Base = declarative_base()
db = engine.connect()
# Dependency
#def get_db():
#db = SessionLocal()
#try:
#yield db
#finally:
#db.close()
# Insert a new record with values given.
def sql_insert(table_name=None, record=None, sql=None, data=None, id_random_length=None):
print('** sql_insert() ***')
@@ -104,7 +115,10 @@ def sql_insert(table_name=None, record=None, sql=None, data=None, id_random_leng
# NOTE: Select records using custom SQL SELECT statements.
def sql_select(sql=None, data=None, table_name=None, record_id=None, record_id_random=None, field_name=None, field_value=None, as_list=False):
print('*** sql_select() ***')
log.setLevel(logging.DEBUG) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
custom_sql = None
if record_id and table_name:
sql = text(
@@ -130,6 +144,8 @@ def sql_select(sql=None, data=None, table_name=None, record_id=None, record_id_r
WHERE `"""+table_name+"""`."""+field_name+""" = :field_value
"""
)
data = {}
data[field_name] = field_value
elif table_name:
sql = text(
"""
@@ -138,42 +154,48 @@ def sql_select(sql=None, data=None, table_name=None, record_id=None, record_id_r
"""
)
elif sql:
print('SQL found')
log.info('SQL found')
custom_sql = True
sql = text(sql)
else:
print('One or more required fields are missing')
log.warn('One or more required fields are missing')
return False
try:
#if record_id or record_id_random:
#result = db.execute(sql, record_id=record_id, record_id_random=record_id_random)
#elif field_name and field_value:
#result = db.execute(sql, field_value=field_value)
#elif sql and data:
#result = db.execute(sql, data)
print('Executing SQL...')
result = db.execute(sql, data=data, record_id=record_id, record_id_random=record_id_random, table_name=table_name, field_name=field_name, field_value=field_value)
if not custom_sql:
log.info('Executing a simple SQL select with no extra data dict...')
result = db.execute(sql, record_id=record_id, record_id_random=record_id_random, table_name=table_name, field_name=field_name, field_value=field_value)
elif custom_sql and data:
log.info('Executing a custom SQL select and including the data dict...')
result = db.execute(sql, data)
elif custom_sql:
log.info('Executing a custom SQL select with no extra data dict...')
result = db.execute(sql)
except Exception as e:
print('*** An exception happened. ***')
print(repr(e))
print('***')
print(str(e))
print('^^^ exception ^^^')
log.error('*** An exception happened. ***')
log.error(repr(e))
log.error('***')
log.error(str(e))
log.error('^^^ exception ^^^')
return False
else:
if result.rowcount == 1 and as_list:
print('Single as list')
log.info('Found one record. Returning as a list.')
record = dict(result.fetchone())
return [record]
elif result.rowcount == 1 and not as_list:
print('Single as single')
log.info('Found one record. Returning as a dict.')
#record = result.fetchone()
record = dict(result.fetchone())
return record
elif result.rowcount > 1:
print('List as list')
log.info('Found more than one record. Returning as a list of dicts.')
#records = result.fetchall()
records = [dict(u) for u in result.fetchall()]
return records
elif as_list:
log.info('No records found. Returning as a list.')
return [None]
else:
return False
log.info('No records found. Returning None.')
return None

View File

@@ -1,12 +1,13 @@
import redis
from datetime import datetime, time, timedelta
from fastapi import APIRouter, Depends, Header, HTTPException, status
from pydantic import BaseModel, EmailStr, Field
from typing import Dict, List, Optional, Set, Union
from .log import *
from .db import *
#router = APIRouter()
#import app
async def get_token_header(x_token: str = Header(...)):
if x_token != 'fake-super-secret-token':
@@ -14,8 +15,27 @@ async def get_token_header(x_token: str = Header(...)):
async def get_account_header(x_account_id: str = Header(...)):
log.setLevel(logging.DEBUG) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
print('get_account_header(): '+x_account_id)
return x_account_id
if len(x_account_id):
log.info('The x-account-id header has a value.')
if account_id := redis_lookup_id_random(table_name='account', record_id_random=x_account_id):
log.setLevel(logging.DEBUG)
log.info('Found the account_id with the account_id_random value: '+x_account_id)
account = { 'id': account_id, 'id_random': x_account_id }
else:
log.setLevel(logging.DEBUG)
log.info('The x-account-id was invalid and not empty...')
#raise HTTPException(status_code=500)
raise HTTPException(status_code=400) # or 404?
#return False
elif x_account_id == '':
log.info('The x-account-id header was empty.')
account = { 'id': None, 'id_random': None }
return account
#Add the processing time to the response header.
@@ -50,3 +70,46 @@ async def get_account_header(x_account_id: str = Header(...)):
#async def get_account_header(x_account_id: str = Header(...)):
#print('get_account_header(): '+x_account_id+'z9999z')
#return x_account_id+'z9999z'
# Attempt to look up id_random key
# If success then return the id number
# If not success and there is a table_name then check the database table passed
# If found in database table then store in Redis
def redis_lookup_id_random(record_id_random=None, table_name=None):
log.setLevel(logging.DEBUG) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
r = redis.Redis(host='localhost', port=6379, db=7, password=None, decode_responses=True)
key_name = 'record_id:'+record_id_random
record_id = r.get(key_name)
#print('Record ID? '+str(record_id))
if record_id:
print('TTL for: '+key_name+' : '+str(record_id)+' is '+str(r.ttl(key_name))+' seconds')
return record_id
elif table_name:
data = { 'id_random': record_id_random }
sql = """
SELECT id
FROM `"""+table_name+"""` AS `table`
WHERE table.id_random = :id_random
"""
if select_results := sql_select(table_name=table_name, record_id_random=record_id_random): # sql_select(sql=sql, data=data)
#print('Record ID random found: '+str(select_results['id']))
record_id = select_results['id']
r.setex(key_name, timedelta(minutes=2), value=record_id)
return record_id
else:
#print('Record ID random was not found')
return None
else:
print('Missing table_name to select from for id_random')
return False
#return False

View File

@@ -16,7 +16,18 @@ from sqlalchemy.exc import IntegrityError, OperationalError
from . import config
from .lib_general import *
from .log import *
from .routers import items, users, websockets
# Import the routers here first:
from .routers import items, journals, users, websockets
# TEST TEST TEST
print('**** Calling db.py ... ****')
#from .db import engine, SessionLocal, Base
from .db import db
print('**** Called db.py ****')
# TEST TEST TEST
#log = logging.getLogger('root')
#log.setLevel(logging.ERROR) # DEBUG > INFO > WARNING > ERROR > CRITICAL
@@ -36,14 +47,7 @@ def get_settings():
app.mount('/static', StaticFiles(directory='static'), name='static')
app.include_router(
users.router,
prefix='/user',
tags=['Users'],
#dependencies=[Depends(get_token_header)],
#dependencies=[Depends(get_account_header)],
#responses={404: {'description': 'Not found'}},
)
# Set up each route once the router has been imported
app.include_router(
items.router,
prefix='/item',
@@ -51,9 +55,26 @@ app.include_router(
#dependencies=[Depends(get_token_header)],
#responses={404: {'description': 'Not found'}},
)
app.include_router(
journals.router,
prefix='/journal',
tags=['Journals'],
#dependencies=[Depends(get_token_header)],
#dependencies=[Depends(get_account_header)],
#responses={404: {'description': 'Not found'}},
)
app.include_router(
users.router,
prefix='/user',
tags=['Users'],
#dependencies=[Depends(get_token_header)],
#dependencies=[Depends(get_account_header)],
#responses={404: {'description': 'Not found'}},
)
app.include_router(
websockets.router,
#prefix='/item',
#prefix='/websocket',
tags=['Websockets'],
#dependencies=[Depends(get_token_header)],
#responses={404: {'description': 'Not found'}},
@@ -82,6 +103,24 @@ app.add_middleware(
# END: CORS
@app.on_event('startup')
async def startup():
log.setLevel(logging.INFO) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
log.info('FastAPI app is starting up...')
#await database.connect()
@app.on_event('shutdown')
async def shutdown():
log.setLevel(logging.INFO) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
log.info('FastAPI app is shutting down...')
#await database.disconnect()
#Add the processing time to the response header.
@app.middleware('http')
async def add_process_time_header(request: Request, call_next):
@@ -95,7 +134,10 @@ async def add_process_time_header(request: Request, call_next):
@app.get('/', tags=['Default'])
async def get_root():
print(config.settings.APP_NAME)
log.setLevel(logging.INFO) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
log.info(config.settings.APP_NAME)
log.setLevel(logging.DEBUG)
@@ -118,3 +160,42 @@ async def get_root():
print('^^^')
return {'hello': 'This is the Aether API using FastAPI.'}
# ### TEST TEST TEST ### #
@app.get('/quick_test', tags=['Default'])
async def quick_test():
log.setLevel(logging.DEBUG) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
log.info('Getting all accounts...')
sql = text(
"""
SELECT *
FROM `account`
"""
)
try:
result = db.execute(sql)
except Exception as e:
log.error('*** An exception happened. ***')
log.error(repr(e))
log.error('***')
log.error(str(e))
log.error('^^^ exception ^^^')
else:
if result.rowcount:
records = result.fetchall()
log.debug(records)
else:
log.warning('Something went wrong.')
log.info('Got the account list')
response = {}
response['hello'] = 'This is the Aether API using FastAPI.'
response['data'] = records
return response
# ### TEST TEST TEST ### #

View File

@@ -1,7 +1,13 @@
from datetime import datetime, time, timedelta
from fastapi import APIRouter, HTTPException, status
from pydantic import BaseModel, EmailStr, Field
from typing import Dict, List, Optional, Set, Union
from ..lib_general import *
from ..log import *
from app.config import settings
from app.db import *
router = APIRouter()

View File

@@ -1,24 +0,0 @@
from fastapi import APIRouter, HTTPException
router = APIRouter()
@router.get("/")
async def read_items():
return [{"name": "Item Foo"}, {"name": "item Bar"}]
@router.get("/{item_id}")
async def read_item(item_id: str):
return {"name": "Fake Specific Item", "item_id": item_id}
@router.put(
"/{item_id}",
tags=["custom"],
responses={403: {"description": "Operation forbidden"}},
)
async def update_item(item_id: str):
if item_id != "foo":
raise HTTPException(status_code=403, detail="You can only update the item: foo")
return {"item_id": item_id, "name": "The Fighters"}

View File

@@ -0,0 +1,40 @@
from datetime import datetime, time, timedelta
from pydantic import BaseModel, EmailStr, Field
from typing import Dict, List, Optional, Set, Union
class JournalBase(BaseModel):
#id_random: str = None # This should not be None. It is required.
#id_random: str = Field(None, example='iyOrkTnHEuyYUNeePbEdIg', min_length=11, max_length=22)
account_id_random: str = None # This should not be None. It is required.
user_id_random: str = Field(None, example='iyOrkTnHEuyYUNeePbEdIg', min_length=11, max_length=22)
default_private: Optional[bool] = None
default_public: Optional[bool] = None
default_personal: Optional[bool] = None
default_professional: Optional[bool] = None
private_passcode: str = Field(None, example='my passcode', min_length=3, max_length=20)
title: str = Field(None, example='The Journal Title', min_length=3, max_length=200)
summary: Optional[str] = None
hide: Optional[bool] = None
status: Optional[int] = None
archive_on: Optional[datetime] = None
archive: Optional[bool] = None
priority: Optional[bool] = None
sort: Optional[int] = None
group: Optional[str] = None
notes: Optional[str] = None
class JournalIn(JournalBase):
id_random: str = Field(None, example='iyOrkTnHEuyYUNeePbEdIg', min_length=11, max_length=22)
class JournalOut(JournalBase):
id_random: str = Field(None, example='iyOrkTnHEuyYUNeePbEdIg', min_length=11, max_length=22)
created_on: datetime
update_on: Optional[datetime] = None

159
app/routers/journals.py Normal file
View File

@@ -0,0 +1,159 @@
from datetime import datetime, time, timedelta
from fastapi import APIRouter, Depends, Header, HTTPException, status
from pydantic import BaseModel, EmailStr, Field
from typing import Dict, List, Optional, Set, Union
from ..lib_general import *
from ..log import *
from app.config import settings
from app.db import *
from .journal_models import *
router = APIRouter()
@router.post(
"/",
response_model=JournalOut,
response_model_exclude_unset=True,
summary='Create a new journal account',
status_code=status.HTTP_201_CREATED
)
async def create_journal(journal: JournalIn, x_account_id: str = Header(...)):
"""
Create a new journal account
"""
journal = dict(journal)
table_name = 'journal'
# Look up the journal['account_id_random'] and match to a record ID from Redis
if account_id := redis_lookup_id_random(table_name='account', record_id_random=x_account_id):
journal['account_id'] = account_id
journal.pop('account_id_random')
else:
print('Something went wrong with the id_random lookup.')
raise HTTPException(status_code=500)
if result := sql_insert(table_name=table_name, record=journal, id_random_length=16):
print(type(result))
if type(result) == int: # isinstance(result, int):
# Select the new record to return as a response.
if new_journal := dict(sql_select(table_name=table_name, record_id=result)):
return new_journal
else:
print('New journal record was not found.')
raise HTTPException(status_code=400)
else:
print('There is likely a duplicate record. A new record was not created.')
raise HTTPException(status_code=400)
else:
print('No journal record was not created')
raise HTTPException(status_code=400)
#@router.patch('/{id_random}', response_model=JournalOut, dependencies=[Depends(get_account_header)])
#async def update_journal(id_random: str, journal: JournalIn, x_account_id: str = Header(...)):
#async def update_journal(id_random: str, journal: JournalIn):
@router.patch(
'/{id_random}',
response_model=JournalOut,
summary='Update a journal account'
)
async def update_journal(id_random: str, journal: JournalIn, x_account_id: str = Depends(get_account_header)):
"""
Update a journal account
"""
journal = {}
journal['id_random'] = id_random
journal['account_id_random'] = x_account_id
journal['title'] = 'tit'
journal['summary'] = 'sum'
#journal['created_on'] = datetime.now()
journal['default_private'] = True
journal['default_public'] = False
journal['default_personal'] = False
journal['default_professional'] = False
return journal
@router.delete('/{id_random}', response_model=bool)
async def delete_journal(id_random: str, x_account_id: str = Depends(get_account_header)):
"""
Delete a journal account
"""
return True
return False
@router.get('/', response_model=List[JournalOut])
@router.get('/list_all', response_model=List[JournalOut])
async def list_journals():
"""
Get a list of journals
"""
log.setLevel(logging.DEBUG)
log.debug(str(locals().keys())+' | '+str(locals().values()))
log.debug(locals())
journals = [{'journalname': 'test.journal.1'}, {'journalname': 'test.journal.2'}, {'journalname': 'Scott.Idem'}]
log.info('Getting all journals...')
sql = """
SELECT *
FROM `journal`
/*WHERE id=1*/
"""
#records = sql_select(sql=sql, as_list=True)
records = sql_select(table_name='v_journal', as_list=True)
if records:
log.info('Got the journal list')
return records
else:
log.info('No journal records found')
raise HTTPException(status_code=404)
@router.get(
'/{journal_id_random}',
response_model=JournalOut,
summary='Get a journal with an id (id_random)'
)
async def get_journal_id(journal_id_random: str, x_account_id: str = Header(...)):
"""
Get a journal with an id (id_random)
"""
log.setLevel(logging.WARN) # DEBUG, INFO, WARN, WARNING, ERROR, EXCEPTION, CRITICAL
log.debug(locals())
if account_id := redis_lookup_id_random(table_name='account', record_id_random=x_account_id):
#journal['account_id'] = account_id
#journal.pop('account_id_random')
pass
else:
log.warning('Something went wrong with the id_random lookup.')
raise HTTPException(status_code=500)
if journal_id := redis_lookup_id_random(table_name='journal', record_id_random=journal_id_random):
#journal['journal_id'] = journal_id
#journal.pop('account_id_random')
pass
else:
log.warning('Something went wrong with the id_random lookup.')
raise HTTPException(status_code=500)
record = sql_select(table_name='v_journal', record_id=journal_id)
if record:
log.info('Got the journal')
return record
else:
log.info('No journal record found')
raise HTTPException(status_code=404)

View File

@@ -7,11 +7,8 @@ from ..lib_general import *
from ..log import *
from app.config import settings
from app.db import *
from app.redis import *
from .user_models import *
#import logging
router = APIRouter()
@@ -66,6 +63,8 @@ async def update_user(id_random: str, user: UserIn, x_account_id: str = Depends(
"""
Update a user account
"""
log.setLevel(logging.DEBUG)
log.debug(locals())
user = {}
user['id_random'] = id_random
@@ -84,6 +83,8 @@ async def delete_user(id_random: str, x_account_id: str = Depends(get_account_he
"""
Delete a user account
"""
log.setLevel(logging.DEBUG)
log.debug(locals())
return True
return False
@@ -91,60 +92,58 @@ async def delete_user(id_random: str, x_account_id: str = Depends(get_account_he
@router.get('/', response_model=List[UserOut])
@router.get('/list_all', response_model=List[UserOut])
async def list_users():
async def list_users(x_account: str = Depends(get_account_header)):
"""
Get a list of users
"""
log.setLevel(logging.DEBUG)
log.debug(str(locals().keys())+' | '+str(locals().values()))
log.debug(locals())
#log.setLevel(logging.INFO)
#log.info(None)
log.setLevel(logging.WARNING)
print('***')
log.debug('This is debug') # 10 DEBUG
log.info('This is info') # 20 INFO
log.warn('This is warn') # 30 WARNING
log.warning('This is a warning') # 30 WARNING
log.error('This is an error') # 40 ERROR
log.exception('This is an exception') # 40 ERROR
log.critical('This is critical') # 50 CRITICAL
users = [{'username': 'test.user.1'}, {'username': 'test.user.2'}, {'username': 'Scott.Idem'}]
print('Getting all users...')
sql = """
SELECT *
FROM `user`
/*WHERE id=1*/
"""
records = sql_select(sql=sql, as_list=True)
#records = sql_select(table_name='user')
if x_account['id']:
log.info('The x-account-id was given and is not empty...')
sql = """
SELECT *
FROM `user`
WHERE account_id = :account_id
"""
records = sql_select(table_name='user', field_name='account_id', field_value=x_account['id'], as_list=True)
elif x_account['id'] is None:
log.info('The x-account-id was given, but is empty...')
sql = """
SELECT *
FROM `user`
"""
records = sql_select(table_name='user', as_list=True)
if records:
print('Got the user list')
log.info('Returning a user list...')
return records
else:
print('No user records found')
log.info('No user records found...')
raise HTTPException(status_code=404)
@router.get('/{username}')
async def get_user_username(username: str, x_account_id: str = Header(...)):
return {'username': username}
async def get_user_username(username: str, x_account: str = Depends(get_account_header)):
log.setLevel(logging.DEBUG)
log.debug(locals())
data = {}
data['username'] = username
#@router.get('/me')
#async def get_user_current():
#user_out: UserOut
if x_account['id']:
sql = """
SELECT *
FROM `user`
WHERE account_id = :account_id AND username=:username
"""
data['account_id'] = x_account['id']
elif x_account['id'] is None:
sql = """
SELECT *
FROM `user`
WHERE (account_id IS NULL OR account_id = "") AND username=:username
"""
record = sql_select(sql=sql, data=data)
#return {'username': 'test.user'}
return record

View File

@@ -1,18 +0,0 @@
from fastapi import APIRouter
router = APIRouter()
@router.get("/users/", tags=["users"])
async def read_users():
return [{"username": "Foo"}, {"username": "Bar"}]
@router.get("/users/me", tags=["users"])
async def read_user_me():
return {"username": "fakecurrentuser"}
@router.get("/users/{username}", tags=["users"])
async def read_user(username: str):
return {"username": username}

View File

@@ -2,6 +2,11 @@ from fastapi import APIRouter, FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from typing import List
from ..lib_general import *
from ..log import *
from app.config import settings
from app.db import *
router = APIRouter()
@@ -23,7 +28,9 @@ html = """
<script>
var client_id = Date.now()
document.querySelector("#ws-id").textContent = client_id;
var ws = new WebSocket(`ws://localhost:5005/ws/${client_id}`);
//var ws = new WebSocket(`ws://localhost:5005/ws/${client_id}`);
var ws = new WebSocket("ws://localhost:8000/ws_redis");
//var ws = new WebSocket("ws://fastapi.localhost/ws_redis");
ws.onmessage = function(event) {
var messages = document.getElementById('messages')
var message = document.createElement('li')
@@ -43,41 +50,64 @@ html = """
"""
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def send_personal_message(self, message: str, websocket: WebSocket):
await websocket.send_text(message)
async def broadcast(self, message: str):
for connection in self.active_connections:
await connection.send_text(message)
manager = ConnectionManager()
@router.get("/ws_test")
async def websocket_root():
async def get():
log.setLevel(logging.DEBUG)
log.debug(locals())
return HTMLResponse(html)
@router.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
await manager.connect(websocket)
try:
while True:
data = await websocket.receive_text()
await manager.send_personal_message(f"You wrote: {data}", websocket)
await manager.broadcast(f"Client #{client_id} says: {data}")
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast(f"Client #{client_id} left the chat")
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
log.setLevel(logging.DEBUG)
log.debug(locals())
await websocket.accept()
await redis_connector(websocket)
async def redis_connector(
websocket: WebSocket, redis_uri: str = "redis://localhost:6379"
):
log.setLevel(logging.DEBUG)
log.debug(locals())
async def consumer_handler(ws: WebSocket, r):
try:
while True:
message = await ws.receive_text()
if message:
#logging.info(ws)
#logging.info(dir(message))
data = json.loads(message)
#await r.publish("chat:c", message)
#await r.publish("chat:c", str(data['message']))
await r.publish("chat:c", str(data['client_id']))
await r.publish("chat:c", str(data))
except WebSocketDisconnect as exc:
# TODO this needs handling better
logger.error(exc)
async def producer_handler(r, ws: WebSocket):
(channel,) = await r.subscribe("chat:c")
assert isinstance(channel, aioredis.Channel)
try:
while True:
message = await channel.get()
if message:
await ws.send_text(message.decode("utf-8"))
except Exception as exc:
# TODO this needs handling better
logger.error(exc)
redis = await aioredis.create_redis_pool(redis_uri)
consumer_task = consumer_handler(websocket, redis)
producer_task = producer_handler(redis, websocket)
done, pending = await asyncio.wait(
[consumer_task, producer_task], return_when=asyncio.FIRST_COMPLETED,
)
logger.debug(f"Done task: {done}")
for task in pending:
logger.debug(f"Canceling task: {task}")
task.cancel()
redis.close()
await redis.wait_closed()