Merge pull request #17 from fdlamotte/awolden/timing-and-event-fixes

Awolden/timing and event fixes
This commit is contained in:
fdlamotte
2025-08-06 08:41:03 +02:00
committed by GitHub
4 changed files with 27 additions and 8 deletions

View File

@@ -49,7 +49,8 @@ class BLEConnection:
if self.client: if self.client:
logger.debug("Using pre-configured BleakClient.") logger.debug("Using pre-configured BleakClient.")
# If a client is already provided, ensure its disconnect callback is set # If a client is already provided, ensure its disconnect callback is set
self.client._disconnected_callback = self.handle_disconnect assert isinstance(self.client, BleakClient)
self.client.set_disconnected_callback(self.handle_disconnect)
self.address = self.client.address self.address = self.client.address
else: else:

View File

@@ -1,4 +1,5 @@
from enum import Enum from enum import Enum
import inspect
import logging import logging
from math import log from math import log
from typing import Any, Dict, Optional, Callable, List, Union from typing import Any, Dict, Optional, Callable, List, Union
@@ -133,6 +134,7 @@ class EventDispatcher:
while self.running: while self.running:
event = await self.queue.get() event = await self.queue.get()
logger.debug(f"Dispatching event: {event.type}, {event.payload}, {event.attributes}") logger.debug(f"Dispatching event: {event.type}, {event.payload}, {event.attributes}")
for subscription in self.subscriptions.copy(): for subscription in self.subscriptions.copy():
# Check if event type matches # Check if event type matches
if subscription.event_type is None or subscription.event_type == event.type: if subscription.event_type is None or subscription.event_type == event.type:
@@ -142,15 +144,24 @@ class EventDispatcher:
if not all(event.attributes.get(key) == value if not all(event.attributes.get(key) == value
for key, value in subscription.attribute_filters.items()): for key, value in subscription.attribute_filters.items()):
continue continue
try:
result = subscription.callback(event.clone()) # Fire the call back asychronously
if asyncio.iscoroutine(result): asyncio.create_task(self._execute_callback(subscription.callback, event.clone()))
await result
except Exception as e:
print(f"Error in event handler: {e}")
self.queue.task_done() self.queue.task_done()
async def _execute_callback(self, callback, event):
"""Execute a callback with proper error handling"""
try:
if asyncio.iscoroutinefunction(callback):
await callback(event)
else:
result = callback(event)
if inspect.iscoroutine(result):
await result
except Exception as e:
logger.error(f"Error in event handler for {event.type}: {e}", exc_info=True)
async def start(self): async def start(self):
if not self.running: if not self.running:
self.running = True self.running = True

View File

@@ -20,6 +20,7 @@ class SerialConnection:
self.inframe = b"" self.inframe = b""
self._disconnect_callback = None self._disconnect_callback = None
self.cx_dly = cx_dly self.cx_dly = cx_dly
self._connected_event = asyncio.Event()
class MCSerialClientProtocol(asyncio.Protocol): class MCSerialClientProtocol(asyncio.Protocol):
def __init__(self, cx): def __init__(self, cx):
@@ -30,12 +31,14 @@ class SerialConnection:
logger.debug('port opened') logger.debug('port opened')
if isinstance(transport, serial_asyncio.SerialTransport) and transport.serial: if isinstance(transport, serial_asyncio.SerialTransport) and transport.serial:
transport.serial.rts = False # You can manipulate Serial object via transport transport.serial.rts = False # You can manipulate Serial object via transport
self.cx._connected_event.set()
def data_received(self, data): def data_received(self, data):
self.cx.handle_rx(data) self.cx.handle_rx(data)
def connection_lost(self, exc): def connection_lost(self, exc):
logger.debug('Serial port closed') logger.debug('Serial port closed')
self.cx._connected_event.clear()
if self.cx._disconnect_callback: if self.cx._disconnect_callback:
asyncio.create_task(self.cx._disconnect_callback("serial_disconnect")) asyncio.create_task(self.cx._disconnect_callback("serial_disconnect"))
@@ -49,12 +52,14 @@ class SerialConnection:
""" """
Connects to the device Connects to the device
""" """
self._connected_event.clear()
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
await serial_asyncio.create_serial_connection( await serial_asyncio.create_serial_connection(
loop, lambda: self.MCSerialClientProtocol(self), loop, lambda: self.MCSerialClientProtocol(self),
self.port, baudrate=self.baudrate) self.port, baudrate=self.baudrate)
await asyncio.sleep(self.cx_dly) # wait for cx to establish await self._connected_event.wait()
logger.info("Serial Connection started") logger.info("Serial Connection started")
return self.port return self.port
@@ -99,6 +104,7 @@ class SerialConnection:
if self.transport: if self.transport:
self.transport.close() self.transport.close()
self.transport = None self.transport = None
self._connected_event.clear()
logger.debug("Serial Connection closed") logger.debug("Serial Connection closed")
def set_disconnect_callback(self, callback): def set_disconnect_callback(self, callback):

View File

@@ -44,6 +44,7 @@ class TestBLEConnection(unittest.TestCase):
asyncio.run(ble_conn.send(data_to_send)) asyncio.run(ble_conn.send(data_to_send))
# Assert # Assert
assert(isinstance(ble_conn.rx_char, MagicMock))
ble_conn.rx_char.write_gatt_char.assert_called_once_with(ble_conn.rx_char, data_to_send, response=True) ble_conn.rx_char.write_gatt_char.assert_called_once_with(ble_conn.rx_char, data_to_send, response=True)
def _get_mock_bleak_client(self): def _get_mock_bleak_client(self):