119 lines
4.3 KiB
Python
119 lines
4.3 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Optional
|
|
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
|
|
from pydantic import ValidationError
|
|
|
|
from app.lib_general_v3 import AccountContext, get_account_context_optional
|
|
from app.lib_websockets_v3 import WS_Message_V3, ws_manager_v3
|
|
|
|
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
|
|
|
|
# Router on which this module's WebSocket route is registered
# (see the @router.websocket decorator further down).
router = APIRouter()
|
|
|
|
@router.websocket('/ws/group/{group_id}/client/{client_id}')
async def v3_ws_endpoint(
    websocket: WebSocket,
    group_id: str,
    client_id: str,
    account: AccountContext = Depends(get_account_context_optional),
):
    """
    Main V3 WebSocket Endpoint.

    Uses granular Redis Pub/Sub for efficient message routing.

    Lifecycle:
      1. Accept the socket and mark the client online.
      2. Subscribe a Pub/Sub object to the channels returned by
         ``ws_manager_v3.get_channel_names(client_id, group_id)``.
      3. Run two tasks concurrently: a receiver (client -> Redis) and a
         sender (Redis -> client). As soon as either one finishes, the
         other is cancelled and awaited.
      4. Always mark the client offline and release the Pub/Sub object,
         even if subscription setup or one of the tasks failed.

    Args:
        websocket: The WebSocket connection for this client.
        group_id: Group identifier taken from the URL path. It is forced
            onto every inbound message.
        client_id: Client identifier taken from the URL path. It is forced
            onto every inbound message as ``from_id``.
        account: Account context resolved by
            ``get_account_context_optional``; only ``auth_method`` is read
            here, for logging.
    """
    # Auth: optional — guests can connect but will be limited by downstream logic.
    # Pass api_key and jwt as query params since browsers cannot set custom WS headers.
    log.info(f"WS V3: Client {client_id} connected to group {group_id} (auth={account.auth_method})")
    await websocket.accept()

    # Filled in during setup. They start "empty" so the cleanup code can tell
    # how far setup got before a failure.
    pubsub = None
    channels = ()
    subscribed = False
    tasks = []

    # --- Handlers ---

    async def send_schema_error(details):
        """Sends the standard schema-error reply to the client."""
        await websocket.send_json({
            "error": "Invalid message schema",
            "details": details,
            "version": "3"
        })

    async def receiver_handler():
        """Handles incoming messages from the client."""
        try:
            while True:
                try:
                    data = await websocket.receive_json()
                except json.JSONDecodeError as je:
                    # A malformed frame is a client error, not a reason to
                    # tear down the whole connection.
                    log.warning(f"WS V3: Malformed JSON from {client_id}: {je}")
                    await send_schema_error([{"msg": f"Invalid JSON: {je.msg}"}])
                    continue

                if not isinstance(data, dict):
                    # Item assignment below requires a JSON object.
                    log.warning(f"WS V3: Non-object payload from {client_id}: {type(data).__name__}")
                    await send_schema_error([{"msg": "Message must be a JSON object"}])
                    continue

                try:
                    # Enforce standardized schema
                    # Force from_id and group_id from path for security
                    data['from_id'] = client_id
                    data['group_id'] = group_id

                    message = WS_Message_V3(**data)

                    # Refresh presence TTL on every heartbeat so long-lived clients
                    # don't drop out of the presence set before they disconnect.
                    if message.msg_type == 'heartbeat':
                        await ws_manager_v3.update_presence(client_id, group_id, online=True)

                    await ws_manager_v3.publish_message(message)

                except ValidationError as ve:
                    log.warning(f"WS V3: Validation error from {client_id}: {ve.json()}")
                    # Round-trip through ve.json() so the payload is guaranteed
                    # to be JSON-serialisable; ve.errors() may carry raw
                    # Python objects that send_json cannot encode.
                    await send_schema_error(json.loads(ve.json()))

        except WebSocketDisconnect:
            # Normal termination: returning ends the task cleanly, which
            # wakes the FIRST_COMPLETED wait below without leaving an
            # unretrieved exception on the task.
            log.info(f"WS V3: Client {client_id} disconnected (receiver)")
        except Exception:
            log.exception(f"WS V3: Unexpected error in receiver for {client_id}")

    async def sender_handler():
        """Handles outgoing messages from Redis to the client."""
        try:
            while True:
                # Use a small timeout to allow for clean task cancellation
                message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)

                if message and message['type'] == 'message':
                    # Forward the structured message directly
                    # Redis stores them as JSON strings
                    # NOTE(review): send_text needs str; presumably the Redis
                    # client decodes responses — confirm.
                    await websocket.send_text(message['data'])

        except WebSocketDisconnect:
            log.info(f"WS V3: Client {client_id} disconnected (sender)")
        except Exception:
            log.exception(f"WS V3: Unexpected error in sender for {client_id}")

    # 1. Presence & Subscription Setup
    # Going online happens outside the try: if it fails there is nothing to undo.
    await ws_manager_v3.update_presence(client_id, group_id, online=True)

    try:
        redis_conn = await ws_manager_v3.get_redis()
        pubsub = redis_conn.pubsub()

        channels = ws_manager_v3.get_channel_names(client_id, group_id)
        await pubsub.subscribe(*channels)
        subscribed = True

        # --- Execution Loop ---

        try:
            # Run both loops concurrently. Whichever finishes first (failure or
            # client disconnect) ends the session; the other one is cancelled
            # and awaited in the cleanup section below.
            tasks = [
                asyncio.create_task(receiver_handler()),
                asyncio.create_task(sender_handler()),
            ]
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

        except Exception as e:
            log.error(f"WS V3: Loop error for {client_id}: {e}")

    finally:
        # 2. Cleanup
        log.info(f"WS V3: Cleaning up connection for {client_id}")
        try:
            # Cancel whatever is still running (no-op for finished tasks) and
            # wait for it to actually stop, so the sender is no longer using
            # the pubsub object when it is closed. gather() also retrieves
            # every task result, avoiding "exception was never retrieved".
            for task in tasks:
                task.cancel()
            if tasks:
                await asyncio.gather(*tasks, return_exceptions=True)
        finally:
            # Nested try/finally: each step runs even if the previous one
            # raised, while the error still propagates to the caller.
            try:
                await ws_manager_v3.update_presence(client_id, group_id, online=False)
            finally:
                if pubsub is not None:
                    try:
                        if subscribed:
                            await pubsub.unsubscribe(*channels)
                    finally:
                        await pubsub.close()
|