380 lines
14 KiB
Python
380 lines
14 KiB
Python
from fastapi import APIRouter, FastAPI, Response, WebSocket, WebSocketDisconnect
|
|
from fastapi.responses import HTMLResponse
|
|
from pydantic import BaseModel, EmailStr, Field
|
|
from typing import Dict, List, Optional, Set, Union
|
|
import redis, asyncio, base64, datetime, hashlib, json, os, pathlib, shutil, time
|
|
|
|
from app.lib_general import log, logging, common_route_params, Common_Route_Params, common_route_params_min, Common_Route_Params_Min
|
|
from app.config import settings
|
|
from app.db_sql import sql_insert, sql_update, sql_insert_or_update, sql_select, sql_delete, redis_lookup_id_random
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
# Minimal in-browser test client, returned verbatim by GET /ws_test.
# The page opens a WebSocket to /ws/<client_id> (client_id = Date.now()) and
# sends each submitted line as JSON: {"client_id": ..., "message": ...}.
# The socket URL is derived from the page's own origin so the page works on any
# host/port and uses wss:// when it is served over HTTPS (browsers block ws://
# from secure pages).
html = """
<!DOCTYPE html>
<html>
    <head>
        <title>Chat</title>
    </head>
    <body>
        <h1>WebSocket Chat</h1>
        <h2>Your ID: <span id="ws-id"></span></h2>
        <form action="" onsubmit="sendMessage(event)">
            <input type="text" id="messageText" autocomplete="off"/>
            <button>Send</button>
        </form>
        <ul id='messages'>
        </ul>
        <script>
            var client_id = Date.now();
            document.querySelector("#ws-id").textContent = client_id;
            var ws_scheme = (window.location.protocol === "https:") ? "wss" : "ws";
            var ws = new WebSocket(`${ws_scheme}://${window.location.host}/ws/${client_id}`);
            // var ws = new WebSocket(`ws://fastapi.localhost:8080/ws/${client_id}`);
            // var ws = new WebSocket(`ws://localhost:5005/ws/${client_id}`);
            // var ws = new WebSocket("ws://localhost:8000/ws");
            // var ws = new WebSocket("ws://fastapi.localhost/ws");
            ws.onmessage = function(event) {
                var messages = document.getElementById('messages');
                var message = document.createElement('li');
                var content = document.createTextNode(event.data);
                message.appendChild(content);
                messages.appendChild(message);
            };
            function sendMessage(event) {
                console.log('*** sendMessage() ***');
                var input = document.getElementById("messageText");
                var data = { 'client_id': client_id, 'message': input.value };
                var data_json_str = JSON.stringify(data);
                ws.send(data_json_str);

                input.value = '';
                event.preventDefault();
            }
        </script>
    </body>
</html>
"""
|
|
|
|
|
|
@router.get("/ws_test")
async def get(response: Response = Response):
    """Serve the browser-based WebSocket test page.

    Switches the shared logger to DEBUG, logs this handler's local variables
    and returns the module-level ``html`` document as an HTML response.

    The ``response`` parameter is declared but not used by the body.
    """
    log.setLevel(logging.DEBUG)
    log.debug(locals())
    return HTMLResponse(content=html)
|
|
|
|
|
|
|
|
class ConnectionManager:
    """In-memory registry of accepted WebSocket connections plus send helpers.

    Connections are kept in a plain list on the instance, so the registry only
    covers sockets accepted through this object.

    Delivery targets: echo, direct, group, broadcast.
    Starlette offers send_/receive_ variants for text, bytes and json.

    NOTE(review): several methods call ``log.setLevel`` on the shared logger,
    which changes the level for every other user of that logger. The calls are
    kept to preserve the existing logging behaviour.
    """

    def __init__(self):
        # NOTE: The active_connections list should be in Redis
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and register the connection."""
        await websocket.accept()
        log.info('WS connect')
        self.active_connections.append(websocket)
        log.debug(self.active_connections)

    def disconnect(self, websocket: WebSocket):
        """Remove a connection from the registry.

        Safe to call for a socket that is not (or no longer) registered.
        """
        log.info('WS disconnect')
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            log.debug('WS disconnect: connection was not registered')

    async def _send_text_logged(self, message: str, websocket: WebSocket):
        """Log connection details at DEBUG level, then send ``message`` as text."""
        log.setLevel(logging.DEBUG)
        log.debug(vars(websocket))
        log.debug(websocket.url)
        log.debug(websocket.client)
        log.debug(websocket.client_state)
        # .get(): a missing handshake header must not abort the send.
        log.debug(websocket.headers.get('sec-websocket-key'))
        await websocket.send_text(message)

    async def echo(self, message: str, websocket: WebSocket):
        """Send ``message`` back to the given connection only."""
        await self._send_text_logged(message, websocket)

    async def direct(self, from_client_id: str, to_client_id: str, data: dict):
        """Send ``data`` as JSON to the connection(s) registered for ``to_client_id``.

        A connection matches when the ``client_id`` path parameter of its route
        equals ``to_client_id`` (compared as strings, because a JSON payload may
        carry the id as a number). Nothing is sent when ``to_client_id`` is
        ``None`` or no connection matches.
        """
        log.setLevel(logging.DEBUG)
        if to_client_id is None:
            log.info(f'Direct message from {from_client_id} has no recipient; dropped')
            return

        wanted = str(to_client_id)
        # Iterate over a snapshot: the list may change while a send is awaited.
        for connection in tuple(self.active_connections):
            log.debug(vars(connection))
            log.debug(connection)
            recipient_id = connection.path_params.get('client_id')
            if recipient_id is not None and str(recipient_id) == wanted:
                log.info(f'Direct message from {from_client_id} to {wanted}')
                await connection.send_json(data)

    async def group(self, group_id: str, data: dict):
        """Send ``data`` as JSON to every connection whose ``group_id`` path parameter matches."""
        log.setLevel(logging.INFO)
        log.debug(locals())

        # Iterate over a snapshot: the list may change while a send is awaited.
        for connection in tuple(self.active_connections):
            log.debug(vars(connection))
            if connection.path_params.get('group_id') == group_id:
                log.info('Found matching Group ID')
                await connection.send_json(data)

    # NOTE: Same as group, but no filter based on path
    async def broadcast(self, message: str):
        """Send ``message`` as text to every registered connection."""
        log.setLevel(logging.INFO)
        log.debug(locals())

        # Iterate over a snapshot: the list may change while a send is awaited.
        for connection in tuple(self.active_connections):
            log.debug(vars(connection))
            await connection.send_text(message)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send ``message`` as text to one connection (same behaviour as ``echo``)."""
        await self._send_text_logged(message, websocket)
|
|
|
|
|
|
|
|
# Single module-level ConnectionManager instance; the WebSocket endpoints below
# all register their connections with, and send through, this object.
manager = ConnectionManager()
|
|
|
|
|
|
# Candidate endpoint layout (still undecided):
#   /room/<id>   - a group of related clients, e.g. a poster presenter or a session room
#   /client/<id> - one specific client/browser (what is browser-specific is still an open question)
#   /person/<id> - one specific person; handles sending and receiving their messages
|
|
|
|
|
|
@router.websocket('/ws/client/{client_id}')
async def ws_client_id(
    websocket: WebSocket,
    client_id: str,
):
    """Per-client WebSocket endpoint that dispatches on the message ``type``.

    Each incoming frame is read as JSON. The fields read are ``type``, ``cmd``,
    ``msg``, ``to_client_id`` and ``to_group_id``. Dispatch by ``type``:

    * ``echo``  - reply to the sender only
    * ``dm``    - direct message to ``to_client_id``
    * ``group`` - message to ``to_group_id`` (broadcast to all when it is absent)
    * ``all``   - broadcast to every connection
    * anything else, or no type - broadcast, prefixed ``Unknown:`` / ``MSG:``

    The connection is always removed from the manager when the handler exits,
    whether the client disconnected or an error occurred; otherwise a dead
    socket would stay registered and break later broadcasts.
    """
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_json()  # Returns dict

            # echo (echo message), dm (direct message), group (group message),
            # all (broadcast message to all), cmd, group_cmd(?)
            msg_type = data.get('type')
            cmd = data.get('cmd')
            msg = data.get('msg')
            to_client_id = data.get('to_client_id')
            to_group_id = data.get('to_group_id')

            log.setLevel(logging.INFO)
            log.info(f'Client ID: {client_id}; Type: {msg_type};')
            log.debug(f'Command: {cmd}')
            log.debug(f'Message: {msg}')
            log.debug(f'To Client ID: {to_client_id}')
            log.debug(f'To Group ID: {to_group_id}')

            if not msg_type:
                await manager.broadcast(f'MSG: {data}')
            elif msg_type == 'echo':
                await manager.echo(f'Echo: {data}', websocket)
            elif msg_type == 'dm':
                await manager.direct(from_client_id=client_id, to_client_id=to_client_id, data=data)
            elif msg_type == 'group':
                if to_group_id is not None:
                    # Deliver only to the addressed group instead of everyone.
                    await manager.group(group_id=to_group_id, data=data)
                else:
                    # No target group given: keep the previous broadcast behaviour.
                    await manager.broadcast(f'Group: {data}')
            elif msg_type == 'all':
                await manager.broadcast(f'All: {data}')
            else:
                await manager.broadcast(f'Unknown: {data}')
    except WebSocketDisconnect:
        # Normal client close; cleanup happens in ``finally``.
        pass
    finally:
        manager.disconnect(websocket)
|
|
|
|
|
|
# NOTE(review): this handler reuses the name ``ws_client_id`` from the
# /ws/client/{client_id} endpoint. Both routes are registered by their
# decorators, but the module attribute ends up pointing at this function only.
# The name is kept to avoid changing the module's interface.
@router.websocket('/ws/group/{group_id}/client/{client_id}')
async def ws_client_id(
    websocket: WebSocket,
    group_id: str,
    client_id: str,
):
    """Group WebSocket endpoint: relay every received JSON message to the group.

    Each incoming frame is read as JSON, stamped with the ``client_id`` and
    ``group_id`` taken from the URL path (overwriting any values the client
    sent), and forwarded with ``manager.group``. Per-type dispatch
    (echo/dm/group/all/cmd) is not implemented for this endpoint.

    The connection is always removed from the manager when the handler exits,
    whether the client disconnected or an error occurred.
    """
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_json()  # Returns dict

            # echo (echo message), dm (direct message), group (group message),
            # all (broadcast message to all), cmd, group_cmd(?)
            msg_type = data.get('type')
            cmd = data.get('cmd')
            msg = data.get('msg')

            log.setLevel(logging.INFO)
            log.info(f'Group ID: {group_id}; Client ID: {client_id}; Type: {msg_type};')
            log.debug(f'Command: {cmd}')
            log.debug(f'Message: {msg}')

            # Trust the path, not the payload, for sender and group identity.
            data['client_id'] = client_id
            data['group_id'] = group_id

            await manager.group(group_id=group_id, data=data)
    except WebSocketDisconnect:
        # Normal client close; cleanup happens in ``finally``.
        pass
    finally:
        manager.disconnect(websocket)
|
|
|
|
|
|
|
|
@router.websocket('/ws/{client_id}')
async def ws_id(
    websocket: WebSocket,
    client_id: int,
):
    """Simple chat endpoint used by the /ws_test page.

    Every received JSON message is acknowledged to the sender
    (``You wrote: ...``) and broadcast to all connections
    (``Client #<id> says: ...``). When the client disconnects, the remaining
    connections are told that it left.

    The connection is removed from the manager on every exit path, so an
    unexpected error cannot leave a dead socket registered.
    """
    await manager.connect(websocket)
    try:
        try:
            while True:
                data = await websocket.receive_json()
                log.debug(data)
                # Debug aid only: a payload without client_id must not end the session.
                payload_client_id = data.get('client_id') if isinstance(data, dict) else None
                log.debug(payload_client_id)
                await manager.send_personal_message(f'You wrote: {data}', websocket)
                await manager.broadcast(f'Client #{client_id} says: {data}')
        finally:
            # Unregister first so the "left" broadcast below skips this socket.
            manager.disconnect(websocket)
    except WebSocketDisconnect:
        await manager.broadcast(f'Client #{client_id} left')
|
|
|
|
|
|
|
|
# @router.websocket('/ws/room/{room_id}')
|
|
# async def ws_room_id(
|
|
# websocket: WebSocket,
|
|
# room_id: str,
|
|
# ):
|
|
# await manager.connect(websocket)
|
|
# await manager.broadcast(f'Welcome to room "{room_id}"!')
|
|
# try:
|
|
# while True:
|
|
# data = await websocket.receive_json()
|
|
# log.debug(data)
|
|
# data_dict = data
|
|
# # data_dict = json.loads(data)
|
|
# log.debug(data_dict['client_id'])
|
|
# client_id = data_dict['client_id']
|
|
# await manager.send_personal_message(f'You wrote: {data}', websocket)
|
|
# await manager.broadcast(f'Client #{client_id} says: {data}')
|
|
# except WebSocketDisconnect:
|
|
# manager.disconnect(websocket)
|
|
# await manager.broadcast(f'Client left the room')
|
|
|
|
|
|
# # NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING
|
|
# # time.sleep(3.5) # NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING
|
|
# # NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING
|
|
|
|
|
|
# @router.websocket('/ws/looping')
|
|
# async def ws_looping(
|
|
# websocket: WebSocket,
|
|
# ):
|
|
# await manager.connect(websocket)
|
|
# # await manager.broadcast(f'Welcome to looping')
|
|
# try:
|
|
# while True:
|
|
# # NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING
|
|
# # await time.sleep(3.5) # NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING
|
|
# # NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING NOTE: WARNING
|
|
# # data = await websocket.receive_json()
|
|
# # log.debug(data)
|
|
# # data_dict = data
|
|
# # data_dict = json.loads(data)
|
|
# # log.debug(data_dict['client_id'])
|
|
# # await manager.send_personal_message(f'You wrote: {data}', websocket)
|
|
# await manager.broadcast(f'Loop!!!')
|
|
# except WebSocketDisconnect:
|
|
# manager.disconnect(websocket)
|
|
# await manager.broadcast(f'Client left looping')
|
|
|
|
|
|
# @router.websocket("/ws/{client_id}")
|
|
# async def websocket_endpoint(
|
|
# websocket: WebSocket,
|
|
# client_id: int,
|
|
# response: Response = Response,
|
|
# ):
|
|
# log.setLevel(logging.DEBUG)
|
|
# log.debug(locals())
|
|
|
|
# log.info('Root of ws. Waiting to accept a websocket and then the redis_connector')
|
|
|
|
# await websocket.accept()
|
|
# await redis_connector(websocket)
|
|
|
|
|
|
async def redis_connector(
    websocket: WebSocket,
    redis_url: str = f"redis://{settings.REDIS['server']}:{settings.REDIS['port']}",
):
    """Bridge one accepted WebSocket to the Redis pub/sub channel ``chat:c``.

    Two tasks run concurrently until either finishes:

    * consumer - reads text frames from the WebSocket, parses them as JSON and
      publishes first the ``client_id`` field and then the whole message
      (both as ``str``) to the channel;
    * producer - subscribes to the channel and forwards every published
      message to the WebSocket as text.

    When one task ends the other is cancelled and the Redis client is closed.

    Args:
        websocket: an already accepted WebSocket.
        redis_url: Redis connection URL; the default is built from
            ``settings.REDIS`` when the module is imported.

    NOTE(review): uses ``redis.asyncio``, which assumes redis-py >= 4.2 is
    installed - confirm against the project's pinned version.
    """
    log.setLevel(logging.DEBUG)
    log.debug(locals())

    # Imported locally under its own name. Assigning to a local called ``redis``
    # (as the previous code did) shadows the module for the whole function and
    # makes ``redis.from_url`` raise UnboundLocalError.
    import redis.asyncio as aioredis

    channel_name = "chat:c"

    async def consumer_handler(ws: WebSocket, r):
        """WebSocket -> Redis: publish each received JSON message."""
        try:
            while True:
                message = await ws.receive_text()
                if message:
                    data = json.loads(message)
                    await r.publish(channel_name, str(data['client_id']))
                    await r.publish(channel_name, str(data))
        except WebSocketDisconnect as exc:
            # TODO this needs handling better
            log.error(exc)

    async def producer_handler(r, ws: WebSocket):
        """Redis -> WebSocket: forward each published message as text."""
        try:
            async with r.pubsub() as pubsub:
                await pubsub.subscribe(channel_name)
                async for message in pubsub.listen():
                    # Skip subscribe/unsubscribe confirmations.
                    if message.get("type") == "message":
                        # decode_responses=True: the payload is already a str.
                        await ws.send_text(message["data"])
        except Exception as exc:
            # TODO this needs handling better
            log.error(exc)

    # Redis client bound to pool of connections (auto-reconnecting).
    async with aioredis.from_url(
        redis_url, encoding="utf-8", decode_responses=True
    ) as client:
        # asyncio.wait() needs Task objects; bare coroutines are rejected on Python 3.11+.
        consumer_task = asyncio.ensure_future(consumer_handler(websocket, client))
        producer_task = asyncio.ensure_future(producer_handler(client, websocket))
        done, pending = await asyncio.wait(
            [consumer_task, producer_task], return_when=asyncio.FIRST_COMPLETED,
        )
        log.debug(f"Done task: {done}")
        for task in pending:
            log.debug(f"Canceling task: {task}")
            task.cancel()
        # Let cancelled tasks unwind before the client is closed.
        await asyncio.gather(*pending, return_exceptions=True)
        # Surface failures instead of leaving "exception was never retrieved" warnings.
        for task in done:
            if not task.cancelled() and task.exception() is not None:
                log.error(task.exception())
|