44import asyncio
55import logging
66import struct
7- from asyncio import IncompleteReadError , StreamReader , StreamWriter
7+ from asyncio import StreamReader , StreamWriter
88from typing import Awaitable
99
1010from google .protobuf .message import DecodeError
@@ -71,7 +71,6 @@ def __init__(self, host: str, uuid: str, loop=None):
7171 self ._reference = None
7272
7373 self ._event_bus : EventBus = None
74- self ._read_task : asyncio .Task = None
7574
7675 self .__sensor_callback_fn : callable = None
7776 self .__alarm_callback_fn : callable = None
@@ -89,43 +88,43 @@ def set_alarm_callback(self, callback: callable):
8988 """Set a callback to be called when an alarm is received."""
9089 self .__alarm_callback_fn = callback
9190
92- async def connect (self , uuid : str ):
91+ async def _connect (self , uuid : str ):
9392 """Connect to the bridge."""
94- await self .disconnect ()
95-
9693 _LOGGER .debug ("Connecting to bridge %s" , self .host )
9794 try :
9895 self ._reader , self ._writer = await asyncio .wait_for (asyncio .open_connection (self .host , self .PORT ), TIMEOUT )
9996 except asyncio .TimeoutError as exc :
100- raise AioComfoConnectTimeout () from exc
97+ _LOGGER .warning ("Timeout while connecting to bridge %s" , self .host )
98+ raise AioComfoConnectTimeout from exc
10199
102100 self ._reference = 1
103101 self ._local_uuid = uuid
104102 self ._event_bus = EventBus ()
105103
106- # We are connected, start the background task
107- self ._read_task = self ._loop .create_task (self ._read_messages ())
104+ async def _read_messages ():
105+ while True :
106+ try :
107+ # Keep processing messages until we are disconnected or shutting down
108+ await self ._process_message ()
108109
109- _LOGGER .debug ("Connected to bridge %s" , self .host )
110+ except asyncio .exceptions .CancelledError :
111+ # We are shutting down. Return to stop the background task
112+ return False
110113
111- async def disconnect ( self ) :
112- """Disconnect from the bridge."""
113- _LOGGER . debug ( "Disconnecting from bridge %s" , self . host )
114+ except AioComfoConnectNotConnected :
115+ # We have been disconnected
116+ raise
114117
115- if self ._read_task :
116- # Cancel the background task
117- self ._read_task .cancel ()
118+ read_task = self ._loop .create_task (_read_messages ())
119+ _LOGGER .debug ("Connected to bridge %s" , self .host )
118120
119- # Wait for background task to finish
120- try :
121- await self ._read_task
122- except asyncio .CancelledError :
123- pass
121+ return read_task
124122
123+ async def _disconnect (self ):
124+ """Disconnect from the bridge."""
125125 if self ._writer :
126126 self ._writer .close ()
127-
128- _LOGGER .debug ("Disconnected from bridge %s" , self .host )
127+ await self ._writer .wait_closed ()
129128
130129 def is_connected (self ) -> bool :
131130 """Returns True if the bridge is connected."""
@@ -135,7 +134,7 @@ async def _send(self, request, request_type, params: dict = None, reply: bool =
135134 """Sends a command and wait for a response if the request is known to return a result."""
136135 # Check if we are actually connected
137136 if not self .is_connected ():
138- raise AioComfoConnectNotConnected ()
137+ raise AioComfoConnectNotConnected
139138
140139 # Construct the message
141140 cmd = zehnder_pb2 .GatewayOperation () # pylint: disable=no-member
@@ -160,6 +159,7 @@ async def _send(self, request, request_type, params: dict = None, reply: bool =
160159 # Send the message
161160 _LOGGER .debug ("TX %s" , message )
162161 self ._writer .write (message .encode ())
162+ await self ._writer .drain ()
163163
164164 # Increase message reference for next message
165165 self ._reference += 1
@@ -168,6 +168,7 @@ async def _send(self, request, request_type, params: dict = None, reply: bool =
168168 return await asyncio .wait_for (fut , TIMEOUT )
169169 except asyncio .TimeoutError as exc :
170170 _LOGGER .warning ("Timeout while waiting for response from bridge" )
171+ await self ._disconnect ()
171172 raise AioComfoConnectTimeout from exc
172173
173174 async def _read (self ) -> Message :
@@ -206,55 +207,51 @@ async def _read(self) -> Message:
206207
207208 return message
208209
209- async def _read_messages (self ):
210- """Receive a message from the bridge."""
211- while self ._read_task .cancelled () is False :
212- try :
213- message = await self ._read ()
214-
215- # pylint: disable=no-member
216- if message .cmd .type == zehnder_pb2 .GatewayOperation .CnRpdoNotificationType :
217- if self .__sensor_callback_fn :
218- self .__sensor_callback_fn (message .msg .pdid , int .from_bytes (message .msg .data , byteorder = "little" , signed = True ))
219- else :
220- _LOGGER .info ("Unhandled CnRpdoNotificationType since no callback is registered." )
210+ async def _process_message (self ):
211+ """Process a message from the bridge."""
212+ try :
213+ message = await self ._read ()
221214
222- elif message .cmd .type == zehnder_pb2 .GatewayOperation .GatewayNotificationType :
223- _LOGGER .debug ("Unhandled GatewayNotificationType" )
215+ # pylint: disable=no-member
216+ if message .cmd .type == zehnder_pb2 .GatewayOperation .CnRpdoNotificationType :
217+ if self .__sensor_callback_fn :
218+ self .__sensor_callback_fn (message .msg .pdid , int .from_bytes (message .msg .data , byteorder = "little" , signed = True ))
219+ else :
220+ _LOGGER .info ("Unhandled CnRpdoNotificationType since no callback is registered." )
224221
225- elif message .cmd .type == zehnder_pb2 .GatewayOperation .CnNodeNotificationType :
226- _LOGGER .debug ("Unhandled CnNodeNotificationType " )
222+ elif message .cmd .type == zehnder_pb2 .GatewayOperation .GatewayNotificationType :
223+ _LOGGER .debug ("Unhandled GatewayNotificationType " )
227224
228- elif message .cmd .type == zehnder_pb2 .GatewayOperation .CnAlarmNotificationType :
229- if self .__alarm_callback_fn :
230- self .__alarm_callback_fn (message .msg .nodeId , message .msg )
231- else :
232- _LOGGER .info ("Unhandled CnAlarmNotificationType since no callback is registered." )
225+ elif message .cmd .type == zehnder_pb2 .GatewayOperation .CnNodeNotificationType :
226+ _LOGGER .debug ("Unhandled CnNodeNotificationType" )
233227
234- elif message .cmd .type == zehnder_pb2 .GatewayOperation .CloseSessionRequestType :
235- _LOGGER .info ("The Bridge has asked us to close the connection." )
236- return # Stop the background task
228+ elif message .cmd .type == zehnder_pb2 .GatewayOperation .CnAlarmNotificationType :
229+ if self .__alarm_callback_fn :
230+ self .__alarm_callback_fn (message .msg .nodeId , message .msg )
231+ else :
232+ _LOGGER .info ("Unhandled CnAlarmNotificationType since no callback is registered." )
237233
238- elif message .cmd .reference :
239- # Emit to the event bus
240- self ._event_bus .emit (message .cmd .reference , message .msg )
234+ elif message .cmd .type == zehnder_pb2 .GatewayOperation .CloseSessionRequestType :
235+ _LOGGER .info ("The Bridge has asked us to close the connection." )
241236
242- else :
243- _LOGGER .warning ("Unhandled message type %s: %s" , message .cmd .type , message )
237+ elif message .cmd .reference :
238+ # Emit to the event bus
239+ self ._event_bus .emit (message .cmd .reference , message .msg )
244240
245- except asyncio . exceptions . CancelledError :
246- return # Stop the background task
241+ else :
242+ _LOGGER . warning ( "Unhandled message type %s: %s" , message . cmd . type , message )
247243
248- except IncompleteReadError :
249- _LOGGER .info ("The connection was closed." )
250- return # Stop the background task
244+ except asyncio .exceptions .IncompleteReadError :
245+ _LOGGER .info ("The connection was closed." )
246+ await self ._disconnect ()
247+ raise AioComfoConnectNotConnected
251248
252- except ComfoConnectError as exc :
253- if exc .message .cmd .reference :
254- self ._event_bus .emit (exc .message .cmd .reference , exc )
249+ except ComfoConnectError as exc :
250+ if exc .message .cmd .reference :
251+ self ._event_bus .emit (exc .message .cmd .reference , exc )
255252
256- except DecodeError as exc :
257- _LOGGER .error ("Failed to decode message: %s" , exc )
253+ except DecodeError as exc :
254+ _LOGGER .error ("Failed to decode message: %s" , exc )
258255
259256 def cmd_start_session (self , take_over : bool = False ) -> Awaitable [Message ]:
260257 """Starts the session on the device by logging in and optionally disconnecting an already existing session."""
0 commit comments