diff --git a/monitor/test/test_websock_server.py b/monitor/test/test_websock_server.py index 08f5108..4980aa7 100644 --- a/monitor/test/test_websock_server.py +++ b/monitor/test/test_websock_server.py @@ -2,25 +2,21 @@ import unittest import asyncio -from unittest.mock import * +from unittest.mock import ANY, AsyncMock, MagicMock, patch from monitor.websocket_connections import WebsocketServer class TestWebsocketServer(unittest.TestCase): - @patch('asyncio.set_event_loop') @patch('monitor.websocket_connections.serve') - def test_init(self, mock_websock_serve, mock_set_loop): - loop = MagicMock() + def test_init(self, mock_websock_serve): + loop = asyncio.new_event_loop() logger = MagicMock() - mock_websock_serve.return_value = "0101" - WebsocketServer.connection_handler = MagicMock() server = WebsocketServer(("ip_address", 4512), None, loop, logger) self.assertEqual(server._loop, loop) self.assertIsNone(server._connections) - mock_set_loop.assert_called_once_with(loop) - mock_websock_serve.assert_called_once_with(server.connection_handler, "ip_address", 4512) - loop.run_until_complete.assert_called_once_with("0101") + mock_websock_serve.assert_not_called() + loop.close() @patch('monitor.websocket_connections.serve') def test_connection_handler(self, mock_websock_serve): @@ -56,12 +52,23 @@ def test_connection_handler(self, mock_websock_serve): connection_mock.remove_client.assert_called_once_with("1234", queue) @patch('asyncio.set_event_loop') - @patch('monitor.websocket_connections.serve') + @patch('monitor.websocket_connections.serve', new_callable=AsyncMock) def test_run(self, mock_websock_serve, mock_set_loop): - loop = MagicMock() + loop = asyncio.new_event_loop() logger = MagicMock() - mock_websock_serve.return_value = "0101" + + server_obj = MagicMock() + wait_closed_future = loop.create_future() + wait_closed_future.set_result(None) + server_obj.wait_closed.return_value = wait_closed_future + mock_websock_serve.return_value = server_obj + server = WebsocketServer(("ip_address", 123), None, loop, logger) + + loop.call_later(0.05, loop.stop) server.run() + mock_set_loop.assert_called_with(loop) - loop.run_forever.assert_called_once_with() + mock_websock_serve.assert_awaited_once_with(ANY, "ip_address", 123) + server_obj.close.assert_called_once_with() + loop.close() diff --git a/monitor/websocket_connections.py b/monitor/websocket_connections.py index b519129..47990e5 100644 --- a/monitor/websocket_connections.py +++ b/monitor/websocket_connections.py @@ -147,11 +147,14 @@ def __init__(self, websock_uri, connections, loop, logger): self._connections = connections self._loop = loop self._logger = logger - hostname, port = websock_uri - asyncio.set_event_loop(loop) - start_server = serve(self.connection_handler, hostname, port) - self._server = loop.run_until_complete(start_server) - self._logger.info("websocket server initialized at {}:{}".format(hostname, port)) + self._hostname, self._port = websock_uri + self._server = None + + async def _start_server(self): + self._server = await serve(self.connection_handler, self._hostname, self._port) + self._logger.info( + "websocket server initialized at {}:{}".format(self._hostname, self._port) + ) async def connection_handler(self, websocket): """ @@ -191,4 +194,10 @@ def run(self): nothing. """ asyncio.set_event_loop(self._loop) - self._loop.run_forever() + self._loop.run_until_complete(self._start_server()) + try: + self._loop.run_forever() + finally: + if self._server is not None: + self._server.close() + self._loop.run_until_complete(self._server.wait_closed())