diff --git a/tests/unit/test_unit_websockets_v3_router.py b/tests/unit/test_unit_websockets_v3_router.py index 1badf4e..f23e9e1 100644 --- a/tests/unit/test_unit_websockets_v3_router.py +++ b/tests/unit/test_unit_websockets_v3_router.py @@ -14,13 +14,28 @@ mock_config.settings = MagicMock() mock_config.settings.REDIS = {'server': 'localhost', 'port': 6379} sys.modules["app.config"] = mock_config -# Mock DB related modules to prevent circular imports or DB connection attempts +# Mock DB and circular/heavy imports before any app module is loaded sys.modules["app.db_sql"] = MagicMock() sys.modules["app.lib_sql_core"] = MagicMock() sys.modules["app.db_connection"] = MagicMock() +sys.modules["app.log"] = MagicMock() + +# Provide a real AccountContext but stub out the heavy lib_general_v3 imports +from app.models.auth_models import AccountContext +mock_lib_general_v3 = MagicMock() +mock_lib_general_v3.AccountContext = AccountContext +mock_lib_general_v3.get_account_context_optional = MagicMock(return_value=AccountContext( + account_id=1, account_id_random="test_account_id", auth_method="api_key" +)) +sys.modules["app.lib_general_v3"] = mock_lib_general_v3 +sys.modules["app.routers.dependencies_v3"] = MagicMock() +sys.modules["app.lib_jwt"] = MagicMock() from app.routers.websockets_v3 import v3_ws_endpoint +# Shared AccountContext fixture used in all test calls +MOCK_ACCOUNT = AccountContext(account_id=1, account_id_random="test_account_id", auth_method="api_key") + class TestWSV3Router(unittest.TestCase): @patch('app.routers.websockets_v3.ws_manager_v3') @@ -31,13 +46,13 @@ class TestWSV3Router(unittest.TestCase): """ # 1. Setup WebSocket Mock mock_ws = AsyncMock() - + # 2. Setup Redis PubSub Mock mock_pubsub = MagicMock() mock_pubsub.subscribe = AsyncMock() mock_pubsub.unsubscribe = AsyncMock() mock_pubsub.close = AsyncMock() - + mock_message = { 'type': 'message', 'data': json.dumps({ @@ -50,10 +65,10 @@ class TestWSV3Router(unittest.TestCase): "sent_at": "2026-01-30T12:00:00Z" }) } - + # Signal to coordinate loops msg_delivered = asyncio.Event() - + # Counters to break the 'while True' loops in the endpoint get_msg_count = 0 recv_json_count = 0 @@ -69,10 +84,10 @@ class TestWSV3Router(unittest.TestCase): raise asyncio.CancelledError("Terminate sender loop") mock_pubsub.get_message = mock_get_message - + mock_redis = MagicMock() mock_redis.pubsub.return_value = mock_pubsub - + # 3. Setup Manager Mock mock_manager.get_redis = AsyncMock(return_value=mock_redis) mock_manager.update_presence = AsyncMock() @@ -96,19 +111,17 @@ class TestWSV3Router(unittest.TestCase): raise asyncio.CancelledError("Terminate receiver loop") mock_ws.receive_json.side_effect = mock_receive_json - + # 4. Run the endpoint logic async def run_endpoint(): try: - # Execute endpoint with a short timeout await asyncio.wait_for( - v3_ws_endpoint(mock_ws, "test_group", "client_a"), + v3_ws_endpoint(mock_ws, "test_group", "client_a", account=MOCK_ACCOUNT), timeout=0.5 ) except (asyncio.TimeoutError, asyncio.CancelledError): pass except Exception as e: - # Suppress our expected loop-termination messages if "Terminate" not in str(e): raise @@ -117,10 +130,65 @@ class TestWSV3Router(unittest.TestCase): # 5. Verifications mock_ws.accept.assert_called_once() mock_manager.update_presence.assert_any_call("client_a", "test_group", online=True) - - # Verify message from Redis was forwarded to WebSocket mock_ws.send_text.assert_called() print("✅ WebSocket Router unit logic verified.") -if __name__ == "__main__": - unittest.main() + @patch('app.routers.websockets_v3.ws_manager_v3') + def test_heartbeat_refreshes_presence(self, mock_manager): + """ + Ensures that a heartbeat message triggers a presence TTL refresh + in addition to being published normally. + """ + mock_ws = AsyncMock() + + mock_pubsub = MagicMock() + mock_pubsub.subscribe = AsyncMock() + mock_pubsub.unsubscribe = AsyncMock() + mock_pubsub.close = AsyncMock() + + async def mock_get_message(*args, **kwargs): + await asyncio.sleep(0.05) + raise asyncio.CancelledError("Terminate sender loop") + + mock_pubsub.get_message = mock_get_message + + mock_redis = MagicMock() + mock_redis.pubsub.return_value = mock_pubsub + + recv_count = 0 + + async def mock_receive_json(): + nonlocal recv_count + recv_count += 1 + if recv_count == 1: + # Send a heartbeat + return {"msg_type": "heartbeat", "target": "echo"} + await asyncio.sleep(0.05) + raise asyncio.CancelledError("Terminate receiver loop") + + mock_ws.receive_json.side_effect = mock_receive_json + + mock_manager.get_redis = AsyncMock(return_value=mock_redis) + mock_manager.update_presence = AsyncMock() + mock_manager.publish_message = AsyncMock() + mock_manager.get_channel_names.return_value = ["ws:group:test"] + + async def run_endpoint(): + try: + await asyncio.wait_for( + v3_ws_endpoint(mock_ws, "test_group", "client_hb", account=MOCK_ACCOUNT), + timeout=0.5 + ) + except (asyncio.TimeoutError, asyncio.CancelledError): + pass + except Exception as e: + if "Terminate" not in str(e): + raise + + asyncio.run(run_endpoint()) + + # Presence should be refreshed: once on connect, once on heartbeat receipt + presence_calls = mock_manager.update_presence.call_args_list + online_calls = [c for c in presence_calls if c.kwargs.get('online') is True or (len(c.args) >= 3 and c.args[2] is True)] + self.assertGreaterEqual(len(online_calls), 2, "Expected at least 2 presence refresh calls (connect + heartbeat)") + print("✅ Heartbeat presence refresh verified.")