from fastapi import APIRouter, FastAPI, Response, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, EmailStr, Field
from typing import Dict, List, Optional, Set, Union
import aioredis
import asyncio
import datetime
import hashlib
import json
import os
import pathlib
import shutil
import time

from app.lib_general import (
    log,
    logging,
    common_route_params,
    Common_Route_Params,
    common_route_params_min,
    Common_Route_Params_Min,
)
from app.config import settings
from app.db_sql import (
    sql_insert,
    sql_update,
    sql_insert_or_update,
    sql_select,
    sql_delete,
    redis_lookup_id_random,
)

router = APIRouter()

# Redis pub/sub channel shared by every connected chat websocket.
CHAT_CHANNEL = "chat:c"

# Page served by GET /ws_test.
# NOTE(review): this literal contains only plain text (no HTML tags or script),
# so the page cannot open a websocket on its own; markup may have been lost
# at some point — confirm against the intended test page.
html = """ Chat

WebSocket Chat

Your ID:

"""


@router.get("/ws_test")
async def get(response: Response = Response):
    """Serve the module-level ``html`` test page."""
    log.setLevel(logging.DEBUG)
    log.debug(locals())
    return HTMLResponse(html)


@router.websocket("/ws/{client_id}")
async def websocket_endpoint(
    websocket: WebSocket,
    client_id: int,
    response: Response = Response,
):
    """Accept a chat websocket and bridge it to the Redis chat channel.

    Args:
        websocket: The incoming websocket connection.
        client_id: Path parameter; only logged here (via ``locals()``).
        response: Unused in this handler.
    """
    log.setLevel(logging.DEBUG)
    log.debug(locals())
    log.info('Root of ws. Waiting to accept a websocket and then the redis_connector')
    await websocket.accept()
    await redis_connector(websocket)


async def redis_connector(
    websocket: WebSocket,
    redis_url: str = "redis://localhost:6379",
):
    """Relay messages between ``websocket`` and the Redis chat channel.

    Two tasks run concurrently until either one finishes:

    * the consumer reads JSON text frames from the websocket and publishes
      ``str(data['client_id'])`` followed by ``str(data)`` to the channel;
    * the producer forwards every message published on the channel to the
      websocket.

    When one side stops (e.g. the client disconnects), the other task is
    cancelled, both tasks are awaited, and the Redis pub/sub connection and
    client are closed — also if this coroutine itself fails or is cancelled.

    Args:
        websocket: An already-accepted websocket.
        redis_url: Redis connection URL passed to ``aioredis.from_url``.
    """
    log.setLevel(logging.DEBUG)
    log.debug(locals())

    async def consumer_handler(ws: WebSocket, r):
        """Publish each incoming websocket text frame to the chat channel."""
        try:
            while True:
                message = await ws.receive_text()
                if not message:
                    continue
                try:
                    data = json.loads(message)
                    client_id = data['client_id']
                except (ValueError, KeyError, TypeError) as exc:
                    # One malformed frame should not tear down the session:
                    # ValueError covers invalid JSON, KeyError a missing key,
                    # TypeError a JSON value that is not an object.
                    log.warning(f"Ignoring malformed chat message {message!r}: {exc!r}")
                    continue
                # NOTE(review): these publish Python str() text, not JSON; kept
                # unchanged because subscribers receive exactly this text.
                await r.publish(CHAT_CHANNEL, str(client_id))
                await r.publish(CHAT_CHANNEL, str(data))
        except WebSocketDisconnect as exc:
            # A disconnect is the normal way a chat session ends.
            log.info(f"Websocket disconnected: {exc!r}")

    async def producer_handler(pubsub, ws: WebSocket):
        """Forward every message published on the subscribed channel to ``ws``."""
        try:
            async for message in pubsub.listen():
                # listen() also yields subscribe/unsubscribe confirmations.
                if message["type"] != "message":
                    continue
                payload = message["data"]  # already str: decode_responses=True
                if payload:
                    await ws.send_text(payload)
        except Exception as exc:
            # TODO this needs handling better
            log.error(exc)

    # Redis client bound to a pool of connections (auto-reconnecting).
    redis = aioredis.from_url(
        redis_url, encoding="utf-8", decode_responses=True
    )
    pubsub = redis.pubsub()
    tasks = []
    try:
        # Subscribe before reading from the websocket so this client also
        # receives the first messages it publishes.
        await pubsub.subscribe(CHAT_CHANNEL)
        # asyncio.wait() requires tasks (bare coroutines are rejected on 3.11+).
        tasks = [
            asyncio.create_task(consumer_handler(websocket, redis)),
            asyncio.create_task(producer_handler(pubsub, websocket)),
        ]
        done, pending = await asyncio.wait(
            tasks,
            return_when=asyncio.FIRST_COMPLETED,
        )
        log.debug(f"Done task: {done}")
        for task in pending:
            log.debug(f"Canceling task: {task}")
            task.cancel()
    finally:
        try:
            for task in tasks:
                task.cancel()  # no-op for tasks that already finished
            # Let cancellations settle and retrieve every task's outcome so
            # nothing is reported as "exception was never retrieved".
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, BaseException) and not isinstance(
                    result, asyncio.CancelledError
                ):
                    log.error(f"Chat task failed: {result!r}")
        finally:
            try:
                await pubsub.close()
            finally:
                await redis.close()