Update WS router test: wire AccountContext mock, add heartbeat presence refresh test
This commit is contained in:
@@ -14,13 +14,28 @@ mock_config.settings = MagicMock()
|
|||||||
mock_config.settings.REDIS = {'server': 'localhost', 'port': 6379}
|
mock_config.settings.REDIS = {'server': 'localhost', 'port': 6379}
|
||||||
sys.modules["app.config"] = mock_config
|
sys.modules["app.config"] = mock_config
|
||||||
|
|
||||||
# Mock DB related modules to prevent circular imports or DB connection attempts
|
# Mock DB and circular/heavy imports before any app module is loaded
|
||||||
sys.modules["app.db_sql"] = MagicMock()
|
sys.modules["app.db_sql"] = MagicMock()
|
||||||
sys.modules["app.lib_sql_core"] = MagicMock()
|
sys.modules["app.lib_sql_core"] = MagicMock()
|
||||||
sys.modules["app.db_connection"] = MagicMock()
|
sys.modules["app.db_connection"] = MagicMock()
|
||||||
|
sys.modules["app.log"] = MagicMock()
|
||||||
|
|
||||||
|
# Provide a real AccountContext but stub out the heavy lib_general_v3 imports
|
||||||
|
from app.models.auth_models import AccountContext
|
||||||
|
mock_lib_general_v3 = MagicMock()
|
||||||
|
mock_lib_general_v3.AccountContext = AccountContext
|
||||||
|
mock_lib_general_v3.get_account_context_optional = MagicMock(return_value=AccountContext(
|
||||||
|
account_id=1, account_id_random="test_account_id", auth_method="api_key"
|
||||||
|
))
|
||||||
|
sys.modules["app.lib_general_v3"] = mock_lib_general_v3
|
||||||
|
sys.modules["app.routers.dependencies_v3"] = MagicMock()
|
||||||
|
sys.modules["app.lib_jwt"] = MagicMock()
|
||||||
|
|
||||||
from app.routers.websockets_v3 import v3_ws_endpoint
|
from app.routers.websockets_v3 import v3_ws_endpoint
|
||||||
|
|
||||||
|
# Shared AccountContext fixture used in all test calls
|
||||||
|
MOCK_ACCOUNT = AccountContext(account_id=1, account_id_random="test_account_id", auth_method="api_key")
|
||||||
|
|
||||||
class TestWSV3Router(unittest.TestCase):
|
class TestWSV3Router(unittest.TestCase):
|
||||||
|
|
||||||
@patch('app.routers.websockets_v3.ws_manager_v3')
|
@patch('app.routers.websockets_v3.ws_manager_v3')
|
||||||
@@ -31,13 +46,13 @@ class TestWSV3Router(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
# 1. Setup WebSocket Mock
|
# 1. Setup WebSocket Mock
|
||||||
mock_ws = AsyncMock()
|
mock_ws = AsyncMock()
|
||||||
|
|
||||||
# 2. Setup Redis PubSub Mock
|
# 2. Setup Redis PubSub Mock
|
||||||
mock_pubsub = MagicMock()
|
mock_pubsub = MagicMock()
|
||||||
mock_pubsub.subscribe = AsyncMock()
|
mock_pubsub.subscribe = AsyncMock()
|
||||||
mock_pubsub.unsubscribe = AsyncMock()
|
mock_pubsub.unsubscribe = AsyncMock()
|
||||||
mock_pubsub.close = AsyncMock()
|
mock_pubsub.close = AsyncMock()
|
||||||
|
|
||||||
mock_message = {
|
mock_message = {
|
||||||
'type': 'message',
|
'type': 'message',
|
||||||
'data': json.dumps({
|
'data': json.dumps({
|
||||||
@@ -50,10 +65,10 @@ class TestWSV3Router(unittest.TestCase):
|
|||||||
"sent_at": "2026-01-30T12:00:00Z"
|
"sent_at": "2026-01-30T12:00:00Z"
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
# Signal to coordinate loops
|
# Signal to coordinate loops
|
||||||
msg_delivered = asyncio.Event()
|
msg_delivered = asyncio.Event()
|
||||||
|
|
||||||
# Counters to break the 'while True' loops in the endpoint
|
# Counters to break the 'while True' loops in the endpoint
|
||||||
get_msg_count = 0
|
get_msg_count = 0
|
||||||
recv_json_count = 0
|
recv_json_count = 0
|
||||||
@@ -69,10 +84,10 @@ class TestWSV3Router(unittest.TestCase):
|
|||||||
raise asyncio.CancelledError("Terminate sender loop")
|
raise asyncio.CancelledError("Terminate sender loop")
|
||||||
|
|
||||||
mock_pubsub.get_message = mock_get_message
|
mock_pubsub.get_message = mock_get_message
|
||||||
|
|
||||||
mock_redis = MagicMock()
|
mock_redis = MagicMock()
|
||||||
mock_redis.pubsub.return_value = mock_pubsub
|
mock_redis.pubsub.return_value = mock_pubsub
|
||||||
|
|
||||||
# 3. Setup Manager Mock
|
# 3. Setup Manager Mock
|
||||||
mock_manager.get_redis = AsyncMock(return_value=mock_redis)
|
mock_manager.get_redis = AsyncMock(return_value=mock_redis)
|
||||||
mock_manager.update_presence = AsyncMock()
|
mock_manager.update_presence = AsyncMock()
|
||||||
@@ -96,19 +111,17 @@ class TestWSV3Router(unittest.TestCase):
|
|||||||
raise asyncio.CancelledError("Terminate receiver loop")
|
raise asyncio.CancelledError("Terminate receiver loop")
|
||||||
|
|
||||||
mock_ws.receive_json.side_effect = mock_receive_json
|
mock_ws.receive_json.side_effect = mock_receive_json
|
||||||
|
|
||||||
# 4. Run the endpoint logic
|
# 4. Run the endpoint logic
|
||||||
async def run_endpoint():
|
async def run_endpoint():
|
||||||
try:
|
try:
|
||||||
# Execute endpoint with a short timeout
|
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
v3_ws_endpoint(mock_ws, "test_group", "client_a"),
|
v3_ws_endpoint(mock_ws, "test_group", "client_a", account=MOCK_ACCOUNT),
|
||||||
timeout=0.5
|
timeout=0.5
|
||||||
)
|
)
|
||||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Suppress our expected loop-termination messages
|
|
||||||
if "Terminate" not in str(e):
|
if "Terminate" not in str(e):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -117,10 +130,65 @@ class TestWSV3Router(unittest.TestCase):
|
|||||||
# 5. Verifications
|
# 5. Verifications
|
||||||
mock_ws.accept.assert_called_once()
|
mock_ws.accept.assert_called_once()
|
||||||
mock_manager.update_presence.assert_any_call("client_a", "test_group", online=True)
|
mock_manager.update_presence.assert_any_call("client_a", "test_group", online=True)
|
||||||
|
|
||||||
# Verify message from Redis was forwarded to WebSocket
|
|
||||||
mock_ws.send_text.assert_called()
|
mock_ws.send_text.assert_called()
|
||||||
print("✅ WebSocket Router unit logic verified.")
|
print("✅ WebSocket Router unit logic verified.")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
@patch('app.routers.websockets_v3.ws_manager_v3')
|
||||||
unittest.main()
|
def test_heartbeat_refreshes_presence(self, mock_manager):
|
||||||
|
"""
|
||||||
|
Ensures that a heartbeat message triggers a presence TTL refresh
|
||||||
|
in addition to being published normally.
|
||||||
|
"""
|
||||||
|
mock_ws = AsyncMock()
|
||||||
|
|
||||||
|
mock_pubsub = MagicMock()
|
||||||
|
mock_pubsub.subscribe = AsyncMock()
|
||||||
|
mock_pubsub.unsubscribe = AsyncMock()
|
||||||
|
mock_pubsub.close = AsyncMock()
|
||||||
|
|
||||||
|
async def mock_get_message(*args, **kwargs):
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
raise asyncio.CancelledError("Terminate sender loop")
|
||||||
|
|
||||||
|
mock_pubsub.get_message = mock_get_message
|
||||||
|
|
||||||
|
mock_redis = MagicMock()
|
||||||
|
mock_redis.pubsub.return_value = mock_pubsub
|
||||||
|
|
||||||
|
recv_count = 0
|
||||||
|
|
||||||
|
async def mock_receive_json():
|
||||||
|
nonlocal recv_count
|
||||||
|
recv_count += 1
|
||||||
|
if recv_count == 1:
|
||||||
|
# Send a heartbeat
|
||||||
|
return {"msg_type": "heartbeat", "target": "echo"}
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
raise asyncio.CancelledError("Terminate receiver loop")
|
||||||
|
|
||||||
|
mock_ws.receive_json.side_effect = mock_receive_json
|
||||||
|
|
||||||
|
mock_manager.get_redis = AsyncMock(return_value=mock_redis)
|
||||||
|
mock_manager.update_presence = AsyncMock()
|
||||||
|
mock_manager.publish_message = AsyncMock()
|
||||||
|
mock_manager.get_channel_names.return_value = ["ws:group:test"]
|
||||||
|
|
||||||
|
async def run_endpoint():
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
v3_ws_endpoint(mock_ws, "test_group", "client_hb", account=MOCK_ACCOUNT),
|
||||||
|
timeout=0.5
|
||||||
|
)
|
||||||
|
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
if "Terminate" not in str(e):
|
||||||
|
raise
|
||||||
|
|
||||||
|
asyncio.run(run_endpoint())
|
||||||
|
|
||||||
|
# Presence should be refreshed: once on connect, once on heartbeat receipt
|
||||||
|
presence_calls = mock_manager.update_presence.call_args_list
|
||||||
|
online_calls = [c for c in presence_calls if c.kwargs.get('online') is True or (len(c.args) >= 3 and c.args[2] is True)]
|
||||||
|
self.assertGreaterEqual(len(online_calls), 2, "Expected at least 2 presence refresh calls (connect + heartbeat)")
|
||||||
|
print("✅ Heartbeat presence refresh verified.")
|
||||||
|
|||||||
Reference in New Issue
Block a user