import asyncio
import json
import logging
from typing import Any, Optional

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
from pydantic import ValidationError

from app.lib_general_v3 import AccountContext, get_account_context_optional
from app.lib_websockets_v3 import WS_Message_V3, ws_manager_v3

log = logging.getLogger(__name__)

router = APIRouter()

# Protocol version tag attached to every error frame sent back to clients.
_WS_PROTOCOL_VERSION = "3"


def _error_frame(error: str, details: Any = None) -> dict:
    """Build the error payload sent to a client whose message was rejected.

    Args:
        error: Short human-readable reason.
        details: Optional JSON-serializable detail (e.g. validation errors).

    Returns:
        A dict with ``error``, optional ``details``, and ``version`` keys.
    """
    frame = {"error": error}
    if details is not None:
        frame["details"] = details
    frame["version"] = _WS_PROTOCOL_VERSION
    return frame


def _as_text(payload: Any) -> Any:
    """Return a Pub/Sub payload as ``str`` when it arrives as bytes.

    Redis clients created without ``decode_responses=True`` deliver message
    data as bytes, while ``WebSocket.send_text`` expects ``str``. Any other
    value is passed through unchanged.
    """
    if isinstance(payload, (bytes, bytearray)):
        return payload.decode("utf-8")
    return payload


async def _receive_loop(websocket: WebSocket, client_id: str, group_id: str) -> None:
    """Validate incoming client frames and publish them until the client leaves.

    Malformed frames (invalid JSON, non-object JSON, schema violations) are
    answered with an error frame and the loop keeps running; only a disconnect
    or an unexpected server-side error ends it.
    """
    try:
        while True:
            raw = await websocket.receive_text()

            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                log.warning("WS V3: Invalid JSON from %s", client_id)
                await websocket.send_json(_error_frame("Invalid JSON"))
                continue

            if not isinstance(data, dict):
                log.warning("WS V3: Non-object message from %s", client_id)
                await websocket.send_json(_error_frame("Message must be a JSON object"))
                continue

            # Enforce standardized schema. Force from_id and group_id from the
            # path for security so a client cannot impersonate another sender.
            data['from_id'] = client_id
            data['group_id'] = group_id
            try:
                message = WS_Message_V3(**data)
            except ValidationError as ve:
                errors_json = ve.json()
                log.warning("WS V3: Validation error from %s: %s", client_id, errors_json)
                # Round-trip through ve.json(): ve.errors() may hold values
                # (e.g. exception objects in ``ctx``) that send_json cannot encode.
                await websocket.send_json(
                    _error_frame("Invalid message schema", json.loads(errors_json))
                )
                continue

            # Refresh presence TTL on every heartbeat so long-lived clients
            # don't drop out of the presence set before they disconnect.
            if message.msg_type == 'heartbeat':
                await ws_manager_v3.update_presence(client_id, group_id, online=True)
            await ws_manager_v3.publish_message(message)
    except WebSocketDisconnect:
        log.info("WS V3: Client %s disconnected (receiver)", client_id)
    except Exception:
        log.exception("WS V3: Unexpected error in receiver for %s", client_id)


async def _forward_loop(websocket: WebSocket, pubsub: Any, client_id: str) -> None:
    """Relay Redis Pub/Sub messages to the client until sending fails."""
    try:
        while True:
            # A short timeout keeps the loop responsive to task cancellation.
            message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
            if message and message['type'] == 'message':
                # Messages are stored in Redis as JSON strings; forward as-is.
                await websocket.send_text(_as_text(message['data']))
    except WebSocketDisconnect:
        log.info("WS V3: Client %s disconnected (sender)", client_id)
    except Exception:
        log.exception("WS V3: Unexpected error in sender for %s", client_id)


async def _release(client_id: str, group_id: str, pubsub: Any, channels: Any) -> None:
    """Best-effort teardown of presence and Pub/Sub state.

    Each step is isolated so a failure in one (e.g. Redis unavailable while
    clearing presence) cannot skip the others and leak the pubsub connection.
    """
    try:
        await ws_manager_v3.update_presence(client_id, group_id, online=False)
    except Exception:
        log.exception("WS V3: Failed to clear presence for %s", client_id)

    if pubsub is None:
        return

    if channels is not None:
        try:
            await pubsub.unsubscribe(*channels)
        except Exception:
            log.exception("WS V3: Failed to unsubscribe %s", client_id)

    try:
        await pubsub.close()
    except Exception:
        log.exception("WS V3: Failed to close pubsub for %s", client_id)


@router.websocket('/ws/group/{group_id}/client/{client_id}')
async def v3_ws_endpoint(
    websocket: WebSocket,
    group_id: str,
    client_id: str,
    account: Optional[AccountContext] = Depends(get_account_context_optional),
):
    """
    Main V3 WebSocket Endpoint. Uses granular Redis Pub/Sub for efficient message routing.

    Accepts the connection, marks the client online, subscribes to its Redis
    channels, then runs a receive loop (client -> Redis) and a forward loop
    (Redis -> client) concurrently until either one finishes. Presence and the
    Pub/Sub subscription are always torn down on exit, including when setup
    fails part-way or the endpoint itself is cancelled.

    Auth is optional: guests can connect but will be limited by downstream
    logic. Browsers cannot set custom WS headers, so api_key and jwt are passed
    as query params.
    """
    await websocket.accept()
    # NOTE(review): the "_optional" dependency presumably yields None for
    # guests; guard the attribute access so a guest cannot crash the handshake.
    auth_method = getattr(account, "auth_method", None)
    log.info("WS V3: Client %s connected to group %s (auth=%s)", client_id, group_id, auth_method)

    pubsub = None
    channels = None
    tasks = []
    try:
        # 1. Presence & Subscription Setup. Done inside the try so a failure
        #    here still runs the cleanup below.
        await ws_manager_v3.update_presence(client_id, group_id, online=True)
        redis_conn = await ws_manager_v3.get_redis()
        pubsub = redis_conn.pubsub()
        channels = ws_manager_v3.get_channel_names(client_id, group_id)
        await pubsub.subscribe(*channels)

        # 2. Run both loops concurrently; stop as soon as either finishes.
        tasks = [
            asyncio.create_task(_receive_loop(websocket, client_id, group_id)),
            asyncio.create_task(_forward_loop(websocket, pubsub, client_id)),
        ]
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    except Exception:
        log.exception("WS V3: Loop error for %s", client_id)
    finally:
        # 3. Cleanup
        log.info("WS V3: Cleaning up connection for %s", client_id)
        # Cancel whatever is still running (cancel() is a no-op on finished
        # tasks) and wait for it to unwind before touching the pubsub, so no
        # task is orphaned and every task exception is retrieved.
        for task in tasks:
            task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
        await _release(client_id, group_id, pubsub, channels)