Refactor to event system

This commit is contained in:
Alex Wolden
2025-04-08 22:56:16 -07:00
parent 8f0ecd7d75
commit a5f1ec5c26
7 changed files with 66 additions and 271 deletions

View File

@@ -62,6 +62,9 @@ class BLEConnection:
await self.client.start_notify(UART_TX_CHAR_UUID, self.handle_rx) await self.client.start_notify(UART_TX_CHAR_UUID, self.handle_rx)
nus = self.client.services.get_service(UART_SERVICE_UUID) nus = self.client.services.get_service(UART_SERVICE_UUID)
if nus is None:
logger.error("Could not find UART service")
return None
self.rx_char = nus.get_characteristic(UART_RX_CHAR_UUID) self.rx_char = nus.get_characteristic(UART_RX_CHAR_UUID)
logger.info("BLE Connection started") logger.info("BLE Connection started")
@@ -82,4 +85,10 @@ class BLEConnection:
asyncio.create_task(self.reader.handle_rx(data)) asyncio.create_task(self.reader.handle_rx(data))
async def send(self, data): async def send(self, data):
if not self.client:
logger.error("Client is not connected")
return False
if not self.rx_char:
logger.error("RX characteristic not found")
return False
await self.client.write_gatt_char(self.rx_char, bytes(data), response=False) await self.client.write_gatt_char(self.rx_char, bytes(data), response=False)

View File

@@ -26,10 +26,13 @@ def deprecated(func):
class CommandHandler: class CommandHandler:
def __init__(self): DEFAULT_TIMEOUT = 5.0
def __init__(self, default_timeout=None):
self._sender_func = None self._sender_func = None
self._reader = None self._reader = None
self.dispatcher = None self.dispatcher = None
self.default_timeout = default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT
def set_connection(self, connection): def set_connection(self, connection):
async def sender(data): async def sender(data):
@@ -42,10 +45,13 @@ class CommandHandler:
def set_dispatcher(self, dispatcher): def set_dispatcher(self, dispatcher):
self.dispatcher = dispatcher self.dispatcher = dispatcher
async def send(self, data, expected_events=None, timeout=5.0): async def send(self, data, expected_events=None, timeout=None):
if not self.dispatcher: if not self.dispatcher:
raise RuntimeError("Dispatcher not set, cannot send commands") raise RuntimeError("Dispatcher not set, cannot send commands")
# Use the provided timeout or fall back to default_timeout
timeout = timeout if timeout is not None else self.default_timeout
if self._sender_func: if self._sender_func:
logger.debug(f"Sending raw data: {data.hex() if isinstance(data, bytes) else data}") logger.debug(f"Sending raw data: {data.hex() if isinstance(data, bytes) else data}")
await self._sender_func(data) await self._sender_func(data)
@@ -163,15 +169,20 @@ class CommandHandler:
data = b"\x0f" + key data = b"\x0f" + key
return await self.send(data, [EventType.OK, EventType.ERROR]) return await self.send(data, [EventType.OK, EventType.ERROR])
async def get_msg(self): async def get_msg(self, timeout=1):
logger.debug("Requesting pending messages") logger.debug("Requesting pending messages")
return await self.send(b"\x0A", [EventType.CONTACT_MSG_RECV, EventType.CHANNEL_MSG_RECV, EventType.ERROR], 1) return await self.send(b"\x0A", [EventType.CONTACT_MSG_RECV, EventType.CHANNEL_MSG_RECV, EventType.ERROR], timeout)
async def send_login(self, dst, pwd): async def send_login(self, dst, pwd):
logger.debug(f"Sending login request to: {dst.hex() if isinstance(dst, bytes) else dst}") logger.debug(f"Sending login request to: {dst.hex() if isinstance(dst, bytes) else dst}")
data = b"\x1a" + dst + pwd.encode("ascii") data = b"\x1a" + dst + pwd.encode("ascii")
return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) return await self.send(data, [EventType.MSG_SENT, EventType.ERROR])
async def send_logout(self, dst):
self.login_resp = asyncio.Future()
data = b"\x1d" + dst
return await self.send(data, [EventType.MSG_SENT, EventType.ERROR])
async def send_statusreq(self, dst): async def send_statusreq(self, dst):
logger.debug(f"Sending status request to: {dst.hex() if isinstance(dst, bytes) else dst}") logger.debug(f"Sending status request to: {dst.hex() if isinstance(dst, bytes) else dst}")
data = b"\x1b" + dst data = b"\x1b" + dst

View File

@@ -40,7 +40,7 @@ class EventType(Enum):
class Event: class Event:
type: EventType type: EventType
payload: Any payload: Any
attributes: Dict[str, Any] = None attributes: Dict[str, Any] = {}
def __post_init__(self): def __post_init__(self):
if self.attributes is None: if self.attributes is None:
@@ -64,7 +64,7 @@ class EventDispatcher:
self.running = False self.running = False
self._task = None self._task = None
def subscribe(self, event_type: Union[EventType, None], callback: Callable[[Event], None]) -> Subscription: def subscribe(self, event_type: Union[EventType, None], callback: Callable[[Event], Union[None, asyncio.Future]]) -> Subscription:
subscription = Subscription(self, event_type, callback) subscription = Subscription(self, event_type, callback)
self.subscriptions.append(subscription) self.subscriptions.append(subscription)
return subscription return subscription
@@ -83,7 +83,9 @@ class EventDispatcher:
for subscription in self.subscriptions.copy(): for subscription in self.subscriptions.copy():
if subscription.event_type is None or subscription.event_type == event.type: if subscription.event_type is None or subscription.event_type == event.type:
try: try:
await subscription.callback(event) result = subscription.callback(event)
if asyncio.iscoroutine(result):
await result
except Exception as e: except Exception as e:
print(f"Error in event handler: {e}") print(f"Error in event handler: {e}")
@@ -106,10 +108,10 @@ class EventDispatcher:
pass pass
self._task = None self._task = None
async def wait_for_event(self, event_type: EventType, timeout: float = None) -> Optional[Event]: async def wait_for_event(self, event_type: EventType, timeout: float | None = None) -> Optional[Event]:
future = asyncio.Future() future = asyncio.Future()
async def event_handler(event: Event): def event_handler(event: Event):
if not future.done(): if not future.done():
future.set_result(event) future.set_result(event)

View File

@@ -28,11 +28,11 @@ class MeshCore:
""" """
Interface to a MeshCore device Interface to a MeshCore device
""" """
def __init__(self, cx, debug=False): def __init__(self, cx, debug=False, default_timeout=None):
self.cx = cx self.cx = cx
self.dispatcher = EventDispatcher() self.dispatcher = EventDispatcher()
self._reader = MessageReader(self.dispatcher) self._reader = MessageReader(self.dispatcher)
self.commands = CommandHandler() self.commands = CommandHandler(default_timeout=default_timeout)
# Set up logger # Set up logger
if debug: if debug:
@@ -58,19 +58,19 @@ class MeshCore:
cx.set_reader(self._reader) cx.set_reader(self._reader)
@classmethod @classmethod
async def create_tcp(cls, host: str, port: int, debug: bool = False) -> 'MeshCore': async def create_tcp(cls, host: str, port: int, debug: bool = False, default_timeout=None) -> 'MeshCore':
"""Create and connect a MeshCore instance using TCP connection""" """Create and connect a MeshCore instance using TCP connection"""
from .tcp_cx import TCPConnection from .tcp_cx import TCPConnection
connection = TCPConnection(host, port) connection = TCPConnection(host, port)
await connection.connect() await connection.connect()
mc = cls(connection, debug=debug) mc = cls(connection, debug=debug, default_timeout=default_timeout)
await mc.connect() await mc.connect()
return mc return mc
@classmethod @classmethod
async def create_serial(cls, port: str, baudrate: int = 115200, debug: bool = False) -> 'MeshCore': async def create_serial(cls, port: str, baudrate: int = 115200, debug: bool = False, default_timeout=None) -> 'MeshCore':
"""Create and connect a MeshCore instance using serial connection""" """Create and connect a MeshCore instance using serial connection"""
from .serial_cx import SerialConnection from .serial_cx import SerialConnection
import asyncio import asyncio
@@ -79,12 +79,12 @@ class MeshCore:
await connection.connect() await connection.connect()
await asyncio.sleep(0.1) # Time for transport to establish await asyncio.sleep(0.1) # Time for transport to establish
mc = cls(connection, debug=debug) mc = cls(connection, debug=debug, default_timeout=default_timeout)
await mc.connect() await mc.connect()
return mc return mc
@classmethod @classmethod
async def create_ble(cls, address: Optional[str] = None, debug: bool = False) -> 'MeshCore': async def create_ble(cls, address: Optional[str] = None, debug: bool = False, default_timeout=None) -> 'MeshCore':
"""Create and connect a MeshCore instance using BLE connection """Create and connect a MeshCore instance using BLE connection
If address is None, it will scan for and connect to the first available MeshCore device. If address is None, it will scan for and connect to the first available MeshCore device.
@@ -96,7 +96,7 @@ class MeshCore:
if result is None: if result is None:
raise ConnectionError("Failed to connect to BLE device") raise ConnectionError("Failed to connect to BLE device")
mc = cls(connection, debug=debug) mc = cls(connection, debug=debug, default_timeout=default_timeout)
await mc.connect() await mc.connect()
return mc return mc
@@ -142,11 +142,15 @@ class MeshCore:
Args: Args:
event_type: Type of event to wait for, from EventType enum event_type: Type of event to wait for, from EventType enum
timeout: Maximum time to wait in seconds, or None for no timeout timeout: Maximum time to wait in seconds, or None to use default_timeout
Returns: Returns:
Event object or None if timeout Event object or None if timeout
""" """
# Use the provided timeout or fall back to default_timeout
if timeout is None:
timeout = self.default_timeout
return await self.dispatcher.wait_for_event(event_type, timeout) return await self.dispatcher.wait_for_event(event_type, timeout)
def _setup_data_tracking(self): def _setup_data_tracking(self):
@@ -181,6 +185,16 @@ class MeshCore:
"""Get the current device time""" """Get the current device time"""
return self._time return self._time
@property
def default_timeout(self):
"""Get the default timeout for commands"""
return self.commands.default_timeout
@default_timeout.setter
def default_timeout(self, value):
"""Set the default timeout for commands"""
self.commands.default_timeout = value
def get_contact_by_name(self, name): def get_contact_by_name(self, name):
""" """
Find a contact by its name (adv_name field) Find a contact by its name (adv_name field)
@@ -275,7 +289,7 @@ class MeshCore:
if hasattr(self, '_auto_fetch_task') and self._auto_fetch_task and not self._auto_fetch_task.done(): if hasattr(self, '_auto_fetch_task') and self._auto_fetch_task and not self._auto_fetch_task.done():
self._auto_fetch_task.cancel() self._auto_fetch_task.cancel()
try: try:
await self._auto_fetch_task await self._auto_fetch_task # type: ignore
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
self._auto_fetch_task = None self._auto_fetch_task = None

View File

@@ -1,249 +0,0 @@
import asyncio
from typing import Dict, Any, Optional, Callable
from .events import EventDispatcher, MessageType, Event
from .reader import MessageReader
from .commands import CommandHandler, deprecated
class MeshCore:
def __init__(self, cx):
self.cx = cx
self.dispatcher = EventDispatcher()
self._reader = MessageReader(self.dispatcher)
self.commands = CommandHandler()
# Set up connections
self.commands.set_connection(cx)
# Initialize state
self.contacts = {}
self.self_info = {}
self.time = 0
# Set the message handler in the connection
cx.set_mc(self)
async def connect(self):
# Start the event dispatcher
await self.dispatcher.start()
# Start the command handler
await self.commands.start()
# Send the initial app start
return await self.commands.send_appstart()
async def disconnect(self):
# Stop the event dispatcher
await self.dispatcher.stop()
# Stop the command handler
await self.commands.stop()
# Internal method - called by the connection
def handle_rx(self, data: bytearray):
asyncio.create_task(self._reader.handle_rx(data))
# Expose subscribe/wait capabilities from the event system
def subscribe(self, message_type, callback):
return self.dispatcher.subscribe(message_type, callback)
async def wait_for_event(self, message_type, timeout=None):
return await self.dispatcher.wait_for_event(message_type, timeout)
# Legacy method implementations that delegate to the command handler
# using the deprecated decorator from commands.py
@deprecated
async def send(self, data, timeout=5):
return await self.commands.send(data, timeout)
@deprecated
async def send_only(self, data):
await self.commands.send_only(data)
@deprecated
async def send_appstart(self):
return await self.commands.send_appstart()
@deprecated
async def send_device_query(self):
return await self.commands.send_device_query()
@deprecated
async def send_advert(self, flood=False):
return await self.commands.send_advert(flood)
@deprecated
async def set_name(self, name):
return await self.commands.set_name(name)
@deprecated
async def set_coords(self, lat, lon):
return await self.commands.set_coords(lat, lon)
@deprecated
async def reboot(self):
return await self.commands.reboot()
@deprecated
async def get_bat(self):
return await self.commands.get_bat()
@deprecated
async def get_time(self):
time_result = await self.commands.get_time()
if isinstance(time_result, int):
self.time = time_result
return self.time
@deprecated
async def set_time(self, val):
return await self.commands.set_time(val)
@deprecated
async def set_tx_power(self, val):
return await self.commands.set_tx_power(val)
@deprecated
async def set_radio(self, freq, bw, sf, cr):
return await self.commands.set_radio(freq, bw, sf, cr)
@deprecated
async def set_tuning(self, rx_dly, af):
return await self.commands.set_tuning(rx_dly, af)
@deprecated
async def set_devicepin(self, pin):
return await self.commands.set_devicepin(pin)
@deprecated
async def get_contacts(self):
await self.commands.get_contacts()
contact_end = await self.dispatcher.wait_for_event(MessageType.CONTACT_END)
if contact_end:
self.contacts = contact_end.payload
return self.contacts
@deprecated
async def ensure_contacts(self):
if not self.contacts:
await self.get_contacts()
@deprecated
async def reset_path(self, key):
return await self.commands.reset_path(key)
@deprecated
async def share_contact(self, key):
return await self.commands.share_contact(key)
@deprecated
async def export_contact(self, key=b""):
return await self.commands.export_contact(key)
@deprecated
async def remove_contact(self, key):
return await self.commands.remove_contact(key)
@deprecated
async def set_out_path(self, contact, path):
contact["out_path"] = path
contact["out_path_len"] = -1
contact["out_path_len"] = int(len(path) / 2)
@deprecated
async def update_contact(self, contact):
out_path_hex = contact["out_path"]
out_path_hex = out_path_hex + (128-len(out_path_hex)) * "0"
adv_name_hex = contact["adv_name"].encode().hex()
adv_name_hex = adv_name_hex + (64-len(adv_name_hex)) * "0"
data = b"\x09" \
+ bytes.fromhex(contact["public_key"])\
+ contact["type"].to_bytes(1)\
+ contact["flags"].to_bytes(1)\
+ contact["out_path_len"].to_bytes(1, 'little', signed=True)\
+ bytes.fromhex(out_path_hex)\
+ bytes.fromhex(adv_name_hex)\
+ contact["last_advert"].to_bytes(4, 'little')\
+ int(contact["adv_lat"]*1e6).to_bytes(4, 'little', signed=True)\
+ int(contact["adv_lon"]*1e6).to_bytes(4, 'little', signed=True)
return await self.send(data)
@deprecated
async def send_login(self, dst, pwd):
await self.commands.send_login(dst, pwd)
login_event = await self.dispatcher.wait_for_event(MessageType.LOGIN_SUCCESS, 0.1)
if login_event:
return True
return await self.commands.send_login(dst, pwd)
@deprecated
async def wait_login(self, timeout=5):
login_event = await self.dispatcher.wait_for_event(MessageType.LOGIN_SUCCESS, timeout)
if login_event:
return True
login_failed = await self.dispatcher.wait_for_event(MessageType.LOGIN_FAILED, 0)
if login_failed:
return False
return False
@deprecated
async def send_statusreq(self, dst):
await self.commands.send_statusreq(dst)
@deprecated
async def wait_status(self, timeout=5):
status_event = await self.dispatcher.wait_for_event(MessageType.STATUS_RESPONSE, timeout)
if status_event:
return status_event.payload
return False
@deprecated
async def send_cmd(self, dst, cmd):
timestamp = await self.get_time()
return await self.commands.send_cmd(dst, cmd, timestamp.to_bytes(4, 'little'))
@deprecated
async def send_msg(self, dst, msg):
timestamp = await self.get_time()
result = await self.commands.send_msg(dst, msg, timestamp.to_bytes(4, 'little'))
return result
@deprecated
async def send_chan_msg(self, chan, msg):
timestamp = await self.get_time()
return await self.commands.send_chan_msg(chan, msg, timestamp.to_bytes(4, 'little'))
@deprecated
async def get_msg(self):
await self.commands.get_msg()
# Wait for any message type that could be received
message_types = [
MessageType.CONTACT_MSG_RECV,
MessageType.CHANNEL_MSG_RECV,
MessageType.NO_MORE_MSGS
]
for msg_type in message_types:
event = await self.dispatcher.wait_for_event(msg_type, 0)
if event:
return event.payload
return False
@deprecated
async def wait_msg(self, timeout=-1):
msg_event = await self.dispatcher.wait_for_event(MessageType.MESSAGES_WAITING, timeout)
return msg_event is not None
@deprecated
async def wait_ack(self, timeout=6):
ack_event = await self.dispatcher.wait_for_event(MessageType.ACK, timeout)
return ack_event is not None
@deprecated
async def send_cli(self, cmd):
return await self.commands.send_cli(cmd)

View File

@@ -15,6 +15,7 @@ class SerialConnection:
self.baudrate = baudrate self.baudrate = baudrate
self.frame_started = False self.frame_started = False
self.frame_size = 0 self.frame_size = 0
self.transport = None
self.header = b"" self.header = b""
self.inframe = b"" self.inframe = b""
@@ -25,6 +26,7 @@ class SerialConnection:
def connection_made(self, transport): def connection_made(self, transport):
self.cx.transport = transport self.cx.transport = transport
logger.debug('port opened') logger.debug('port opened')
if isinstance(transport, serial_asyncio.SerialTransport) and transport.serial:
transport.serial.rts = False # You can manipulate Serial object via transport transport.serial.rts = False # You can manipulate Serial object via transport
def data_received(self, data): def data_received(self, data):
@@ -79,6 +81,9 @@ class SerialConnection:
self.handle_rx(data[self.frame_size-framelen:]) self.handle_rx(data[self.frame_size-framelen:])
async def send(self, data): async def send(self, data):
if not self.transport:
logger.error("Transport not connected, cannot send data")
return
size = len(data) size = len(data)
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
logger.debug(f"sending pkt : {pkt}") logger.debug(f"sending pkt : {pkt}")

View File

@@ -18,7 +18,7 @@ class TCPConnection:
self.header = b"" self.header = b""
self.inframe = b"" self.inframe = b""
class MCClientProtocol: class MCClientProtocol(asyncio.Protocol):
def __init__(self, cx): def __init__(self, cx):
self.cx = cx self.cx = cx
@@ -76,6 +76,9 @@ class TCPConnection:
self.handle_rx(data[self.frame_size-framelen:]) self.handle_rx(data[self.frame_size-framelen:])
async def send(self, data): async def send(self, data):
if not self.transport:
logger.error("Transport not connected, cannot send data")
return
size = len(data) size = len(data)
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
logger.debug(f"sending pkt : {pkt}") logger.debug(f"sending pkt : {pkt}")