import datetime
import json
import logging
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field
import redis.asyncio as redis

from app.config import settings

log = logging.getLogger(__name__)


# --- Models ---

class WS_Message_V3(BaseModel):
    """
    Standardized message schema for WebSockets V3.

    ``target`` decides which Redis channel ``WS_Manager_V3.publish_message``
    routes the message to; ``to_id`` is needed for ``'direct'`` targets and
    ``group_id`` for ``'group'`` targets.
    """
    version: str = "3"
    msg_type: str = Field(..., description="'msg', 'cmd', 'heartbeat', 'presence'")
    target: str = Field(..., description="'direct', 'group', 'broadcast', 'echo'")
    from_id: str = Field(..., description="client_id_random of the sender")
    to_id: Optional[str] = Field(None, description="target client_id_random (for direct messages)")
    group_id: Optional[str] = Field(None, description="target group_id_random (for group messages)")
    cmd: Optional[str] = Field(None, description="Specific command string (e.g., 'RELOAD', 'OPEN_FILE')")
    msg: Optional[str] = Field(None, description="Human-readable message content")
    payload: Dict[str, Any] = Field(default_factory=dict, description="Flexible JSON data payload")
    # Timezone-aware UTC timestamp taken when the model instance is created.
    sent_at: datetime.datetime = Field(default_factory=lambda: datetime.datetime.now(datetime.timezone.utc))

    class Config:
        # Serialize datetimes as ISO-8601 strings in ``.json()`` output.
        json_encoders = {
            datetime.datetime: lambda v: v.isoformat()
        }


# --- Manager ---

class WS_Manager_V3:
    """
    Manages Redis Granular Pub/Sub and Presence for WebSockets V3.

    Channel layout:
        ws:client:<client_id>   direct messages (and echoes) for one client
        ws:group:<group_id>     messages for every subscriber of a group
        ws:broadcast            system-wide messages

    Presence is tracked as one Redis set per group: ws:presence:<group_id>.
    """

    _BROADCAST_CHANNEL = "ws:broadcast"

    def __init__(self, redis_db: int = 6, presence_ttl_seconds: int = 3600):
        """
        Args:
            redis_db: Redis logical database index used for pub/sub and presence.
            presence_ttl_seconds: TTL (seconds) re-applied to a group's presence
                set each time a client in that group is marked online.
        """
        self.redis_db = redis_db
        self.presence_ttl_seconds = presence_ttl_seconds
        self.redis_url = f"redis://{settings.REDIS['server']}:{settings.REDIS['port']}"
        self._redis_conn: Optional[redis.Redis] = None

    @staticmethod
    def _client_channel(client_id: str) -> str:
        """Channel carrying direct messages addressed to one client."""
        return f"ws:client:{client_id}"

    @staticmethod
    def _group_channel(group_id: str) -> str:
        """Channel carrying messages addressed to a whole group."""
        return f"ws:group:{group_id}"

    @staticmethod
    def _presence_key(group_id: str) -> str:
        """Key of the Redis set holding the online client IDs of a group."""
        return f"ws:presence:{group_id}"

    async def get_redis(self) -> redis.Redis:
        """
        Return the shared async Redis client, creating it on first use.

        The client decodes responses to ``str`` (``decode_responses=True``).
        """
        # There is no await between the None-check and the assignment, so
        # concurrent coroutines on one event loop cannot create two clients.
        if self._redis_conn is None:
            log.info("WS V3: Connecting to Redis DB %s", self.redis_db)
            self._redis_conn = redis.Redis.from_url(
                self.redis_url,
                db=self.redis_db,
                encoding='utf-8',
                decode_responses=True,
            )
        return self._redis_conn

    def get_channel_names(self, client_id: str, group_id: Optional[str] = None) -> List[str]:
        """
        Return the Redis channels a client should subscribe to.

        Always includes the client's direct channel and the broadcast channel;
        the group channel is appended when ``group_id`` is non-empty.
        """
        channels = [
            self._client_channel(client_id),  # Direct messages
            self._BROADCAST_CHANNEL,          # System-wide messages
        ]
        if group_id:
            channels.append(self._group_channel(group_id))  # Group messages
        return channels

    async def update_presence(self, client_id: str, group_id: str, online: bool = True):
        """
        Mark ``client_id`` online or offline in ``group_id``'s presence set.

        Going online adds the client and re-applies ``presence_ttl_seconds`` to
        the WHOLE group set. The TTL is per group, not per client, so an
        individual client that disconnects without an offline update stays
        listed until the entire set expires. Going offline removes the client.
        """
        r = await self.get_redis()
        key = self._presence_key(group_id)
        if online:
            # SADD + EXPIRE in one MULTI/EXEC transaction: a cancellation or crash
            # between two separate calls could leave the set with no TTL at all.
            async with r.pipeline(transaction=True) as pipe:
                await pipe.sadd(key, client_id).expire(key, self.presence_ttl_seconds).execute()
        else:
            await r.srem(key, client_id)

    async def get_online_clients(self, group_id: str) -> List[str]:
        """
        Return the client IDs currently in ``group_id``'s presence set.

        ``SMEMBERS`` yields a Python ``set``; it is returned as a sorted list so
        the result matches the annotated ``List[str]`` type, is deterministic,
        and can be JSON-serialized directly.
        """
        r = await self.get_redis()
        members = await r.smembers(self._presence_key(group_id))
        return sorted(members)

    async def publish_message(self, message: WS_Message_V3):
        """
        Publish ``message`` (serialized with ``.json()``) to the granular Redis
        channel selected by ``message.target``:

            'direct'    -> ws:client:<to_id>    (dropped with a warning if to_id is empty)
            'group'     -> ws:group:<group_id>  (dropped with a warning if group_id is empty)
            'broadcast' -> ws:broadcast
            'echo'      -> ws:client:<from_id>  (sent back to the sender)

        Any other target is dropped with a warning.
        """
        if message.target == "direct":
            if not message.to_id:
                log.warning("WS V3: Attempted direct publish without to_id")
                return
            channel = self._client_channel(message.to_id)
        elif message.target == "group":
            if not message.group_id:
                log.warning("WS V3: Attempted group publish without group_id")
                return
            channel = self._group_channel(message.group_id)
        elif message.target == "broadcast":
            channel = self._BROADCAST_CHANNEL
        elif message.target == "echo":
            channel = self._client_channel(message.from_id)
        else:
            # Previously such messages vanished without a trace.
            log.warning("WS V3: Unknown publish target %r; message dropped", message.target)
            return

        # Only obtain the connection once we know there is something to send.
        r = await self.get_redis()
        log.debug("WS V3: Publishing to %s", channel)
        await r.publish(channel, message.json())


# Global instance
ws_manager_v3 = WS_Manager_V3()