Files
OSIT-AE-API-FastAPI/tests/unit/test_unit_websockets_v3_router.py

195 lines
7.1 KiB
Python

import sys
import os
import asyncio
import json
import unittest
from unittest.mock import MagicMock, AsyncMock, patch
# Make the project root importable no matter which directory pytest is launched from.
# This file lives in <root>/tests/unit/, so the root is two directories up. The current
# working directory is still added first, preserving the original lookup precedence.
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
for _path in (os.getcwd(), _PROJECT_ROOT):
    if _path not in sys.path:
        sys.path.append(_path)

# Mock app.config BEFORE any app import so the real settings loader never runs.
mock_config = MagicMock()
mock_config.settings = MagicMock()
mock_config.settings.REDIS = {'server': 'localhost', 'port': 6379}
sys.modules["app.config"] = mock_config

# Mock DB access and circular/heavy imports before any app module is loaded.
for _module_name in ("app.db_sql", "app.lib_sql_core", "app.db_connection", "app.log"):
    sys.modules[_module_name] = MagicMock()

# Provide a real AccountContext but stub out the heavy lib_general_v3 imports.
from app.models.auth_models import AccountContext

# Shared AccountContext: returned by the stubbed dependency and passed explicitly in every
# test call, so both code paths see the same account.
MOCK_ACCOUNT = AccountContext(account_id=1, account_id_random="test_account_id", auth_method="api_key")

mock_lib_general_v3 = MagicMock()
mock_lib_general_v3.AccountContext = AccountContext
mock_lib_general_v3.get_account_context_optional = MagicMock(return_value=MOCK_ACCOUNT)
sys.modules["app.lib_general_v3"] = mock_lib_general_v3
sys.modules["app.routers.dependencies_v3"] = MagicMock()
sys.modules["app.lib_jwt"] = MagicMock()

# Must come after the sys.modules stubs above, otherwise the real dependencies are imported.
from app.routers.websockets_v3 import v3_ws_endpoint
class TestWSV3Router(unittest.TestCase):
    """Unit tests for ``v3_ws_endpoint`` with the WebSocket, Redis and ``ws_manager_v3`` mocked.

    The endpoint's loops are stopped by the mocks themselves: after a few calls they raise
    ``asyncio.CancelledError`` carrying a "Terminate ..." message, which ``_run_endpoint``
    treats as a normal way for a run to end.
    """

    @staticmethod
    def _make_pubsub(get_message):
        """Return a Redis pub/sub mock with awaitable subscribe/unsubscribe/close and *get_message*."""
        pubsub = MagicMock()
        pubsub.subscribe = AsyncMock()
        pubsub.unsubscribe = AsyncMock()
        pubsub.close = AsyncMock()
        pubsub.get_message = get_message
        return pubsub

    @staticmethod
    def _configure_manager(mock_manager, pubsub):
        """Wire the patched ``ws_manager_v3`` so ``get_redis()`` yields a client whose ``pubsub()`` is *pubsub*."""
        redis_client = MagicMock()
        redis_client.pubsub.return_value = pubsub
        mock_manager.get_redis = AsyncMock(return_value=redis_client)
        mock_manager.update_presence = AsyncMock()
        mock_manager.publish_message = AsyncMock()
        mock_manager.get_channel_names.return_value = ["ws:group:test"]

    @staticmethod
    async def _run_endpoint(ws, group, client_id, timeout=0.5):
        """Run the endpoint until a mock terminates it or *timeout* seconds elapse.

        Timeouts and the mocks' "Terminate" cancellations are expected ways for a run to
        stop. Any other exception is re-raised so genuine failures surface.
        """
        try:
            await asyncio.wait_for(
                v3_ws_endpoint(ws, group, client_id, account=MOCK_ACCOUNT),
                timeout=timeout,
            )
        except (asyncio.TimeoutError, asyncio.CancelledError):
            pass
        except Exception as e:
            if "Terminate" not in str(e):
                raise

    @patch('app.routers.websockets_v3.ws_manager_v3')
    def test_v3_ws_endpoint_logic(self, mock_manager):
        """
        Tests the core logic of the V3 WebSocket endpoint, ensuring
        Redis subscription and bidirectional message handling are initiated.
        """
        mock_ws = AsyncMock()
        mock_message = {
            'type': 'message',
            'data': json.dumps({
                "version": "3",
                "msg_type": "msg",
                "target": "group",
                "from_id": "other_client",
                "msg": "Hello from Redis",
                "payload": {},
                "sent_at": "2026-01-30T12:00:00Z"
            })
        }
        # Signal that the sender-side mock has handed out the Redis message. It is created
        # inside scenario(), i.e. within the loop asyncio.run() starts. Before Python 3.10
        # an Event binds to the loop that is current at construction time, so creating it
        # here would tie it to a different loop.
        msg_delivered = None
        # Counters used to break the endpoint's 'while True' loops.
        get_msg_count = 0
        recv_json_count = 0

        async def mock_get_message(*args, **kwargs):
            nonlocal get_msg_count
            get_msg_count += 1
            if get_msg_count == 1:
                msg_delivered.set()
                return mock_message
            await asyncio.sleep(0.05)
            # Raise CancelledError to terminate the loop cleanly
            raise asyncio.CancelledError("Terminate sender loop")

        async def mock_receive_json():
            nonlocal recv_json_count
            recv_json_count += 1
            if recv_json_count == 1:
                # Hold the client message until the sender loop has fetched the Redis message
                await msg_delivered.wait()
                return {
                    "msg_type": "msg",
                    "target": "group",
                    "msg": "Client A saying hi"
                }
            await asyncio.sleep(0.05)
            # Raise CancelledError to terminate the loop cleanly
            raise asyncio.CancelledError("Terminate receiver loop")

        mock_pubsub = self._make_pubsub(mock_get_message)
        self._configure_manager(mock_manager, mock_pubsub)
        mock_ws.receive_json.side_effect = mock_receive_json

        async def scenario():
            nonlocal msg_delivered
            msg_delivered = asyncio.Event()
            await self._run_endpoint(mock_ws, "test_group", "client_a")

        asyncio.run(scenario())

        # Verifications
        mock_ws.accept.assert_called_once()
        mock_pubsub.subscribe.assert_called()
        mock_manager.update_presence.assert_any_call("client_a", "test_group", online=True)
        mock_ws.send_text.assert_called()
        print("✅ WebSocket Router unit logic verified.")

    @patch('app.routers.websockets_v3.ws_manager_v3')
    def test_heartbeat_refreshes_presence(self, mock_manager):
        """
        Ensures that a heartbeat message triggers a presence TTL refresh
        in addition to being published normally.
        """
        mock_ws = AsyncMock()
        recv_count = 0

        async def mock_get_message(*args, **kwargs):
            # No Redis traffic in this test: idle briefly, then stop the sender loop.
            await asyncio.sleep(0.05)
            raise asyncio.CancelledError("Terminate sender loop")

        async def mock_receive_json():
            nonlocal recv_count
            recv_count += 1
            if recv_count == 1:
                # Send a heartbeat
                return {"msg_type": "heartbeat", "target": "echo"}
            await asyncio.sleep(0.05)
            raise asyncio.CancelledError("Terminate receiver loop")

        self._configure_manager(mock_manager, self._make_pubsub(mock_get_message))
        mock_ws.receive_json.side_effect = mock_receive_json

        asyncio.run(self._run_endpoint(mock_ws, "test_group", "client_hb"))

        # Presence should be refreshed: once on connect, once on heartbeat receipt.
        # 'online' may be passed by keyword or as the third positional argument.
        online_calls = [
            c for c in mock_manager.update_presence.call_args_list
            if c.kwargs.get('online') is True or (len(c.args) >= 3 and c.args[2] is True)
        ]
        self.assertGreaterEqual(len(online_calls), 2, "Expected at least 2 presence refresh calls (connect + heartbeat)")
        print("✅ Heartbeat presence refresh verified.")