diff --git a/src/meshcore/ble_cx.py b/src/meshcore/ble_cx.py index 0ce06d9..0a7380f 100644 --- a/src/meshcore/ble_cx.py +++ b/src/meshcore/ble_cx.py @@ -51,6 +51,14 @@ class BLEConnection: self.pin = pin self.rx_char = None self._disconnect_callback = None + self._background_tasks: set[asyncio.Task] = set() + + def _spawn_background(self, coro) -> asyncio.Task: + """Create a tracked background task (prevents GC of fire-and-forget tasks).""" + task = asyncio.create_task(coro) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task async def connect(self): """ @@ -116,9 +124,12 @@ class BLEConnection: await self.client.pair() logger.info("BLE pairing successful") except Exception as e: - logger.warning(f"BLE pairing failed: {e}") - # Don't fail the connection if pairing fails, as the device - # might already be paired or not require pairing + logger.error(f"BLE pairing failed: {e}") + # A failed pairing leaves the transport in a half-usable + # state — re-raise so the caller gets a clean failure + # instead of a silently degraded connection. + await self.client.disconnect() + raise except BleakDeviceNotFoundError: return None @@ -154,8 +165,19 @@ class BLEConnection: self.client = self._user_provided_client self.device = self._user_provided_device + # Re-register disconnect callback on the reset client so subsequent + # disconnects after a reconnect cycle are still detected. + if self.client is not None and hasattr(self.client, 'set_disconnected_callback'): + try: + self.client.set_disconnected_callback(self.handle_disconnect) + except Exception: + # set_disconnected_callback may not be available on all bleak + # versions; the next connect() call will re-create the client + # with the callback anyway. + pass + if self._disconnect_callback: - asyncio.create_task(self._disconnect_callback("ble_disconnect")) + self._spawn_background(self._disconnect_callback("ble_disconnect")) def set_disconnect_callback(self, callback): """Set callback to handle disconnections.""" @@ -166,16 +188,24 @@ class BLEConnection: def handle_rx(self, _: BleakGATTCharacteristic, data: bytearray): if self.reader is not None: - asyncio.create_task(self.reader.handle_rx(data)) + self._spawn_background(self.reader.handle_rx(data)) async def send(self, data): if not self.client: logger.error("Client is not connected") + if self._disconnect_callback: + await self._disconnect_callback("ble_transport_lost") return False if not self.rx_char: logger.error("RX characteristic not found") return False - await self.client.write_gatt_char(self.rx_char, bytes(data), response=True) + try: + await self.client.write_gatt_char(self.rx_char, bytes(data), response=True) + except Exception as exc: + logger.warning(f"BLE write failed: {exc}") + if self._disconnect_callback: + await self._disconnect_callback(f"ble_write_failed: {exc}") + return False async def disconnect(self): """Disconnect from the BLE device.""" diff --git a/src/meshcore/commands/base.py b/src/meshcore/commands/base.py index dcd4673..543460a 100644 --- a/src/meshcore/commands/base.py +++ b/src/meshcore/commands/base.py @@ -58,17 +58,32 @@ def _validate_destination(dst: DestinationType, prefix_length: int = 6) -> bytes class CommandHandlerBase: + """Base class for command handlers. + + .. note:: + The internal ``asyncio.Lock`` is created lazily on first access + so that it binds to the correct running event loop (required for + Python 3.9/3.10 compatibility). + """ + DEFAULT_TIMEOUT = 15.0 def __init__(self, default_timeout: Optional[float] = None): self._sender_func: Optional[Callable[[bytes], Coroutine[Any, Any, None]]] = None self._reader: Optional[MessageReader] = None self.dispatcher: Optional[EventDispatcher] = None - self._mesh_request_lock = asyncio.Lock() + self.__mesh_request_lock: Optional[asyncio.Lock] = None self.default_timeout = ( default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT ) + @property + def _mesh_request_lock(self) -> asyncio.Lock: + """Lazy-init lock so it binds to the running loop, not import-time.""" + if self.__mesh_request_lock is None: + self.__mesh_request_lock = asyncio.Lock() + return self.__mesh_request_lock + def set_connection(self, connection: Any) -> None: async def sender(data: bytes) -> None: await connection.send(data) @@ -90,6 +105,14 @@ class CommandHandlerBase: expected_events: Optional[Union[EventType, List[EventType]]] = None, timeout: Optional[float] = None, ) -> Event: + """Wait for the first of *expected_events* to arrive. + + Returns the first matched ``Event``. When ``EventType.ERROR`` is + among the expected types, the caller **must** check + ``result.is_error()`` before accessing command-specific payload + keys — an ERROR payload is ``{"reason": "..."}`` and will + ``KeyError`` on any other key. + """ try: # Convert single event to list if needed if not isinstance(expected_events, list): @@ -129,9 +152,6 @@ class CommandHandlerBase: logger.debug(f"Command error: {e}") return Event(EventType.ERROR, {"error": str(e)}) - return Event(EventType.ERROR, {}) - - async def send( self, data: bytes, @@ -151,7 +171,14 @@ class CommandHandlerBase: timeout: Timeout in seconds, or None to use default_timeout Returns: - Event: The full event object that was received in response to the command + Event: The full event object that was received in response to + the command. + + Important: + When ``EventType.ERROR`` is included in *expected_events*, the + returned event may be an error response. Callers **must** + check ``result.is_error()`` before accessing command-specific + payload keys to avoid ``KeyError``. """ if not self.dispatcher: raise RuntimeError("Dispatcher not set, cannot send commands") @@ -170,7 +197,7 @@ class CommandHandlerBase: futures: List[asyncio.Future] = [] subscriptions = [] - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() for event_type in expected_events: future = loop.create_future() @@ -279,6 +306,7 @@ class CommandHandlerBase: contact = self._get_contact_by_prefix(dst_bytes.hex()) # need a contact for return path if contact is None: logger.error("No contact found") + return Event(EventType.ERROR, {"reason": "contact_not_found"}) zero_hop = False if contact["out_path_len"] == -1: diff --git a/src/meshcore/commands/contact.py b/src/meshcore/commands/contact.py index 436a066..fae3b26 100644 --- a/src/meshcore/commands/contact.py +++ b/src/meshcore/commands/contact.py @@ -191,6 +191,24 @@ class ContactCommands(CommandHandlerBase): data = b"\x3B" return await self.send(data, [EventType.AUTOADD_CONFIG, EventType.ERROR]) + async def get_contact_by_key(self, pubkey: bytes) -> Event: + """N09: Retrieve a single contact by its public key (CMD 30). + + Args: + pubkey: 32-byte public key of the contact. + + Returns: + Event with the contact data (same format as CONTACT/NEXT_CONTACT), + or ERROR if not found. + """ + if not isinstance(pubkey, (bytes, bytearray)): + raise TypeError("pubkey must be bytes-like") + # Truncate or pad to 32 bytes + key_bytes = bytes(pubkey[:32]) + logger.debug(f"Getting contact by key: {key_bytes.hex()}") + data = b"\x1e" + key_bytes + return await self.send(data, [EventType.NEXT_CONTACT, EventType.ERROR]) + async def get_advert_path(self, key: DestinationType) -> Event: key_bytes = _validate_destination(key, prefix_length=32) logger.debug(f"getting advert path for: {key} {key_bytes.hex()}") diff --git a/src/meshcore/commands/device.py b/src/meshcore/commands/device.py index af986db..6175dc8 100644 --- a/src/meshcore/commands/device.py +++ b/src/meshcore/commands/device.py @@ -4,6 +4,7 @@ from hashlib import sha256 from typing import Optional from ..events import Event, EventType +from ..packets import CommandType from .base import CommandHandlerBase, DestinationType, _validate_destination logger = logging.getLogger("meshcore") @@ -13,7 +14,7 @@ class DeviceCommands(CommandHandlerBase): async def send_appstart(self) -> Event: logger.debug("Sending appstart command") b1 = bytearray(b"\x01\x03 mccli") - return await self.send(b1, [EventType.SELF_INFO]) + return await self.send(b1, [EventType.SELF_INFO, EventType.ERROR]) async def send_device_query(self) -> Event: logger.debug("Sending device query command") @@ -129,32 +130,50 @@ class DeviceCommands(CommandHandlerBase): return await self.send(data, [EventType.OK, EventType.ERROR]) async def set_telemetry_mode_base(self, telemetry_mode_base: int) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["telemetry_mode_base"] = telemetry_mode_base return await self.set_other_params_from_infos(infos) async def set_telemetry_mode_loc(self, telemetry_mode_loc: int) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["telemetry_mode_loc"] = telemetry_mode_loc return await self.set_other_params_from_infos(infos) async def set_telemetry_mode_env(self, telemetry_mode_env: int) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["telemetry_mode_env"] = telemetry_mode_env return await self.set_other_params_from_infos(infos) async def set_manual_add_contacts(self, manual_add_contacts: bool) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["manual_add_contacts"] = manual_add_contacts return await self.set_other_params_from_infos(infos) async def set_advert_loc_policy(self, advert_loc_policy: int) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["adv_loc_policy"] = advert_loc_policy return await self.set_other_params_from_infos(infos) async def set_multi_acks(self, multi_acks: int) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["multi_acks"] = multi_acks return await self.set_other_params_from_infos(infos) @@ -273,20 +292,89 @@ class DeviceCommands(CommandHandlerBase): return await self.sign_finish(timeout=timeout, data_size=len(data)) + async def has_connection(self) -> Event: + """N09: Check if the device has an active connection (CMD 28). + + Returns: + Event with a 1-byte response indicating connection status, + or ERROR. + """ + logger.debug("Checking device connection status") + return await self.send(b"\x1c", [EventType.OK, EventType.ERROR]) + + async def get_tuning(self) -> Event: + """N03/N09: Request current tuning parameters (CMD_GET_TUNING_PARAMS = 43). + + Firmware responds with RESP_CODE_TUNING_PARAMS (23): 9 bytes containing + rx_delay (4 bytes LE) and airtime_factor (4 bytes LE). + + Returns: + Event of type TUNING_PARAMS with rx_delay and airtime_factor, + or ERROR. + """ + logger.debug("Getting tuning parameters") + return await self.send(b"\x2b", [EventType.TUNING_PARAMS, EventType.ERROR]) + + async def request_factory_reset(self) -> str: + """N09: Request a factory reset token (step 1 of 2). + + This method returns a confirmation token string. Pass it to + ``confirm_factory_reset(token)`` to actually execute the reset. + The two-step pattern is a Python-side safety measure; the firmware + itself has no token verification. + + Returns: + A confirmation token string to pass to confirm_factory_reset(). + """ + import secrets + token = secrets.token_hex(8) + logger.warning( + "Factory reset requested. Call confirm_factory_reset('%s') to proceed. " + "This will ERASE ALL DATA on the device.", token + ) + # Store the token on the instance for validation + self._factory_reset_token = token + return token + + async def confirm_factory_reset(self, token: str) -> Event: + """N09: Execute factory reset after token confirmation (step 2 of 2). + + Args: + token: The token returned by request_factory_reset(). + + Returns: + Event with OK or ERROR. + + Raises: + ValueError: If the token does not match. + """ + expected = getattr(self, "_factory_reset_token", None) + if expected is None or token != expected: + raise ValueError( + "Invalid or expired factory reset token. " + "Call request_factory_reset() first." + ) + self._factory_reset_token = None # Consume the token + logger.warning("Executing factory reset — all device data will be erased") + return await self.send(b"\x33", [EventType.OK, EventType.ERROR]) + async def get_stats_core(self) -> Event: logger.debug("Getting core statistics") - # CMD_GET_STATS (56) + STATS_TYPE_CORE (0) - return await self.send(b"\x38\x00", [EventType.STATS_CORE, EventType.ERROR]) + # R04: Use CommandType enum instead of literal bytes + cmd = bytes([CommandType.GET_STATS.value, 0x00]) # GET_STATS + STATS_TYPE_CORE + return await self.send(cmd, [EventType.STATS_CORE, EventType.ERROR]) async def get_stats_radio(self) -> Event: logger.debug("Getting radio statistics") - # CMD_GET_STATS (56) + STATS_TYPE_RADIO (1) - return await self.send(b"\x38\x01", [EventType.STATS_RADIO, EventType.ERROR]) + # R04: Use CommandType enum instead of literal bytes + cmd = bytes([CommandType.GET_STATS.value, 0x01]) # GET_STATS + STATS_TYPE_RADIO + return await self.send(cmd, [EventType.STATS_RADIO, EventType.ERROR]) async def get_stats_packets(self) -> Event: logger.debug("Getting packet statistics") - # CMD_GET_STATS (56) + STATS_TYPE_PACKETS (2) - return await self.send(b"\x38\x02", [EventType.STATS_PACKETS, EventType.ERROR]) + # R04: Use CommandType enum instead of literal bytes + cmd = bytes([CommandType.GET_STATS.value, 0x02]) # GET_STATS + STATS_TYPE_PACKETS + return await self.send(cmd, [EventType.STATS_PACKETS, EventType.ERROR]) async def get_allowed_repeat_freq(self) -> Event: logger.debug("Getting allowed repeat freqs") diff --git a/src/meshcore/commands/messaging.py b/src/meshcore/commands/messaging.py index 0c15479..176635c 100644 --- a/src/meshcore/commands/messaging.py +++ b/src/meshcore/commands/messaging.py @@ -144,8 +144,12 @@ class MessagingCommands(CommandHandlerBase): logger.info(f"Retry sending msg: {attempts + 1}") result = await self.send_msg(dst, msg, timestamp, attempt=attempts) - if result.type == EventType.ERROR: - logger.error(f"⚠️ Failed to send message: {result.payload}") + if result.is_error(): + logger.error(f"Failed to send message: {result.payload}") + attempts += 1 + if flood: + flood_attempts += 1 + continue exp_ack = result.payload["expected_ack"].hex() timeout = result.payload["suggested_timeout"] / 1000 * 1.2 if timeout==0 else timeout @@ -255,7 +259,7 @@ class MessagingCommands(CommandHandlerBase): elif path_hash_len == 8 : flags = 3 else : - logger.error(f"Invalid path format: {e}") + logger.error(f"Invalid path format: unknown path_hash_len {path_hash_len}") return Event(EventType.ERROR, {"reason": "invalid_path_format"}) else: flags = 0 @@ -291,12 +295,34 @@ class MessagingCommands(CommandHandlerBase): cmd_data.append(flags) cmd_data.extend(path_bytes) + # N05: Firmware requires strict len > 10 (MyMesh.cpp:1620). + # When path is empty, cmd(1)+tag(4)+auth(4)+flags(1) = 10 bytes exactly, + # which is silently rejected. Pad with one zero byte to reach 11. + if len(cmd_data) <= 10: + cmd_data.append(0x00) + logger.debug( f"Sending trace: tag={tag}, auth={auth_code}, flags={flags}, path={path_bytes.hex()}" ) return await self.send(cmd_data, [EventType.MSG_SENT, EventType.ERROR]) + async def send_raw_data(self, payload: bytes) -> Event: + """N09: Send raw data via CMD_SEND_RAW_DATA (25). + + Sends an arbitrary payload through the mesh network. + + Args: + payload: Raw bytes to send. + + Returns: + Event with MSG_SENT or ERROR. + """ + if not isinstance(payload, (bytes, bytearray)): + raise TypeError("payload must be bytes-like") + data = b"\x19" + bytes(payload) + return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) + async def set_flood_scope(self, scope): if scope is None: logger.debug(f"Resetting scope") diff --git a/src/meshcore/connection_manager.py b/src/meshcore/connection_manager.py index c95ec37..bcd09ff 100644 --- a/src/meshcore/connection_manager.py +++ b/src/meshcore/connection_manager.py @@ -4,14 +4,23 @@ Connection manager that orchestrates reconnection logic for any connection type. import asyncio import logging -from typing import Optional, Any, Callable, Protocol +from typing import Optional, Any, Awaitable, Callable, Protocol from .events import Event, EventType logger = logging.getLogger("meshcore") class ConnectionProtocol(Protocol): - """Protocol defining the interface that connection classes must implement.""" + """Protocol defining the interface that connection classes must implement. + + Return contract for connect(): + - On success: return a truthy value (typically an address string) + that identifies the connection. This value is included in the + CONNECTED event payload as ``connection_info``. + - On failure: return ``None`` (soft failure — triggers a retry in + ``_attempt_reconnect``) **or** raise an exception (hard failure — + also triggers a retry, logged as an error). + """ async def connect(self) -> Optional[Any]: """Connect and return connection info, or None if failed.""" @@ -39,11 +48,13 @@ class ConnectionManager: event_dispatcher=None, auto_reconnect: bool = False, max_reconnect_attempts: int = 3, + reconnect_callback: Optional[Callable[[], Awaitable[None]]] = None, ): self.connection = connection self.event_dispatcher = event_dispatcher self.auto_reconnect = auto_reconnect self.max_reconnect_attempts = max_reconnect_attempts + self._reconnect_callback = reconnect_callback self._reconnect_attempts = 0 self._is_connected = False @@ -109,45 +120,51 @@ class ConnectionManager: ) async def _attempt_reconnect(self): - """Attempt to reconnect with flat delay.""" - logger.debug( - f"Attempting reconnection ({self._reconnect_attempts + 1}/{self.max_reconnect_attempts})" - ) - self._reconnect_attempts += 1 + """Attempt to reconnect using an iterative loop. - # Flat 1 second delay for all attempts - await asyncio.sleep(1) + Runs as a single persistent task for the entire reconnect session. + Previous implementation used tail-recursion via create_task which + orphaned the running task reference — disconnect() could only cancel + the newest pointer, leaving earlier attempts in flight (F03). + """ + while self._reconnect_attempts < self.max_reconnect_attempts: + self._reconnect_attempts += 1 + logger.debug( + f"Attempting reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})" + ) + + # Flat 1 second delay for all attempts + await asyncio.sleep(1) + + try: + result = await self.connection.connect() + if result is not None: + self._is_connected = True + self._reconnect_attempts = 0 + + # Invoke reconnect callback (e.g. send_appstart) if provided + if self._reconnect_callback is not None: + try: + await self._reconnect_callback() + except Exception as cb_err: + logger.warning( + f"Reconnect callback failed: {cb_err}" + ) - try: - result = await self.connection.connect() - if result is not None: - self._is_connected = True - self._reconnect_attempts = 0 - await self._emit_event( - EventType.CONNECTED, - {"connection_info": result, "reconnected": True}, - ) - logger.debug("Reconnected successfully") - else: - # Reconnection failed, try again if we haven't exceeded max attempts - if self._reconnect_attempts < self.max_reconnect_attempts: - self._reconnect_task = asyncio.create_task( - self._attempt_reconnect() - ) - else: await self._emit_event( - EventType.DISCONNECTED, - {"reason": "reconnect_failed", "max_attempts_exceeded": True}, + EventType.CONNECTED, + {"connection_info": result, "reconnected": True}, ) - except Exception as e: - logger.debug(f"Reconnection attempt failed: {e}") - if self._reconnect_attempts < self.max_reconnect_attempts: - self._reconnect_task = asyncio.create_task(self._attempt_reconnect()) - else: - await self._emit_event( - EventType.DISCONNECTED, - {"reason": f"reconnect_error: {e}", "max_attempts_exceeded": True}, - ) + logger.debug("Reconnected successfully") + return + except Exception as e: + logger.debug(f"Reconnection attempt failed: {e}") + + # All attempts exhausted + await self._emit_event( + EventType.DISCONNECTED, + {"reason": "reconnect_failed", "max_attempts_exceeded": True}, + ) async def _emit_event(self, event_type: EventType, payload: dict): """Emit connection events if dispatcher is available.""" diff --git a/src/meshcore/events.py b/src/meshcore/events.py index f8a7521..940ee18 100644 --- a/src/meshcore/events.py +++ b/src/meshcore/events.py @@ -49,6 +49,9 @@ class EventType(Enum): PATH_RESPONSE = "path_response" PRIVATE_KEY = "private_key" DISABLED = "disabled" + CONTACT_DELETED = "contact_deleted" + CONTACTS_FULL = "contacts_full" + TUNING_PARAMS = "tuning_params" CONTROL_DATA = "control_data" DISCOVER_RESPONSE = "discover_response" NEIGHBOURS_RESPONSE = "neighbours_response" @@ -104,6 +107,17 @@ class Event: if kwargs: self.attributes.update(kwargs) + def is_error(self) -> bool: + """Return True if this event represents an error response. + + Callers that include ``EventType.ERROR`` in their expected-events + list **must** check ``result.is_error()`` (or ``result.type == + EventType.ERROR``) before accessing keyed payload fields, because + an ERROR payload contains ``{"reason": "..."}`` — not the + command-specific keys the caller expects on the happy path. + """ + return self.type == EventType.ERROR + def clone(self): """ Create a copy of the event. @@ -129,11 +143,28 @@ class Subscription: class EventDispatcher: + """Event dispatch engine. + + .. note:: + ``start()`` must be called before dispatching or processing events. + The internal ``asyncio.Queue`` is created lazily inside ``start()`` + so that it binds to the correct running event loop (required for + Python 3.9/3.10 compatibility). + """ + def __init__(self): - self.queue: asyncio.Queue[Event] = asyncio.Queue() + self.queue: Optional[asyncio.Queue[Event]] = None self.subscriptions: List[Subscription] = [] self.running = False self._task = None + self._background_tasks: set[asyncio.Task] = set() + + def _spawn_background(self, coro) -> asyncio.Task: + """Create a tracked background task (prevents GC of fire-and-forget tasks).""" + task = asyncio.create_task(coro) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task def subscribe( self, @@ -166,6 +197,10 @@ class EventDispatcher: self.subscriptions.remove(subscription) async def dispatch(self, event: Event): + if self.queue is None: + raise RuntimeError( + "EventDispatcher.start() must be called before dispatching events" + ) await self.queue.put(event) async def _process_events(self): @@ -197,7 +232,7 @@ class EventDispatcher: # returns - avoids the race where create_task schedules the callback after # the waiter has already timed out with done=set(). if asyncio.iscoroutinefunction(subscription.callback): - asyncio.create_task(self._execute_callback(subscription.callback, event.clone())) + self._spawn_background(self._execute_callback(subscription.callback, event.clone())) else: try: subscription.callback(event.clone()) @@ -220,6 +255,8 @@ class EventDispatcher: async def start(self): if not self.running: + if self.queue is None: + self.queue = asyncio.Queue() self.running = True self._task = asyncio.create_task(self._process_events()) @@ -227,7 +264,12 @@ class EventDispatcher: if self.running: self.running = False if self._task: - await self.queue.join() + if self.queue is not None: + await self.queue.join() + # Wait for any in-flight async callbacks to complete before + # tearing down (F07: task_done fires before callbacks finish). + if self._background_tasks: + await asyncio.gather(*self._background_tasks, return_exceptions=True) self._task.cancel() try: await self._task diff --git a/src/meshcore/meshcore.py b/src/meshcore/meshcore.py index 24a1550..a7e06c1 100644 --- a/src/meshcore/meshcore.py +++ b/src/meshcore/meshcore.py @@ -28,10 +28,17 @@ class MeshCore: auto_reconnect: bool = False, max_reconnect_attempts: int = 3, ): - # Wrap connection with ConnectionManager + # Wrap connection with ConnectionManager. + # The reconnect callback ensures send_appstart() runs after every + # transport-level reconnect, which is required by firmware to + # initialize the session (F02). self.dispatcher = EventDispatcher() self.connection_manager = ConnectionManager( - cx, self.dispatcher, auto_reconnect, max_reconnect_attempts + cx, + self.dispatcher, + auto_reconnect, + max_reconnect_attempts, + reconnect_callback=self._on_reconnect, ) self.cx = self.connection_manager # For backward compatibility @@ -174,6 +181,15 @@ class MeshCore: return None return mc + async def _on_reconnect(self): + """Callback invoked by ConnectionManager after a successful reconnect. + + Firmware requires CMD_APP_START after every transport-level connection + to initialize the session. MeshCore.connect() does this on the initial + connection; this callback ensures it also happens on reconnects (F02). + """ + await self.commands.send_appstart() + async def connect(self): await self.dispatcher.start() result = await self.connection_manager.connect() diff --git a/src/meshcore/meshcore_parser.py b/src/meshcore/meshcore_parser.py index 2139663..046709a 100644 --- a/src/meshcore/meshcore_parser.py +++ b/src/meshcore/meshcore_parser.py @@ -42,8 +42,30 @@ class MeshcorePacketParser: Returns : completed log_data """ + # Minimum viable payload is 2 bytes (1 header + 1 path_byte) for a + # direct route. Anything shorter is provably broken — for example, + # the LOG_DATA branch in reader.py only requires `len(data) > 3`, + # which means a 4-byte LOG_DATA frame produces a 1-byte payload + # here, and `path_byte = pbuf.read(1)[0]` further down would raise + # IndexError on the empty buffer. Populate sentinel values so the + # caller's downstream `log_data['route_type']` etc. lookups don't + # KeyError, then return early. + if len(payload) < 2: + logger.debug(f"parsePacketPayload: payload too short ({len(payload)} bytes < 2), returning sentinel log_data") + log_data["route_type"] = -1 + log_data["route_typename"] = "UNK" + log_data["payload_type"] = -1 + log_data["payload_typename"] = "UNK" + log_data["payload_ver"] = 0 + log_data["path_len"] = 0 + log_data["path_hash_size"] = 1 + log_data["path"] = "" + log_data["pkt_payload"] = b"" + log_data["pkt_hash"] = 0 + return log_data + pbuf = io.BytesIO(payload) - + header = pbuf.read(1)[0] route_type = header & 0x03 payload_type = (header & 0x3c) >> 2 @@ -128,7 +150,7 @@ class MeshcorePacketParser: uncrypted = cipher.decrypt(msg) timestamp = int.from_bytes(uncrypted[0:4], "little", signed=False) attempt = uncrypted[4] & 3 - txt_type = int.from_bytes(uncrypted[4:4], "little", signed=False) >> 2 + txt_type = int.from_bytes(uncrypted[4:5], "little", signed=False) >> 2 message = uncrypted[5:].strip(b"\0") msg_hash = int.from_bytes(SHA256.new(timestamp.to_bytes(4, "little", signed=False) + message).digest()[0:4], "little", signed=False) log_data["message"] = message.decode("utf-8", "ignore") @@ -149,39 +171,42 @@ class MeshcorePacketParser: del self.channels_log[:25] elif not payload is None and payload_type == 0x04: # Advert - pk_buf = io.BytesIO(pkt_payload) - adv_key = pk_buf.read(32).hex() - adv_timestamp = int.from_bytes(pk_buf.read(4), "little", signed=False) - signature = pk_buf.read(64).hex() - flags = pk_buf.read(1)[0] - adv_type = flags & 0x0F - adv_lat = None - adv_lon = None - adv_feat1 = None - adv_feat2 = None - if flags & 0x10 > 0: #has location - adv_lat = int.from_bytes(pk_buf.read(4), "little", signed=True)/1000000.0 - adv_lon = int.from_bytes(pk_buf.read(4), "little", signed=True)/1000000.0 - if flags & 0x20 > 0: #has feature1 - adv_feat1 = pk_buf.read(2).hex() - if flags & 0x40 > 0: #has feature2 - adv_feat2 = pk_buf.read(2).hex() - if flags & 0x80 > 0: #has name - adv_name = pk_buf.read().decode("utf-8", "ignore").strip("\x00") - log_data["adv_name"] = adv_name + try: + pk_buf = io.BytesIO(pkt_payload) + adv_key = pk_buf.read(32).hex() + adv_timestamp = int.from_bytes(pk_buf.read(4), "little", signed=False) + signature = pk_buf.read(64).hex() + flags = pk_buf.read(1)[0] + adv_type = flags & 0x0F + adv_lat = None + adv_lon = None + adv_feat1 = None + adv_feat2 = None + if flags & 0x10 > 0: #has location + adv_lat = int.from_bytes(pk_buf.read(4), "little", signed=True)/1000000.0 + adv_lon = int.from_bytes(pk_buf.read(4), "little", signed=True)/1000000.0 + if flags & 0x20 > 0: #has feature1 + adv_feat1 = pk_buf.read(2).hex() + if flags & 0x40 > 0: #has feature2 + adv_feat2 = pk_buf.read(2).hex() + if flags & 0x80 > 0: #has name + adv_name = pk_buf.read().decode("utf-8", "ignore").strip("\x00") + log_data["adv_name"] = adv_name - log_data["adv_key"] = adv_key - log_data["adv_timestamp"] = adv_timestamp - log_data["signature"] = signature - log_data["adv_flags"] = flags - log_data["adv_type"] = adv_type - if not adv_lat is None : - log_data["adv_lat"] = adv_lat - if not adv_lon is None : - log_data["adv_lon"] = adv_lon - if not adv_feat1 is None: - log_data["adv_feat1"] = adv_feat1 - if not adv_feat2 is None: - log_data["adv_feat2"] = adv_feat2 + log_data["adv_key"] = adv_key + log_data["adv_timestamp"] = adv_timestamp + log_data["signature"] = signature + log_data["adv_flags"] = flags + log_data["adv_type"] = adv_type + if not adv_lat is None : + log_data["adv_lat"] = adv_lat + if not adv_lon is None : + log_data["adv_lon"] = adv_lon + if not adv_feat1 is None: + log_data["adv_feat1"] = adv_feat1 + if not adv_feat2 is None: + log_data["adv_feat2"] = adv_feat2 + except (IndexError, ValueError) as e: + logger.debug(f"parsePacketPayload: malformed ADVERT payload ({type(e).__name__}: {e}), len={len(pkt_payload)}") return log_data diff --git a/src/meshcore/packets.py b/src/meshcore/packets.py index b5cef23..ad27fce 100644 --- a/src/meshcore/packets.py +++ b/src/meshcore/packets.py @@ -71,6 +71,7 @@ class CommandType(Enum): SET_AUTOADD_CONFIG = 58 GET_AUTOADD_CONFIG = 59 GET_ALLOWED_REPEAT_FREQ = 60 + GET_STATS = 56 # R04: CMD_GET_STATS — used by get_stats_core/radio/packets SET_PATH_HASH_MODE = 61 # Packet prefixes for the protocol @@ -120,3 +121,6 @@ class PacketType(Enum): PATH_DISCOVERY_RESPONSE = 0x8D CONTROL_DATA = 0x8E CONTACT_DELETED = 0x8F + CONTACTS_FULL = 0x90 # N02: MyMesh::onContactsFull() — 1-byte push, no payload + # Note: 0x90 == ControlType.NODE_DISCOVER_RESP in a different namespace. + # Not a literal conflict (PacketType vs ControlType), but a maintenance hazard. diff --git a/src/meshcore/reader.py b/src/meshcore/reader.py index 802004a..5ae2bce 100644 --- a/src/meshcore/reader.py +++ b/src/meshcore/reader.py @@ -3,6 +3,7 @@ import json import struct import time import io +import traceback from typing import Any, Dict from .events import Event, EventType, EventDispatcher, ErrorMessages from .meshcore_parser import MeshcorePacketParser @@ -69,853 +70,917 @@ class MessageReader: except IndexError as e: logger.warning(f"Received empty packet: {e}") return - logger.debug(f"Received data: {data.hex()}") + try: + logger.debug(f"Received data: {data.hex()}") - # Handle command responses - if packet_type_value == PacketType.OK.value: - result: Dict[str, Any] = {} - if len(data) == 5: - result["value"] = int.from_bytes(data[1:5], byteorder="little") + # Handle command responses + if packet_type_value == PacketType.OK.value: + result: Dict[str, Any] = {} + if len(data) == 5: + result["value"] = int.from_bytes(data[1:5], byteorder="little") - # Dispatch event for the OK response - await self.dispatcher.dispatch(Event(EventType.OK, result)) + # Dispatch event for the OK response + await self.dispatcher.dispatch(Event(EventType.OK, result)) - elif packet_type_value == PacketType.ERROR.value: - if len(data) > 1: - result = { "error_code": data[1], } - if data[1] in ErrorMessages: - result["code_string"] = ErrorMessages[data[1]] - else: - result = {} - - # Dispatch event for the ERROR response - await self.dispatcher.dispatch(Event(EventType.ERROR, result)) - - elif packet_type_value == PacketType.CONTACT_START.value: - self.contact_nb = int.from_bytes(data[1:5], byteorder="little") - self.contacts = {} - - elif ( - packet_type_value == PacketType.CONTACT.value - or packet_type_value == PacketType.PUSH_CODE_NEW_ADVERT.value - ): - c = {} - c["public_key"] = dbuf.read(32).hex() - c["type"] = dbuf.read(1)[0] - c["flags"] = dbuf.read(1)[0] - plen = int.from_bytes(dbuf.read(1), signed=False, byteorder="little") - if plen == 255: # flood - c["out_path_hash_mode"] = -1 - c["out_path_len"] = -1 # 6 LSB - else: - c["out_path_hash_mode"] = plen >> 6 - c["out_path_len"] = plen & 0x3F # 6 LSB - c["out_path"] = dbuf.read(64).replace(b"\0", b"").hex() - c["adv_name"] = dbuf.read(32).decode("utf-8", "ignore").replace("\0", "") - c["last_advert"] = int.from_bytes(dbuf.read(4), byteorder="little") - c["adv_lat"] = ( - int.from_bytes(dbuf.read(4), byteorder="little", signed=True) / 1e6 - ) - c["adv_lon"] = ( - int.from_bytes(dbuf.read(4), byteorder="little", signed=True) / 1e6 - ) - c["lastmod"] = int.from_bytes(dbuf.read(4), byteorder="little") - - if packet_type_value == PacketType.PUSH_CODE_NEW_ADVERT.value: - await self.dispatcher.dispatch(Event(EventType.NEW_CONTACT, c)) - else: - await self.dispatcher.dispatch(Event(EventType.NEXT_CONTACT, c)) - self.contacts[c["public_key"]] = c - - elif packet_type_value == PacketType.ADVERT_PATH.value : - r = {} - r["timestamp"] = int.from_bytes(dbuf.read(4), "little", signed=False) - plen = int.from_bytes(dbuf.read(1), "little", signed=False) - if plen == 255: # flood, should not happen - r["path_hash_mode"] = -1 - r["path_len"] = -1 - else: - r["path_hash_mode"] = plen >> 6 # 2 upper bytes - r["path_len"] = plen & 0x3F - r["path"] = dbuf.read().replace(b"\0", b"").hex() - - await self.dispatcher.dispatch(Event(EventType.ADVERT_PATH, r)) - - elif packet_type_value == PacketType.CONTACT_END.value: - lastmod = int.from_bytes(dbuf.read(4), byteorder="little") - attributes = { - "lastmod": lastmod, - } - await self.dispatcher.dispatch( - Event(EventType.CONTACTS, self.contacts, attributes) - ) - - elif packet_type_value == PacketType.SELF_INFO.value: - self_info = {} - self_info["adv_type"] = dbuf.read(1)[0] - self_info["tx_power"] = dbuf.read(1)[0] - self_info["max_tx_power"] = dbuf.read(1)[0] - self_info["public_key"] = dbuf.read(32).hex() - self_info["adv_lat"] = ( - int.from_bytes(dbuf.read(4), byteorder="little", signed=True) / 1e6 - ) - self_info["adv_lon"] = ( - int.from_bytes(dbuf.read(4), byteorder="little", signed=True) / 1e6 - ) - self_info["multi_acks"] = dbuf.read(1)[0] - self_info["adv_loc_policy"] = dbuf.read(1)[0] - telemetry_mode = dbuf.read(1)[0] - self_info["telemetry_mode_env"] = (telemetry_mode >> 4) & 0b11 - self_info["telemetry_mode_loc"] = (telemetry_mode >> 2) & 0b11 - self_info["telemetry_mode_base"] = (telemetry_mode) & 0b11 - self_info["manual_add_contacts"] = dbuf.read(1)[0] > 0 - self_info["radio_freq"] = ( - int.from_bytes(dbuf.read(4), byteorder="little") / 1000 - ) - self_info["radio_bw"] = ( - int.from_bytes(dbuf.read(4), byteorder="little") / 1000 - ) - self_info["radio_sf"] = dbuf.read(1)[0] - self_info["radio_cr"] = dbuf.read(1)[0] - self_info["name"] = dbuf.read().decode("utf-8", "ignore") - await self.dispatcher.dispatch(Event(EventType.SELF_INFO, self_info)) - - elif packet_type_value == PacketType.MSG_SENT.value: - res = {} - res["type"] = dbuf.read(1)[0] - res["expected_ack"] = dbuf.read(4) - res["suggested_timeout"] = int.from_bytes(dbuf.read(4), byteorder="little") - - attributes = { - "type": res["type"], - "expected_ack": res["expected_ack"].hex(), - } - - await self.dispatcher.dispatch(Event(EventType.MSG_SENT, res, attributes)) - - elif packet_type_value == PacketType.CONTACT_MSG_RECV.value: - res = {} - res["type"] = "PRIV" - res["pubkey_prefix"] = dbuf.read(6).hex() - plen = dbuf.read(1)[0] - if plen == 255 : # direct message - res["path_hash_mode"] = -1 - res["path_len"] = plen - else: - res["path_hash_mode"] = plen >> 6 - res["path_len"] = plen & 0x3F - txt_type = dbuf.read(1)[0] - res["txt_type"] = txt_type - res["sender_timestamp"] = int.from_bytes(dbuf.read(4), byteorder="little") - if txt_type == 2: - res["signature"] = dbuf.read(4).hex() - res["text"] = dbuf.read().decode("utf-8", "ignore") - - attributes = { - "pubkey_prefix": res["pubkey_prefix"], - "txt_type": res["txt_type"], - } - - evt_type = EventType.CONTACT_MSG_RECV - - await self.dispatcher.dispatch(Event(evt_type, res, attributes)) - - elif packet_type_value == 16: # A reply to CMD_SYNC_NEXT_MESSAGE (ver >= 3) - res = {} - res["type"] = "PRIV" - res["SNR"] = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) / 4 - dbuf.read(2) # reserved - res["pubkey_prefix"] = dbuf.read(6).hex() - plen = dbuf.read(1)[0] - if plen == 255 : # direct message - res["path_hash_mode"] = -1 - res["path_len"] = plen - else: - res["path_hash_mode"] = plen >> 6 - res["path_len"] = plen & 0x3F - txt_type = dbuf.read(1)[0] - res["txt_type"] = txt_type - res["sender_timestamp"] = int.from_bytes(dbuf.read(4), byteorder="little") - if txt_type == 2: - res["signature"] = dbuf.read(4).hex() - res["text"] = dbuf.read().decode("utf-8", "ignore") - - attributes = { - "pubkey_prefix": res["pubkey_prefix"], - "txt_type": res["txt_type"], - } - - await self.dispatcher.dispatch( - Event(EventType.CONTACT_MSG_RECV, res, attributes) - ) - - elif packet_type_value == PacketType.CHANNEL_MSG_RECV.value: - res = {} - res["type"] = "CHAN" - res["channel_idx"] = dbuf.read(1)[0] - plen = dbuf.read(1)[0] - if plen == 255 : # direct message - res["path_hash_mode"] = -1 - res["path_len"] = plen - else: - res["path_hash_mode"] = plen >> 6 - res["path_len"] = plen & 0x3F - res["txt_type"] = dbuf.read(1)[0] - res["sender_timestamp"] = int.from_bytes(dbuf.read(4), byteorder="little", signed=False) - text = dbuf.read().strip(b"\0") - res["text"] = text.decode("utf-8", "ignore") - - # search for text in log_channels - txt_hash = int.from_bytes(SHA256.new(res["sender_timestamp"].to_bytes(4, "little", signed=False)+text).digest()[0:4], "little", signed=False) - if self.decrypt_channels: - logged = await self.packet_parser.findLogChannelMsg(txt_hash) - if not logged is None: - res["path"] = logged["path"] - res["RSSI"] = logged["rssi"] - res["SNR"] = logged["snr"] - res["recv_time"] = logged["recv_time"] - res["attempt"] = logged["attempt"] - - attributes = { - "channel_idx": res["channel_idx"], - "txt_type": res["txt_type"], - } - - await self.dispatcher.dispatch( - Event(EventType.CHANNEL_MSG_RECV, res, attributes) - ) - - elif packet_type_value == 17: # A reply to CMD_SYNC_NEXT_MESSAGE (ver >= 3) - res = {} - res["type"] = "CHAN" - res["SNR"] = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) / 4 - dbuf.read(2) # reserved - res["channel_idx"] = dbuf.read(1)[0] - plen = dbuf.read(1)[0] - if plen == 255 : # direct message - res["path_hash_mode"] = -1 - res["path_len"] = plen - else: - res["path_hash_mode"] = plen >> 6 - res["path_len"] = plen & 0x3F - res["txt_type"] = dbuf.read(1)[0] - res["sender_timestamp"] = int.from_bytes(dbuf.read(4), byteorder="little", signed=False) - text = dbuf.read() - res["text"] = text.decode("utf-8", "ignore") - - # search for text in log_channels - if self.decrypt_channels: - txt_hash = int.from_bytes(SHA256.new(res["sender_timestamp"].to_bytes(4, "little", signed=False)+text).digest()[0:4], "little", signed=False) - res["txt_hash"] = txt_hash - logged = await self.packet_parser.findLogChannelMsg(txt_hash) - - if not logged is None: - res["path"] = logged["path"] - res["RSSI"] = logged["rssi"] - res["recv_time"] = logged["recv_time"] - res["attempt"] = logged["attempt"] - - attributes = { - "channel_idx": res["channel_idx"], - "txt_type": res["txt_type"], - } - - await self.dispatcher.dispatch( - Event(EventType.CHANNEL_MSG_RECV, res, attributes) - ) - - elif packet_type_value == PacketType.CURRENT_TIME.value: - time_value = int.from_bytes(dbuf.read(4), byteorder="little") - result = {"time": time_value} - await self.dispatcher.dispatch(Event(EventType.CURRENT_TIME, result)) - - elif packet_type_value == PacketType.NO_MORE_MSGS.value: - result = {"messages_available": False} - await self.dispatcher.dispatch(Event(EventType.NO_MORE_MSGS, result)) - - elif packet_type_value == PacketType.CONTACT_URI.value: - contact_uri = "meshcore://" + dbuf.read().hex() - result = {"uri": contact_uri} - await self.dispatcher.dispatch(Event(EventType.CONTACT_URI, result)) - - elif packet_type_value == PacketType.BATTERY.value: - battery_level = int.from_bytes(dbuf.read(2), byteorder="little") - result = {"level": battery_level} - if len(data) > 3: # has storage info as well - result["used_kb"] = int.from_bytes(dbuf.read(4), byteorder="little") - result["total_kb"] = int.from_bytes(dbuf.read(4), byteorder="little") - await self.dispatcher.dispatch(Event(EventType.BATTERY, result)) - - elif packet_type_value == PacketType.DEVICE_INFO.value: - res = {} - fw_ver = dbuf.read(1)[0] - res["fw ver"] = fw_ver - if fw_ver >= 3: - res["max_contacts"] = dbuf.read(1)[0] * 2 - res["max_channels"] = dbuf.read(1)[0] - res["ble_pin"] = int.from_bytes(dbuf.read(4), byteorder="little") - res["fw_build"] = dbuf.read(12).decode("utf-8", "ignore").replace("\0", "") - res["model"] = dbuf.read(40).decode("utf-8", "ignore").replace("\0", "") - res["ver"] = dbuf.read(20).decode("utf-8", "ignore").replace("\0", "") - if fw_ver >= 9: # has repeater mode - rpt = dbuf.read(1) - if len(rpt) > 0: - res["repeat"] = (rpt[0] != 0) - if fw_ver >= 10: # has path_hash_mode - path_hash_mode = dbuf.read(1)[0] - res["path_hash_mode"] = path_hash_mode - await self.dispatcher.dispatch(Event(EventType.DEVICE_INFO, res)) - - elif packet_type_value == PacketType.CUSTOM_VARS.value: - logger.debug(f"received custom vars response: {data.hex()}") - res = {} - rawdata = dbuf.read().decode("utf-8", "ignore") - if not rawdata == "": - pairs = rawdata.split(",") - for p in pairs: - psplit = p.split(":") - res[psplit[0]] = psplit[1] - logger.debug(f"got custom vars : {res}") - await self.dispatcher.dispatch(Event(EventType.CUSTOM_VARS, res)) - - elif packet_type_value == PacketType.STATS.value: # RESP_CODE_STATS (24) - logger.debug(f"received stats response: {data.hex()}") - # RESP_CODE_STATS: All stats responses use code 24 with sub-type byte - # Byte 0: response_code (24), Byte 1: stats_type (0=core, 1=radio, 2=packets) - if len(data) < 2: - logger.error(f"Stats response too short: {len(data)} bytes, need at least 2 for header") - await self.dispatcher.dispatch(Event(EventType.ERROR, {"reason": "invalid_frame_length"})) - return - - stats_type = data[1] - - if stats_type == 0: # STATS_TYPE_CORE - # RESP_CODE_STATS + STATS_TYPE_CORE: 11 bytes total - # Format: 1: + result = { "error_code": data[1], } + if data[1] in ErrorMessages: + result["code_string"] = ErrorMessages[data[1]] else: - try: - battery_mv, uptime_secs, errors, queue_len = struct.unpack('> 6 + c["out_path_len"] = plen & 0x3F # 6 LSB + c["out_path"] = dbuf.read(64).replace(b"\0", b"").hex() + c["adv_name"] = dbuf.read(32).decode("utf-8", "ignore").replace("\0", "") + c["last_advert"] = int.from_bytes(dbuf.read(4), byteorder="little") + c["adv_lat"] = ( + int.from_bytes(dbuf.read(4), byteorder="little", signed=True) / 1e6 + ) + c["adv_lon"] = ( + int.from_bytes(dbuf.read(4), byteorder="little", signed=True) / 1e6 + ) + c["lastmod"] = int.from_bytes(dbuf.read(4), byteorder="little") + + if packet_type_value == PacketType.PUSH_CODE_NEW_ADVERT.value: + await self.dispatcher.dispatch(Event(EventType.NEW_CONTACT, c)) else: - try: - recv, sent, flood_tx, direct_tx, flood_rx, direct_rx = struct.unpack('= 30: - (recv_errors,) = struct.unpack('= 0: - res["channel_name"] = name_bytes[:null_pos].decode("utf-8", "ignore") - else: - res["channel_name"] = name_bytes.decode("utf-8", "ignore") - - res["channel_secret"] = dbuf.read(16) - res["channel_hash"] = SHA256.new(res["channel_secret"]).hexdigest()[0:2] - - await self.packet_parser.newChannel(res) - - await self.dispatcher.dispatch(Event(EventType.CHANNEL_INFO, res, res)) - - # Push notifications - elif packet_type_value == PacketType.ADVERTISEMENT.value: - logger.debug("Advertisement received") - res = {} - res["public_key"] = dbuf.read(32).hex() - await self.dispatcher.dispatch(Event(EventType.ADVERTISEMENT, res, res)) - - elif packet_type_value == PacketType.PATH_UPDATE.value: - logger.debug("Code path update") - res = {} - res["public_key"] = dbuf.read(32).hex() - await self.dispatcher.dispatch(Event(EventType.PATH_UPDATE, res, res)) - - elif packet_type_value == PacketType.ACK.value: - logger.debug("Received ACK") - ack_data = {} - - if len(data) >= 5: - ack_data["code"] = dbuf.read(4).hex() - - attributes = {"code": ack_data.get("code", "")} - - await self.dispatcher.dispatch(Event(EventType.ACK, ack_data, attributes)) - - elif packet_type_value == PacketType.MESSAGES_WAITING.value: - logger.debug("Msgs are waiting") - await self.dispatcher.dispatch(Event(EventType.MESSAGES_WAITING, {})) - - elif packet_type_value == PacketType.RAW_DATA.value: - res = {} - res["SNR"] = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) / 4 - res["RSSI"] = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) - res["payload"] = dbuf.read(4).hex() - logger.debug("Received raw data") - print(res) - await self.dispatcher.dispatch(Event(EventType.RAW_DATA, res)) - - elif packet_type_value == PacketType.LOGIN_SUCCESS.value: - res = {} - attributes = {} - if len(data) > 1: - perms = dbuf.read(1)[0] - res["permissions"] = perms - res["is_admin"] = (perms & 1) == 1 # Check if admin bit is set - - res["pubkey_prefix"] = dbuf.read(6).hex() - - attributes = {"pubkey_prefix": res.get("pubkey_prefix")} - - await self.dispatcher.dispatch( - Event(EventType.LOGIN_SUCCESS, res, attributes) - ) - - elif packet_type_value == PacketType.LOGIN_FAILED.value: - res = {} - attributes = {} - - dbuf.read(1) - - if len(data) > 7: - res["pubkey_prefix"] = pbuf.read(6).hex() - - attributes = {"pubkey_prefix": res.get("pubkey_prefix")} - - await self.dispatcher.dispatch( - Event(EventType.LOGIN_FAILED, res, attributes) - ) - - elif packet_type_value == PacketType.STATUS_RESPONSE.value: - res = parse_status(data, offset=8) - data_hex = data[8:].hex() - logger.debug(f"Status response: {data_hex}") - - attributes = { - "pubkey_prefix": res["pubkey_pre"], - } - - await self.dispatcher.dispatch( - Event(EventType.STATUS_RESPONSE, res, attributes) - ) - - elif packet_type_value == PacketType.LOG_DATA.value: - logger.debug(f"Received RF log data: {data.hex()}") - - # Parse as raw RX data - log_data: Dict[str, Any] = {"raw_hex": data[1:].hex()} - attributes = {} - - recv_time = int(time.time()) - log_data["recv_time"] = recv_time - attributes["recv_time"] = recv_time - - # First byte is SNR (signed byte, multiplied by 4) - if len(data) > 1: - snr_byte = dbuf.read(1)[0] - # Convert to signed value - snr = (snr_byte if snr_byte < 128 else snr_byte - 256) / 4.0 - log_data["snr"] = snr - - # Second byte is RSSI (signed byte) - if len(data) > 2: - rssi_byte = dbuf.read(1)[0] - # Convert to signed value - rssi = rssi_byte if rssi_byte < 128 else rssi_byte - 256 - log_data["rssi"] = rssi - - # Remaining bytes are the raw data payload - payload = None - if len(data) > 3: - payload=dbuf.read() - log_data["payload"] = payload.hex() - log_data["payload_length"] = len(payload) - - # Parse payload and get some info from it - log_data = await self.packet_parser.parsePacketPayload(payload, log_data) - attributes['route_type'] = log_data['route_type'] - attributes['payload_type'] = log_data['payload_type'] - attributes['path_len'] = log_data['path_len'] - attributes['path'] = log_data['path'] - - # Dispatch as RF log data - await self.dispatcher.dispatch( - Event(EventType.RX_LOG_DATA, log_data, attributes) - ) - - elif packet_type_value == PacketType.TRACE_DATA.value: - logger.debug(f"Received trace data: {data.hex()}") - res = {} - - # According to the source, format is: - # 0x89, reserved(0), path_len, flags, tag(4), auth(4), path_hashes[], path_snrs[], final_snr - - reserved = dbuf.read(1)[0] - path_len = dbuf.read(1)[0] - flags = dbuf.read(1)[0] - tag = int.from_bytes(dbuf.read(4), byteorder="little") - auth_code = int.from_bytes(dbuf.read(4), byteorder="little") - - path_hash_len = 1 << (flags&3) - path_len = int(path_len / path_hash_len) - - # Initialize result - res["tag"] = tag - res["auth"] = auth_code - res["flags"] = flags - res["path_len"] = path_len - - # Process path as array of objects with hash and SNR - path_nodes = [] - - if path_len > 0 and len(data) >= 12 + path_len + (path_len * path_hash_len) + 1: - # Extract path with hash and SNR pairs - for i in range(path_len): - node = { - "hash": dbuf.read(path_hash_len).hex(), - } - path_nodes.append(node) - - for n in path_nodes: - node_snr = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) - n["snr"] = node_snr / 4.0 - - # Add the final node (our device) with its SNR - final_snr = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) / 4.0 - path_nodes.append({"snr": final_snr}) - - res["path"] = path_nodes - - logger.debug(f"Parsed trace data: {res}") - - attributes = { - "tag": res["tag"], - "auth_code": res["auth"], - } - - await self.dispatcher.dispatch(Event(EventType.TRACE_DATA, res, attributes)) - - elif packet_type_value == PacketType.TELEMETRY_RESPONSE.value: - logger.debug(f"Received telemetry data: {data.hex()}") - res = {} - - dbuf.read(1) - - res["pubkey_pre"] = dbuf.read(6).hex() - buf = dbuf.read() - - """Parse a given byte string and return as a LppFrame object.""" - i = 0 - lpp_data_list = [] - while i < len(buf) and buf[i] != 0: - lppdata = LppData.from_bytes(buf[i:]) - lpp_data_list.append(lppdata) - i = i + len(lppdata) - - lpp = json.loads( - json.dumps(LppFrame(lpp_data_list), default=lpp_json_encoder) - ) - - res["lpp"] = lpp - - attributes = { - "raw": buf.hex(), - "pubkey_prefix": res["pubkey_pre"], - } - - await self.dispatcher.dispatch( - Event(EventType.TELEMETRY_RESPONSE, res, attributes) - ) - - elif packet_type_value == PacketType.ALLOWED_REPEAT_FREQ.value: - res = {} - freqs = [] - - cont = True - try: - while cont: - min = int.from_bytes(dbuf.read(4), "little", signed=False) - max = int.from_bytes(dbuf.read(4), "little", signed=False) - if min == 0 or max == 0: - cont = False - else: - freqs.append({"min" : min, "max": max}) - except e: - print(e) - - res["freqs"] = freqs - - await self.dispatcher.dispatch( - Event(EventType.ALLOWED_REPEAT_FREQ, res) - ) - - elif packet_type_value == PacketType.BINARY_RESPONSE.value: - dbuf.read(1) - tag = dbuf.read(4).hex() - response_data = dbuf.read() - logger.debug(f"Received binary data: {data.hex()}, tag {tag}, data {response_data.hex()}") - - # Always dispatch generic BINARY_RESPONSE - binary_res = {"tag": tag, "data": response_data.hex()} - await self.dispatcher.dispatch( - Event(EventType.BINARY_RESPONSE, binary_res, {"tag": tag}) - ) - - # Check for tracked request type and dispatch specific response - if tag in self.pending_binary_requests: - request_type = self.pending_binary_requests[tag]["request_type"] - is_anon = self.pending_binary_requests[tag]["is_anon"] - pubkey_prefix = self.pending_binary_requests[tag]["pubkey_prefix"] - context = self.pending_binary_requests[tag]["context"] - del self.pending_binary_requests[tag] - logger.debug(f"Processing binary response for tag {tag}, type {request_type}, pubkey_prefix {pubkey_prefix}") - - if not is_anon: - - if request_type == BinaryReqType.STATUS and len(response_data) >= 52: - res = {} - res = parse_status(response_data, pubkey_prefix=pubkey_prefix) - await self.dispatcher.dispatch( - Event(EventType.STATUS_RESPONSE, res, {"pubkey_prefix": res["pubkey_pre"], "tag": tag}) - ) - - elif request_type == BinaryReqType.TELEMETRY : - try: - lpp = lpp_parse(response_data) - telem_res = {"tag": tag, "lpp": lpp, "pubkey_prefix": pubkey_prefix} - await self.dispatcher.dispatch( - Event(EventType.TELEMETRY_RESPONSE, telem_res, telem_res) - ) - except Exception as e: - logger.error(f"Error parsing binary telemetry response: {e}") - - elif request_type == BinaryReqType.MMA: - try: - mma_result = lpp_parse_mma(response_data[4:]) # Skip 4-byte header - mma_res = {"tag": tag, "mma_data": mma_result, "pubkey_prefix": pubkey_prefix} - await self.dispatcher.dispatch( - Event(EventType.MMA_RESPONSE, mma_res, mma_res) - ) - except Exception as e: - logger.error(f"Error parsing binary MMA response: {e}") - - elif request_type == BinaryReqType.ACL: - try: - acl_result = parse_acl(response_data) - acl_res = {"tag": tag, "acl_data": acl_result, "pubkey_prefix": pubkey_prefix} - await self.dispatcher.dispatch( - Event(EventType.ACL_RESPONSE, acl_res, {"tag": tag, "pubkey_prefix": pubkey_prefix}) - ) - except Exception as e: - logger.error(f"Error parsing binary ACL response: {e}") - - elif request_type == BinaryReqType.NEIGHBOURS: - try: - pk_plen = context["pubkey_prefix_length"] - bbuf = io.BytesIO(response_data) - - res = { - "pubkey_prefix": pubkey_prefix, - "tag": tag - } - res.update(context) # add context in result - - res["neighbours_count"] = int.from_bytes(bbuf.read(2), "little", signed=True) - results_count = int.from_bytes(bbuf.read(2), "little", signed=True) - res["results_count"] = results_count - - neighbours_list = [] - - for _ in range (results_count): - neighb = {} - neighb["pubkey"] = bbuf.read(pk_plen).hex() - neighb["secs_ago"] = int.from_bytes(bbuf.read(4), "little", signed=True) - neighb["snr"] = int.from_bytes(bbuf.read(1), "little", signed=True) / 4 - neighbours_list.append(neighb) - - res["neighbours"] = neighbours_list - - await self.dispatcher.dispatch( - Event(EventType.NEIGHBOURS_RESPONSE, res, {"tag": tag, "pubkey_prefix": pubkey_prefix}) - ) - - except Exception as e: - logger.error(f"Error parsing binary NEIGHBOURS response: {e}") - - else: - logger.debug(f"No tracked request found for binary response tag {tag}") - - elif packet_type_value == PacketType.PATH_DISCOVERY_RESPONSE.value: - logger.debug(f"Received path discovery response: {data.hex()}") - res = {} - dbuf.read(1) - res["pubkey_pre"] = dbuf.read(6).hex() - opl = dbuf.read(1)[0] - opl_hlen = ((opl & 0xc0) >> 6) + 1 - opl = opl & 0x3f - res["out_path_len"] = opl - res["out_path_hash_len"] = opl_hlen - res["out_path"] = dbuf.read(opl*opl_hlen).hex() - ipl = dbuf.read(1)[0] - ipl_hlen = ((ipl & 0xc0) >> 6) + 1 - ipl = ipl & 0x3f - res["in_path_len"] = ipl - res["in_path_hash_len"] = ipl_hlen - res["in_path"] = dbuf.read(ipl*ipl_hlen).hex() - - attributes = {"pubkey_pre": res["pubkey_pre"]} - - await self.dispatcher.dispatch( - Event(EventType.PATH_RESPONSE, res, attributes) - ) - - elif packet_type_value == PacketType.PRIVATE_KEY.value: - logger.debug(f"Received private key response: {data.hex()}") - if len(data) >= 65: # 1 byte response code + 64 bytes private key - private_key = dbuf.read(64) # Extract 64-byte private key - res = {"private_key": private_key} - await self.dispatcher.dispatch(Event(EventType.PRIVATE_KEY, res)) - else: - logger.error(f"Invalid private key response length: {len(data)}") - - elif packet_type_value == PacketType.SIGN_START.value: - logger.debug(f"Received sign start response: {data.hex()}") - # Payload: 1 reserved byte, 4-byte max length - dbuf.read(1) - max_len = int.from_bytes(dbuf.read(4), "little") - res = {"max_length": max_len} - await self.dispatcher.dispatch(Event(EventType.SIGN_START, res)) - - elif packet_type_value == PacketType.SIGNATURE.value: - logger.debug(f"Received signature: {data.hex()}") - signature = dbuf.read() - res = {"signature": signature} - await self.dispatcher.dispatch(Event(EventType.SIGNATURE, res)) - - elif packet_type_value == PacketType.DISABLED.value: - logger.debug("Received disabled response") - res = {"reason": "private_key_export_disabled"} - await self.dispatcher.dispatch(Event(EventType.DISABLED, res)) - - elif packet_type_value == PacketType.CONTROL_DATA.value: - logger.debug("Received control data packet") - res={} - res["SNR"] = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) / 4 - res["RSSI"] = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) - res["path_len"] = dbuf.read(1)[0] - payload = dbuf.read() - payload_type = payload[0] - res["payload_type"] = payload_type - res["payload"] = payload - - attributes = {"payload_type": payload_type} - await self.dispatcher.dispatch( - Event(EventType.CONTROL_DATA, res, attributes) - ) - - # decode NODE_DISCOVER_RESP - if payload_type & 0xF0 == ControlType.NODE_DISCOVER_RESP.value: - pbuf = io.BytesIO(payload[1:]) - ndr = dict(res) - del ndr["payload_type"] - del ndr["payload"] - ndr["node_type"] = payload_type & 0x0F - ndr["SNR_in"] = int.from_bytes(pbuf.read(1), byteorder="little", signed=True)/4 - ndr["tag"] = pbuf.read(4).hex() - - pubkey = pbuf.read() - if len(pubkey) < 32: - pubkey = pubkey[0:8] + await self.dispatcher.dispatch(Event(EventType.NEXT_CONTACT, c)) + self.contacts[c["public_key"]] = c + + elif packet_type_value == PacketType.ADVERT_PATH.value : + r = {} + r["timestamp"] = int.from_bytes(dbuf.read(4), "little", signed=False) + plen = int.from_bytes(dbuf.read(1), "little", signed=False) + if plen == 255: # flood, should not happen + r["path_hash_mode"] = -1 + r["path_len"] = -1 else: - pubkey = pubkey[0:32] + r["path_hash_mode"] = plen >> 6 # 2 upper bytes + r["path_len"] = plen & 0x3F + r["path"] = dbuf.read().replace(b"\0", b"").hex() - ndr["pubkey"] = pubkey.hex() + await self.dispatcher.dispatch(Event(EventType.ADVERT_PATH, r)) + + elif packet_type_value == PacketType.CONTACT_END.value: + lastmod = int.from_bytes(dbuf.read(4), byteorder="little") + attributes = { + "lastmod": lastmod, + } + await self.dispatcher.dispatch( + Event(EventType.CONTACTS, self.contacts, attributes) + ) + + elif packet_type_value == PacketType.SELF_INFO.value: + self_info = {} + self_info["adv_type"] = dbuf.read(1)[0] + self_info["tx_power"] = dbuf.read(1)[0] + self_info["max_tx_power"] = dbuf.read(1)[0] + self_info["public_key"] = dbuf.read(32).hex() + self_info["adv_lat"] = ( + int.from_bytes(dbuf.read(4), byteorder="little", signed=True) / 1e6 + ) + self_info["adv_lon"] = ( + int.from_bytes(dbuf.read(4), byteorder="little", signed=True) / 1e6 + ) + self_info["multi_acks"] = dbuf.read(1)[0] + self_info["adv_loc_policy"] = dbuf.read(1)[0] + telemetry_mode = dbuf.read(1)[0] + self_info["telemetry_mode_env"] = (telemetry_mode >> 4) & 0b11 + self_info["telemetry_mode_loc"] = (telemetry_mode >> 2) & 0b11 + self_info["telemetry_mode_base"] = (telemetry_mode) & 0b11 + self_info["manual_add_contacts"] = dbuf.read(1)[0] > 0 + self_info["radio_freq"] = ( + int.from_bytes(dbuf.read(4), byteorder="little") / 1000 + ) + self_info["radio_bw"] = ( + int.from_bytes(dbuf.read(4), byteorder="little") / 1000 + ) + self_info["radio_sf"] = dbuf.read(1)[0] + self_info["radio_cr"] = dbuf.read(1)[0] + self_info["name"] = dbuf.read().decode("utf-8", "ignore") + await self.dispatcher.dispatch(Event(EventType.SELF_INFO, self_info)) + + elif packet_type_value == PacketType.MSG_SENT.value: + res = {} + res["type"] = dbuf.read(1)[0] + res["expected_ack"] = dbuf.read(4) + res["suggested_timeout"] = int.from_bytes(dbuf.read(4), byteorder="little") attributes = { - "node_type" : ndr["node_type"], - "tag" : ndr["tag"], - "pubkey" : ndr["pubkey"], + "type": res["type"], + "expected_ack": res["expected_ack"].hex(), + } + + await self.dispatcher.dispatch(Event(EventType.MSG_SENT, res, attributes)) + + elif packet_type_value == PacketType.CONTACT_MSG_RECV.value: + res = {} + res["type"] = "PRIV" + res["pubkey_prefix"] = dbuf.read(6).hex() + plen = dbuf.read(1)[0] + if plen == 255 : # direct message + res["path_hash_mode"] = -1 + res["path_len"] = plen + else: + res["path_hash_mode"] = plen >> 6 + res["path_len"] = plen & 0x3F + txt_type = dbuf.read(1)[0] + res["txt_type"] = txt_type + res["sender_timestamp"] = int.from_bytes(dbuf.read(4), byteorder="little") + if txt_type == 2: + res["signature"] = dbuf.read(4).hex() + res["text"] = dbuf.read().decode("utf-8", "ignore") + + attributes = { + "pubkey_prefix": res["pubkey_prefix"], + "txt_type": res["txt_type"], + } + + evt_type = EventType.CONTACT_MSG_RECV + + await self.dispatcher.dispatch(Event(evt_type, res, attributes)) + + elif packet_type_value == PacketType.CONTACT_MSG_RECV_V3.value: # A reply to CMD_SYNC_NEXT_MESSAGE (ver >= 3) + res = {} + res["type"] = "PRIV" + res["SNR"] = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) / 4 + dbuf.read(2) # reserved + res["pubkey_prefix"] = dbuf.read(6).hex() + plen = dbuf.read(1)[0] + if plen == 255 : # direct message + res["path_hash_mode"] = -1 + res["path_len"] = plen + else: + res["path_hash_mode"] = plen >> 6 + res["path_len"] = plen & 0x3F + txt_type = dbuf.read(1)[0] + res["txt_type"] = txt_type + res["sender_timestamp"] = int.from_bytes(dbuf.read(4), byteorder="little") + if txt_type == 2: + res["signature"] = dbuf.read(4).hex() + res["text"] = dbuf.read().decode("utf-8", "ignore") + + attributes = { + "pubkey_prefix": res["pubkey_prefix"], + "txt_type": res["txt_type"], } await self.dispatcher.dispatch( - Event(EventType.DISCOVER_RESPONSE, ndr, attributes) + Event(EventType.CONTACT_MSG_RECV, res, attributes) ) - else: - logger.debug(f"Unhandled data received {data}") - logger.debug(f"Unhandled packet type: {packet_type_value}") + elif packet_type_value == PacketType.CHANNEL_MSG_RECV.value: + res = {} + res["type"] = "CHAN" + res["channel_idx"] = dbuf.read(1)[0] + plen = dbuf.read(1)[0] + if plen == 255 : # direct message + res["path_hash_mode"] = -1 + res["path_len"] = plen + else: + res["path_hash_mode"] = plen >> 6 + res["path_len"] = plen & 0x3F + res["txt_type"] = dbuf.read(1)[0] + res["sender_timestamp"] = int.from_bytes(dbuf.read(4), byteorder="little", signed=False) + text = dbuf.read().strip(b"\0") + res["text"] = text.decode("utf-8", "ignore") + + # search for text in log_channels + txt_hash = int.from_bytes(SHA256.new(res["sender_timestamp"].to_bytes(4, "little", signed=False)+text).digest()[0:4], "little", signed=False) + if self.decrypt_channels: + logged = await self.packet_parser.findLogChannelMsg(txt_hash) + if not logged is None: + res["path"] = logged["path"] + res["RSSI"] = logged["rssi"] + res["SNR"] = logged["snr"] + res["recv_time"] = logged["recv_time"] + res["attempt"] = logged["attempt"] + + attributes = { + "channel_idx": res["channel_idx"], + "txt_type": res["txt_type"], + } + + await self.dispatcher.dispatch( + Event(EventType.CHANNEL_MSG_RECV, res, attributes) + ) + + elif packet_type_value == PacketType.CHANNEL_MSG_RECV_V3.value: # A reply to CMD_SYNC_NEXT_MESSAGE (ver >= 3) + res = {} + res["type"] = "CHAN" + res["SNR"] = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) / 4 + dbuf.read(2) # reserved + res["channel_idx"] = dbuf.read(1)[0] + plen = dbuf.read(1)[0] + if plen == 255 : # direct message + res["path_hash_mode"] = -1 + res["path_len"] = plen + else: + res["path_hash_mode"] = plen >> 6 + res["path_len"] = plen & 0x3F + res["txt_type"] = dbuf.read(1)[0] + res["sender_timestamp"] = int.from_bytes(dbuf.read(4), byteorder="little", signed=False) + text = dbuf.read() + res["text"] = text.decode("utf-8", "ignore") + + # search for text in log_channels + if self.decrypt_channels: + txt_hash = int.from_bytes(SHA256.new(res["sender_timestamp"].to_bytes(4, "little", signed=False)+text).digest()[0:4], "little", signed=False) + res["txt_hash"] = txt_hash + logged = await self.packet_parser.findLogChannelMsg(txt_hash) + + if not logged is None: + res["path"] = logged["path"] + res["RSSI"] = logged["rssi"] + res["recv_time"] = logged["recv_time"] + res["attempt"] = logged["attempt"] + + attributes = { + "channel_idx": res["channel_idx"], + "txt_type": res["txt_type"], + } + + await self.dispatcher.dispatch( + Event(EventType.CHANNEL_MSG_RECV, res, attributes) + ) + + elif packet_type_value == PacketType.CURRENT_TIME.value: + time_value = int.from_bytes(dbuf.read(4), byteorder="little") + result = {"time": time_value} + await self.dispatcher.dispatch(Event(EventType.CURRENT_TIME, result)) + + elif packet_type_value == PacketType.NO_MORE_MSGS.value: + result = {"messages_available": False} + await self.dispatcher.dispatch(Event(EventType.NO_MORE_MSGS, result)) + + elif packet_type_value == PacketType.CONTACT_URI.value: + contact_uri = "meshcore://" + dbuf.read().hex() + result = {"uri": contact_uri} + await self.dispatcher.dispatch(Event(EventType.CONTACT_URI, result)) + + elif packet_type_value == PacketType.BATTERY.value: + # Full RESP_CODE_BATT_AND_STORAGE: 1 type + 2 level + 4 used_kb + 4 total_kb = 11 bytes. + # Minimum viable frame is 3 bytes (type + level). Shorter frames are + # malformed — dbuf.read(2) would return short bytes and + # int.from_bytes(b"", ...) silently yields 0 (same class as N07). + if len(data) < 3: + logger.debug( + "BATTERY frame too short for level field " + f"({len(data)} bytes < 3), skipping" + ) + return + battery_level = int.from_bytes(dbuf.read(2), byteorder="little") + result = {"level": battery_level} + # The previous `len(data) > 3` guard let 4-10 byte truncated frames + # through, producing silent zero values for used_kb/total_kb because + # io.BytesIO.read() returns short data without raising. + if len(data) >= 11: # has storage info as well + result["used_kb"] = int.from_bytes(dbuf.read(4), byteorder="little") + result["total_kb"] = int.from_bytes(dbuf.read(4), byteorder="little") + await self.dispatcher.dispatch(Event(EventType.BATTERY, result)) + + elif packet_type_value == PacketType.DEVICE_INFO.value: + res = {} + fw_ver = dbuf.read(1)[0] + res["fw ver"] = fw_ver + if fw_ver >= 3: + res["max_contacts"] = dbuf.read(1)[0] * 2 + res["max_channels"] = dbuf.read(1)[0] + res["ble_pin"] = int.from_bytes(dbuf.read(4), byteorder="little") + res["fw_build"] = dbuf.read(12).decode("utf-8", "ignore").replace("\0", "") + res["model"] = dbuf.read(40).decode("utf-8", "ignore").replace("\0", "") + res["ver"] = dbuf.read(20).decode("utf-8", "ignore").replace("\0", "") + if fw_ver >= 9: # has repeater mode + rpt = dbuf.read(1) + if len(rpt) > 0: + res["repeat"] = (rpt[0] != 0) + if fw_ver >= 10: # has path_hash_mode + path_hash_mode = dbuf.read(1)[0] + res["path_hash_mode"] = path_hash_mode + await self.dispatcher.dispatch(Event(EventType.DEVICE_INFO, res)) + + elif packet_type_value == PacketType.CUSTOM_VARS.value: + logger.debug(f"received custom vars response: {data.hex()}") + res = {} + rawdata = dbuf.read().decode("utf-8", "ignore") + if not rawdata == "": + pairs = rawdata.split(",") + for p in pairs: + psplit = p.split(":") + res[psplit[0]] = psplit[1] + logger.debug(f"got custom vars : {res}") + await self.dispatcher.dispatch(Event(EventType.CUSTOM_VARS, res)) + + elif packet_type_value == PacketType.STATS.value: # RESP_CODE_STATS (24) + logger.debug(f"received stats response: {data.hex()}") + # RESP_CODE_STATS: All stats responses use code 24 with sub-type byte + # Byte 0: response_code (24), Byte 1: stats_type (0=core, 1=radio, 2=packets) + if len(data) < 2: + logger.error(f"Stats response too short: {len(data)} bytes, need at least 2 for header") + await self.dispatcher.dispatch(Event(EventType.ERROR, {"reason": "invalid_frame_length"})) + return + + stats_type = data[1] + + if stats_type == 0: # STATS_TYPE_CORE + # RESP_CODE_STATS + STATS_TYPE_CORE: 11 bytes total + # Format: = 30: + (recv_errors,) = struct.unpack('= 0: + res["channel_name"] = name_bytes[:null_pos].decode("utf-8", "ignore") + else: + res["channel_name"] = name_bytes.decode("utf-8", "ignore") + + res["channel_secret"] = dbuf.read(16) + res["channel_hash"] = SHA256.new(res["channel_secret"]).hexdigest()[0:2] + + await self.packet_parser.newChannel(res) + + await self.dispatcher.dispatch(Event(EventType.CHANNEL_INFO, res, res)) + + # Push notifications + elif packet_type_value == PacketType.ADVERTISEMENT.value: + logger.debug("Advertisement received") + res = {} + res["public_key"] = dbuf.read(32).hex() + await self.dispatcher.dispatch(Event(EventType.ADVERTISEMENT, res, res)) + + elif packet_type_value == PacketType.PATH_UPDATE.value: + logger.debug("Code path update") + res = {} + res["public_key"] = dbuf.read(32).hex() + await self.dispatcher.dispatch(Event(EventType.PATH_UPDATE, res, res)) + + elif packet_type_value == PacketType.ACK.value: + logger.debug("Received ACK") + ack_data = {} + + if len(data) >= 5: + ack_data["code"] = dbuf.read(4).hex() + + attributes = {"code": ack_data.get("code", "")} + + await self.dispatcher.dispatch(Event(EventType.ACK, ack_data, attributes)) + + elif packet_type_value == PacketType.MESSAGES_WAITING.value: + logger.debug("Msgs are waiting") + await self.dispatcher.dispatch(Event(EventType.MESSAGES_WAITING, {})) + + elif packet_type_value == PacketType.RAW_DATA.value: + res = {} + res["SNR"] = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) / 4 + res["RSSI"] = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) + res["payload"] = dbuf.read(4).hex() + logger.debug("Received raw data") + logger.debug(res) + await self.dispatcher.dispatch(Event(EventType.RAW_DATA, res)) + + elif packet_type_value == PacketType.LOGIN_SUCCESS.value: + res = {} + attributes = {} + if len(data) > 1: + perms = dbuf.read(1)[0] + res["permissions"] = perms + res["is_admin"] = (perms & 1) == 1 # Check if admin bit is set + if len(data) > 7: + res["pubkey_prefix"] = dbuf.read(6).hex() + + attributes = {"pubkey_prefix": res.get("pubkey_prefix")} + + await self.dispatcher.dispatch( + Event(EventType.LOGIN_SUCCESS, res, attributes) + ) + + elif packet_type_value == PacketType.LOGIN_FAILED.value: + res = {} + attributes = {} + + dbuf.read(1) + + if len(data) > 7: + res["pubkey_prefix"] = dbuf.read(6).hex() + + attributes = {"pubkey_prefix": res.get("pubkey_prefix")} + + await self.dispatcher.dispatch( + Event(EventType.LOGIN_FAILED, res, attributes) + ) + + elif packet_type_value == PacketType.STATUS_RESPONSE.value: + # parse_status with offset=8 reads up through data[56:60] + # (rx_airtime field), so the full payload is 60 bytes: + # 1 type + 1 reserved + 6 pubkey + 52 status fields. The + # BINARY_RESPONSE STATUS path below gates with `>= 52` on + # the offset-stripped buffer; this gate is the equivalent + # for the push path with the 8-byte header included. + if len(data) < 60: + logger.debug(f"STATUS_RESPONSE push frame too short ({len(data)} bytes < 60), skipping parse") + return + res = parse_status(data, offset=8) + data_hex = data[8:].hex() + logger.debug(f"Status response: {data_hex}") + + attributes = { + "pubkey_prefix": res["pubkey_pre"], + } + + await self.dispatcher.dispatch( + Event(EventType.STATUS_RESPONSE, res, attributes) + ) + + elif packet_type_value == PacketType.LOG_DATA.value: + logger.debug(f"Received RF log data: {data.hex()}") + + # Parse as raw RX data + log_data: Dict[str, Any] = {"raw_hex": data[1:].hex()} + attributes = {} + + recv_time = int(time.time()) + log_data["recv_time"] = recv_time + attributes["recv_time"] = recv_time + + # First byte is SNR (signed byte, multiplied by 4) + if len(data) > 1: + snr_byte = dbuf.read(1)[0] + # Convert to signed value + snr = (snr_byte if snr_byte < 128 else snr_byte - 256) / 4.0 + log_data["snr"] = snr + + # Second byte is RSSI (signed byte) + if len(data) > 2: + rssi_byte = dbuf.read(1)[0] + # Convert to signed value + rssi = rssi_byte if rssi_byte < 128 else rssi_byte - 256 + log_data["rssi"] = rssi + + # Remaining bytes are the raw data payload + payload = None + if len(data) > 3: + payload=dbuf.read() + log_data["payload"] = payload.hex() + log_data["payload_length"] = len(payload) + + # Parse payload and get some info from it + log_data = await self.packet_parser.parsePacketPayload(payload, log_data) + attributes['route_type'] = log_data['route_type'] + attributes['payload_type'] = log_data['payload_type'] + attributes['path_len'] = log_data['path_len'] + attributes['path'] = log_data['path'] + + # Dispatch as RF log data + await self.dispatcher.dispatch( + Event(EventType.RX_LOG_DATA, log_data, attributes) + ) + + elif packet_type_value == PacketType.TRACE_DATA.value: + logger.debug(f"Received trace data: {data.hex()}") + res = {} + + # According to the source, format is: + # 0x89, reserved(0), path_len, flags, tag(4), auth(4), path_hashes[], path_snrs[], final_snr + + reserved = dbuf.read(1)[0] + path_len = dbuf.read(1)[0] + flags = dbuf.read(1)[0] + tag = int.from_bytes(dbuf.read(4), byteorder="little") + auth_code = int.from_bytes(dbuf.read(4), byteorder="little") + + path_hash_len = 1 << (flags&3) + path_len = int(path_len / path_hash_len) + + # Initialize result + res["tag"] = tag + res["auth"] = auth_code + res["flags"] = flags + res["path_len"] = path_len + + # Process path as array of objects with hash and SNR + path_nodes = [] + + if path_len > 0 and len(data) >= 12 + path_len + (path_len * path_hash_len) + 1: + # Extract path with hash and SNR pairs + for i in range(path_len): + node = { + "hash": dbuf.read(path_hash_len).hex(), + } + path_nodes.append(node) + + for n in path_nodes: + node_snr = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) + n["snr"] = node_snr / 4.0 + + # Add the final node (our device) with its SNR + final_snr = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) / 4.0 + path_nodes.append({"snr": final_snr}) + + res["path"] = path_nodes + + logger.debug(f"Parsed trace data: {res}") + + attributes = { + "tag": res["tag"], + "auth_code": res["auth"], + } + + await self.dispatcher.dispatch(Event(EventType.TRACE_DATA, res, attributes)) + + elif packet_type_value == PacketType.TELEMETRY_RESPONSE.value: + logger.debug(f"Received telemetry data: {data.hex()}") + res = {} + + dbuf.read(1) + + res["pubkey_pre"] = dbuf.read(6).hex() + buf = dbuf.read() + + """Parse a given byte string and return as a LppFrame object.""" + i = 0 + lpp_data_list = [] + while i < len(buf) and buf[i] != 0: + lppdata = LppData.from_bytes(buf[i:]) + lpp_data_list.append(lppdata) + i = i + len(lppdata) + + lpp = json.loads( + json.dumps(LppFrame(lpp_data_list), default=lpp_json_encoder) + ) + + res["lpp"] = lpp + + attributes = { + "raw": buf.hex(), + "pubkey_prefix": res["pubkey_pre"], + } + + await self.dispatcher.dispatch( + Event(EventType.TELEMETRY_RESPONSE, res, attributes) + ) + + elif packet_type_value == PacketType.ALLOWED_REPEAT_FREQ.value: + res = {} + freqs = [] + + cont = True + try: + while cont: + min = int.from_bytes(dbuf.read(4), "little", signed=False) + max = int.from_bytes(dbuf.read(4), "little", signed=False) + if min == 0 or max == 0: + cont = False + else: + freqs.append({"min" : min, "max": max}) + except Exception as e: + logger.warning(f"Error parsing ALLOWED_REPEAT_FREQ payload: {e}") + + res["freqs"] = freqs + + await self.dispatcher.dispatch( + Event(EventType.ALLOWED_REPEAT_FREQ, res) + ) + + elif packet_type_value == PacketType.BINARY_RESPONSE.value: + dbuf.read(1) + tag = dbuf.read(4).hex() + response_data = dbuf.read() + logger.debug(f"Received binary data: {data.hex()}, tag {tag}, data {response_data.hex()}") + + # Always dispatch generic BINARY_RESPONSE + binary_res = {"tag": tag, "data": response_data.hex()} + await self.dispatcher.dispatch( + Event(EventType.BINARY_RESPONSE, binary_res, {"tag": tag}) + ) + + # Check for tracked request type and dispatch specific response + if tag in self.pending_binary_requests: + request_type = self.pending_binary_requests[tag]["request_type"] + is_anon = self.pending_binary_requests[tag]["is_anon"] + pubkey_prefix = self.pending_binary_requests[tag]["pubkey_prefix"] + context = self.pending_binary_requests[tag]["context"] + del self.pending_binary_requests[tag] + logger.debug(f"Processing binary response for tag {tag}, type {request_type}, pubkey_prefix {pubkey_prefix}") + + if not is_anon: + + if request_type == BinaryReqType.STATUS and len(response_data) >= 52: + res = {} + res = parse_status(response_data, pubkey_prefix=pubkey_prefix) + await self.dispatcher.dispatch( + Event(EventType.STATUS_RESPONSE, res, {"pubkey_prefix": res["pubkey_pre"], "tag": tag}) + ) + + elif request_type == BinaryReqType.TELEMETRY : + try: + lpp = lpp_parse(response_data) + telem_res = {"tag": tag, "lpp": lpp, "pubkey_prefix": pubkey_prefix} + await self.dispatcher.dispatch( + Event(EventType.TELEMETRY_RESPONSE, telem_res, telem_res) + ) + except Exception as e: + logger.error(f"Error parsing binary telemetry response: {e}") + + elif request_type == BinaryReqType.MMA: + try: + mma_result = lpp_parse_mma(response_data[4:]) # Skip 4-byte header + mma_res = {"tag": tag, "mma_data": mma_result, "pubkey_prefix": pubkey_prefix} + await self.dispatcher.dispatch( + Event(EventType.MMA_RESPONSE, mma_res, mma_res) + ) + except Exception as e: + logger.error(f"Error parsing binary MMA response: {e}") + + elif request_type == BinaryReqType.ACL: + try: + acl_result = parse_acl(response_data) + acl_res = {"tag": tag, "acl_data": acl_result, "pubkey_prefix": pubkey_prefix} + await self.dispatcher.dispatch( + Event(EventType.ACL_RESPONSE, acl_res, {"tag": tag, "pubkey_prefix": pubkey_prefix}) + ) + except Exception as e: + logger.error(f"Error parsing binary ACL response: {e}") + + elif request_type == BinaryReqType.NEIGHBOURS: + try: + pk_plen = context["pubkey_prefix_length"] + bbuf = io.BytesIO(response_data) + + res = { + "pubkey_prefix": pubkey_prefix, + "tag": tag + } + res.update(context) # add context in result + + res["neighbours_count"] = int.from_bytes(bbuf.read(2), "little", signed=True) + results_count = int.from_bytes(bbuf.read(2), "little", signed=True) + res["results_count"] = results_count + + neighbours_list = [] + + for _ in range (results_count): + neighb = {} + neighb["pubkey"] = bbuf.read(pk_plen).hex() + neighb["secs_ago"] = int.from_bytes(bbuf.read(4), "little", signed=True) + neighb["snr"] = int.from_bytes(bbuf.read(1), "little", signed=True) / 4 + neighbours_list.append(neighb) + + res["neighbours"] = neighbours_list + + await self.dispatcher.dispatch( + Event(EventType.NEIGHBOURS_RESPONSE, res, {"tag": tag, "pubkey_prefix": pubkey_prefix}) + ) + + except Exception as e: + logger.error(f"Error parsing binary NEIGHBOURS response: {e}") + + else: + logger.debug(f"No tracked request found for binary response tag {tag}") + + elif packet_type_value == PacketType.PATH_DISCOVERY_RESPONSE.value: + logger.debug(f"Received path discovery response: {data.hex()}") + res = {} + dbuf.read(1) + res["pubkey_pre"] = dbuf.read(6).hex() + opl = dbuf.read(1)[0] + opl_hlen = ((opl & 0xc0) >> 6) + 1 + opl = opl & 0x3f + res["out_path_len"] = opl + res["out_path_hash_len"] = opl_hlen + res["out_path"] = dbuf.read(opl*opl_hlen).hex() + ipl = dbuf.read(1)[0] + ipl_hlen = ((ipl & 0xc0) >> 6) + 1 + ipl = ipl & 0x3f + res["in_path_len"] = ipl + res["in_path_hash_len"] = ipl_hlen + res["in_path"] = dbuf.read(ipl*ipl_hlen).hex() + + attributes = {"pubkey_pre": res["pubkey_pre"]} + + await self.dispatcher.dispatch( + Event(EventType.PATH_RESPONSE, res, attributes) + ) + + elif packet_type_value == PacketType.PRIVATE_KEY.value: + logger.debug(f"Received private key response: {data.hex()}") + if len(data) >= 65: # 1 byte response code + 64 bytes private key + private_key = dbuf.read(64) # Extract 64-byte private key + res = {"private_key": private_key} + await self.dispatcher.dispatch(Event(EventType.PRIVATE_KEY, res)) + else: + logger.error(f"Invalid private key response length: {len(data)}") + + elif packet_type_value == PacketType.SIGN_START.value: + logger.debug(f"Received sign start response: {data.hex()}") + # Payload: 1 reserved byte, 4-byte max length + dbuf.read(1) + max_len = int.from_bytes(dbuf.read(4), "little") + res = {"max_length": max_len} + await self.dispatcher.dispatch(Event(EventType.SIGN_START, res)) + + elif packet_type_value == PacketType.SIGNATURE.value: + logger.debug(f"Received signature: {data.hex()}") + signature = dbuf.read() + res = {"signature": signature} + await self.dispatcher.dispatch(Event(EventType.SIGNATURE, res)) + + elif packet_type_value == PacketType.DISABLED.value: + logger.debug("Received disabled response") + res = {"reason": "private_key_export_disabled"} + await self.dispatcher.dispatch(Event(EventType.DISABLED, res)) + + elif packet_type_value == PacketType.CONTROL_DATA.value: + logger.debug("Received control data packet") + res={} + res["SNR"] = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) / 4 + res["RSSI"] = int.from_bytes(dbuf.read(1), byteorder="little", signed=True) + res["path_len"] = dbuf.read(1)[0] + payload = dbuf.read() + if len(payload) == 0: + logger.debug("CONTROL_DATA frame has empty payload, skipping") + return + payload_type = payload[0] + res["payload_type"] = payload_type + res["payload"] = payload + + attributes = {"payload_type": payload_type} + await self.dispatcher.dispatch( + Event(EventType.CONTROL_DATA, res, attributes) + ) + + # decode NODE_DISCOVER_RESP + if payload_type & 0xF0 == ControlType.NODE_DISCOVER_RESP.value: + pbuf = io.BytesIO(payload[1:]) + ndr = dict(res) + del ndr["payload_type"] + del ndr["payload"] + ndr["node_type"] = payload_type & 0x0F + ndr["SNR_in"] = int.from_bytes(pbuf.read(1), byteorder="little", signed=True)/4 + ndr["tag"] = pbuf.read(4).hex() + + pubkey = pbuf.read() + if len(pubkey) < 32: + pubkey = pubkey[0:8] + else: + pubkey = pubkey[0:32] + + ndr["pubkey"] = pubkey.hex() + + attributes = { + "node_type" : ndr["node_type"], + "tag" : ndr["tag"], + "pubkey" : ndr["pubkey"], + } + + await self.dispatcher.dispatch( + Event(EventType.DISCOVER_RESPONSE, ndr, attributes) + ) + elif packet_type_value == PacketType.CONTACT_DELETED.value: + # N01: PUSH_CODE_CONTACT_DELETED (0x8F) — 1-byte code + 32-byte pubkey + # Emitted by MyMesh::onContactOverwrite() (MyMesh.cpp:325-334) + if len(data) < 33: + logger.debug("CONTACT_DELETED frame too short (%d bytes, need 33)", len(data)) + return + pubkey = data[1:33].hex() + await self.dispatcher.dispatch( + Event(EventType.CONTACT_DELETED, {"pubkey": pubkey}, {"pubkey": pubkey}) + ) + + elif packet_type_value == PacketType.CONTACTS_FULL.value: + # N02: PUSH_CODE_CONTACTS_FULL (0x90) — 1-byte push, no payload + # Emitted by MyMesh::onContactsFull() (MyMesh.cpp:336) + await self.dispatcher.dispatch(Event(EventType.CONTACTS_FULL, {})) + + elif packet_type_value == PacketType.TUNING_PARAMS.value: + # N03: RESP_CODE_TUNING_PARAMS (23) — response to CMD_GET_TUNING_PARAMS (43) + # Format: 1-byte code + 4-byte rx_delay (LE) + 4-byte airtime_factor (LE) = 9 bytes + # Emitted by MyMesh.cpp:1307-1313 + if len(data) < 9: + logger.debug("TUNING_PARAMS frame too short (%d bytes, need 9)", len(data)) + await self.dispatcher.dispatch( + Event(EventType.ERROR, {"reason": "invalid_frame_length"}) + ) + return + rx_delay = int.from_bytes(data[1:5], byteorder="little") + airtime_factor = int.from_bytes(data[5:9], byteorder="little") + res = {"rx_delay": rx_delay, "airtime_factor": airtime_factor} + await self.dispatcher.dispatch(Event(EventType.TUNING_PARAMS, res)) + + else: + logger.debug(f"Unhandled data received {data}") + logger.debug(f"Unhandled packet type: {packet_type_value}") + except Exception as e: + logger.error( + "handle_rx parse error: %s: %s | raw=%s\n%s", + type(e).__name__, + e, + data.hex(), + traceback.format_exc(), + ) \ No newline at end of file diff --git a/src/meshcore/serial_cx.py b/src/meshcore/serial_cx.py index 61163bd..088142b 100644 --- a/src/meshcore/serial_cx.py +++ b/src/meshcore/serial_cx.py @@ -20,11 +20,19 @@ class SerialConnection: self._disconnect_callback = None self.cx_dly = cx_dly self._connected_event = asyncio.Event() + self._background_tasks: set[asyncio.Task] = set() self.frame_expected_size = 0 self.inframe = b"" self.header = b"" + def _spawn_background(self, coro) -> asyncio.Task: + """Create a tracked background task (prevents GC of fire-and-forget tasks).""" + task = asyncio.create_task(coro) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task + class MCSerialClientProtocol(asyncio.Protocol): def __init__(self, cx): self.cx = cx @@ -44,7 +52,7 @@ class SerialConnection: self.cx._connected_event.clear() if self.cx._disconnect_callback: - asyncio.create_task(self.cx._disconnect_callback("serial_disconnect")) + self.cx._spawn_background(self.cx._disconnect_callback("serial_disconnect")) def pause_writing(self): logger.debug("pause writing") @@ -52,12 +60,16 @@ class SerialConnection: def resume_writing(self): logger.debug("resume writing") - async def connect(self): + async def connect(self, timeout: float = 10.0): """ - Connects to the device + Connects to the device. + + Args: + timeout: Maximum seconds to wait for connection_made callback. + Defaults to 10.0. Raises asyncio.TimeoutError on expiry. """ self._connected_event.clear() - + loop = asyncio.get_running_loop() await serial_asyncio.create_serial_connection( loop, @@ -66,7 +78,7 @@ class SerialConnection: baudrate=self.baudrate, ) - await self._connected_event.wait() + await asyncio.wait_for(self._connected_event.wait(), timeout=timeout) logger.info("Serial Connection started") return self.port @@ -102,7 +114,7 @@ class SerialConnection: self.frame_expected_size = 0 if len(data) > 0: # rerun handle_rx on remaining data self.handle_rx(data) - return + return # nothing left to process after reset upbound = self.frame_expected_size - len(self.inframe) if len(data) < upbound: @@ -114,7 +126,7 @@ class SerialConnection: data = data[upbound:] if self.reader is not None: # feed meshcore reader - asyncio.create_task(self.reader.handle_rx(self.inframe)) + self._spawn_background(self.reader.handle_rx(self.inframe)) # reset inframe self.inframe = b"" self.header = b"" @@ -125,11 +137,18 @@ class SerialConnection: async def send(self, data): if not self.transport: logger.error("Transport not connected, cannot send data") + if self._disconnect_callback: + await self._disconnect_callback("serial_transport_lost") return size = len(data) pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data logger.debug(f"sending pkt : {pkt}") - self.transport.write(pkt) + try: + self.transport.write(pkt) + except OSError as exc: + logger.warning(f"Serial write failed: {exc}") + if self._disconnect_callback: + await self._disconnect_callback(f"serial_write_failed: {exc}") async def disconnect(self): """Close the serial connection.""" diff --git a/src/meshcore/tcp_cx.py b/src/meshcore/tcp_cx.py index 497c3b2..80c8af3 100644 --- a/src/meshcore/tcp_cx.py +++ b/src/meshcore/tcp_cx.py @@ -24,6 +24,14 @@ class TCPConnection: self.frame_expected_size = 0 self.header = b"" self.inframe = b"" + self._background_tasks: set[asyncio.Task] = set() + + def _spawn_background(self, coro) -> asyncio.Task: + """Create a tracked background task (prevents GC of fire-and-forget tasks).""" + task = asyncio.create_task(coro) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task class MCClientProtocol(asyncio.Protocol): def __init__(self, cx): @@ -38,7 +46,6 @@ class TCPConnection: def data_received(self, data): logger.debug("data received") - self.cx._receive_count += 1 self.cx.handle_rx(data) def error_received(self, exc): @@ -47,7 +54,7 @@ class TCPConnection: def connection_lost(self, exc): logger.debug("TCP server closed the connection") if self.cx._disconnect_callback: - asyncio.create_task(self.cx._disconnect_callback("tcp_disconnect")) + self.cx._spawn_background(self.cx._disconnect_callback("tcp_disconnect")) async def connect(self): """ @@ -59,10 +66,7 @@ class TCPConnection: ) logger.info("TCP Connection started") - future = asyncio.Future() - future.set_result(self.host) - - return future + return self.host def set_reader(self, reader): self.reader = reader @@ -96,7 +100,7 @@ class TCPConnection: self.frame_expected_size = 0 if len(data) > 0: # rerun handle_rx on remaining data self.handle_rx(data) - return + return # nothing left to process after reset upbound = self.frame_expected_size - len(self.inframe) if len(data) < upbound : @@ -106,9 +110,13 @@ class TCPConnection: self.inframe = self.inframe + data[0:upbound] data = data[upbound:] + # Increment per completed MeshCore frame, not per TCP segment (N04). + # The threshold heuristic in send() compares _send_count vs + # _receive_count — counting per-segment skews it under fragmentation. + self._receive_count += 1 if self.reader is not None: # feed meshcore reader - asyncio.create_task(self.reader.handle_rx(self.inframe)) + self._spawn_background(self.reader.handle_rx(self.inframe)) # reset inframe self.inframe = b"" self.header = b"" @@ -137,7 +145,12 @@ class TCPConnection: size = len(data) pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data logger.debug(f"sending pkt : {pkt}") - self.transport.write(pkt) + try: + self.transport.write(pkt) + except (OSError, ConnectionResetError) as exc: + logger.warning(f"TCP write failed: {exc}") + if self._disconnect_callback: + await self._disconnect_callback(f"tcp_write_failed: {exc}") async def disconnect(self): """Close the TCP connection.""" diff --git a/tests/test_ble_pin_pairing.py b/tests/test_ble_pin_pairing.py index aaa3acc..8b49184 100644 --- a/tests/test_ble_pin_pairing.py +++ b/tests/test_ble_pin_pairing.py @@ -37,7 +37,7 @@ class TestBLEPinPairing(unittest.TestCase): @patch("meshcore.ble_cx.BleakClient") def test_ble_connection_with_pin_failed_pairing(self, mock_bleak_client): - """Test BLE connection with PIN when pairing fails but connection continues""" + """Test BLE connection with PIN when pairing fails — re-raises (F17).""" # Arrange mock_client_instance = self._get_mock_bleak_client() mock_client_instance.pair = AsyncMock(side_effect=Exception("Pairing failed")) @@ -47,17 +47,16 @@ class TestBLEPinPairing(unittest.TestCase): pin = "123456" ble_conn = BLEConnection(address=address, pin=pin) - # Act - result = asyncio.run(ble_conn.connect()) - - # Assert + # Act & Assert — pairing failure now re-raises instead of being + # swallowed, because a half-usable transport is worse than a clean + # failure (forensics finding F17). + with self.assertRaises(Exception) as ctx: + asyncio.run(ble_conn.connect()) + self.assertIn("Pairing failed", str(ctx.exception)) mock_client_instance.connect.assert_called_once() mock_client_instance.pair.assert_called_once() - mock_client_instance.start_notify.assert_called_once_with( - UART_TX_CHAR_UUID, ble_conn.handle_rx - ) - # Connection should still succeed even if pairing fails - self.assertEqual(result, address) + # disconnect should be called to clean up the failed connection + mock_client_instance.disconnect.assert_called_once() @patch("meshcore.ble_cx.BleakClient") def test_ble_connection_without_pin_no_pairing(self, mock_bleak_client): diff --git a/tests/unit/test_asyncio_lifecycle.py b/tests/unit/test_asyncio_lifecycle.py new file mode 100644 index 0000000..f70f2f0 --- /dev/null +++ b/tests/unit/test_asyncio_lifecycle.py @@ -0,0 +1,235 @@ +""" +Verification tests for asyncio lifecycle fixes. +""" + +import asyncio +import gc +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from meshcore.events import Event, EventDispatcher, EventType +from meshcore.tcp_cx import TCPConnection +from meshcore.serial_cx import SerialConnection +from meshcore.commands.base import CommandHandlerBase + + +class TestBackgroundTaskTracking(unittest.TestCase): + """Fire-and-forget create_task calls must be tracked to prevent GC.""" + + def test_tcp_spawn_background_retains_task(self): + """TCP _spawn_background adds the task to _background_tasks.""" + async def _run(): + cx = TCPConnection("127.0.0.1", 5555) + completed = asyncio.Event() + + async def dummy(): + completed.set() + + task = cx._spawn_background(dummy()) + assert task in cx._background_tasks + await completed.wait() + # After completion, done_callback should have discarded it + await asyncio.sleep(0) # let done callback fire + assert task not in cx._background_tasks + + asyncio.run(_run()) + + def test_serial_spawn_background_retains_task(self): + """Serial _spawn_background adds the task to _background_tasks.""" + async def _run(): + with patch("meshcore.serial_cx.asyncio.Event") as mock_event: + mock_event.return_value = MagicMock() + cx = SerialConnection("/dev/null", 115200) + completed = asyncio.Event() + + async def dummy(): + completed.set() + + task = cx._spawn_background(dummy()) + assert task in cx._background_tasks + await completed.wait() + await asyncio.sleep(0) + assert task not in cx._background_tasks + + asyncio.run(_run()) + + def test_event_dispatcher_spawn_background_retains_task(self): + """EventDispatcher _spawn_background adds task to _background_tasks.""" + async def _run(): + dispatcher = EventDispatcher() + completed = asyncio.Event() + + async def dummy(): + completed.set() + + task = dispatcher._spawn_background(dummy()) + assert task in dispatcher._background_tasks + await completed.wait() + await asyncio.sleep(0) + assert task not in dispatcher._background_tasks + + asyncio.run(_run()) + + def test_tcp_handle_rx_uses_tracked_task(self): + """TCP handle_rx dispatches reader.handle_rx via _spawn_background.""" + async def _run(): + cx = TCPConnection("127.0.0.1", 5555) + reader = AsyncMock() + reader.handle_rx = AsyncMock() + cx.set_reader(reader) + + # Build a minimal valid frame: 0x3e + 2-byte LE size + payload + payload = b"\x01\x02\x03" + size = len(payload).to_bytes(2, "little") + frame = b"\x3e" + size + payload + + cx.handle_rx(frame) + # Task should be tracked + assert len(cx._background_tasks) == 1 + # Let task complete + await asyncio.sleep(0.05) + reader.handle_rx.assert_awaited_once_with(payload) + + asyncio.run(_run()) + + def test_tcp_connection_lost_uses_tracked_task(self): + """TCP connection_lost dispatches disconnect callback via _spawn_background.""" + async def _run(): + cx = TCPConnection("127.0.0.1", 5555) + callback = AsyncMock() + cx.set_disconnect_callback(callback) + + protocol = cx.MCClientProtocol(cx) + protocol.connection_lost(None) + + assert len(cx._background_tasks) == 1 + await asyncio.sleep(0.05) + callback.assert_awaited_once_with("tcp_disconnect") + + asyncio.run(_run()) + + def test_gc_does_not_cancel_tracked_tasks(self): + """Tracked tasks survive GC pressure (the whole point of tracking).""" + async def _run(): + cx = TCPConnection("127.0.0.1", 5555) + result = [] + + async def slow_task(): + await asyncio.sleep(0.05) + result.append("done") + + cx._spawn_background(slow_task()) + # Force GC — untracked tasks could be collected here + gc.collect() + await asyncio.sleep(0.1) + assert result == ["done"] + + asyncio.run(_run()) + + +class TestTaskDoneCorrectness(unittest.TestCase): + """EventDispatcher.stop() must wait for in-flight async callbacks.""" + + def test_stop_waits_for_async_callbacks(self): + """stop() should not return until async callbacks have completed.""" + async def _run(): + dispatcher = EventDispatcher() + await dispatcher.start() + + callback_completed = False + + async def slow_callback(event): + nonlocal callback_completed + await asyncio.sleep(0.1) + callback_completed = True + + dispatcher.subscribe(EventType.OK, slow_callback) + await dispatcher.dispatch(Event(EventType.OK, {})) + + # Give the dispatch loop a moment to pick up the event + await asyncio.sleep(0.02) + + # stop() should wait for slow_callback to finish + await dispatcher.stop() + assert callback_completed, "stop() returned before async callback completed" + + asyncio.run(_run()) + + +class TestDeferredPrimitiveConstruction(unittest.TestCase): + """Queue and Lock must not bind to import-time loop.""" + + def test_event_dispatcher_queue_is_none_before_start(self): + """EventDispatcher.queue should be None until start() is called.""" + dispatcher = EventDispatcher() + assert dispatcher.queue is None + + def test_event_dispatcher_queue_created_on_start(self): + """start() creates the queue.""" + async def _run(): + dispatcher = EventDispatcher() + assert dispatcher.queue is None + await dispatcher.start() + assert dispatcher.queue is not None + assert isinstance(dispatcher.queue, asyncio.Queue) + await dispatcher.stop() + + asyncio.run(_run()) + + def test_event_dispatcher_dispatch_before_start_raises(self): + """dispatch() before start() should raise RuntimeError.""" + async def _run(): + dispatcher = EventDispatcher() + with self.assertRaises(RuntimeError): + await dispatcher.dispatch(Event(EventType.OK, {})) + + asyncio.run(_run()) + + def test_command_handler_lock_is_none_before_use(self): + """CommandHandlerBase lock should be None until first access.""" + handler = CommandHandlerBase() + assert handler._CommandHandlerBase__mesh_request_lock is None + + def test_command_handler_lock_created_on_access(self): + """Accessing _mesh_request_lock creates it lazily.""" + async def _run(): + handler = CommandHandlerBase() + lock = handler._mesh_request_lock + assert isinstance(lock, asyncio.Lock) + # Second access returns same instance + assert handler._mesh_request_lock is lock + + asyncio.run(_run()) + + +class TestGetRunningLoop(unittest.TestCase): + """get_event_loop() replaced with get_running_loop() in send().""" + + def test_send_uses_get_running_loop(self): + """send() should call get_running_loop, not get_event_loop.""" + async def _run(): + handler = CommandHandlerBase() + dispatcher = EventDispatcher() + await dispatcher.start() + handler.set_dispatcher(dispatcher) + + mock_sender = AsyncMock() + handler._sender_func = mock_sender + + # Patch get_running_loop to verify it's called + with patch("meshcore.commands.base.asyncio.get_running_loop", wraps=asyncio.get_running_loop) as mock_grl: + # send with expected_events triggers the loop = asyncio.get_running_loop() path + result = await handler.send( + b"\x01", + expected_events=[EventType.OK], + timeout=0.05, + ) + mock_grl.assert_called() + + await dispatcher.stop() + + asyncio.run(_run()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_commands.py b/tests/unit/test_commands.py index 44dff6f..02516e1 100644 --- a/tests/unit/test_commands.py +++ b/tests/unit/test_commands.py @@ -28,6 +28,10 @@ def mock_dispatcher(): sub.unsubscribe = MagicMock() dispatcher._last_subscribe_handler = handler dispatcher._last_subscribe_event_type = event_type + # Immediately resolve the future so send() doesn't block + asyncio.get_event_loop().call_soon( + handler, Event(event_type, {}) + ) return sub dispatcher.subscribe = MagicMock(side_effect=fake_subscribe) @@ -80,6 +84,13 @@ async def test_send_with_event(command_handler, mock_connection, mock_dispatcher async def test_send_timeout(command_handler, mock_connection, mock_dispatcher): + # Override to NOT resolve events, so we can test the timeout path + def non_resolving_subscribe(event_type, handler, attribute_filters=None): + sub = MagicMock(spec=Subscription) + sub.unsubscribe = MagicMock() + return sub + mock_dispatcher.subscribe = MagicMock(side_effect=non_resolving_subscribe) + result = await command_handler.send(b"test_command", [EventType.OK], timeout=0.1) assert result.type == EventType.ERROR assert result.payload == {"reason": "no_event_received"} diff --git a/tests/unit/test_connection_manager.py b/tests/unit/test_connection_manager.py new file mode 100644 index 0000000..d92cf06 --- /dev/null +++ b/tests/unit/test_connection_manager.py @@ -0,0 +1,294 @@ +"""Tests for reconnect-path fixes.""" + +import asyncio + +import pytest + +from meshcore.connection_manager import ConnectionManager +from meshcore.events import Event, EventDispatcher, EventType + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class FakeConnection: + """Minimal stub that satisfies ConnectionProtocol.""" + + def __init__(self, connect_results=None): + """ + Args: + connect_results: iterator of return values for successive + connect() calls. ``None`` means soft failure; a string + means success; raising is also supported via sentinel. + """ + self._connect_results = list(connect_results or ["ok"]) + self._call_index = 0 + self.reader = None + + async def connect(self): + if self._call_index < len(self._connect_results): + result = self._connect_results[self._call_index] + self._call_index += 1 + else: + result = self._connect_results[-1] + if isinstance(result, Exception): + raise result + return result + + async def disconnect(self): + pass + + async def send(self, data): + pass + + def set_reader(self, reader): + self.reader = reader + + +class RaisingConnection(FakeConnection): + """Connection that raises on every connect() attempt.""" + + def __init__(self, exc=None): + super().__init__() + self._exc = exc or ConnectionError("boom") + + async def connect(self): + raise self._exc + + +class _EventCollector: + """Subscribes to all events and records them.""" + + def __init__(self, dispatcher: EventDispatcher): + self.events: list[Event] = [] + dispatcher.subscribe(None, self._on_event) + + async def _on_event(self, event: Event): + self.events.append(event) + + +# --------------------------------------------------------------------------- +# TCP connect() should return a plain value, not an asyncio.Future +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tcp_connect_returns_plain_string(): + """TCPConnection.connect() returns self.host (a plain string), not an + asyncio.Future. We test indirectly via ConnectionManager — the + CONNECTED event payload should contain a plain string, not a Future + object.""" + conn = FakeConnection(connect_results=["10.0.0.1"]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + collector = _EventCollector(dispatcher) + mgr = ConnectionManager(conn, dispatcher) + + result = await mgr.connect() + + assert result == "10.0.0.1" + # Give the dispatcher a moment to deliver the event + await asyncio.sleep(0.05) + connected_events = [e for e in collector.events if e.type == EventType.CONNECTED] + assert len(connected_events) == 1 + payload = connected_events[0].payload + assert payload["connection_info"] == "10.0.0.1" + # The payload value must NOT be an asyncio.Future + assert not isinstance(payload["connection_info"], asyncio.Future) + finally: + await dispatcher.stop() + + +# --------------------------------------------------------------------------- +# Reconnect attempts must not compound (no tail-recursive create_task) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_reconnect_loop_does_not_compound(): + """_attempt_reconnect must use a single iterative loop. After + max_reconnect_attempts failures, exactly that many connect() calls + should have been made — no exponential fan-out from orphaned tasks.""" + # All attempts fail (return None) + conn = FakeConnection(connect_results=[None, None, None, None]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + collector = _EventCollector(dispatcher) + mgr = ConnectionManager( + conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=3, + ) + mgr._is_connected = True # simulate a live connection + + await mgr.handle_disconnect("test_disconnect") + # Wait for the reconnect loop to exhaust all attempts + # (3 attempts × 1s sleep each, but we can just await the task) + if mgr._reconnect_task: + await mgr._reconnect_task + + # Exactly 3 connect() calls should have been made + assert conn._call_index == 3 + + # A DISCONNECTED event with max_attempts_exceeded should have fired + await asyncio.sleep(0.05) + disconnected = [e for e in collector.events if e.type == EventType.DISCONNECTED] + assert len(disconnected) == 1 + assert disconnected[0].payload.get("max_attempts_exceeded") is True + finally: + await dispatcher.stop() + + +@pytest.mark.asyncio +async def test_disconnect_cancels_reconnect_loop(): + """disconnect() during an active reconnect loop must cancel the + single task cleanly — no orphaned tasks left running.""" + # Simulate a connection that always fails (returns None), giving us + # time to call disconnect() mid-loop. + conn = FakeConnection(connect_results=[None, None, None, None, None]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + mgr = ConnectionManager( + conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=5, + ) + mgr._is_connected = True + + await mgr.handle_disconnect("test_disconnect") + + # Let the first attempt start (wait just past the 1s sleep) + await asyncio.sleep(1.2) + assert conn._call_index >= 1 # at least one attempt made + + # Now disconnect — should cancel the loop + await mgr.disconnect() + + assert mgr._reconnect_task is None + calls_at_cancel = conn._call_index + + # Wait a bit and confirm no more attempts happened + await asyncio.sleep(2) + assert conn._call_index == calls_at_cancel + finally: + await dispatcher.stop() + + +# --------------------------------------------------------------------------- +# reconnect_callback (send_appstart) is called after reconnect +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_reconnect_callback_called_after_reconnect(): + """When ConnectionManager reconnects successfully, the + reconnect_callback (e.g. send_appstart) must be invoked.""" + callback_called = [] + + async def fake_appstart(): + callback_called.append(True) + + # First connect() fails (None), second succeeds + conn = FakeConnection(connect_results=[None, "10.0.0.1"]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + mgr = ConnectionManager( + conn, dispatcher, + auto_reconnect=True, + max_reconnect_attempts=3, + reconnect_callback=fake_appstart, + ) + mgr._is_connected = True + + await mgr.handle_disconnect("test_disconnect") + if mgr._reconnect_task: + await mgr._reconnect_task + + assert len(callback_called) == 1 + finally: + await dispatcher.stop() + + +@pytest.mark.asyncio +async def test_reconnect_callback_failure_does_not_crash_loop(): + """If the reconnect_callback raises, the reconnect still counts as + successful (transport is up) — the callback failure is logged but + does not crash the loop or leave the manager in a broken state.""" + async def failing_callback(): + raise RuntimeError("appstart failed") + + # connect() succeeds on first attempt + conn = FakeConnection(connect_results=["10.0.0.1"]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + collector = _EventCollector(dispatcher) + mgr = ConnectionManager( + conn, dispatcher, + auto_reconnect=True, + max_reconnect_attempts=3, + reconnect_callback=failing_callback, + ) + mgr._is_connected = True + + await mgr.handle_disconnect("test_disconnect") + if mgr._reconnect_task: + await mgr._reconnect_task + + # Despite callback failure, CONNECTED event should have fired + await asyncio.sleep(0.05) + connected = [e for e in collector.events if e.type == EventType.CONNECTED] + assert len(connected) == 1 + assert mgr._is_connected is True + finally: + await dispatcher.stop() + + +# --------------------------------------------------------------------------- +# connect() returning None is a soft failure (BLE scan miss) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_connect_none_is_soft_failure(): + """When connect() returns None (e.g. BLE scan found no device), + ConnectionManager.connect() should NOT set _is_connected and should + NOT emit a CONNECTED event.""" + conn = FakeConnection(connect_results=[None]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + collector = _EventCollector(dispatcher) + mgr = ConnectionManager(conn, dispatcher) + + result = await mgr.connect() + + assert result is None + assert mgr._is_connected is False + await asyncio.sleep(0.05) + connected = [e for e in collector.events if e.type == EventType.CONNECTED] + assert len(connected) == 0 + finally: + await dispatcher.stop() + + +@pytest.mark.asyncio +async def test_no_reconnect_callback_is_noop(): + """When no reconnect_callback is provided (backwards compat for + direct ConnectionManager users), reconnect should still work.""" + conn = FakeConnection(connect_results=["10.0.0.1"]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + mgr = ConnectionManager( + conn, dispatcher, + auto_reconnect=True, + max_reconnect_attempts=3, + # No reconnect_callback — default None + ) + mgr._is_connected = True + + await mgr.handle_disconnect("test_disconnect") + if mgr._reconnect_task: + await mgr._reconnect_task + + assert mgr._is_connected is True + finally: + await dispatcher.stop() diff --git a/tests/unit/test_error_handling.py b/tests/unit/test_error_handling.py new file mode 100644 index 0000000..7f88e28 --- /dev/null +++ b/tests/unit/test_error_handling.py @@ -0,0 +1,236 @@ +"""Verification tests for error response handling fixes. + +The tests confirm that error responses are surfaced cleanly instead +of causing KeyError, TypeError, NameError, or silent fallthrough. +""" +import asyncio +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from meshcore.commands import CommandHandler +from meshcore.events import EventType, Event, Subscription + +pytestmark = pytest.mark.asyncio + +VALID_PUBKEY_HEX = "0123456789abcdef" * 4 # 64 hex chars = 32 bytes + + +# ── Fixtures ─────────────────────────────────────────────────────── + +@pytest.fixture +def mock_connection(): + connection = MagicMock() + connection.send = AsyncMock() + return connection + + +@pytest.fixture +def mock_dispatcher(): + dispatcher = MagicMock() + dispatcher.wait_for_event = AsyncMock() + dispatcher.dispatch = AsyncMock() + + def fake_subscribe(event_type, handler, attribute_filters=None): + sub = MagicMock(spec=Subscription) + sub.unsubscribe = MagicMock() + dispatcher._last_subscribe_handler = handler + dispatcher._last_subscribe_event_type = event_type + return sub + + dispatcher.subscribe = MagicMock(side_effect=fake_subscribe) + return dispatcher + + +@pytest.fixture +def command_handler(mock_connection, mock_dispatcher): + handler = CommandHandler() + + async def sender(data): + await mock_connection.send(data) + + handler._sender_func = sender + handler.dispatcher = mock_dispatcher + return handler + + +def setup_error_response(mock_dispatcher): + """Configure dispatcher to return an ERROR event for any subscribe.""" + def fake_subscribe(evt_type, handler, attr_filters=None): + sub = MagicMock(spec=Subscription) + sub.unsubscribe = MagicMock() + # Always fire ERROR regardless of which event type was subscribed + if evt_type == EventType.ERROR: + asyncio.get_event_loop().call_soon( + handler, Event(EventType.ERROR, {"reason": "test_error"}) + ) + return sub + + mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe) + + +def setup_event_response(mock_dispatcher, event_type, payload): + """Configure dispatcher to return a specific event.""" + def fake_subscribe(evt_type, handler, attr_filters=None): + sub = MagicMock(spec=Subscription) + sub.unsubscribe = MagicMock() + if evt_type == event_type: + asyncio.get_event_loop().call_soon( + handler, Event(event_type, payload) + ) + return sub + + mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe) + + +# ── Event.is_error() helper ────────────────────────────────── + +async def test_event_is_error_true(): + """is_error() returns True for ERROR events.""" + event = Event(EventType.ERROR, {"reason": "test"}) + assert event.is_error() is True + + +async def test_event_is_error_false(): + """is_error() returns False for non-ERROR events.""" + event = Event(EventType.OK, {}) + assert event.is_error() is False + event2 = Event(EventType.SELF_INFO, {"name": "test"}) + assert event2.is_error() is False + + +# ── send_msg_with_retry continues on ERROR ────────────── + +async def test_send_msg_with_retry_error_no_keyerror( + command_handler, mock_dispatcher +): + """send_msg_with_retry returns None (exhausted retries) on + persistent ERROR instead of raising KeyError on missing 'expected_ack'.""" + setup_error_response(mock_dispatcher) + + # Provide a mock contact so the path logic doesn't interfere + command_handler._get_contact_by_prefix = MagicMock(return_value=None) + + # max_attempts=2 so it retries once then gives up + result = await command_handler.send_msg_with_retry( + VALID_PUBKEY_HEX, "hello", max_attempts=2, timeout=0.1 + ) + + # Should return None (no ACK received) rather than raising KeyError + assert result is None + + +# ── send_appstart includes ERROR in expected events ────────── + +async def test_send_appstart_returns_error( + command_handler, mock_dispatcher +): + """send_appstart returns ERROR event instead of hanging on timeout.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.send_appstart() + + assert result.type == EventType.ERROR + assert result.is_error() is True + assert result.payload["reason"] == "test_error" + + +# ── device setters return ERROR from send_appstart ─────────── + +async def test_set_telemetry_mode_base_error( + command_handler, mock_dispatcher +): + """set_telemetry_mode_base returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_telemetry_mode_base(1) + + assert result.is_error() + assert result.payload["reason"] == "test_error" + + +async def test_set_telemetry_mode_loc_error( + command_handler, mock_dispatcher +): + """set_telemetry_mode_loc returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_telemetry_mode_loc(1) + + assert result.is_error() + + +async def test_set_telemetry_mode_env_error( + command_handler, mock_dispatcher +): + """set_telemetry_mode_env returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_telemetry_mode_env(1) + + assert result.is_error() + + +async def test_set_manual_add_contacts_error( + command_handler, mock_dispatcher +): + """set_manual_add_contacts returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_manual_add_contacts(True) + + assert result.is_error() + + +async def test_set_advert_loc_policy_error( + command_handler, mock_dispatcher +): + """set_advert_loc_policy returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_advert_loc_policy(1) + + assert result.is_error() + + +async def test_set_multi_acks_error( + command_handler, mock_dispatcher +): + """set_multi_acks returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_multi_acks(1) + + assert result.is_error() + + +# ── send_anon_req returns ERROR on contact not found ───────── + +async def test_send_anon_req_contact_not_found( + command_handler, mock_dispatcher +): + """send_anon_req returns ERROR event when contact prefix not found, + instead of raising TypeError on NoneType subscript.""" + command_handler._get_contact_by_prefix = MagicMock(return_value=None) + + result = await command_handler.send_anon_req( + VALID_PUBKEY_HEX, MagicMock(value=1) + ) + + assert result.is_error() + assert result.payload["reason"] == "contact_not_found" + + +# ── send_trace handles unknown path_hash_len without NameError ── + +async def test_send_trace_unknown_path_hash_len( + command_handler, mock_connection, mock_dispatcher +): + """send_trace with a path whose segments don't match any known + path_hash_len returns ERROR cleanly instead of NameError on 'e'.""" + # 5-char hex segments → path_hash_len = 2.5 → doesn't match 1,2,4,8 + result = await command_handler.send_trace( + auth_code=0, tag=1, flags=None, path="abcde" + ) + + assert result.is_error() + assert result.payload["reason"] == "invalid_path_format" diff --git a/tests/unit/test_protocol_surface_gaps.py b/tests/unit/test_protocol_surface_gaps.py new file mode 100644 index 0000000..a490a81 --- /dev/null +++ b/tests/unit/test_protocol_surface_gaps.py @@ -0,0 +1,364 @@ +"""Verification tests for protocol surface gaps. + +Each test constructs a mock firmware frame and verifies the SDK dispatches +the correct EventType with the expected payload fields. +""" + +import asyncio +import struct +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from meshcore.events import Event, EventType, EventDispatcher +from meshcore.reader import MessageReader +from meshcore.packets import PacketType, CommandType + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_reader(): + """Create a MessageReader with a mock dispatcher that records dispatched events.""" + dispatcher = MagicMock(spec=EventDispatcher) + dispatched = [] + + async def _capture(event): + dispatched.append(event) + + dispatcher.dispatch = AsyncMock(side_effect=_capture) + reader = MessageReader(dispatcher) + return reader, dispatched + + +# --------------------------------------------------------------------------- +# CONTACT_DELETED handler +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_contact_deleted_dispatches_event(): + """A 33-byte CONTACT_DELETED frame dispatches EventType.CONTACT_DELETED.""" + reader, dispatched = _make_reader() + pubkey = bytes(range(32)) + frame = bytes([PacketType.CONTACT_DELETED.value]) + pubkey + assert len(frame) == 33 + + await reader.handle_rx(bytearray(frame)) + + assert len(dispatched) == 1 + evt = dispatched[0] + assert evt.type == EventType.CONTACT_DELETED + assert evt.payload["pubkey"] == pubkey.hex() + assert evt.attributes["pubkey"] == pubkey.hex() + + +@pytest.mark.asyncio +async def test_contact_deleted_short_frame_ignored(): + """A CONTACT_DELETED frame shorter than 33 bytes is silently dropped.""" + reader, dispatched = _make_reader() + # Only 10 bytes — too short + frame = bytes([PacketType.CONTACT_DELETED.value]) + b"\x00" * 9 + + await reader.handle_rx(bytearray(frame)) + + assert len(dispatched) == 0 + + +# --------------------------------------------------------------------------- +# CONTACTS_FULL handler + enum entry +# --------------------------------------------------------------------------- + +def test_contacts_full_enum_exists(): + """PacketType.CONTACTS_FULL == 0x90.""" + assert PacketType.CONTACTS_FULL.value == 0x90 + + +@pytest.mark.asyncio +async def test_contacts_full_dispatches_event(): + """A 1-byte CONTACTS_FULL push dispatches EventType.CONTACTS_FULL.""" + reader, dispatched = _make_reader() + frame = bytes([PacketType.CONTACTS_FULL.value]) + + await reader.handle_rx(bytearray(frame)) + + assert len(dispatched) == 1 + evt = dispatched[0] + assert evt.type == EventType.CONTACTS_FULL + assert evt.payload == {} + + +# --------------------------------------------------------------------------- +# TUNING_PARAMS handler +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tuning_params_dispatches_event(): + """A 9-byte TUNING_PARAMS frame dispatches with rx_delay and airtime_factor.""" + reader, dispatched = _make_reader() + rx_delay = 500 + airtime_factor = 200 + frame = ( + bytes([PacketType.TUNING_PARAMS.value]) + + rx_delay.to_bytes(4, "little") + + airtime_factor.to_bytes(4, "little") + ) + assert len(frame) == 9 + + await reader.handle_rx(bytearray(frame)) + + assert len(dispatched) == 1 + evt = dispatched[0] + assert evt.type == EventType.TUNING_PARAMS + assert evt.payload["rx_delay"] == 500 + assert evt.payload["airtime_factor"] == 200 + + +@pytest.mark.asyncio +async def test_tuning_params_short_frame_dispatches_error(): + """A TUNING_PARAMS frame shorter than 9 bytes dispatches ERROR.""" + reader, dispatched = _make_reader() + # Only 5 bytes — too short + frame = bytes([PacketType.TUNING_PARAMS.value]) + b"\x01\x00\x00\x00" + + await reader.handle_rx(bytearray(frame)) + + assert len(dispatched) == 1 + evt = dispatched[0] + assert evt.type == EventType.ERROR + assert evt.payload["reason"] == "invalid_frame_length" + + +# --------------------------------------------------------------------------- +# send_trace() one-byte pad +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_send_trace_empty_path_pads_to_11_bytes(): + """send_trace() with no path produces an 11-byte packet (not 10).""" + from meshcore.commands.messaging import MessagingCommands + + cmd = MessagingCommands.__new__(MessagingCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000}) + + cmd.send = mock_send + + await cmd.send_trace(auth_code=0, tag=1, flags=0, path=None) + + assert captured_data is not None + # cmd(1) + tag(4) + auth(4) + flags(1) + pad(1) = 11 + assert len(captured_data) == 11 + assert captured_data[-1] == 0x00 # The pad byte + + +@pytest.mark.asyncio +async def test_send_trace_with_path_no_padding(): + """send_trace() with a non-empty path does NOT add padding.""" + from meshcore.commands.messaging import MessagingCommands + + cmd = MessagingCommands.__new__(MessagingCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000}) + + cmd.send = mock_send + + # 2-byte path hash (flags=1 means hash_len=2) + await cmd.send_trace(auth_code=0, tag=1, flags=1, path=b"\xAA\xBB") + + assert captured_data is not None + # cmd(1) + tag(4) + auth(4) + flags(1) + path(2) = 12 — no pad needed + assert len(captured_data) == 12 + + +# --------------------------------------------------------------------------- +# Command wrapper: send_raw_data +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_send_raw_data_wrapper(): + """send_raw_data sends CMD 0x19 + payload.""" + from meshcore.commands.messaging import MessagingCommands + + cmd = MessagingCommands.__new__(MessagingCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000}) + + cmd.send = mock_send + + await cmd.send_raw_data(b"\xDE\xAD") + + assert captured_data is not None + assert captured_data[0] == 0x19 # CMD_SEND_RAW_DATA + assert captured_data[1:] == b"\xDE\xAD" + + +# --------------------------------------------------------------------------- +# Command wrapper: has_connection +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_has_connection_wrapper(): + """has_connection sends CMD 0x1c.""" + from meshcore.commands.device import DeviceCommands + + cmd = DeviceCommands.__new__(DeviceCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.OK, {"value": 1}) + + cmd.send = mock_send + + await cmd.has_connection() + + assert captured_data is not None + assert captured_data == b"\x1c" + + +# --------------------------------------------------------------------------- +# Command wrapper: get_tuning +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_get_tuning_wrapper(): + """get_tuning sends CMD 0x2b (GET_TUNING_PARAMS = 43).""" + from meshcore.commands.device import DeviceCommands + + cmd = DeviceCommands.__new__(DeviceCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.TUNING_PARAMS, {"rx_delay": 500, "airtime_factor": 200}) + + cmd.send = mock_send + + result = await cmd.get_tuning() + + assert captured_data == b"\x2b" + assert result.type == EventType.TUNING_PARAMS + + +# --------------------------------------------------------------------------- +# Command wrapper: get_contact_by_key +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_get_contact_by_key_wrapper(): + """get_contact_by_key sends CMD 0x1e + 32-byte pubkey.""" + from meshcore.commands.contact import ContactCommands + + cmd = ContactCommands.__new__(ContactCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.NEXT_CONTACT, {"public_key": "ab" * 32}) + + cmd.send = mock_send + + pubkey = bytes(range(32)) + await cmd.get_contact_by_key(pubkey) + + assert captured_data is not None + assert captured_data[0] == 0x1E # CMD_GET_CONTACT_BY_KEY + assert captured_data[1:] == pubkey + + +# --------------------------------------------------------------------------- +# Command wrapper: factory_reset (two-step) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_factory_reset_two_step(): + """factory_reset requires a token from request_factory_reset.""" + from meshcore.commands.device import DeviceCommands + + cmd = DeviceCommands.__new__(DeviceCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.OK, {}) + + cmd.send = mock_send + + # Step 1: request token + token = await cmd.request_factory_reset() + assert isinstance(token, str) + assert len(token) == 16 # hex-encoded 8 bytes + + # Step 2: confirm with wrong token fails + with pytest.raises(ValueError, match="Invalid or expired"): + await cmd.confirm_factory_reset("wrong_token") + + # Step 2: confirm with correct token succeeds + await cmd.confirm_factory_reset(token) + assert captured_data == b"\x33" # CMD_FACTORY_RESET + + +@pytest.mark.asyncio +async def test_factory_reset_without_request_fails(): + """confirm_factory_reset without request_factory_reset raises ValueError.""" + from meshcore.commands.device import DeviceCommands + + cmd = DeviceCommands.__new__(DeviceCommands) + + with pytest.raises(ValueError, match="Invalid or expired"): + await cmd.confirm_factory_reset("any_token") + + +# --------------------------------------------------------------------------- +# GET_STATS enum entry +# --------------------------------------------------------------------------- + +def test_get_stats_enum_exists(): + """CommandType.GET_STATS == 56.""" + assert CommandType.GET_STATS.value == 56 + + +@pytest.mark.asyncio +async def test_get_stats_core_uses_enum(): + """get_stats_core sends CommandType.GET_STATS.value (0x38) + 0x00.""" + from meshcore.commands.device import DeviceCommands + + cmd = DeviceCommands.__new__(DeviceCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.STATS_CORE, {}) + + cmd.send = mock_send + + await cmd.get_stats_core() + + assert captured_data is not None + assert captured_data[0] == CommandType.GET_STATS.value # 0x38 = 56 + assert captured_data[1] == 0x00 diff --git a/tests/unit/test_reader.py b/tests/unit/test_reader.py index 39bb8ac..758968e 100644 --- a/tests/unit/test_reader.py +++ b/tests/unit/test_reader.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import asyncio +import logging from unittest.mock import AsyncMock from meshcore.events import EventType from meshcore.reader import MessageReader @@ -87,4 +88,193 @@ async def test_binary_response(): print(f"⚠️ Unknown request type {request_type}, no specific event expected") if __name__ == "__main__": - asyncio.run(test_binary_response()) \ No newline at end of file + asyncio.run(test_binary_response()) + + +# --------------------------------------------------------------------------- +# Reader/parser crash-safety verification tests +# --------------------------------------------------------------------------- + +class _CapturingDispatcher: + """Quiet dispatcher that records every dispatched event.""" + def __init__(self): + self.events = [] + + async def dispatch(self, event): + self.events.append(event) + + +@pytest.mark.asyncio +async def test_handle_rx_malformed_frame_logged_and_swallowed(caplog): + """Malformed frame must not propagate, must be logged with traceback.""" + dispatcher = _CapturingDispatcher() + reader = MessageReader(dispatcher) + + # 4-byte CHANNEL_MSG_RECV_V3 frame: type byte (0x11) + 1 SNR byte + + # 2 reserved bytes, but no channel_idx byte. The handler will raise + # IndexError on the next dbuf.read(1)[0] when the buffer is empty. + # The umbrella try/except must catch it, log the parse error, and + # return cleanly. + malformed = bytearray.fromhex("11100000") + + with caplog.at_level(logging.ERROR, logger="meshcore"): + await reader.handle_rx(malformed) # must not raise + + error_records = [r for r in caplog.records if "handle_rx parse error" in r.message] + assert error_records, ( + f"Expected an error log containing 'handle_rx parse error'; " + f"got: {[r.message for r in caplog.records]}" + ) + # Traceback should be present in the log message + assert "Traceback" in error_records[0].message, ( + "Umbrella log message must include a traceback" + ) + # No CHANNEL_MSG_RECV event should have been dispatched + assert not any(e.type == EventType.CHANNEL_MSG_RECV for e in dispatcher.events) + + +@pytest.mark.asyncio +async def test_battery_short_frame_omits_storage_fields(): + """Short BATTERY frame must not silently yield zero used_kb/total_kb.""" + dispatcher = _CapturingDispatcher() + reader = MessageReader(dispatcher) + + # 3-byte BATTERY frame: type 0x0c + 2 level bytes (no storage tail). + # Pre-fix the `len(data) > 3` gate would have let any frame >= 4 bytes + # through, producing a BATTERY event with bogus zero used_kb/total_kb + # because io.BytesIO.read() returns short data without raising. + # Post-fix (`len(data) >= 11`) the storage fields are skipped entirely. + short_battery = bytearray.fromhex("0c8000") + + await reader.handle_rx(short_battery) + + battery_events = [e for e in dispatcher.events if e.type == EventType.BATTERY] + assert len(battery_events) == 1, ( + f"Expected exactly one BATTERY event, got {len(battery_events)}" + ) + payload = battery_events[0].payload + assert payload["level"] == 0x0080, f"Unexpected level: {payload['level']}" + assert "used_kb" not in payload, ( + "Short BATTERY frame must not include used_kb (would be a silent zero)" + ) + assert "total_kb" not in payload, ( + "Short BATTERY frame must not include total_kb (would be a silent zero)" + ) + + +@pytest.mark.asyncio +async def test_battery_too_short_for_level(caplog): + """BATTERY frame shorter than 3 bytes must be dropped entirely (Option B). + + A 1-byte frame (just the packet-type byte 0x0c, no level bytes) would cause + dbuf.read(2) to return b"" and int.from_bytes(b"", ...) to silently yield 0. + The fix adds an early return with a debug log. + """ + dispatcher = _CapturingDispatcher() + reader = MessageReader(dispatcher) + + # 1-byte BATTERY frame: only the type byte, no level payload. + too_short = bytearray.fromhex("0c") + + with caplog.at_level(logging.DEBUG, logger="meshcore"): + await reader.handle_rx(too_short) + + battery_events = [e for e in dispatcher.events if e.type == EventType.BATTERY] + assert len(battery_events) == 0, ( + "BATTERY frame shorter than 3 bytes must not dispatch an event" + ) + debug_records = [ + r for r in caplog.records if "BATTERY frame too short" in r.message + ] + assert debug_records, "Expected a debug log about the short BATTERY frame" + + +@pytest.mark.asyncio +async def test_status_response_short_frame_skipped(caplog): + """Short STATUS_RESPONSE push frame must be skipped, not parsed with bogus zeros.""" + dispatcher = _CapturingDispatcher() + reader = MessageReader(dispatcher) + + # 30-byte STATUS_RESPONSE push frame, well below the 60-byte minimum. + # First byte is the type (0x87 = PacketType.STATUS_RESPONSE), the rest + # is arbitrary filler. parse_status with offset=8 reads up through + # data[56:60], so anything < 60 bytes would yield short reads and + # silent zero values pre-fix. + short_status = bytearray([0x87] + [0xAA] * 29) + assert len(short_status) == 30 + + with caplog.at_level(logging.DEBUG, logger="meshcore"): + await reader.handle_rx(short_status) + + status_events = [e for e in dispatcher.events if e.type == EventType.STATUS_RESPONSE] + assert len(status_events) == 0, ( + "Short STATUS_RESPONSE push frame must not dispatch a parsed event" + ) + assert any( + "STATUS_RESPONSE push frame too short" in r.message for r in caplog.records + ), "Expected a debug log line for short STATUS_RESPONSE frames" + + +@pytest.mark.asyncio +async def test_parse_packet_payload_txt_type_decodes_high_bits(): + """txt_type must decode the high 6 bits of byte 4, not always be 0.""" + from Crypto.Cipher import AES + from Crypto.Hash import HMAC, SHA256 + from meshcore.meshcore_parser import MeshcorePacketParser + + parser = MeshcorePacketParser() + parser.decrypt_channels = True + + # Set up a synthetic channel with a known 16-byte AES key. Direct dict + # assignment matches how the parser stores channels (newChannel is async + # and serves the same purpose). + channel_secret = b"\x01" * 16 + channel_hash_byte = 0xAB + parser.channels[0] = { + "channel_idx": 0, + "channel_name": "test", + "channel_hash": "ab", + "channel_secret": channel_secret, + } + + # 16-byte plaintext (one AES block): + # bytes 0-3 = sender_timestamp (little-endian) + # byte 4 = (txt_type << 2) | attempt + # bytes 5-15 = message + null padding + # Pick txt_type=5, attempt=1 → byte 4 = (5 << 2) | 1 = 0x15. + # Pre-fix uncrypted[4:4] is empty so txt_type would be 0; + # post-fix uncrypted[4:5] yields 0x15 >> 2 = 5. + plaintext = b"\x00\x00\x00\x00\x15hello\x00\x00\x00\x00\x00\x00" + assert len(plaintext) == 16 + + encrypted = AES.new(channel_secret, AES.MODE_ECB).encrypt(plaintext) + + # cipher_mac = first 2 bytes of HMAC-SHA256(channel_secret, encrypted) + h = HMAC.new(channel_secret, digestmod=SHA256) + h.update(encrypted) + cipher_mac = h.digest()[:2] + + # pkt_payload layout: 1-byte chan_hash + 2-byte cipher_mac + ciphertext + pkt_payload = bytes([channel_hash_byte]) + cipher_mac + encrypted + + # parsePacketPayload expects the full payload buffer: + # header byte (route_type=1 DIRECT, payload_type=5 channel, ver=0) + # path_byte (path_len=0, path_hash_size=1) → 0x00 + # pkt_payload + header = 0x15 # route_type=1, payload_type=5, payload_ver=0 + path_byte = 0x00 + payload = bytes([header, path_byte]) + pkt_payload + + log_data = await parser.parsePacketPayload(payload, log_data={}) + + assert log_data["payload_type"] == 0x05 + assert "txt_type" in log_data, ( + f"txt_type missing from log_data — channel decrypt path was not reached. " + f"log_data keys: {list(log_data.keys())}" + ) + assert log_data["txt_type"] == 5, ( + f"Expected txt_type=5, got {log_data['txt_type']}" + ) + assert log_data["attempt"] == 1, ( + f"Expected attempt=1, got {log_data['attempt']}" + ) \ No newline at end of file diff --git a/tests/unit/test_transport_symmetry.py b/tests/unit/test_transport_symmetry.py new file mode 100644 index 0000000..1882f3d --- /dev/null +++ b/tests/unit/test_transport_symmetry.py @@ -0,0 +1,238 @@ +""" +Verification tests for transport symmetry fixes. + +Covers: send symmetry across transports, serial disconnect callback on +transport-lost, serial connect timeout, oversize-frame return, BLE +disconnect-callback re-registration, BLE pairing failure re-raise, +TCP counter per frame not per segment. +""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from meshcore.tcp_cx import TCPConnection +from meshcore.serial_cx import SerialConnection + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class RecordingReader: + """Minimal reader mock that records dispatched frames.""" + def __init__(self): + self.frames = [] + + async def handle_rx(self, data): + self.frames.append(bytes(data)) + + +# --------------------------------------------------------------------------- +# TCP send() wraps transport.write in try/except +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tcp_send_write_error_fires_disconnect(): + """TCP: OSError during transport.write fires _disconnect_callback.""" + cx = TCPConnection("127.0.0.1", 5000) + cb = AsyncMock() + cx.set_disconnect_callback(cb) + + mock_transport = MagicMock() + mock_transport.write.side_effect = OSError("Broken pipe") + cx.transport = mock_transport + cx._send_count = 0 + cx._receive_count = 0 + + await cx.send(b"\x01\x02\x03") + + cb.assert_awaited_once() + reason = cb.call_args[0][0] + assert "tcp_write_failed" in reason + + +# --------------------------------------------------------------------------- +# Serial send() fires disconnect on transport-lost and write error +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_serial_send_no_transport_fires_disconnect(): + """Serial: send() on None transport fires _disconnect_callback .""" + cx = SerialConnection("/dev/null", 115200) + cb = AsyncMock() + cx.set_disconnect_callback(cb) + cx.transport = None + + await cx.send(b"\x01") + + cb.assert_awaited_once() + reason = cb.call_args[0][0] + assert reason == "serial_transport_lost" + + +@pytest.mark.asyncio +async def test_serial_send_write_error_fires_disconnect(): + """Serial: OSError during transport.write fires _disconnect_callback.""" + cx = SerialConnection("/dev/null", 115200) + cb = AsyncMock() + cx.set_disconnect_callback(cb) + + mock_transport = MagicMock() + mock_transport.write.side_effect = OSError("Device not configured") + cx.transport = mock_transport + + await cx.send(b"\x01") + + cb.assert_awaited_once() + reason = cb.call_args[0][0] + assert "serial_write_failed" in reason + + +# --------------------------------------------------------------------------- +# BLE send() fires disconnect on transport-lost and write error +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_ble_send_no_client_fires_disconnect(): + """BLE: send() with no client fires _disconnect_callback.""" + # Can't import BLEConnection directly if bleak isn't installed, + # so test via dynamic import with a guard. + try: + from meshcore.ble_cx import BLEConnection + except ImportError: + pytest.skip("bleak not installed") + + # BLEConnection.__init__ checks BLEAK_AVAILABLE; patch it + with patch("meshcore.ble_cx.BLEAK_AVAILABLE", True), \ + patch("meshcore.ble_cx.BleakClient", MagicMock()): + cx = BLEConnection.__new__(BLEConnection) + cx.client = None + cx._user_provided_client = None + cx._user_provided_address = None + cx._user_provided_device = None + cx.address = None + cx.device = None + cx.pin = None + cx.rx_char = None + cb = AsyncMock() + cx._disconnect_callback = cb + + result = await cx.send(b"\x01") + + assert result is False + cb.assert_awaited_once() + reason = cb.call_args[0][0] + assert reason == "ble_transport_lost" + + +# --------------------------------------------------------------------------- +# Serial connect() times out if connection_made never fires +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_serial_connect_timeout(): + """Serial: connect() raises TimeoutError if connection_made never fires.""" + cx = SerialConnection("/dev/null", 115200) + + # Mock create_serial_connection to do nothing (never fires connection_made) + async def mock_create(*args, **kwargs): + return (MagicMock(), MagicMock()) + + with patch("meshcore.serial_cx.serial_asyncio.create_serial_connection", + side_effect=mock_create): + with pytest.raises(asyncio.TimeoutError): + await cx.connect(timeout=0.1) + + +# --------------------------------------------------------------------------- +# Oversize frame resets state and returns without dispatch +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tcp_oversize_frame_empty_data_returns(): + """TCP: oversize header with no trailing data returns without dispatch.""" + cx = TCPConnection("127.0.0.1", 5000) + reader = RecordingReader() + cx.set_reader(reader) + + # Build a frame header with size > 300 and no payload data after header + # Header: 0x3e + 2-byte LE size (e.g. 500 = 0x01F4) + header = b"\x3e" + (500).to_bytes(2, "little") + cx.handle_rx(header) + await asyncio.sleep(0) + + # No frames should be dispatched, and state should be reset + assert reader.frames == [] + assert cx.header == b"" + assert cx.inframe == b"" + assert cx.frame_expected_size == 0 + + +@pytest.mark.asyncio +async def test_serial_oversize_frame_empty_data_returns(): + """Serial: oversize header with no trailing data returns without dispatch.""" + cx = SerialConnection("/dev/null", 115200) + reader = RecordingReader() + cx.set_reader(reader) + + header = b"\x3e" + (500).to_bytes(2, "little") + cx.handle_rx(header) + await asyncio.sleep(0) + + assert reader.frames == [] + assert cx.header == b"" + assert cx.inframe == b"" + assert cx.frame_expected_size == 0 + + +# --------------------------------------------------------------------------- +# TCP receive counter increments per MeshCore frame, not per TCP segment +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tcp_receive_count_per_frame_not_per_segment(): + """TCP: _receive_count increments per completed frame, not per data_received call.""" + cx = TCPConnection("127.0.0.1", 5000) + reader = RecordingReader() + cx.set_reader(reader) + cx._receive_count = 0 + + # Build a 4-byte payload frame + payload = b"\xAA\xBB\xCC\xDD" + frame = b"\x3e" + len(payload).to_bytes(2, "little") + payload + + # Split the frame into 3 TCP segments (simulating fragmentation) + protocol = TCPConnection.MCClientProtocol(cx) + protocol.data_received(frame[:2]) # partial header + protocol.data_received(frame[2:5]) # rest of header + 2 bytes payload + protocol.data_received(frame[5:]) # remaining payload + + await asyncio.sleep(0) + + # 3 data_received calls but only 1 completed frame + assert cx._receive_count == 1 + assert reader.frames == [payload] + + +@pytest.mark.asyncio +async def test_tcp_multiple_frames_count_correctly(): + """TCP: two complete frames in separate segments → _receive_count == 2.""" + cx = TCPConnection("127.0.0.1", 5000) + reader = RecordingReader() + cx.set_reader(reader) + cx._receive_count = 0 + + payload1 = b"\x01\x02" + frame1 = b"\x3e" + len(payload1).to_bytes(2, "little") + payload1 + payload2 = b"\x03\x04\x05" + frame2 = b"\x3e" + len(payload2).to_bytes(2, "little") + payload2 + + protocol = TCPConnection.MCClientProtocol(cx) + protocol.data_received(frame1) + protocol.data_received(frame2) + await asyncio.sleep(0) + + assert cx._receive_count == 2 + assert reader.frames == [payload1, payload2]